Direct Preference Optimization (DPO) Training with SageMaker#

This notebook demonstrates how to use the DPOTrainer to fine-tune large language models using Direct Preference Optimization (DPO). DPO is a technique that trains models to align with human preferences by learning from preference data without requiring a separate reward model.

What is DPO?#

Direct Preference Optimization (DPO) is a method for training language models to follow human preferences. Unlike traditional RLHF (Reinforcement Learning from Human Feedback), DPO directly optimizes the model using preference pairs without needing a reward model.
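To make the objective concrete, here is a minimal sketch of the per-example DPO loss in plain Python (standard formulation; `beta=0.1` is an illustrative default, and in practice the log-probabilities would come from your policy model and a frozen reference model):

```python
import math

def dpo_loss(policy_logp_chosen, policy_logp_rejected,
             ref_logp_chosen, ref_logp_rejected, beta=0.1):
    """Per-example DPO loss: -log sigmoid(beta * (chosen log-ratio - rejected log-ratio))."""
    chosen_ratio = policy_logp_chosen - ref_logp_chosen        # log pi(y_w|x) - log pi_ref(y_w|x)
    rejected_ratio = policy_logp_rejected - ref_logp_rejected  # log pi(y_l|x) - log pi_ref(y_l|x)
    margin = beta * (chosen_ratio - rejected_ratio)
    return math.log1p(math.exp(-margin))  # == -log(sigmoid(margin))

# With no preference shift the loss equals log(2) (~0.693); when the policy
# favors the chosen response more than the reference does, it drops below that.
print(dpo_loss(-1.0, -2.0, -1.5, -1.5))
```

Minimizing this loss pushes the policy to assign relatively more probability mass to the chosen response than to the rejected one, without ever fitting an explicit reward model.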

Key Benefits:

  • Simpler than RLHF - no reward model required

  • More stable training process

  • Direct optimization on preference data

  • Works with LoRA for efficient fine-tuning

Workflow Overview#

  1. Prepare Preference Dataset: Upload preference data in JSONL format

  2. Register Dataset: Create a SageMaker AI Registry dataset

  3. Configure DPO Trainer: Set up model, training parameters, and resources

  4. Execute Training: Run the DPO fine-tuning job

  5. Track Results: Monitor training with MLflow integration

Step 1: Prepare and Register Preference Dataset#

DPO requires preference data in a specific format where each example contains:

  • prompt: The input text

  • chosen: The preferred response

  • rejected: The less preferred response

The dataset should be in JSONL format with each line containing one preference example.
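For illustration, a hedged sketch that writes and validates a tiny preference file in this format (the example rows and the filename are made up; real data would come from human annotation):

```python
import json

# Hypothetical preference examples following the prompt/chosen/rejected schema
examples = [
    {
        "prompt": "Explain DPO in one sentence.",
        "chosen": "DPO fine-tunes a model directly on preference pairs without a reward model.",
        "rejected": "DPO is a kind of database.",
    },
    {
        "prompt": "Name one benefit of LoRA.",
        "chosen": "LoRA trains a small number of adapter weights, reducing memory use.",
        "rejected": "LoRA makes the model bigger.",
    },
]

# JSONL: one self-contained JSON object per line
with open("preference_dataset_sample.jsonl", "w") as f:
    for ex in examples:
        f.write(json.dumps(ex) + "\n")

# Validate: every line must parse and contain exactly the three required keys
with open("preference_dataset_sample.jsonl") as f:
    rows = [json.loads(line) for line in f]
assert all(set(r) == {"prompt", "chosen", "rejected"} for r in rows)
```

A quick validation pass like this before uploading can save a failed training job later.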

from sagemaker.ai_registry.dataset import DataSet
from sagemaker.ai_registry.dataset_utils import CustomizationTechnique


# Register dataset in SageMaker AI Registry
# This creates a versioned dataset that can be referenced by ARN
# Provide a source (a local file path or an S3 URL)
dataset = DataSet.create(
    name="demo-6",
    source="s3://nova-mlflow-us-west-2/dataset/preference_dataset_train_256.jsonl"
)

print(f"Dataset ARN: {dataset.arn}")

Create a model package group (if one does not already exist)#

from sagemaker.core.resources import ModelPackage, ModelPackageGroup

model_package_group = ModelPackageGroup.create(model_package_group_name="test-model-package-group")

Step 2: Configure and Execute DPO Training#

The DPOTrainer provides a high-level interface for DPO fine-tuning with the following key features:

Key Parameters:#

Required Parameters

  • model: the base model ID of a SageMaker HubContent model available for fine-tuning, or ModelPackage artifacts

Optional Parameters

  • training_type: a value of the TrainingType enum (sagemaker.train.common), either LORA or FULL

  • model_package_group: a model package group name or a ModelPackageGroup object. Mandatory when a base model ID is provided; optional when a model package is provided.

  • mlflow_resource_arn: the MLflow app ARN used to track the training job

  • mlflow_experiment_name: the MLflow experiment name (str)

  • mlflow_run_name: the MLflow run name (str)

  • training_dataset: the training dataset, as a Dataset ARN or a Dataset object. A training dataset is required for a job to run; it can be provided either to the trainer constructor or to .train().

  • validation_dataset: the validation dataset, as a Dataset ARN or a Dataset object

  • s3_output_path: S3 path for the trained model artifacts

Training Features:#

  • Serverless Training: Automatically managed compute resources

  • LoRA Integration: Parameter-efficient fine-tuning

  • MLflow Tracking: Automatic experiment and metrics logging

  • Model Versioning: Automatic model package creation

Reference#

Refer to this doc for other models that support Model Customization: https://docs.aws.amazon.com/bedrock/latest/userguide/custom-model-supported.html

import random

from sagemaker.train.dpo_trainer import DPOTrainer
from sagemaker.train.common import TrainingType
# Create DPOTrainer instance with comprehensive configuration
trainer = DPOTrainer(
    # Base model from SageMaker Hub
    model="meta-textgeneration-llama-3-2-1b-instruct",
    
    # Use LoRA for efficient fine-tuning
    training_type=TrainingType.LORA,
    
    # Model versioning and storage
    model_package_group=model_package_group, # or use an existing model package group arn
        
    # Training data (from Step 1)
    training_dataset=dataset.arn,
    
    # Output configuration
    s3_output_path="s3://mc-flows-sdk-testing/output/",

    
    # Unique job name
    base_job_name=f"dpo-job-{random.randint(1, 1000)}",
    accept_eula=True
)

# Customize training hyperparameters
# DPO-specific parameters are automatically loaded from the model's recipe
trainer.hyperparameters.max_epochs = 1  # Quick training for demo

print("Starting DPO training job...")
print(f"Job name: {trainer.base_job_name}")
print(f"Base model: {trainer._model_name}")

# Execute training with monitoring
training_job = trainer.train(wait=True)

print(f"Training completed! Job ARN: {training_job.training_job_arn}")
import json
import re
from sagemaker.core.utils.utils import Unassigned

def pretty_print(obj):
    # Recursively strip Unassigned placeholders so the output is clean JSON
    def parse_unassigned(item):
        if isinstance(item, Unassigned):
            return None
        if isinstance(item, dict):
            return {k: parse_unassigned(v) for k, v in item.items() if parse_unassigned(v) is not None}
        if isinstance(item, list):
            return [parse_unassigned(x) for x in item if parse_unassigned(x) is not None]
        if isinstance(item, str) and "Unassigned object" in item:
            # Recover key=value pairs embedded in stringified Unassigned objects
            pairs = re.findall(r"(\w+)=([^<][^=]*?)(?=\s+\w+=|$)", item)
            result = {k: v.strip("'\"") for k, v in pairs}
            return result if result else None
        return item

    cleaned = parse_unassigned(obj.__dict__ if hasattr(obj, '__dict__') else obj)
    print(json.dumps(cleaned, indent=2, default=str))

# Print the training job object
pretty_print(training_job)

from sagemaker.core.resources import TrainingJob

# Retrieve and inspect an existing training job by name
response = TrainingJob.get(training_job_name="generate-sql-queries-bas-base-judge-y6cfcrah49j7-090dlKtAnQ")
pretty_print(response)

Next Steps#

After training completes, you can:

  1. Deploy the Model: Use ModelBuilder to deploy the fine-tuned model

  2. Evaluate Performance: Compare responses from base vs fine-tuned model

  3. Monitor Metrics: Review training metrics in MLflow

  4. Iterate: Adjust hyperparameters and retrain if needed

Example Deployment:#

from sagemaker.serve import ModelBuilder

# Deploy the fine-tuned model
model_builder = ModelBuilder(model=training_job)
model_builder.build(role_arn="arn:aws:iam::account:role/SageMakerRole")
endpoint = model_builder.deploy(endpoint_name="dpo-finetuned-llama")

The fine-tuned model will now generate responses that better align with the preferences in your training data.
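Once the endpoint is live, it can be invoked through the SageMaker runtime API. A hedged sketch follows: the `invoke_endpoint` call is boto3's real API, but the request/response payload shape here is an assumption that depends on the serving container.

```python
import json

def build_request(prompt, max_new_tokens=128):
    # A common text-generation payload shape; adjust to your container's schema
    return json.dumps({"inputs": prompt, "parameters": {"max_new_tokens": max_new_tokens}})

def invoke(endpoint_name, prompt):
    # Requires AWS credentials; sends the request via the SageMaker runtime client
    import boto3
    runtime = boto3.client("sagemaker-runtime")
    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=build_request(prompt),
    )
    return response["Body"].read().decode()

# Example (against the endpoint deployed above):
# invoke("dpo-finetuned-llama", "Explain DPO briefly.")
```

Comparing responses from the base and fine-tuned endpoints on the same prompts is a quick sanity check that the preference training had the intended effect.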