SageMaker V3 Train-to-Inference E2E with MLflow Integration#
This notebook demonstrates the complete end-to-end workflow from training a custom PyTorch model to deploying it for inference on SageMaker cloud infrastructure, with MLflow 3.x tracking and model registry integration.
Prerequisites#
SageMaker MLflow App created (tracking server ARN required)
IAM permissions for MLflow tracking and model registry
AWS credentials configured
Step 0: Install Dependencies#
Note: There are known issues with MLflow model path resolution in older releases. Install the pinned MLflow version below, which includes the latest fixes.
# Install fix for MLflow path resolution issues
%pip install mlflow==3.4.0
NOTE: You must restart your kernel#
Step 1: Configuration#
Set up MLflow tracking server and training configuration.
import uuid
from sagemaker.core import image_uris
from sagemaker.core.helper.session_helper import Session
# =============================================================================
# MLflow Configuration - UPDATE THIS WITH YOUR TRACKING SERVER ARN
# =============================================================================
# Eg. "arn:aws:sagemaker:us-east-1:12345678:mlflow-app/app-ABCDEFGH123"
MLFLOW_TRACKING_ARN = "XXXXX"

# AWS Configuration
# FIX: boto_region_name is an instance property; reading it off the Session
# class yields the property object rather than a region string, which would
# break image_uris.retrieve below. Instantiate the session first.
AWS_REGION = Session().boto_region_name

# Get PyTorch training image dynamically for the chosen region/framework combo
PYTORCH_TRAINING_IMAGE = image_uris.retrieve(
    framework="pytorch",
    region=AWS_REGION,
    version="2.5",
    py_version="py311",
    instance_type="ml.m5.xlarge",
    image_scope="training"
)
print(f"Using PyTorch training image: {PYTORCH_TRAINING_IMAGE}")

# Naming prefixes shared by all resources created in this notebook
MODEL_NAME_PREFIX = "mlflow-e2e-model"
ENDPOINT_NAME_PREFIX = "mlflow-e2e-endpoint"
TRAINING_JOB_PREFIX = "mlflow-e2e-pytorch"
MLFLOW_EXPERIMENT_NAME = "sagemaker-v3-e2e-training"
MLFLOW_REGISTERED_MODEL_NAME = "pytorch-simple-classifier"

# Generate unique identifiers so repeated runs don't collide on resource names
unique_id = str(uuid.uuid4())[:8]
training_job_name = f"{TRAINING_JOB_PREFIX}-{unique_id}"
model_name = f"{MODEL_NAME_PREFIX}-{unique_id}"
endpoint_name = f"{ENDPOINT_NAME_PREFIX}-{unique_id}"

print(f"Training job name: {training_job_name}")
print(f"Model name: {model_name}")
print(f"Endpoint name: {endpoint_name}")
Step 2: Connect to MLflow Tracking Server#
import mlflow

# Connect to the SageMaker MLflow tracking server (the ARN doubles as the
# tracking URI for the sagemaker-mlflow plugin).
mlflow.set_tracking_uri(MLFLOW_TRACKING_ARN)

# Create the experiment if it doesn't exist, otherwise reuse it
mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)

