SageMaker V3 Hyperparameter Tuning Example#
This notebook demonstrates how to use the V3 SageMaker Python SDK to perform hyperparameter tuning with PyTorch on the MNIST dataset.
Key V3 Changes#
Estimator → ModelTrainer: Use the ModelTrainer class instead of framework-specific estimators
fit() → tune(): Call tuner.tune() instead of tuner.fit()
Inputs: Use InputData objects or simple S3 URIs
SourceCode: Configure training scripts with a SourceCode object
What This Example Shows#
Setting up a PyTorch training script for MNIST
Creating a ModelTrainer with framework container
Configuring HyperparameterTuner with parameter ranges
Running a hyperparameter tuning job
Monitoring and analyzing results
Setup and Imports#
# V3 Imports
from sagemaker.train import ModelTrainer
from sagemaker.train.configs import Compute, SourceCode, InputData, StoppingCondition
from sagemaker.train.tuner import HyperparameterTuner
from sagemaker.core.parameter import ContinuousParameter, CategoricalParameter
from sagemaker.core.helper.session_helper import Session, get_execution_role
import os
Configure Session and Variables#
# Initialize SageMaker session
# Session wraps the boto3 clients and resolves the region and default bucket
# used throughout the rest of the notebook.
sagemaker_session = Session()
region = sagemaker_session.boto_region_name
default_bucket = sagemaker_session.default_bucket()

# Role Configuration
# Option 1: Auto-detect (works in SageMaker Studio/Notebook instances)
# Option 2: Manually specify your SageMaker execution role ARN
try:
    role = get_execution_role()
    print(f"✓ Auto-detected role: {role}")
except Exception as e:
    # Auto-detection fails outside SageMaker-managed environments (e.g. a
    # local laptop); fall back to a user-supplied ARN placeholder.
    print(f"⚠️ Could not auto-detect role: {e}")
    # Manually specify your SageMaker execution role ARN here:
    role = "<IAM Role ARN>"
    print(f"✓ Using manually specified role: {role}")

# Define prefixes for organization
prefix = "v3-hpo-pytorch-mnist"  # S3 key prefix for the uploaded dataset
base_job_prefix = "pytorch-mnist-hpo"  # prefix for generated job names
default_bucket_prefix = sagemaker_session.default_bucket_prefix

# Apply bucket prefix if specified — keeps all objects under the
# session-configured namespace when one is set on the Session.
if default_bucket_prefix:
    prefix = f"{default_bucket_prefix}/{prefix}"
    base_job_prefix = f"{default_bucket_prefix}/{base_job_prefix}"

# Configuration
training_instance_type = "ml.m5.xlarge"  # instance type for each training job
account_id = sagemaker_session.account_id()
local_dir = "data"  # local directory the MNIST download lands in

print(f"\nRegion: {region}")
print(f"Role: {role}")
print(f"Bucket: {default_bucket}")
print(f"Prefix: {prefix}")
Prepare Training Data#
Download MNIST dataset and upload to S3.
# Download MNIST dataset
# torchvision is imported in this cell because it is only needed for
# fetching the dataset locally, not inside the training jobs.
from torchvision.datasets import MNIST
from torchvision import transforms

# Point torchvision at the region-local SageMaker example-files mirror so
# the download does not depend on the upstream MNIST host.
MNIST.mirrors = [
    f"https://sagemaker-example-files-prod-{region}.s3.amazonaws.com/datasets/image/MNIST/"
]

print("Downloading MNIST dataset...")
# The dataset object is discarded; this call is only for the side effect of
# downloading the raw files into local_dir.
MNIST(
    local_dir,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))  # canonical MNIST mean/std
    ]),
)

# Upload to S3 so the training jobs can read it as an input channel.
print(f"Uploading data to S3...")
s3_data_uri = sagemaker_session.upload_data(
    path=local_dir,
    bucket=default_bucket,
    key_prefix=f"{prefix}/data"
)
print(f"Training data uploaded to: {s3_data_uri}")
Create Training Script#
The mnist.py training script is in the current directory.
# Sanity-check that the training entry point exists before any job config
# references it — a missing script would only fail much later, at job start.
import os

script_exists = os.path.exists("mnist.py")
if not script_exists:
    print("✗ Warning: mnist.py not found in current directory.")
    print(" Please ensure mnist.py exists in the same directory as this notebook.")
else:
    print("✓ Training script found: mnist.py")
Configure ModelTrainer#
Create a ModelTrainer instance with PyTorch training container.
# Configure source code
# The entire source_dir is packaged and shipped with each training job;
# entry_script is the module SageMaker executes inside the container.
source_code = SourceCode(
    source_dir=".",  # Current directory containing mnist.py
    entry_script="mnist.py"
)

