SageMaker V3 Custom Distributed Training Example#

This notebook demonstrates how to create and use custom distributed training drivers with SageMaker V3 ModelTrainer.

import os
import tempfile
import shutil

from sagemaker.train.model_trainer import ModelTrainer
from sagemaker.train.configs import SourceCode
from sagemaker.train.distributed import DistributedConfig
from sagemaker.core.helper.session_helper import Session, get_execution_role
from sagemaker.core import image_uris

Step 1: Setup Session and Create Test Files#

Initialize the SageMaker session and create the custom distributed driver files.

# Initialize the SageMaker session and resolve the execution role and region.
sagemaker_session = Session()
role = get_execution_role()
region = sagemaker_session.boto_region_name

# Resolve the standard PyTorch CPU training image for this region.
DEFAULT_CPU_IMAGE = image_uris.retrieve(
    framework="pytorch",
    region=region,
    version="2.0.0",
    py_version="py310",
    instance_type="ml.m5.xlarge",
    image_scope="training",
)

# Scratch space for the generated driver and training scripts.
temp_dir = tempfile.mkdtemp()
custom_drivers_dir = os.path.join(temp_dir, "custom_drivers")
scripts_dir = os.path.join(temp_dir, "scripts")

for directory in (custom_drivers_dir, scripts_dir):
    os.makedirs(directory, exist_ok=True)

print(f"Created temporary directories in: {temp_dir}")

Step 2: Create Custom Driver and Entry Script#

Create the custom driver script and entry script for training.

# Create the custom driver script that SageMaker runs on each node.
# It validates the SM_* environment contract, then launches the user's
# entry script in a subprocess.
# Fix: use `is not None` instead of `!= None` (PEP 8 / E711) inside the
# generated script.
driver_script = '''
import json
import os
import subprocess
import sys

def main():
    # Config supplied by the CustomDriver object, serialized by the SDK.
    driver_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"])
    process_count_per_node = driver_config["process_count_per_node"]
    assert process_count_per_node is not None

    # Hyperparameters are passed through as a JSON dict.
    hps = json.loads(os.environ["SM_HPS"])
    assert hps is not None
    assert isinstance(hps, dict)

    # Sanity-check the well-known container paths.
    source_dir = os.environ["SM_SOURCE_DIR"]
    assert source_dir == "/opt/ml/input/data/code"
    sm_drivers_dir = os.environ["SM_DISTRIBUTED_DRIVER_DIR"]
    assert sm_drivers_dir == "/opt/ml/input/data/sm_drivers/distributed_drivers"

    entry_script = os.environ["SM_ENTRY_SCRIPT"]
    assert entry_script is not None

    # Run the entry script with the same interpreter as the driver.
    python = sys.executable

    command = [python, entry_script]
    print(f"Running command: {command}")
    subprocess.run(command, check=True)

if __name__ == "__main__":
    print("Running custom driver script")
    main()
    print("Finished running custom driver script")
'''

with open(os.path.join(custom_drivers_dir, "driver.py"), 'w') as f:
    f.write(driver_script)

print("Created custom driver script")
# Create the pseudo training entry script that the custom driver launches.
# Fixes inside the generated script: `is not None` instead of `!= None`
# (PEP 8 / E711), and the loop variable is renamed `epoch` — it holds one
# epoch index, not a count (printed output is unchanged).
entry_script = '''
import json
import os
import time

def main():
    # Hyperparameters arrive as a JSON dict via SM_HPS.
    hps = json.loads(os.environ["SM_HPS"])
    assert hps is not None
    print(f"Hyperparameters: {hps}")

    print("Running pseudo training script")
    for epoch in range(hps["epochs"]):
        print(f"Epoch: {epoch}")
        time.sleep(1)
    print("Finished running pseudo training script")

    # Save results to the model directory so they are uploaded with the
    # model artifacts.
    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
    os.makedirs(model_dir, exist_ok=True)

    results = {"status": "success", "epochs_completed": hps["epochs"]}
    with open(os.path.join(model_dir, "results.json"), "w") as f:
        json.dump(results, f, indent=2)

if __name__ == "__main__":
    main()
'''

with open(os.path.join(scripts_dir, "entry_script.py"), 'w') as f:
    f.write(entry_script)

print("Created entry script")

Step 3: Define Custom Distributed Driver#

Create the custom distributed driver class.

