Batch Transform Job with XGBoost Model
Deploy an XGBoost model and run batch inference with a SageMaker batch transform job to generate predictions on validation data.
import boto3
from sagemaker.core.helper.session_helper import get_execution_role, Session
from sagemaker.core.transformer import Transformer
from sagemaker.serve.model_builder import ModelBuilder
from sagemaker.core.image_uris import retrieve
# Resolve the AWS region and the execution role used by the calls below.
region = boto3.Session().region_name
role = get_execution_role()
print(f"RoleArn: {role}")

# The SageMaker session provides the default bucket that every S3 path
# in this walkthrough is built on.
sagemaker_session = Session()
bucket = sagemaker_session.default_bucket()
print(f"Demo Bucket: {bucket}")

# S3 layout: all locations share a single key prefix inside the bucket.
prefix = "demo-transform"
reports_prefix = f"{prefix}/reports"
s3_report_path = f"s3://{bucket}/{reports_prefix}"
transform_output_path = f"s3://{bucket}/{prefix}/transform-outputs"
print(f"Transform Output path: {transform_output_path}")
print(f"Report path: {s3_report_path}")
Deploy Model
model_file_name = "xgb-churn-prediction-model.tar.gz"
!aws s3 cp data/{model_file_name} s3://{bucket}/{prefix}/{model_file_name}
model_url = "https://{}.s3-{}.amazonaws.com/{}/{}".format(bucket, region, prefix, model_file_name)
image_uri = retrieve("xgboost", boto3.Session().region_name, "0.90-1")
model_builder = ModelBuilder(
image_uri=image_uri,
s3_model_data_url=model_url,
role_arn=role,
sagemaker_session=sagemaker_session,
)
model_builder.build(model_name="my-transform-model")
Validation
# Dataset used to get predictions
!aws s3 cp data/validation.csv s3://{bucket}/{prefix}/transform_input/validation/validation.csv
transformer = Transformer(
model_name="my-transform-model",
instance_count=1,
instance_type="ml.m5.xlarge",
accept="text/csv",
assemble_with="Line",
output_path=transform_output_path,
sagemaker_session=sagemaker_session,
)
data_input=f"s3://{bucket}/{prefix}/transform_input/validation"
transform_arg = transformer.transform(
data_input,
content_type="text/csv",
split_type="Line",
# exclude the ground truth (first column) from the validation set
# when doing inference.
input_filter="$[1:]",
)