Direct Preference Optimization (DPO) Training with SageMaker#
This notebook demonstrates how to use the DPOTrainer to fine-tune large language models using Direct Preference Optimization (DPO). DPO is a technique that trains models to align with human preferences by learning from preference data without requiring a separate reward model.
What is DPO?#
Direct Preference Optimization (DPO) is a method for training language models to follow human preferences. Unlike traditional RLHF (Reinforcement Learning from Human Feedback), DPO directly optimizes the model using preference pairs without needing a reward model.
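Concretely, DPO replaces the RLHF reward-modeling and RL stages with a single classification-style loss over preference pairs (Rafailov et al., 2023):

$$
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta; \pi_{\mathrm{ref}}) = -\,\mathbb{E}_{(x,\, y_w,\, y_l) \sim \mathcal{D}} \left[ \log \sigma\!\left( \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)} \right) \right]
$$

where $\pi_\theta$ is the model being trained, $\pi_{\mathrm{ref}}$ is a frozen reference model, $y_w$ and $y_l$ are the chosen and rejected responses for prompt $x$, and $\beta$ controls how far the model may drift from the reference.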
Key Benefits:
Simpler than RLHF - no reward model required
More stable training process
Direct optimization on preference data
Works with LoRA for efficient fine-tuning
Workflow Overview#
Prepare Preference Dataset: Upload preference data in JSONL format
Register Dataset: Create a SageMaker AI Registry dataset
Configure DPO Trainer: Set up model, training parameters, and resources
Execute Training: Run the DPO fine-tuning job
Track Results: Monitor training with MLflow integration
Step 1: Prepare and Register Preference Dataset#
DPO requires preference data in a specific format where each example contains:
prompt: The input text
chosen: The preferred response
rejected: The less preferred response
The dataset should be in JSONL format with each line containing one preference example.
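For illustration, the snippet below writes one such record to a JSONL file; the prompt and responses are placeholder values.

import json

# One preference example per line; all values here are illustrative placeholders.
example = {
    "prompt": "Summarize the benefits of unit testing.",
    "chosen": "Unit tests catch regressions early, document expected behavior, and make refactoring safer.",
    "rejected": "Testing is when you run the code to see if it works.",
}

with open("preference_dataset_train.jsonl", "w") as f:
    f.write(json.dumps(example) + "\n")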
from sagemaker.ai_registry.dataset import DataSet
from sagemaker.ai_registry.dataset_utils import CustomizationTechnique
# Register the dataset in the SageMaker AI Registry.
# This creates a versioned dataset that can be referenced by ARN.
# Provide a source (it can be a local file path or an S3 URL).
dataset = DataSet.create(
    name="demo-6",
    source="s3://nova-mlflow-us-west-2/dataset/preference_dataset_train_256.jsonl"
)
print(f"Dataset ARN: {dataset.arn}")
Create a Model Package Group (if one does not already exist)#
from sagemaker.core.resources import ModelPackage, ModelPackageGroup
model_package_group = ModelPackageGroup.create(model_package_group_name="test-model-package-group")
Step 2: Configure and Execute DPO Training#
The DPOTrainer provides a high-level interface for DPO fine-tuning with the following key features:
Key Parameters:#
Required Parameters
model: Base model ID from SageMaker HubContent that is available to fine-tune, or ModelPackage artifacts
Optional Parameters
training_type: Choose from the TrainingType enum (sagemaker.train.common): either LORA or FULL
model_package_group: Model package group name or ModelPackageGroup object. This parameter is mandatory when a base model ID is provided, but optional when a model package is provided.
mlflow_resource_arn: MLflow app ARN used to track the training job
mlflow_experiment_name: MLflow experiment name (str)
mlflow_run_name: MLflow run name (str)
training_dataset: Training dataset, as a Dataset ARN or Dataset object. (A training dataset is required for a training job to run; it can be provided either via the Trainer or via .train().)
validation_dataset: Validation dataset, as a Dataset ARN or Dataset object
s3_output_path: S3 path for the trained model artifacts
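As an illustrative sketch (not the canonical configuration), the optional MLflow and validation parameters might be wired as shown below. The tracking server ARN, bucket names, and experiment/run names are placeholders, and the validation data is assumed to be registered the same way as the training set in Step 1.

from sagemaker.ai_registry.dataset import DataSet
from sagemaker.train.dpo_trainer import DPOTrainer
from sagemaker.train.common import TrainingType

# Hypothetical validation set, registered like the training set in Step 1
validation_dataset = DataSet.create(
    name="demo-6-val",  # placeholder name
    source="s3://your-bucket/dataset/preference_dataset_val.jsonl",  # placeholder path
)

trainer = DPOTrainer(
    model="meta-textgeneration-llama-3-2-1b-instruct",
    training_type=TrainingType.LORA,
    model_package_group=model_package_group,
    training_dataset=dataset.arn,
    validation_dataset=validation_dataset.arn,
    # Placeholder MLflow tracking server ARN and names
    mlflow_resource_arn="arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/my-server",
    mlflow_experiment_name="dpo-preference-tuning",
    mlflow_run_name="llama-3-2-1b-lora",
    s3_output_path="s3://your-bucket/output/",
    accept_eula=True,
)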
Training Features:#
Serverless Training: Automatically managed compute resources
LoRA Integration: Parameter-efficient fine-tuning
MLflow Tracking: Automatic experiment and metrics logging
Model Versioning: Automatic model package creation
Reference#
Refer to this documentation for other models that support model customization: https://docs.aws.amazon.com/bedrock/latest/userguide/custom-model-supported.html
import random
from sagemaker.train.dpo_trainer import DPOTrainer
from sagemaker.train.common import TrainingType
# Create a DPOTrainer instance with comprehensive configuration
trainer = DPOTrainer(
    # Base model from SageMaker Hub
    model="meta-textgeneration-llama-3-2-1b-instruct",
    # Use LoRA for efficient fine-tuning
    training_type=TrainingType.LORA,
    # Model versioning and storage
    model_package_group=model_package_group,  # or use an existing model package group ARN
    # Training data (from Step 1)
    training_dataset=dataset.arn,
    # Output configuration
    s3_output_path="s3://mc-flows-sdk-testing/output/",
    # Unique job name
    base_job_name=f"dpo-job-{random.randint(1, 1000)}",
    accept_eula=True
)
# Customize training hyperparameters
# DPO-specific parameters are automatically loaded from the model's recipe
trainer.hyperparameters.max_epochs = 1 # Quick training for demo
print("Starting DPO training job...")
print(f"Job name: {trainer.base_job_name}")
print(f"Base model: {trainer._model_name}")
# Execute training with monitoring
training_job = trainer.train(wait=True)
print(f"Training completed! Job ARN: {training_job.training_job_arn}")
import json
import re

from sagemaker.core.utils.utils import Unassigned

def pretty_print(obj):
    """Strip Unassigned placeholders from a sagemaker-core object and print it as JSON."""
    def parse_unassigned(item):
        if isinstance(item, Unassigned):
            return None
        if isinstance(item, dict):
            return {k: parse_unassigned(v) for k, v in item.items() if parse_unassigned(v) is not None}
        if isinstance(item, list):
            return [parse_unassigned(x) for x in item if parse_unassigned(x) is not None]
        if isinstance(item, str) and "Unassigned object" in item:
            # Recover key=value pairs from the string repr of partially assigned objects
            pairs = re.findall(r"(\w+)=([^<][^=]*?)(?=\s+\w+=|$)", item)
            result = {k: v.strip("'\"") for k, v in pairs}
            return result if result else None
        return item

    cleaned = parse_unassigned(obj.__dict__ if hasattr(obj, '__dict__') else obj)
    print(json.dumps(cleaned, indent=2, default=str))
pretty_print(training_job)
# Retrieve the training job resource by name and print it
from sagemaker.core.resources import TrainingJob

response = TrainingJob.get(training_job_name=training_job.training_job_name)
pretty_print(response)
Next Steps#
After training completes, you can:
Deploy the Model: Use ModelBuilder to deploy the fine-tuned model
Evaluate Performance: Compare responses from the base and fine-tuned models
Monitor Metrics: Review training metrics in MLflow
Iterate: Adjust hyperparameters and retrain if needed
Example Deployment:#
from sagemaker.serve import ModelBuilder
# Deploy the fine-tuned model
model_builder = ModelBuilder(model=training_job)
model_builder.build(role_arn="arn:aws:iam::account:role/SageMakerRole")
endpoint = model_builder.deploy(endpoint_name="dpo-finetuned-llama")
The fine-tuned model will now generate responses that better align with the preferences in your training data.
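Once the endpoint is in service, you can invoke it directly; below is a minimal sketch using the standard SageMaker runtime client. The payload schema is an assumption and depends on the serving container, so adjust the body to match your deployment.

import json
import boto3

# Assumption: the container accepts a JSON body with "inputs" and "parameters".
runtime = boto3.client("sagemaker-runtime")
response = runtime.invoke_endpoint(
    EndpointName="dpo-finetuned-llama",
    ContentType="application/json",
    Body=json.dumps({
        "inputs": "Explain Direct Preference Optimization in one sentence.",
        "parameters": {"max_new_tokens": 128},
    }),
)
print(response["Body"].read().decode("utf-8"))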