# FIX: was an f-string with no placeholders (F541)
print("Connected to MLflow tracking server")
print(f"Experiment: {MLFLOW_EXPERIMENT_NAME}")
Step 3: Create Training Code with MLflow Logging#
Create a PyTorch training script that logs metrics and registers the model to MLflow.
import tempfile
import os
def create_pytorch_training_code_with_mlflow(mlflow_tracking_arn, experiment_name, registered_model_name):
    """Write a PyTorch training script plus requirements.txt into a fresh temp dir.

    The generated ``train.py`` trains a tiny softmax classifier on synthetic
    data, logs params/metrics to the given MLflow tracking server, and
    registers the trained model under ``registered_model_name``.

    Args:
        mlflow_tracking_arn: SageMaker MLflow tracking server ARN (used as the
            tracking URI inside the generated script).
        experiment_name: MLflow experiment the training run is logged under.
        registered_model_name: Name used both as the artifact path and the
            registry name when logging the model.

    Returns:
        Path of the temporary directory containing ``train.py`` and
        ``requirements.txt``. The caller is responsible for cleanup.
    """
    temp_dir = tempfile.mkdtemp()
    # NOTE: double braces ({{ }}) below are f-string escapes and become single
    # braces in the generated script; single-brace names are interpolated now.
    train_script = f'''import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import os
import mlflow
import mlflow.pytorch
from mlflow.models import infer_signature


class SimpleModel(nn.Module):
    def __init__(self, input_dim=4, output_dim=2):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return torch.softmax(self.linear(x), dim=1)


def train():
    # MLflow setup
    mlflow.set_tracking_uri("{mlflow_tracking_arn}")
    mlflow.set_experiment("{experiment_name}")

    # Hyperparameters
    learning_rate = 0.01
    epochs = 10
    batch_size = 32
    input_dim = 4
    output_dim = 2

    with mlflow.start_run() as run:
        # Log hyperparameters
        mlflow.log_params({{
            "learning_rate": learning_rate,
            "epochs": epochs,
            "batch_size": batch_size,
            "input_dim": input_dim,
            "output_dim": output_dim,
            "optimizer": "Adam",
            "loss_function": "CrossEntropyLoss"
        }})

        model = SimpleModel(input_dim, output_dim)
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        criterion = nn.CrossEntropyLoss()

        # Synthetic data
        X = torch.randn(100, input_dim)
        y = torch.randint(0, output_dim, (100,))
        dataset = TensorDataset(X, y)
        dataloader = DataLoader(dataset, batch_size=batch_size)

        # Training loop with metric logging
        model.train()
        for epoch in range(epochs):
            epoch_loss = 0.0
            correct = 0
            total = 0
            for batch_x, batch_y in dataloader:
                optimizer.zero_grad()
                outputs = model(batch_x)
                loss = criterion(outputs, batch_y)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += batch_y.size(0)
                correct += (predicted == batch_y).sum().item()
            avg_loss = epoch_loss / len(dataloader)
            accuracy = correct / total
            # Log metrics per epoch
            mlflow.log_metrics({{
                "train_loss": avg_loss,
                "train_accuracy": accuracy
            }}, step=epoch)
            print(f"Epoch {{epoch+1}}/{{epochs}} - Loss: {{avg_loss:.4f}}, Accuracy: {{accuracy:.4f}}")

        # Log final metrics
        mlflow.log_metrics({{
            "final_loss": avg_loss,
            "final_accuracy": accuracy
        }})

        # Infer signature and register model to MLflow
        model.eval()
        signature = infer_signature(
            X.numpy(),
            model(X).detach().numpy()
        )

        # Log and register model in one step
        mlflow.pytorch.log_model(
            model,
            name="{registered_model_name}",
            signature=signature,
            registered_model_name="{registered_model_name}"
        )

        print(f"Model registered to MLflow: {registered_model_name}")
        print(f"Run ID: {{run.info.run_id}}")

    print("Training completed!")


if __name__ == "__main__":
    train()
'''
    # Explicit UTF-8 so the generated files don't depend on the host's locale.
    with open(os.path.join(temp_dir, 'train.py'), 'w', encoding='utf-8') as f:
        f.write(train_script)
    with open(os.path.join(temp_dir, 'requirements.txt'), 'w', encoding='utf-8') as f:
        f.write('mlflow==3.4.0\nsagemaker-mlflow==0.2.0\ncloudpickle==3.1.2\n')
    return temp_dir
# Materialize the training package (train.py + requirements.txt) on local disk.
training_code_dir = create_pytorch_training_code_with_mlflow(
    mlflow_tracking_arn=MLFLOW_TRACKING_ARN,
    experiment_name=MLFLOW_EXPERIMENT_NAME,
    registered_model_name=MLFLOW_REGISTERED_MODEL_NAME,
)
print(f"Training code created in: {training_code_dir}")
Step 4: Create ModelTrainer and Start Training#
Use ModelTrainer to run the training script on SageMaker managed infrastructure. The training job will log metrics to MLflow and register the model to the MLflow registry.
from sagemaker.train.model_trainer import ModelTrainer
from sagemaker.train.configs import SourceCode

