Model Customization using SageMaker Training Job#
This notebook provides an end-to-end walkthrough for creating a SageMaker training job with an Amazon Nova model and deploying it for inference.
Setup and Dependencies#
import os
import json
import boto3
from rich.pretty import pprint
from sagemaker.core.helper.session_helper import Session
REGION = boto3.Session().region_name
sm_client = boto3.client("sagemaker", region_name=REGION)
# Create SageMaker session
sagemaker_session = Session(sagemaker_client=sm_client)
print(f"Region: {REGION}")
# For native MLflow metrics during Trainer wait, run the line below with the appropriate region
os.environ["SAGEMAKER_MLFLOW_CUSTOM_ENDPOINT"] = f"https://mlflow.sagemaker.{REGION}.app.aws"
Create Training Dataset#
The section below provides sample code to create a training dataset and retrieve its ARN.
from sagemaker.ai_registry.dataset import DataSet
from sagemaker.ai_registry.dataset_utils import CustomizationTechnique
# Register dataset in SageMaker AI Registry. This creates a versioned dataset that can be referenced by ARN
dataset = DataSet.create(
name="demo-sft-dataset",
source="s3://your-bucket/dataset/training_dataset.jsonl", # source can be S3 or local path
# customization_technique=CustomizationTechnique.SFT,  # or DPO or RLVR; optional technique name for a minimal dataset format check
wait=True
)
print(f"TRAINING_DATASET ARN: {dataset.arn}")
# TRAINING_DATASET = dataset.arn
# Required Configs
BASE_MODEL = ""
# MODEL_PACKAGE_GROUP_NAME is the same as CUSTOM_MODEL_NAME
MODEL_PACKAGE_GROUP_NAME = ""
TRAINING_DATASET = ""
S3_OUTPUT_PATH = ""
ROLE_ARN = ""
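Everything that follows assumes these values are set. A quick guard (a sketch in plain Python, no SageMaker APIs, and not part of the SDK) fails fast on anything left empty:

```python
def check_required_configs(configs: dict) -> None:
    """Raise a ValueError naming any config that is still an empty string."""
    missing = [name for name, value in configs.items() if not value]
    if missing:
        raise ValueError(f"Please set the following configs: {', '.join(missing)}")
```

For example, run `check_required_configs({"BASE_MODEL": BASE_MODEL, "ROLE_ARN": ROLE_ARN})` before creating the trainer.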
Create Model Package Group#
from sagemaker.core.resources import ModelPackageGroup
model_package_group = ModelPackageGroup.create(
model_package_group_name=MODEL_PACKAGE_GROUP_NAME,
model_package_group_description='' # Required Description
)
Part 1: Fine-tuning#
Step 1: Creating the Trainer#
Choose one of the following trainer techniques:#
Option 1: SFT Trainer (Supervised Fine-Tuning)
Option 2: Create RLVRTrainer (Reinforcement Learning with Verifiable Rewards).
Option 3: DPO Trainer (Direct Preference Optimization)
Instructions: Run only ONE of the trainers, not all of them.
Create SFT Trainer (Supervised Fine-Tuning)#
Key Parameters:#
model: base model ID on SageMaker HubContent that is available to fine-tune, or ModelPackage artifacts
training_type: choose LORA or FULL from the TrainingType enum (sagemaker.train.common) (optional)
model_package_group: ModelPackage group name or ModelPackageGroup object (optional)
mlflow_resource_arn: MLflow app ARN to track the training job (optional)
mlflow_experiment_name: MLflow experiment name (str) (optional)
mlflow_run_name: MLflow run name (str) (optional)
training_dataset: training dataset, either a Dataset ARN or an S3 path (required for a training job to run; can be provided via the Trainer or .train()) (optional)
validation_dataset: validation dataset, either a Dataset ARN or an S3 path (optional)
s3_output_path: S3 path for the trained model artifacts (optional)
base_job_name: unique job name (optional)
from sagemaker.train.sft_trainer import SFTTrainer
from sagemaker.train.common import TrainingType
trainer = SFTTrainer(
model=BASE_MODEL,
training_type=TrainingType.LORA,
model_package_group=model_package_group,
training_dataset=TRAINING_DATASET,
s3_output_path=S3_OUTPUT_PATH,
sagemaker_session=sagemaker_session,
role=ROLE_ARN
)
OR#
Create RLVRTrainer (Reinforcement Learning with Verifiable Rewards)#
Key Parameters:#
model: base model ID on SageMaker HubContent that is available to fine-tune, or ModelPackage artifacts
custom_reward_function: custom reward function / Evaluator ARN (optional)
model_package_group: ModelPackage group name or ModelPackageGroup object (optional)
mlflow_resource_arn: MLflow app ARN to track the training job (optional)
mlflow_experiment_name: MLflow experiment name (str) (optional)
mlflow_run_name: MLflow run name (str) (optional)
training_dataset: training dataset, either a Dataset ARN or an S3 path (required for a training job to run; can be provided via the Trainer or .train()) (optional)
validation_dataset: validation dataset, either a Dataset ARN or an S3 path (optional)
s3_output_path: S3 path for the trained model artifacts (optional)
from sagemaker.train.rlvr_trainer import RLVRTrainer
trainer = RLVRTrainer(
model=BASE_MODEL,
model_package_group=model_package_group,
training_dataset=TRAINING_DATASET,
s3_output_path=S3_OUTPUT_PATH,
sagemaker_session=sagemaker_session,
role=ROLE_ARN
)
OR#
Create DPO Trainer (Direct Preference Optimization)#
Direct Preference Optimization (DPO) is a method for training language models to follow human preferences. Unlike traditional RLHF (Reinforcement Learning from Human Feedback), DPO directly optimizes the model using preference pairs without needing a reward model.
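To make the idea of preference pairs concrete, the snippet below builds one pair and serializes it as a JSONL line. The field names (prompt, chosen, rejected) are a common convention, not necessarily the exact schema the trainer expects; check the dataset format documentation for your model.

```python
import json

# One illustrative preference pair: the same prompt with a preferred
# ("chosen") and a dispreferred ("rejected") completion.
record = {
    "prompt": "Summarize the water cycle in one sentence.",
    "chosen": "Water evaporates, condenses into clouds, and returns as precipitation.",
    "rejected": "The water cycle is a thing that happens with water.",
}

# Preference datasets are typically stored one JSON object per line (JSONL).
line = json.dumps(record)
print(line)
```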
Key Parameters:#
model: base model to fine-tune (from SageMaker Hub)
training_type: fine-tuning method (LoRA recommended for efficiency)
training_dataset: ARN of the registered preference dataset
model_package_group: where to store the fine-tuned model
mlflow_resource_arn: MLflow tracking server for experiment logging
from sagemaker.train.dpo_trainer import DPOTrainer
from sagemaker.train.common import TrainingType
trainer = DPOTrainer(
model=BASE_MODEL,
training_type=TrainingType.LORA,
model_package_group=model_package_group,
training_dataset=TRAINING_DATASET,
s3_output_path=S3_OUTPUT_PATH,
sagemaker_session=sagemaker_session,
role=ROLE_ARN
)
Step 2: Get Fine-tuning Options and Modify#
print("Default Finetuning Options:")
pprint(trainer.hyperparameters.to_dict())
# Modify options like object attributes
trainer.hyperparameters.learning_rate = 0.0002
print("\nModified/User defined Options:")
pprint(trainer.hyperparameters.to_dict())
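Several hyperparameters can be overridden in one pass. This sketch assumes the hyperparameters object exposes each option as a plain attribute, as the cell above suggests; apply_overrides is a hypothetical helper, not part of the SDK:

```python
def apply_overrides(hyperparameters, overrides: dict) -> None:
    """Set each override as an attribute, refusing names the object doesn't have."""
    for name, value in overrides.items():
        if not hasattr(hyperparameters, name):
            raise AttributeError(f"Unknown hyperparameter: {name}")
        setattr(hyperparameters, name, value)
```

For example, `apply_overrides(trainer.hyperparameters, {"learning_rate": 0.0002})` applies the same change as the cell above while rejecting typos in option names.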
Step 3: Start Training#
training_job = trainer.train(wait=True)
TRAINING_JOB_NAME = training_job.training_job_name
pprint(training_job)
Step 4: Describe Training job#
from sagemaker.core.resources import TrainingJob
response = TrainingJob.get(training_job_name=TRAINING_JOB_NAME)
pprint(response)
Part 2: Model Evaluation#
This section demonstrates the basic user-facing flow for creating and managing evaluation jobs.
Step 1: Create BenchmarkEvaluator#
Create a BenchmarkEvaluator instance with the desired benchmark. The evaluator will use Jinja2 templates to render a complete pipeline definition.
Key Parameters:#
benchmark: benchmark type from the Benchmark enum
model: model ARN from SageMaker hub content
s3_output_path: S3 location for evaluation outputs
mlflow_resource_arn: MLflow tracking server ARN for experiment tracking (optional)
model_package_group: model package group ARN (optional)
source_model_package: source model package ARN (optional)
model_artifact: ARN of the model artifact for lineage tracking (auto-inferred from source_model_package) (optional)
Note: When you call evaluate(), the system starts the evaluation job. The evaluator will:
Build the template context with all required parameters
Render the pipeline definition from DETERMINISTIC_TEMPLATE using Jinja2
Create or update the pipeline with the rendered definition
Start the pipeline execution with empty parameters (all values pre-substituted)
from sagemaker.train.evaluate import BenchMarkEvaluator
from sagemaker.train.evaluate import get_benchmarks, get_benchmark_properties
from rich.pretty import pprint
import logging
logging.basicConfig(
level=logging.INFO,
format='%(levelname)s - %(name)s - %(message)s'
)
# Get available benchmarks
Benchmark = get_benchmarks()
pprint(list(Benchmark))
# Print properties for a specific benchmark
pprint(get_benchmark_properties(benchmark=Benchmark.GEN_QA))
# Create evaluator with GEN_QA benchmark
evaluator = BenchMarkEvaluator(
benchmark=Benchmark.GEN_QA,
model=BASE_MODEL,
s3_output_path=S3_OUTPUT_PATH,
)
pprint(evaluator)
Step 2: Run Evaluation#
# Run evaluation
execution = evaluator.evaluate()
print(f"Evaluation job started!")
print(f"Job ARN: {execution.arn}")
print(f"Job Name: {execution.name}")
print(f"Status: {execution.status.overall_status}")
pprint(execution)
Step 3: Monitor Execution#
execution.refresh()
print(f"Current status: {execution.status}")
# Display individual step statuses
if execution.status.step_details:
print("\nStep Details:")
for step in execution.status.step_details:
print(f" {step.name}: {step.status}")
Step 4: Wait for Completion#
Wait for the pipeline to complete. This provides rich progress updates in Jupyter notebooks:
execution.wait(target_status="Succeeded", poll=5, timeout=3600)
print(f"\nFinal Status: {execution.status.overall_status}")
Step 5: View Results#
Display the evaluation results in a formatted table:
execution.show_results()
Part 3: Deploying the Model to Bedrock for Inference#
Trained model artifacts and checkpoints are stored in your designated escrow S3 bucket. You can access the training checkpoint location from the describe_training_job response.
By calling the create_custom_model API, you can create a custom model that references the model artifacts stored in your S3 escrow bucket.
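The helper below splits the S3 output URI with urlparse, which places the bucket in netloc and the key in path:

```python
from urllib.parse import urlparse

# For an s3:// URI, urlparse puts the bucket in netloc and the key in path.
parsed = urlparse("s3://my-bucket/training/output/model.tar.gz")
bucket = parsed.netloc           # "my-bucket"
key = parsed.path.lstrip("/")    # "training/output/model.tar.gz"
print(bucket, key)
```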
import boto3
import json
from urllib.parse import urlparse
bedrock_custom_model_name = "" # customize as needed
describe_training_response = sm_client.describe_training_job(TrainingJobName=TRAINING_JOB_NAME)
training_output_s3_uri = describe_training_response['OutputDataConfig']['S3OutputPath']
def get_s3_manifest(training_output_s3_uri):
try:
s3_client = boto3.client('s3')
parsed_uri = urlparse(training_output_s3_uri)
bucket = parsed_uri.netloc
key = parsed_uri.path.lstrip('/')
manifest_key = f"{key.rstrip('/')}/{TRAINING_JOB_NAME}/output/output/manifest.json"
print(f"Fetching manifest from s3://{bucket}/{manifest_key}")
response = s3_client.get_object(Bucket=bucket, Key=manifest_key)
manifest_content = response['Body'].read().decode('utf-8')
manifest = json.loads(manifest_content)
if 'checkpoint_s3_bucket' not in manifest:
raise ValueError("Checkpoint location not found in manifest")
print(f"Successfully retrieved checkpoint S3 URI: {manifest['checkpoint_s3_bucket']}")
return manifest['checkpoint_s3_bucket']
except s3_client.exceptions.NoSuchKey:
raise FileNotFoundError(f"Manifest file not found at s3://{bucket}/{manifest_key}")
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse manifest JSON: {str(e)}")
except Exception as e:
raise Exception(f"Error fetching manifest: {str(e)}")
s3_checkpoint_path = get_s3_manifest(training_output_s3_uri)
bedrock_client = boto3.Session().client(service_name="bedrock", region_name=REGION)
# Alternatively, the checkpoint path can be read directly from the training job's CheckpointConfig:
# s3_checkpoint_path = describe_training_response["CheckpointConfig"]["S3Uri"]
try:
response = bedrock_client.create_custom_model(
modelName=bedrock_custom_model_name,
modelSourceConfig={"s3DataSource": {"s3Uri": s3_checkpoint_path}},
roleArn=ROLE_ARN,
# Optionally, add modelTags here
)
print("Custom model ARN:", response["modelArn"])
except Exception as e:
print(f"An unexpected error occurred: {e}")
To monitor the job, use the get_custom_model operation to retrieve the job status. Please allow some time for the job to complete, as this can take up to 20 minutes.
import time
while True:
custom_model_response = bedrock_client.get_custom_model(modelIdentifier=bedrock_custom_model_name)
model_status = custom_model_response["modelStatus"]
print(f"Custom model status: {model_status}")
if model_status == "Active":
break
elif model_status == "Failed":
raise Exception(f"Custom model creation failed with status: {model_status}")
time.sleep(30)
print("Custom model is ACTIVE.")
custom_model_response
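The polling loop above runs until the status changes and never times out. A generic variant with a deadline (a sketch; fetch_status stands for any zero-argument callable, such as a lambda wrapping get_custom_model) is safer in automation:

```python
import time

def wait_for_status(fetch_status, success="Active", failure=("Failed",),
                    poll_seconds=30, timeout_seconds=1800):
    """Poll fetch_status() until it returns success, hits a failure, or the deadline passes."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        status = fetch_status()
        if status == success:
            return status
        if status in failure:
            raise RuntimeError(f"Resource entered failure status: {status}")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Still '{status}' after {timeout_seconds} seconds")
        time.sleep(poll_seconds)
```

For example: `wait_for_status(lambda: bedrock_client.get_custom_model(modelIdentifier=bedrock_custom_model_name)["modelStatus"])`.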
After you create a custom model, you can set up inference using one of the following options:
Purchase Provisioned Throughput – Purchase Provisioned Throughput for your model to set up dedicated compute capacity with guaranteed throughput for consistent performance and lower latency. For more information about Provisioned Throughput, see Increase model invocation capacity with Provisioned Throughput in Amazon Bedrock. For more information about using custom models with Provisioned Throughput, see Purchase Provisioned Throughput for a custom model.
Deploy custom model for on-demand inference (only LoRA fine-tuned Amazon Nova models) – To set up on-demand inference, you deploy the custom model with a custom model deployment. After you deploy the model, you invoke it using the ARN for the custom model deployment. With on-demand inference, you only pay for what you use and you don’t need to set up provisioned compute resources. For more information about deploying custom models for on-demand inference, see Deploy a custom model for on-demand inference.
Deploy custom model for inference by using Provisioned Throughput#
provisioned_model_name = "test-provisioned-model"
custom_model_id = custom_model_response["modelArn"]
try:
response = bedrock_client.create_provisioned_model_throughput(
modelId=custom_model_id, provisionedModelName=provisioned_model_name, modelUnits=1
)
provisioned_model_arn = response["provisionedModelArn"]
print("Provisioned model ARN:", provisioned_model_arn)
except Exception as e:
print(f"An unexpected error occurred: {e}")
Wait for the provisioned model to reach the InService status:
while True:
response = bedrock_client.get_provisioned_model_throughput(
provisionedModelId=provisioned_model_arn
)
model_status = response["status"]
print(f"Provisioned model status: {model_status}")
if model_status == "InService":
break
elif model_status == "Failed":
raise Exception(f"Provisioned model failed with status: {model_status}")
time.sleep(30)
print("Provisioned model is in service.")
response
Finally, you can invoke the model like any other Bedrock-hosted model using the InvokeModel API.
# Invoke model (Inference)
bedrock_runtime = boto3.client("bedrock-runtime", region_name=REGION)
request_body = {
"inferenceConfig": {"max_new_tokens": 1000, "temperature": 0.7, "top_p": 0.9},
"messages": [
{
"role": "user",
"content": [
{"text": "Tell me about Amazon Bedrock in less than 100 words."}
],
}
],
}
response = bedrock_runtime.invoke_model(
modelId=provisioned_model_arn,
body=json.dumps(request_body),
contentType="application/json",
accept="application/json",
)
response_body = json.loads(response["body"].read())
print(response_body["output"]["message"]["content"][0]["text"])
Deploy custom model for On-Demand Inference#
Important Note: On-demand inference is currently supported only for LoRA-based fine-tuned models.
Once the custom model has reached Active status, deploy it for on-demand inference by creating a custom model deployment.
model_deployment_name = "<model-deployment-name>"
custom_model_arn=custom_model_response["modelArn"]
try:
response = bedrock_client.create_custom_model_deployment(
modelDeploymentName=model_deployment_name,
modelArn=custom_model_arn,
description="<model-deployment-description>",
tags=[
{
"key":"<your-key>",
"value":"<your-value>"
}
]
)
custom_model_deployment_arn = response["customModelDeploymentArn"]
print("Custom model deployment ARN:", custom_model_deployment_arn)
except Exception as e:
print(f"An unexpected error occurred: {e}")
response
while True:
response = bedrock_client.get_custom_model_deployment(customModelDeploymentIdentifier=custom_model_deployment_arn)
model_status = response["status"]
print(f"Custom model deployment status: {model_status}")
if model_status == "Active":
break
elif model_status == "Failed":
raise Exception(f"Custom model deployment failed with status: {model_status}")
time.sleep(30)
print("Custom model is ACTIVE.")
response
bedrock_runtime = boto3.client("bedrock-runtime", region_name=REGION)
# invoke a deployed custom model using Converse API
response = bedrock_runtime.converse(
modelId=custom_model_deployment_arn,
messages=[
{
"role": "user",
"content": [
{
"text": "Tell me about Amazon Bedrock in less than 100 words.",
}
]
}
]
)
result = response.get('output')
print(result)
# invoke a deployed custom model using InvokeModel API
request_body = {
"schemaVersion": "messages-v1",
"messages": [{"role": "user",
"content": [{"text": "Tell me about Amazon Bedrock in less than 100 words."}]}],
"system": [{"text": "What is amazon bedrock?"}],
"inferenceConfig": {"maxTokens": 500,
"topP": 0.9,
"temperature": 0.0
}
}
body = json.dumps(request_body)
response = bedrock_runtime.invoke_model(
modelId=custom_model_deployment_arn,
body=body
)
# Extract and print the response text
model_response = json.loads(response["body"].read())
response_text = model_response["output"]["message"]["content"][0]["text"]
print(response_text)
Cleanup#
Delete the resources that were created to stop incurring charges.
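The cells below repeat the same try/except pattern for each delete call; that pattern can be factored into a small helper (a sketch, the explicit calls below work as-is):

```python
def safe_delete(description, delete_fn, **kwargs):
    """Attempt a delete call and report the outcome instead of raising."""
    print(f"Deleting {description}")
    try:
        delete_fn(**kwargs)
        print(f"{description} deleted successfully.")
        return True
    except Exception as e:
        print(f"Error deleting {description}: {e}")
        return False
```

For example: `safe_delete("custom model", bedrock_client.delete_custom_model, modelIdentifier=custom_model_id)`.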
# Delete provisioned model throughput
print(f"Deleting provisioned model throughput: {provisioned_model_arn}")
try:
bedrock_client.delete_provisioned_model_throughput(
provisionedModelId=provisioned_model_name
)
print("Provisioned model throughput deleted successfully.")
except Exception as e:
print(f"Error deleting provisioned throughput: {e}")
# Delete custom model deployment if you have used on-demand inference.
print(f"Deleting custom model deployment: {custom_model_deployment_arn}")
try:
bedrock_client.delete_custom_model_deployment(
customModelDeploymentIdentifier=custom_model_deployment_arn
)
print("Custom model deployment deleted successfully.")
except Exception as e:
print(f"Error deleting custom model deployment: {e}")
# Delete custom model
print(f"Deleting custom model: {custom_model_id}")
try:
bedrock_client.delete_custom_model(modelIdentifier=custom_model_id)
print("Custom model deleted successfully.")
except Exception as e:
print(f"Error deleting custom model: {e}")