# Configure compute resources — one CPU instance per training job.
compute = Compute(
    instance_type=training_instance_type,
    instance_count=1,
    volume_size_in_gb=30
)

# Configure stopping condition (hard cap per training job)
stopping_condition = StoppingCondition(
    max_runtime_in_seconds=3600  # 1 hour
)

# Get PyTorch training image.
# FIX: training_instance_type is ml.m5.xlarge (a CPU instance) and the
# hyperparameters below select the CPU-oriented "gloo" backend, so the CPU
# variant of the PyTorch Deep Learning Container must be used. The previous
# "-gpu-" image expects CUDA devices that m5 instances do not have.
training_image = f"763104351884.dkr.ecr.{region}.amazonaws.com/pytorch-training:1.10.0-cpu-py38"

# Create ModelTrainer — the V3 replacement for framework estimators; it is
# handed to the HyperparameterTuner as the job template.
model_trainer = ModelTrainer(
    training_image=training_image,
    source_code=source_code,
    compute=compute,
    stopping_condition=stopping_condition,
    hyperparameters={
        "epochs": 1,  # Use 1 epoch for faster tuning
        "backend": "gloo"  # CPU-friendly distributed backend
    },
    sagemaker_session=sagemaker_session,
    role=role,
    base_job_name=base_job_prefix
)

print("ModelTrainer configured successfully")
print(f"Training Image: {training_image}")
print(f"Instance Type: {training_instance_type}")
Configure HyperparameterTuner#
Define hyperparameter ranges and create a HyperparameterTuner to optimize the model.
# Search space: the learning rate is explored continuously on [0.001, 0.1];
# the batch size is drawn from a fixed menu of powers of two.
tuning_ranges = {
    "lr": ContinuousParameter(0.001, 0.1),
    "batch-size": CategoricalParameter([32, 64, 128, 256, 512]),
}

# The tuner minimizes the test-set loss emitted by the training script.
target_metric = "average test loss"

# The metric value is scraped from the training-job logs with this regex.
scraped_metrics = [
    {
        "Name": "average test loss",
        "Regex": "Test set: Average loss: ([0-9\\.]+)"
    }
]

# Create HyperparameterTuner: random search over the ranges above, three
# jobs in total with at most two running concurrently, and automatic early
# stopping of unpromising trials.
tuner = HyperparameterTuner(
    model_trainer=model_trainer,
    objective_metric_name=target_metric,
    hyperparameter_ranges=tuning_ranges,
    metric_definitions=scraped_metrics,
    max_jobs=3,
    max_parallel_jobs=2,
    strategy="Random",
    objective_type="Minimize",
    early_stopping_type="Auto"
)

print("HyperparameterTuner configured successfully")
Run Hyperparameter Tuning Job#
Start the hyperparameter tuning job.
# Prepare input data
# Single "training" channel backed by the MNIST data uploaded earlier —
# presumably consumed by mnist.py via SM_CHANNEL_TRAINING; confirm against
# the training script.
training_data = InputData(
    channel_name="training",
    data_source=s3_data_uri
)

# Start tuning job
print("Starting hyperparameter tuning job...")
tuner.tune(
    inputs=[training_data],
    wait=False  # return immediately; status is polled in later cells
)

# NOTE(review): _current_job_name is a private attribute — this relies on
# SDK internals; verify whether the V3 tuner exposes a public accessor.
tuning_job_name = tuner._current_job_name
print(f"\nTuning job started: {tuning_job_name}")
Monitor Status#
# Poll the tuning job once and report its name and current state.
desc = tuner.describe()
job_name = desc.hyper_parameter_tuning_job_name
job_status = desc.hyper_parameter_tuning_job_status
print(f"Job Name: {job_name}")
print(f"Status: {job_status}")
Wait for Completion (Optional)#
# Uncomment to wait
# tuner.wait()
Get Best Job (After Completion)#
# Report the winning training job. Before the tuning job completes the SDK
# raises, which is surfaced as a friendly message instead of a traceback.
try:
    winning_job = tuner.best_training_job()
except Exception as e:
    print(f"Not yet available: {e}")
else:
    print(f"Best Training Job: {winning_job}")
Analyze Results (After Completion)#
# Pull per-trial results into a DataFrame. Analytics are unavailable until
# at least one training job has reported metrics, so failures are printed
# as a notice rather than raised.
try:
    results_df = tuner.analytics().dataframe()
    print(f"Results: {results_df.shape}")
    display(results_df.head())
except Exception as e:
    print(f"Analytics not yet available: {e}")