# Package the generated script as the training job's source code.
source_code = SourceCode(
    source_dir=training_code_dir,
    entry_script="train.py",
    requirements="requirements.txt",
)

# Trainer targeting SageMaker managed infrastructure with the PyTorch image.
model_trainer = ModelTrainer(
    training_image=PYTORCH_TRAINING_IMAGE,
    source_code=source_code,
    base_job_name=training_job_name,
)

# Kick off the (blocking) training job; the script logs to MLflow as it runs.
print(f"Starting training job: {training_job_name}")
print("Metrics will be logged to MLflow during training...")
model_trainer.train()

print("Training completed! Check MLflow UI for metrics and registered model.")
Step 5: Get Registered Model from MLflow#
Retrieve the registered model from MLflow to get the model URI (models:/<name>/<version>) needed for deployment with ModelBuilder.
# Get the latest version of the registered model
from mlflow import MlflowClient

client = MlflowClient()

# FIX: ``RegisteredModel.latest_versions`` is stage-based and deprecated in
# MLflow 2.9+/3.x; list all versions and take the numerically highest instead.
versions = client.search_model_versions(f"name = '{MLFLOW_REGISTERED_MODEL_NAME}'")
latest_version = max(versions, key=lambda v: int(v.version))
model_version = latest_version.version
model_source = latest_version.source

# Get S3 URL of model files (for info only)
artifact_uri = client.get_model_version_download_uri(MLFLOW_REGISTERED_MODEL_NAME, model_version)

# MLflow model registry path to use with ModelBuilder
mlflow_model_path = f"models:/{MLFLOW_REGISTERED_MODEL_NAME}/{model_version}"

print(f"Registered Model: {MLFLOW_REGISTERED_MODEL_NAME}")
print(f"Latest Version: {model_version}")
print(f"Source: {model_source}")
print(f"Model artifacts location: {artifact_uri}")
Step 6: Deploy from MLflow Model Registry#
Use ModelBuilder to deploy the model directly from MLflow registry to a SageMaker endpoint.
import json
import torch
from sagemaker.serve.marshalling.custom_payload_translator import CustomPayloadTranslator
from sagemaker.serve.builder.schema_builder import SchemaBuilder
# =============================================================================
# Custom translators for PyTorch tensor conversion
#
# PyTorch models expect tensors, but SageMaker endpoints communicate via JSON.
# These translators handle the conversion between JSON payloads and PyTorch tensors.
# =============================================================================
class PyTorchInputTranslator(CustomPayloadTranslator):
    """Translate between JSON payloads and PyTorch input tensors."""

    def __init__(self):
        super().__init__(content_type='application/json', accept_type='application/json')

    def serialize_payload_to_bytes(self, payload: object) -> bytes:
        # Tensors aren't JSON-serializable directly; convert to nested lists first.
        data = payload.tolist() if isinstance(payload, torch.Tensor) else payload
        return json.dumps(data).encode('utf-8')

    def deserialize_payload_from_stream(self, stream) -> object:
        # The model consumes float32 tensors.
        return torch.tensor(json.load(stream), dtype=torch.float32)
class PyTorchOutputTranslator(CustomPayloadTranslator):
    """Translate between PyTorch output tensors and JSON payloads."""

    def __init__(self):
        super().__init__(content_type='application/json', accept_type='application/json')

    def serialize_payload_to_bytes(self, payload: object) -> bytes:
        # Tensors aren't JSON-serializable directly; convert to nested lists first.
        data = payload.tolist() if isinstance(payload, torch.Tensor) else payload
        return json.dumps(data).encode('utf-8')

    def deserialize_payload_from_stream(self, stream) -> object:
        # Outputs come back as plain JSON (nested lists of probabilities).
        return json.load(stream)
# Representative payloads for schema inference:
# one 4-feature observation in, two class probabilities out.
sample_input = [[0.1, 0.2, 0.3, 0.4]]
sample_output = [[0.8, 0.2]]