class CustomDriver(DistributedConfig):
    """Minimal custom distributed driver configuration.

    Points ModelTrainer at the locally generated driver script; the
    configuration is serialized into SM_DISTRIBUTED_CONFIG for the driver
    to read at runtime.
    """

    # Number of worker processes to launch per node.
    # NOTE(review): annotated `int` but defaults to None — presumably the
    # DistributedConfig base (pydantic-style) accepts this; confirm.
    process_count_per_node: int = None

    @property
    def driver_dir(self) -> str:
        """Local directory containing the driver script (uploaded by the SDK)."""
        return custom_drivers_dir

    @property
    def driver_script(self) -> str:
        """Filename of the driver entry point inside ``driver_dir``."""
        return "driver.py"

# Fix: the driver-script print was an f-string with no placeholders (F541).
print("Custom distributed driver class defined!")
print(f"Driver directory: {custom_drivers_dir}")
print("Driver script: driver.py")

Step 4: Configure Source Code and Hyperparameters#

Set up the source code and hyperparameters for training.

# Package the generated entry script for upload with the training job.
source_code = SourceCode(
    source_dir=scripts_dir,
    entry_script="entry_script.py",
)

# Surfaced to the container as SM_HPS; the entry script reads "epochs".
hyperparameters = {"epochs": 10}

# Two worker processes per node for the custom driver.
custom_driver = CustomDriver(process_count_per_node=2)

print(f"Source directory: {scripts_dir}")
# Fix: was an f-string with no placeholders (F541).
print("Entry script: entry_script.py")
print(f"Hyperparameters: {hyperparameters}")
print(f"Custom driver: {custom_driver}")

Step 5: Create ModelTrainer with Custom Driver#

Initialize ModelTrainer with the custom distributed configuration.

# Wire the custom distributed driver into ModelTrainer.
model_trainer = ModelTrainer(
    sagemaker_session=sagemaker_session,
    training_image=DEFAULT_CPU_IMAGE,
    hyperparameters=hyperparameters,
    source_code=source_code,
    distributed=custom_driver,
    base_job_name="custom-distributed-driver",
)

print("ModelTrainer created with custom distributed driver!")
# Fix: was an f-string with no placeholders (F541).
print("Job name: custom-distributed-driver")
print(f"Distributed configuration: {model_trainer.distributed}")

Step 6: Run Custom Distributed Training#

Start the distributed training job using the custom driver.

print("Starting custom distributed training...")

# Broad except is deliberate here: this is a notebook top-level boundary
# and we want to record success/failure rather than abort the notebook.
try:
    model_trainer.train()
    # Fix: was an f-string with no placeholders (F541).
    print("Custom distributed training completed successfully!")
    print(f"Job name: {model_trainer._latest_training_job.training_job_name}")
    training_successful = True
except Exception as e:
    print(f"Training failed with error: {e}")
    training_successful = False

Step 7: Analyze Training Results#

Examine the results from the custom distributed training.

# Report the outcome of the custom distributed training run.
if not training_successful:
    print("Training was not successful.")
else:
    latest_job = model_trainer._latest_training_job  # NOTE: private SDK attribute
    job_name = latest_job.training_job_name
    model_artifacts = latest_job.model_artifacts

    summary_lines = [
        "Custom Distributed Training Results:",
        "=" * 40,
        f"Job Name: {job_name}",
        f"Model Artifacts: {model_artifacts}",
        f"Training Image: {DEFAULT_CPU_IMAGE}",
    ]
    print("\n".join(summary_lines))

    print("\nCustom Driver Configuration:")
    print(f"Driver Class: {custom_driver.__class__.__name__}")
    print(f"Process Count Per Node: {custom_driver.process_count_per_node}")
    print(f"Driver Directory: {custom_driver.driver_dir}")
    print(f"Driver Script: {custom_driver.driver_script}")

    print("\nHyperparameters Used:")
    for name, value in hyperparameters.items():
        print(f"  {name}: {value}")

    print("\n✓ Custom distributed training completed successfully!")

Step 8: Clean Up#

Clean up temporary files.

# Best-effort removal of the scratch directory created in Step 1.
# Fix: catch OSError instead of the overly broad Exception —
# shutil.rmtree failures surface as OSError subclasses.
try:
    shutil.rmtree(temp_dir)
    print(f"Cleaned up temporary directory: {temp_dir}")
except OSError as e:
    print(f"Could not clean up temp directory: {e}")

print("Cleanup completed!")

Summary#

This notebook demonstrated:

  1. Custom distributed driver creation: Extending DistributedConfig for specialized needs

  2. Driver coordination: How custom drivers manage training processes

  3. ModelTrainer integration: Seamless integration with SageMaker V3 training

  4. Custom training logic: Implementing specialized training patterns

Custom distributed drivers provide flexibility for implementing specialized coordination logic, framework integration, and advanced debugging capabilities for distributed training scenarios.