RLAIF Example - Finetuning with SageMaker#
This notebook demonstrates the basic user flow for RLAIF finetuning, starting from a model available in SageMaker JumpStart. For information on the models available in JumpStart, see: https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-foundation-models-latest.html
Setup and Configuration#
Initialize the environment by importing the necessary libraries and configuring AWS credentials.
# Configure AWS credentials and region (run once, if needed)
#! ada credentials update --provider=isengard --account=<> --role=Admin --profile=default --once
#! aws configure set region us-west-2
import os

import boto3
from rich import print as rprint
from rich.pretty import pprint
from sagemaker.core.helper.session_helper import Session
from sagemaker.core.resources import ModelPackage
from sagemaker.train.configs import InputData
from sagemaker.train.rlaif_trainer import RLAIFTrainer

# os.environ['SAGEMAKER_REGION'] = 'us-east-1'
# For MLflow native metrics while the Trainer waits, run the line below with the appropriate region
os.environ["SAGEMAKER_MLFLOW_CUSTOM_ENDPOINT"] = "https://mlflow.sagemaker.us-west-2.app.aws"
Prepare and Register Dataset#
Prepare and register the dataset for finetuning.
from sagemaker.ai_registry.dataset import DataSet
from sagemaker.ai_registry.dataset_utils import CustomizationTechnique
# Register the dataset in SageMaker AI Registry.
# This creates a versioned dataset that can be referenced by ARN.
# Provide a source (it can be a local file path or an S3 URL).
dataset = DataSet.create(
    name="demo-2",
    source="s3://mc-flows-sdk-testing/input_data/rlvr-rlaif-test-data/train_285.jsonl",
)
print(f"Dataset ARN: {dataset.arn}")
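Since `DataSet.create` accepts a local file path as well as an S3 URL, you can build a small JSON Lines file locally and register it directly. The sketch below shows the JSONL mechanics only; the record fields (a single `prompt` key) are an illustrative assumption, not the confirmed schema for RLAIF training data, so check the dataset requirements for your chosen technique and model.

```python
import json
from pathlib import Path

# Illustrative records only -- the exact field names your RLAIF dataset
# needs are an assumption here, not the confirmed schema.
records = [
    {"prompt": "Summarize: The quick brown fox jumps over the lazy dog."},
    {"prompt": "Summarize: SageMaker JumpStart hosts many foundation models."},
]

path = Path("train_sample.jsonl")
with path.open("w") as f:
    for record in records:
        # JSON Lines: one standalone JSON object per line
        f.write(json.dumps(record) + "\n")

# Each line parses independently, which is what makes JSONL streamable.
loaded = [json.loads(line) for line in path.read_text().splitlines()]
print(len(loaded))  # 2
```

The resulting local path could then be passed as `source` to `DataSet.create` in place of the S3 URL above.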
Create a Model Package Group (if one does not already exist)#
from sagemaker.core.resources import ModelPackage, ModelPackageGroup
model_package_group = ModelPackageGroup.create(model_package_group_name="test-model-package-group")
Create RLAIFTrainer#
Required Parameters

- model: base model ID on SageMaker HubContent that is available to finetune, or ModelPackage artifacts

Optional Parameters

- reward_model_id: Bedrock model ID to be used as the judge.
- reward_prompt: Reward prompt ARN or a builtin prompt; refer to: https://docs.aws.amazon.com/bedrock/latest/userguide/model-evaluation-metrics.html
- model_package_group: ModelPackage group name or ModelPackageGroup object. This parameter is mandatory when a base model ID is provided, but optional when a model package is provided.
- mlflow_resource_arn: MLflow app ARN to track the training job
- mlflow_experiment_name: MLflow experiment name (str)
- mlflow_run_name: MLflow run name (str)
- training_dataset: Training dataset, as a Dataset ARN or Dataset object. (Note that a training dataset is required for a training job to run; it can be provided via the Trainer or via .train().)
- validation_dataset: Validation dataset, as a Dataset ARN or Dataset object
- s3_output_path: S3 path for the trained model artifacts
Reference#
Refer to this doc for other models that support Model Customization: https://docs.aws.amazon.com/bedrock/latest/userguide/custom-model-supported.html
Refer to this file for supported reward models: https://github.com/aws/sagemaker-python-sdk/blob/master/sagemaker-train/src/sagemaker/train/constants.py#L46
# For fine-tuning
rlaif_trainer = RLAIFTrainer(
    model="meta-textgeneration-llama-3-2-1b-instruct",
    model_package_group=model_package_group,  # or use an existing model package group ARN
    reward_model_id="openai.gpt-oss-120b-1:0",
    reward_prompt="Builtin.Summarize",
    mlflow_experiment_name="test-rlaif-finetuned-models-exp",
    mlflow_run_name="test-rlaif-finetuned-models-run",
    training_dataset=dataset.arn,
    s3_output_path="s3://mc-flows-sdk-testing/output/",
    accept_eula=True,
)
Discover and update Finetuning options#
Each technique and model has overridable hyperparameters that the user can tune.
print("Default Finetuning Options:")
pprint(rlaif_trainer.hyperparameters.to_dict())
# Show details for each finetuning option
rlaif_trainer.hyperparameters.get_info()
# To override an option, assign to the corresponding attribute, e.g.:
# rlaif_trainer.hyperparameters.<option_name> = <value>
Start RLAIF training#
training_job = rlaif_trainer.train(wait=True)
import json
import re

from sagemaker.core.utils.utils import Unassigned


def pretty_print(obj):
    """Recursively strip Unassigned placeholders and print the object as JSON."""

    def parse_unassigned(item):
        if isinstance(item, Unassigned):
            return None
        if isinstance(item, dict):
            return {k: parse_unassigned(v) for k, v in item.items() if parse_unassigned(v) is not None}
        if isinstance(item, list):
            return [parse_unassigned(x) for x in item if parse_unassigned(x) is not None]
        if isinstance(item, str) and "Unassigned object" in item:
            pairs = re.findall(r"(\w+)=([^<][^=]*?)(?=\s+\w+=|$)", item)
            result = {k: v.strip("'\"") for k, v in pairs}
            return result if result else None
        return item

    cleaned = parse_unassigned(obj.__dict__ if hasattr(obj, "__dict__") else obj)
    print(json.dumps(cleaned, indent=2, default=str))
pretty_print(training_job)
View any Training job details#
We can get the details and status of any training job with TrainingJob.get(…)
from sagemaker.core.resources import TrainingJob
response = TrainingJob.get(training_job_name="meta-textgeneration-llama-3-2-1b-instruct-rlaif-20251124140754")
pretty_print(response)
Test RLAIF with Custom Reward Prompt#
Here we provide a user-defined reward prompt (an evaluator ARN) instead of a builtin prompt.
Create a custom reward prompt#
from rich.pretty import pprint
from sagemaker.ai_registry.air_constants import REWARD_FUNCTION, REWARD_PROMPT
from sagemaker.ai_registry.evaluator import Evaluator
evaluator = Evaluator.create(
    name="jamj-rp2",
    source="/Users/jamjee/workplace/hubpuller/prompt/custom_prompt.jinja",
    type=REWARD_PROMPT,
)
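Conceptually, a reward prompt is a template that the judge model fills in with the candidate output before scoring it. The sketch below is a plain-Python illustration of that template-filling step; the variable names and scoring rubric are assumptions for illustration, not the actual Jinja schema the service expects.

```python
# Illustrative only: a judge-style template filled with a candidate response.
# The {response} variable name and the 1-5 rubric are assumptions, not the
# actual Bedrock/SageMaker reward prompt schema.
template = (
    "You are grading a summary for faithfulness and brevity.\n"
    "Candidate response:\n{response}\n"
    "Return a score from 1 (poor) to 5 (excellent) with a short justification."
)

candidate = "The fox jumped over the dog."
judge_prompt = template.format(response=candidate)
print(candidate in judge_prompt)  # True
```

A custom reward prompt like `custom_prompt.jinja` above plays this role: it defines the grading instructions, and the service substitutes each training sample's output into it before querying the judge model.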
Use it with RLAIF Trainer#
# For fine-tuning with the custom reward prompt
rlaif_trainer = RLAIFTrainer(
    model="meta-textgeneration-llama-3-2-1b-instruct",
    model_package_group="sdk-test-finetuned-models",
    reward_model_id="openai.gpt-oss-120b-1:0",
    reward_prompt=evaluator.arn,
    mlflow_experiment_name="test-rlaif-finetuned-models-exp",
    mlflow_run_name="test-rlaif-finetuned-models-run",
    training_dataset=dataset.arn,
    s3_output_path="s3://mc-flows-sdk-testing/output/",
    accept_eula=True,
)
training_job = rlaif_trainer.train(wait=True)
pretty_print(training_job)