Batch Transform Job with XGBoost Model

Deploy an XGBoost model and run batch inference using a SageMaker batch transform job to generate predictions on validation data.

import boto3
from sagemaker.core.helper.session_helper import get_execution_role, Session
from sagemaker.core.transformer import Transformer
from sagemaker.serve.model_builder import ModelBuilder
from sagemaker.core.image_uris import retrieve

# Resolve the AWS region of the active boto3 session; reused below when
# selecting the regional XGBoost container image.
region = boto3.Session().region_name

# IAM execution role that SageMaker assumes for the model and transform job.
role = get_execution_role()
print(f"RoleArn: {role}")

sagemaker_session = Session()
# Session-default bucket (created on demand) that holds all demo artifacts.
bucket = sagemaker_session.default_bucket()
print(f"Demo Bucket: {bucket}")

# S3 layout: everything for this demo lives under the "demo-transform" prefix.
prefix = "demo-transform"
reports_prefix = f"{prefix}/reports"
s3_report_path = f"s3://{bucket}/{reports_prefix}"

# Destination for the batch transform job's prediction outputs.
transform_output_path = f"s3://{bucket}/{prefix}/transform-outputs"

print(f"Transform Output path: {transform_output_path}")
print(f"Report path: {s3_report_path}")

Deploy Model

model_file_name = "xgb-churn-prediction-model.tar.gz"

!aws s3 cp data/{model_file_name} s3://{bucket}/{prefix}/{model_file_name}
model_url = "https://{}.s3-{}.amazonaws.com/{}/{}".format(bucket, region, prefix, model_file_name)
image_uri = retrieve("xgboost", boto3.Session().region_name, "0.90-1")

model_builder = ModelBuilder(
    image_uri=image_uri,
    s3_model_data_url=model_url,
    role_arn=role,
    sagemaker_session=sagemaker_session,
)

model_builder.build(model_name="my-transform-model")

Validation

# Dataset used to get predictions

!aws s3 cp data/validation.csv s3://{bucket}/{prefix}/transform_input/validation/validation.csv
transformer = Transformer(
    model_name="my-transform-model",
    instance_count=1,
    instance_type="ml.m5.xlarge",
    accept="text/csv",
    assemble_with="Line",
    output_path=transform_output_path,
    sagemaker_session=sagemaker_session,
)
data_input=f"s3://{bucket}/{prefix}/transform_input/validation"

transform_arg = transformer.transform(
    data_input,
    content_type="text/csv",
    split_type="Line",
    # exclude the ground truth (first column) from the validation set
    # when doing inference.
    input_filter="$[1:]",
)