SFTTrainer Example - Finetuning with SageMaker#

This notebook demonstrates the basic user flow for SFT finetuning from a model available in SageMaker JumpStart. Information on the models available in JumpStart: https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-foundation-models-latest.html

Setup and Configuration#

Initialize the environment by importing the necessary libraries and configuring AWS credentials.

# Configure AWS credentials and region
#! ada credentials update --provider=isengard --account=<> --role=Admin --profile=default --once
#! aws configure set region us-west-2
from sagemaker.train.sft_trainer import SFTTrainer
from sagemaker.train.common import TrainingType
from sagemaker.core.training.configs import InputData
from rich import print as rprint
from rich.pretty import pprint
from sagemaker.core.resources import ModelPackage

import boto3
from sagemaker.core.helper.session_helper import Session
import os


# For MLFlow native metrics while the Trainer waits, run the line below with the appropriate region
os.environ["SAGEMAKER_MLFLOW_CUSTOM_ENDPOINT"] = "https://mlflow.sagemaker.us-west-2.app.aws"
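The MLFlow endpoint above is hard-coded for us-west-2. If you work in another region, the same host pattern can be parameterized; note that the `https://mlflow.sagemaker.<region>.app.aws` layout is an assumption generalized from the single URL above, so verify the endpoint for your region before relying on it:

```python
import os

def set_mlflow_endpoint(region: str) -> str:
    """Build and export the region-specific MLFlow endpoint.

    Assumes the endpoint follows the pattern shown above; verify the
    URL for your region against your MLflow tracking configuration.
    """
    endpoint = f"https://mlflow.sagemaker.{region}.app.aws"
    os.environ["SAGEMAKER_MLFLOW_CUSTOM_ENDPOINT"] = endpoint
    return endpoint

set_mlflow_endpoint("us-west-2")
```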

Finetuning with Jumpstart base model#

Prepare and Register Dataset#

Prepare and Register Dataset for Finetuning

from sagemaker.ai_registry.dataset import DataSet
from sagemaker.ai_registry.dataset_utils import CustomizationTechnique



# Register dataset in SageMaker AI Registry
# This creates a versioned dataset that can be referenced by ARN
# Provide a source (it can be local file path or S3 URL)
dataset = DataSet.create(
    name="demo-1",
    source="s3://mc-flows-sdk-testing/input_data/sft/sample_data_256_final.jsonl"
)

print(f"Dataset ARN: {dataset.arn}")
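The dataset registered above points at a JSONL file. As a rough sketch of what such a file can look like, the snippet below writes a hypothetical prompt/completion JSONL; the actual schema a given base model expects may differ (e.g. a chat/messages layout), so check the model's finetuning documentation before preparing data:

```python
import json

# Hypothetical prompt/completion records -- the schema required by a given
# base model may differ; verify against the model's finetuning docs.
records = [
    {"prompt": "What is SFT?", "completion": "Supervised fine-tuning of a base model on labeled examples."},
    {"prompt": "Name one PEFT method.", "completion": "LoRA."},
]

# Each line of a JSONL file is an independent JSON object
with open("sample_sft_data.jsonl", "w") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")

with open("sample_sft_data.jsonl") as f:
    lines = [json.loads(line) for line in f]
print(len(lines))
```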

Create a Model Package Group (if it does not already exist)#

from sagemaker.core.resources import ModelPackage, ModelPackageGroup

model_package_group = ModelPackageGroup.create(model_package_group_name="test-model-package-group")

Create SFTTrainer#

Required Parameters

  • model: base model ID on SageMaker HubContent that is available to finetune (or) ModelPackage artifacts

Optional Parameters

  • training_type: Choose from the TrainingType enum (sagemaker.train.common): either LORA or FULL.

  • model_package_group: ModelPackage group name or ModelPackageGroup object. This parameter is mandatory when a base model ID is provided, but optional when a model package is provided.

  • mlflow_resource_arn: MLFlow app ARN to track the training job

  • mlflow_experiment_name: MLFlow app experiment name (str)

  • mlflow_run_name: MLFlow app run name (str)

  • training_dataset: Training dataset - should be a Dataset ARN or DataSet object. (Please note a training dataset is required for a training job to run; it can be provided either via the Trainer or via .train().)

  • validation_dataset: Validation Dataset - should be a Dataset ARN or Dataset object

  • s3_output_path: S3 path for the trained model artifacts

Reference#

Refer to this doc for other models that support Model Customization: https://docs.aws.amazon.com/bedrock/latest/userguide/custom-model-supported.html

# For fine-tuning 
sft_trainer = SFTTrainer(
    model="meta-textgeneration-llama-3-2-1b-instruct", 
    training_type=TrainingType.LORA, 
    model_package_group=model_package_group, # or use an existing model package group arn
    mlflow_experiment_name="test-finetuned-models-exp", 
    mlflow_run_name="test-finetuned-models-run", 
    training_dataset=dataset.arn, 
    s3_output_path="s3://mc-flows-sdk-testing/output/",
    accept_eula=True
)

Discover and update Finetuning options#

Each technique and model has overridable hyperparameters that can be tuned by the user.

print("Default Finetuning Options:")
pprint(sft_trainer.hyperparameters.to_dict())
# To update any hyperparameter, simply assign the value, example:
sft_trainer.hyperparameters.global_batch_size = 16
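The assignment-based override above can be illustrated with a minimal standalone sketch. The class below is a hypothetical stand-in, not the SDK's implementation; the field name `global_batch_size` and the `to_dict` method simply mirror the interface used above, and the default values are made up for illustration:

```python
from dataclasses import dataclass, asdict

@dataclass
class Hyperparameters:
    # Hypothetical defaults for illustration only
    global_batch_size: int = 8
    learning_rate: float = 2e-4
    epochs: int = 3

    def to_dict(self) -> dict:
        return asdict(self)

hp = Hyperparameters()
hp.global_batch_size = 16  # same assignment pattern as the trainer above
print(hp.to_dict())
```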

Start SFT training#

training_job = sft_trainer.train(
    wait=True,
)
import json
import re
from sagemaker.core.utils.utils import Unassigned
from sagemaker.core.resources import TrainingJob

response = TrainingJob.get(training_job_name="meta-textgeneration-llama-3-2-1b-instruct-sft-20251201114921")

def pretty_print(obj):
    def parse_unassigned(item):
        if isinstance(item, Unassigned):
            return None
        if isinstance(item, dict):
            return {k: parsed for k, v in item.items() if (parsed := parse_unassigned(v)) is not None}
        if isinstance(item, list):
            return [parsed for x in item if (parsed := parse_unassigned(x)) is not None]
        if isinstance(item, str) and "Unassigned object" in item:
            pairs = re.findall(r"(\w+)=([^<][^=]*?)(?=\s+\w+=|$)", item)
            result = {k: v.strip("'\"") for k, v in pairs}
            return result if result else None
        return item

    cleaned = parse_unassigned(obj.__dict__ if hasattr(obj, '__dict__') else obj)
    print(json.dumps(cleaned, indent=2, default=str))

pretty_print(response)
# To skip waiting and monitor the training job later:

'''
training_job = sft_trainer.train(
    wait=False,
)
'''
pretty_print(training_job)
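To see what the regex fallback branch of `pretty_print` does in isolation, here is a self-contained run on a hypothetical stringified object in the `key=value` layout the function targets:

```python
import re

# Hypothetical stringified object in the key=value layout pretty_print targets
raw = "training_job_name='demo-job' status='Completed'"

# Same pattern as in pretty_print: capture word keys and lazily-matched
# values, stopping at the next key or the end of the string
pairs = re.findall(r"(\w+)=([^<][^=]*?)(?=\s+\w+=|$)", raw)
parsed = {k: v.strip("'\"") for k, v in pairs}
print(parsed)
```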

View any Training job details#

We can get any training job's details and status with TrainingJob.get(…)

from sagemaker.core.resources import TrainingJob

response = TrainingJob.get(training_job_name="meta-textgeneration-llama-3-2-1b-instruct-sft-20251123162832")
pretty_print(response)

Continued Finetuning (or) Finetuning on Model Artifacts#

Discover a ModelPackage and get its details#

from rich import print as rprint
from rich.pretty import pprint
from sagemaker.core.resources import ModelPackage, ModelPackageGroup

#model_package_iter = ModelPackage.get_all(model_package_group_name="test-finetuned-models-gamma")
model_package = ModelPackage.get(model_package_name="arn:aws:sagemaker:us-west-2:<>:model-package/sdk-test-finetuned-models/2")

pretty_print(model_package)

Create Trainer#

Trainer creation is the same as in the Finetuning section above, except the model input is a ModelPackage (previously trained artifacts).

# For fine-tuning 
sft_trainer = SFTTrainer(
    model=model_package, # Union[str, ModelPackage]
    training_type=TrainingType.LORA, 
    model_package_group="sdk-test-finetuned-models", # Optional when a model package is provided
    mlflow_experiment_name="test-finetuned-models-exp", # Optional[str]
    mlflow_run_name="test-finetuned-models-run", # Optional[str]
    training_dataset=dataset.arn, # Optional: Dataset ARN or DataSet object
    s3_output_path="s3://mc-flows-sdk-testing/output/",
)

Start the Training#

training_job = sft_trainer.train(
    wait=True,
)
pretty_print(training_job)

SFT Trainer Nova testing#

os.environ['SAGEMAKER_REGION'] = 'us-east-1'

# For fine-tuning 
sft_trainer_nova = SFTTrainer(
    #model="test-nova-lite-v2", 
    #model="nova-textgeneration-micro",
    model="nova-textgeneration-lite-v2",
    training_type=TrainingType.LORA, 
    model_package_group="sdk-test-finetuned-models", 
    mlflow_experiment_name="test-nova-finetuned-models-exp", 
    mlflow_run_name="test-nova-finetuned-models-run", 
    training_dataset="arn:aws:sagemaker:us-east-1:<>:hub-content/sdktest/DataSet/sft-nova-test-dataset/0.0.1",
    s3_output_path="s3://mc-flows-sdk-testing-us-east-1/output/"
)
sft_trainer_nova.hyperparameters.to_dict()
training_job = sft_trainer_nova.train(
    wait=True,
)