input_translator = PyTorchInputTranslator()
output_translator = PyTorchOutputTranslator()

schema_builder = SchemaBuilder(
    sample_input=sample_input,
    sample_output=sample_output,
    input_translator=input_translator,
    output_translator=output_translator,
)
from sagemaker.serve.model_builder import ModelBuilder
from sagemaker.serve.mode.function_pointers import Mode

# Point ModelBuilder at the registered model in the MLflow registry.
mlflow_metadata = {
    "MLFLOW_MODEL_PATH": mlflow_model_path,
    "MLFLOW_TRACKING_ARN": MLFLOW_TRACKING_ARN,
}
# Pin serving dependencies explicitly instead of auto-capturing the local env.
serving_dependencies = {
    "auto": False,
    "custom": ["mlflow==3.4.0", "sagemaker==3.3.1", "numpy==2.4.1", "cloudpickle==3.1.2"],
}

# Cloud deployment to a SageMaker endpoint
model_builder = ModelBuilder(
    mode=Mode.SAGEMAKER_ENDPOINT,
    schema_builder=schema_builder,
    model_metadata=mlflow_metadata,
    dependencies=serving_dependencies,
)
print(f"ModelBuilder configured with MLflow model: {mlflow_model_path}")

# Build the model
core_model = model_builder.build(model_name=model_name, region=AWS_REGION)
print(f"Model built: {core_model.model_name}")

# Deploy to SageMaker endpoint
core_endpoint = model_builder.deploy(
    endpoint_name=endpoint_name,
    initial_instance_count=1,
)
print(f"Endpoint deployed: {core_endpoint.endpoint_name}")
Step 7: Test the Deployed Model#
Invoke the endpoint with a sample input. The model returns class probabilities (2 classes) as a softmax output.
import boto3

# Smoke-test the endpoint with one JSON-encoded observation.
test_data = [[0.1, 0.2, 0.3, 0.4]]

runtime_client = boto3.client('sagemaker-runtime')
response = runtime_client.invoke_endpoint(
    EndpointName=core_endpoint.endpoint_name,
    Body=json.dumps(test_data),
    ContentType='application/json',
)

raw_body = response['Body'].read().decode('utf-8')
prediction = json.loads(raw_body)

print(f"Input: {test_data}")
print(f"Prediction: {prediction}")
Step 8: Clean Up Resources#
import shutil
from sagemaker.core.resources import EndpointConfig

# Clean up AWS resources (endpoint, its config, and the model).
# NOTE(review): this assumes the endpoint config shares the endpoint's name —
# appears true for configs created by ModelBuilder.deploy above; confirm if
# this pattern is reused with a differently-named config.
core_endpoint_config = EndpointConfig.get(endpoint_config_name=core_endpoint.endpoint_name)
core_model.delete()
core_endpoint.delete()
core_endpoint_config.delete()
print("AWS resources cleaned up!")

# Clean up training code directory (best-effort: a failure here is harmless,
# so report it instead of raising)
try:
    shutil.rmtree(training_code_dir)
    print("Cleaned up training code directory")
except Exception as e:
    print(f"Could not clean up training code: {e}")

print("Note: MLflow experiment runs and registered models are preserved.")
Summary#
This notebook demonstrates cloud deployment of a PyTorch model with MLflow integration:
Training: Runs on SageMaker managed infrastructure with ModelTrainer
MLflow Integration: Logs metrics, parameters, and registers model to MLflow registry
Deployment: Uses ModelBuilder to deploy directly from MLflow registry to a SageMaker endpoint
Inference: Invokes the endpoint with JSON payloads
Key MLflow integration points:
`mlflow.log_params()` — hyperparameters; `mlflow.log_metrics()` — training metrics per epoch; `mlflow.pytorch.log_model()` — model artifact with registry; `ModelBuilder` with `MLFLOW_MODEL_PATH` — deploy from registry
Key patterns:
Custom `CustomPayloadTranslator` subclasses for PyTorch tensor serialization