SageMaker V3 Train-to-Inference E2E with MLflow Integration

This notebook demonstrates the complete end-to-end workflow from training a custom PyTorch model to deploying it for inference on SageMaker cloud infrastructure, with MLflow 3.x tracking and model registry integration.

Prerequisites

  • SageMaker MLflow App created (tracking server ARN required)

  • IAM permissions for MLflow tracking and model registry

  • AWS credentials configured

Step 0: Install Dependencies

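Note: There are known issues with MLflow model path resolution. The cell below pins MLflow 3.4.0, the same version that the training job and the serving container use. If you still hit path-resolution errors, install the latest SageMaker Python SDK from GitHub, which contains the most recent fixes.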

# Pin MLflow 3.4.0 to work around MLflow path resolution issues; the training
# job's requirements.txt and the serving container's dependencies pin the same version
%pip install mlflow==3.4.0

NOTE: You must restart your kernel after the install above so that the pinned MLflow version is loaded.

Step 1: Configuration

Set up MLflow tracking server and training configuration.

import uuid
from sagemaker.core import image_uris
from sagemaker.core.helper.session_helper import Session

# =============================================================================
# MLflow Configuration - UPDATE THIS WITH YOUR TRACKING SERVER ARN
# =============================================================================
# Eg. "arn:aws:sagemaker:us-east-1:12345678:mlflow-app/app-ABCDEFGH123"
MLFLOW_TRACKING_ARN = "XXXXX"

# Fail fast if the placeholder was not replaced. Otherwise MLflow never reaches
# the SageMaker tracking server, and the problem only surfaces much later:
# inside the (billed) training job, or when looking up the registered model.
if MLFLOW_TRACKING_ARN == "XXXXX":
    raise ValueError(
        "Set MLFLOW_TRACKING_ARN to your SageMaker MLflow tracking server ARN "
        "before running this notebook."
    )

# AWS Configuration
# boto_region_name has to be read from a Session *instance*. Reading it on the
# Session class itself does not yield the configured region string.
sagemaker_session = Session()
AWS_REGION = sagemaker_session.boto_region_name
print(f"Using AWS region: {AWS_REGION}")

# Get PyTorch training image dynamically
PYTORCH_TRAINING_IMAGE = image_uris.retrieve(
    framework="pytorch",
    region=AWS_REGION,
    version="2.5",
    py_version="py311",
    instance_type="ml.m5.xlarge",
    image_scope="training"
)
print(f"Using PyTorch training image: {PYTORCH_TRAINING_IMAGE}")

# Naming prefixes
MODEL_NAME_PREFIX = "mlflow-e2e-model"
ENDPOINT_NAME_PREFIX = "mlflow-e2e-endpoint"
TRAINING_JOB_PREFIX = "mlflow-e2e-pytorch"
MLFLOW_EXPERIMENT_NAME = "sagemaker-v3-e2e-training"
MLFLOW_REGISTERED_MODEL_NAME = "pytorch-simple-classifier"

# Generate unique identifiers (first 8 hex chars of a random UUID) so reruns
# of the notebook do not collide on SageMaker resource names
unique_id = str(uuid.uuid4())[:8]
training_job_name = f"{TRAINING_JOB_PREFIX}-{unique_id}"
model_name = f"{MODEL_NAME_PREFIX}-{unique_id}"
endpoint_name = f"{ENDPOINT_NAME_PREFIX}-{unique_id}"

print(f"Training job name: {training_job_name}")
print(f"Model name: {model_name}")
print(f"Endpoint name: {endpoint_name}")

Step 2: Connect to MLflow Tracking Server

import mlflow

# Connect to SageMaker MLflow tracking server
mlflow.set_tracking_uri(MLFLOW_TRACKING_ARN)

# Create the experiment if needed and make it the active one; the returned
# Experiment object is kept so its id can be reported below.
experiment = mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)

print("Connected to MLflow tracking server")
print(f"Experiment: {MLFLOW_EXPERIMENT_NAME} (id: {experiment.experiment_id})")

Step 3: Create Training Code with MLflow Logging

Create a PyTorch training script that logs metrics and registers the model to MLflow.

import tempfile
import os

def create_pytorch_training_code_with_mlflow(mlflow_tracking_arn, experiment_name, registered_model_name):
    """Generate a SageMaker training source directory with MLflow logging.

    Writes two files into a newly created temporary directory:

    * ``train.py``: trains a one-layer PyTorch classifier (4 features in,
      2 classes out) on random synthetic data. It logs hyperparameters and
      per-epoch loss/accuracy to MLflow, then logs the model and registers it
      in the MLflow Model Registry.
    * ``requirements.txt``: pins ``mlflow``, ``sagemaker-mlflow`` and
      ``cloudpickle`` for the training job.

    Args:
        mlflow_tracking_arn: Tracking URI (SageMaker MLflow ARN) that the
            generated script connects to.
        experiment_name: MLflow experiment the training run is logged under.
        registered_model_name: Used both as the logged model's name and as
            the Model Registry name.

    Returns:
        Path of the created temporary directory. It is not removed
        automatically; the caller is responsible for deleting it.
    """
    temp_dir = tempfile.mkdtemp()

    # Configuration values are embedded with ``!r`` so each becomes a correctly
    # quoted and escaped Python string literal in the generated script, even
    # when it contains quotes or backslashes. Doubled braces in the template
    # become literal braces in the generated code.
    train_script = f'''import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import os
import mlflow
import mlflow.pytorch
from mlflow.models import infer_signature

class SimpleModel(nn.Module):
    def __init__(self, input_dim=4, output_dim=2):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def logits(self, x):
        # Raw, unnormalized class scores (used for training).
        return self.linear(x)

    def forward(self, x):
        # Class probabilities (what the deployed model returns).
        return torch.softmax(self.logits(x), dim=1)

def train():
    # MLflow setup
    mlflow.set_tracking_uri({mlflow_tracking_arn!r})
    mlflow.set_experiment({experiment_name!r})

    # Hyperparameters
    learning_rate = 0.01
    epochs = 10
    batch_size = 32
    input_dim = 4
    output_dim = 2

    with mlflow.start_run() as run:
        # Log hyperparameters
        mlflow.log_params({{
            "learning_rate": learning_rate,
            "epochs": epochs,
            "batch_size": batch_size,
            "input_dim": input_dim,
            "output_dim": output_dim,
            "optimizer": "Adam",
            "loss_function": "CrossEntropyLoss"
        }})

        model = SimpleModel(input_dim, output_dim)
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        criterion = nn.CrossEntropyLoss()

        # Synthetic data
        X = torch.randn(100, input_dim)
        y = torch.randint(0, output_dim, (100,))
        dataset = TensorDataset(X, y)
        dataloader = DataLoader(dataset, batch_size=batch_size)

        # Training loop with metric logging
        model.train()
        for epoch in range(epochs):
            epoch_loss = 0.0
            correct = 0
            total = 0

            for batch_x, batch_y in dataloader:
                optimizer.zero_grad()
                # CrossEntropyLoss applies log-softmax internally, so it needs raw
                # logits; passing softmax output would apply softmax twice.
                logits = model.logits(batch_x)
                loss = criterion(logits, batch_y)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                predicted = logits.argmax(dim=1)
                total += batch_y.size(0)
                correct += (predicted == batch_y).sum().item()

            avg_loss = epoch_loss / len(dataloader)
            accuracy = correct / total

            # Log metrics per epoch
            mlflow.log_metrics({{
                "train_loss": avg_loss,
                "train_accuracy": accuracy
            }}, step=epoch)

            print(f"Epoch {{epoch+1}}/{{epochs}} - Loss: {{avg_loss:.4f}}, Accuracy: {{accuracy:.4f}}")

        # Log final metrics
        mlflow.log_metrics({{
            "final_loss": avg_loss,
            "final_accuracy": accuracy
        }})

        # Infer the signature from the probabilities returned by forward()
        model.eval()
        with torch.no_grad():
            probabilities = model(X)
        signature = infer_signature(X.numpy(), probabilities.numpy())

        # Log and register model in one step
        mlflow.pytorch.log_model(
            model,
            name={registered_model_name!r},
            signature=signature,
            registered_model_name={registered_model_name!r}
        )

        print("Model registered to MLflow:", {registered_model_name!r})
        print(f"Run ID: {{run.info.run_id}}")

        print("Training completed!")

if __name__ == "__main__":
    train()
'''

    with open(os.path.join(temp_dir, 'train.py'), 'w', encoding='utf-8') as f:
        f.write(train_script)

    with open(os.path.join(temp_dir, 'requirements.txt'), 'w', encoding='utf-8') as f:
        f.write('mlflow==3.4.0\nsagemaker-mlflow==0.2.0\ncloudpickle==3.1.2\n')

    return temp_dir

# Generate train.py + requirements.txt wired to this notebook's MLflow settings
training_code_dir = create_pytorch_training_code_with_mlflow(
    mlflow_tracking_arn=MLFLOW_TRACKING_ARN,
    experiment_name=MLFLOW_EXPERIMENT_NAME,
    registered_model_name=MLFLOW_REGISTERED_MODEL_NAME,
)
print(f"Training code created in: {training_code_dir}")

Step 4: Create ModelTrainer and Start Training

Use ModelTrainer to run the training script on SageMaker managed infrastructure. The training job will log metrics to MLflow and register the model to the MLflow registry.

from sagemaker.train.model_trainer import ModelTrainer
from sagemaker.train.configs import SourceCode

# Training on SageMaker managed infrastructure. requirements.txt is passed
# alongside train.py so the job can install the pinned MLflow packages.
model_trainer = ModelTrainer(
    training_image=PYTORCH_TRAINING_IMAGE,
    source_code=SourceCode(
        source_dir=training_code_dir,
        entry_script="train.py",
        requirements="requirements.txt",
    ),
    # base_job_name is a prefix: the SDK appends a suffix to form the actual
    # training job name, so search by this prefix in the console.
    base_job_name=training_job_name,
)

# Start training job
print(f"Starting training job with base name: {training_job_name}")
print("Metrics will be logged to MLflow during training...")

model_trainer.train()
print("Training completed! Check MLflow UI for metrics and registered model.")

Step 5: Get Registered Model from MLflow

Retrieve the registered model from MLflow to get the model URI (models:/<name>/<version>) needed for deployment with ModelBuilder.

# Get the latest version of the registered model
from mlflow import MlflowClient

client = MlflowClient()

# RegisteredModel.latest_versions is grouped by registry stage (a concept
# deprecated since MLflow 2.9), and its first entry is not guaranteed to be the
# newest version. Search all versions and take the highest version number.
versions = client.search_model_versions(f"name='{MLFLOW_REGISTERED_MODEL_NAME}'")
if not versions:
    raise RuntimeError(
        f"No versions found for registered model '{MLFLOW_REGISTERED_MODEL_NAME}'. "
        "Check that the training job finished and registered the model."
    )
latest_version = max(versions, key=lambda mv: int(mv.version))
model_version = latest_version.version
model_source = latest_version.source

# Get S3 URL of model files (for info only)
artifact_uri = client.get_model_version_download_uri(MLFLOW_REGISTERED_MODEL_NAME, model_version)

# MLflow model registry path to use with ModelBuilder
mlflow_model_path = f"models:/{MLFLOW_REGISTERED_MODEL_NAME}/{model_version}"

print(f"Registered Model: {MLFLOW_REGISTERED_MODEL_NAME}")
print(f"Latest Version: {model_version}")
print(f"Source: {model_source}")
print(f"Model artifacts location: {artifact_uri}")

Step 6: Deploy from MLflow Model Registry

Use ModelBuilder to deploy the model directly from MLflow registry to a SageMaker endpoint.

import json
import torch
from sagemaker.serve.marshalling.custom_payload_translator import CustomPayloadTranslator
from sagemaker.serve.builder.schema_builder import SchemaBuilder

# =============================================================================
# Custom translators for PyTorch tensor conversion
#
# PyTorch models expect tensors, but SageMaker endpoints communicate via JSON.
# These translators handle the conversion between JSON payloads and PyTorch tensors.
# =============================================================================

class PyTorchInputTranslator(CustomPayloadTranslator):
    """Handles input serialization/deserialization for PyTorch models.

    Serializes payloads to UTF-8 JSON bytes, and deserializes a JSON stream
    into a ``torch.float32`` tensor.
    """

    def __init__(self):
        # JSON is used as both the content type and the accept type.
        super().__init__(content_type='application/json', accept_type='application/json')

    def serialize_payload_to_bytes(self, payload: object) -> bytes:
        """Encode *payload* as UTF-8 JSON bytes.

        Objects with a callable ``tolist()`` (``torch.Tensor``, NumPy arrays)
        are first converted to nested Python lists, because ``json`` cannot
        encode them directly. Anything else must already be JSON-serializable.
        """
        to_list = getattr(payload, "tolist", None)
        if callable(to_list):
            payload = to_list()
        return json.dumps(payload).encode('utf-8')

    def deserialize_payload_from_stream(self, stream) -> object:
        """Parse a JSON stream into a ``torch.float32`` tensor."""
        data = json.load(stream)
        return torch.tensor(data, dtype=torch.float32)

class PyTorchOutputTranslator(CustomPayloadTranslator):
    """Handles output serialization/deserialization for PyTorch models.

    Serializes payloads to UTF-8 JSON bytes, and deserializes a JSON stream
    into plain Python objects (no tensor conversion).
    """

    def __init__(self):
        # JSON is used as both the content type and the accept type.
        super().__init__(content_type='application/json', accept_type='application/json')

    def serialize_payload_to_bytes(self, payload: object) -> bytes:
        """Encode *payload* as UTF-8 JSON bytes.

        Objects with a callable ``tolist()`` (``torch.Tensor``, NumPy arrays)
        are first converted to nested Python lists, because ``json`` cannot
        encode them directly. Anything else must already be JSON-serializable.
        """
        to_list = getattr(payload, "tolist", None)
        if callable(to_list):
            payload = to_list()
        return json.dumps(payload).encode('utf-8')

    def deserialize_payload_from_stream(self, stream) -> object:
        """Parse a JSON stream into Python lists/dicts/scalars."""
        return json.load(stream)

# Example payloads SchemaBuilder uses to infer the endpoint's I/O schema:
# one row of 4 features in, one row of 2 class probabilities out.
sample_input = [[0.1, 0.2, 0.3, 0.4]]
sample_output = [[0.8, 0.2]]

input_translator = PyTorchInputTranslator()
output_translator = PyTorchOutputTranslator()

schema_builder = SchemaBuilder(
    sample_input=sample_input,
    sample_output=sample_output,
    input_translator=input_translator,
    output_translator=output_translator,
)
from sagemaker.serve.model_builder import ModelBuilder
from sagemaker.serve.mode.function_pointers import Mode

# Cloud deployment: build a SageMaker model straight from the MLflow registry
# entry, then host it on a real-time endpoint.
mlflow_metadata = {
    "MLFLOW_MODEL_PATH": mlflow_model_path,
    "MLFLOW_TRACKING_ARN": MLFLOW_TRACKING_ARN,
}
# Explicit package pins for the serving container; "auto": False presumably
# turns off automatic dependency detection (confirm in the ModelBuilder docs).
serving_dependencies = {
    "auto": False,
    "custom": ["mlflow==3.4.0", "sagemaker==3.3.1", "numpy==2.4.1", "cloudpickle==3.1.2"],
}

model_builder = ModelBuilder(
    mode=Mode.SAGEMAKER_ENDPOINT,
    schema_builder=schema_builder,
    model_metadata=mlflow_metadata,
    dependencies=serving_dependencies,
)
print(f"ModelBuilder configured with MLflow model: {mlflow_model_path}")

core_model = model_builder.build(model_name=model_name, region=AWS_REGION)
print(f"Model built: {core_model.model_name}")

core_endpoint = model_builder.deploy(endpoint_name=endpoint_name, initial_instance_count=1)
print(f"Endpoint deployed: {core_endpoint.endpoint_name}")

Step 7: Test the Deployed Model

Invoke the endpoint with a sample input. The model returns class probabilities (2 classes) as a softmax output.

import boto3

# Test with JSON input: one row of 4 features
test_data = [[0.1, 0.2, 0.3, 0.4]]

# Target the region the endpoint was deployed to. Without region_name, boto3
# falls back to the default region of the local AWS configuration, which can
# differ from AWS_REGION.
runtime_client = boto3.client('sagemaker-runtime', region_name=AWS_REGION)
response = runtime_client.invoke_endpoint(
    EndpointName=core_endpoint.endpoint_name,
    Body=json.dumps(test_data),
    ContentType='application/json'
)

prediction = json.loads(response['Body'].read().decode('utf-8'))
print(f"Input: {test_data}")
print(f"Prediction: {prediction}")

Step 8: Clean Up Resources

import shutil
from sagemaker.core.resources import EndpointConfig

# Clean up AWS resources: the endpoint first, then the endpoint config and the
# model it was built from.
# Resolve the config name from the endpoint rather than assuming it equals the
# endpoint name; fall back to the endpoint name if the attribute is unset.
endpoint_config_name = (
    getattr(core_endpoint, "endpoint_config_name", None) or core_endpoint.endpoint_name
)
core_endpoint_config = EndpointConfig.get(endpoint_config_name=endpoint_config_name)
core_endpoint.delete()
core_endpoint_config.delete()
core_model.delete()
print("AWS resources cleaned up!")

# Clean up the locally generated training code directory (best effort)
try:
    shutil.rmtree(training_code_dir)
except OSError as e:
    print(f"Could not clean up training code: {e}")
else:
    print("Cleaned up training code directory")

print("Note: MLflow experiment runs and registered models are preserved.")

Summary

This notebook demonstrates cloud deployment of a PyTorch model with MLflow integration:

  1. Training: Runs on SageMaker managed infrastructure with ModelTrainer

  2. MLflow Integration: Logs metrics, parameters, and registers model to MLflow registry

  3. Deployment: Uses ModelBuilder to deploy directly from MLflow registry to a SageMaker endpoint

  4. Inference: Invokes the endpoint with JSON payloads

Key MLflow integration points:

  • mlflow.log_params() - hyperparameters

  • mlflow.log_metrics() - training metrics per epoch

  • mlflow.pytorch.log_model() - model artifact with registry

  • ModelBuilder with MLFLOW_MODEL_PATH - deploy from registry

Key patterns:

  • CustomPayloadTranslator subclasses for PyTorch tensor serialization