SageMaker V3 Custom Distributed Training Example#
This notebook demonstrates how to create and use custom distributed training drivers with SageMaker V3 ModelTrainer.
import os
import shutil
import tempfile
from typing import Optional

from sagemaker.core import image_uris
from sagemaker.core.helper.session_helper import Session, get_execution_role
from sagemaker.train.configs import SourceCode
from sagemaker.train.distributed import DistributedConfig
from sagemaker.train.model_trainer import ModelTrainer
Step 1: Setup Session and Create Test Files#
Initialize the SageMaker session and create the custom distributed driver files.
# Establish the SageMaker session, the execution role, and the region
# the training job will run in.
sagemaker_session = Session()
role = get_execution_role()
region = sagemaker_session.boto_region_name

# Resolve a CPU training image (PyTorch 2.0, Python 3.10) for this region.
DEFAULT_CPU_IMAGE = image_uris.retrieve(
    framework="pytorch",
    region=region,
    version="2.0.0",
    py_version="py310",
    instance_type="ml.m5.xlarge",
    image_scope="training",
)

# Scratch workspace: one directory for the custom driver, one for the
# training entry script. Both are deleted in the cleanup step.
temp_dir = tempfile.mkdtemp()
custom_drivers_dir = os.path.join(temp_dir, "custom_drivers")
scripts_dir = os.path.join(temp_dir, "scripts")
for directory in (custom_drivers_dir, scripts_dir):
    os.makedirs(directory, exist_ok=True)
print(f"Created temporary directories in: {temp_dir}")
Step 2: Create Custom Driver and Entry Script#
Create the custom driver script and entry script for training.
# Build the custom distributed driver as a source string. ModelTrainer
# uploads everything under CustomDriver.driver_dir into the container; at
# runtime SageMaker executes this driver instead of the entry script, and
# the driver is responsible for launching the entry script itself.
# NOTE: the string content is the exact file written to disk — it validates
# the SM_* environment variables SageMaker injects, then runs the entry
# script as a subprocess.
driver_script = '''
import json
import os
import subprocess
import sys

def main():
    driver_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"])
    process_count_per_node = driver_config["process_count_per_node"]
    assert process_count_per_node != None
    hps = json.loads(os.environ["SM_HPS"])
    assert hps != None
    assert isinstance(hps, dict)
    source_dir = os.environ["SM_SOURCE_DIR"]
    assert source_dir == "/opt/ml/input/data/code"
    sm_drivers_dir = os.environ["SM_DISTRIBUTED_DRIVER_DIR"]
    assert sm_drivers_dir == "/opt/ml/input/data/sm_drivers/distributed_drivers"
    entry_script = os.environ["SM_ENTRY_SCRIPT"]
    assert entry_script != None
    python = sys.executable
    command = [python, entry_script]
    print(f"Running command: {command}")
    subprocess.run(command, check=True)

if __name__ == "__main__":
    print("Running custom driver script")
    main()
    print("Finished running custom driver script")
'''

# Write the driver into the directory that CustomDriver.driver_dir returns.
with open(os.path.join(custom_drivers_dir, "driver.py"), 'w') as f:
    f.write(driver_script)
print("Created custom driver script")
# Build the pseudo training entry script. It reads hyperparameters from the
# SM_HPS environment variable, "trains" for hps["epochs"] epochs (sleeping
# one second per epoch), and writes a results.json into the model directory
# so the output shows up in the job's model artifacts.
# NOTE: the string content is the exact file written to disk.
entry_script = '''
import json
import os
import time

def main():
    hps = json.loads(os.environ["SM_HPS"])
    assert hps != None
    print(f"Hyperparameters: {hps}")
    print("Running pseudo training script")
    for epochs in range(hps["epochs"]):
        print(f"Epoch: {epochs}")
        time.sleep(1)
    print("Finished running pseudo training script")
    # Save results
    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
    os.makedirs(model_dir, exist_ok=True)
    results = {"status": "success", "epochs_completed": hps["epochs"]}
    with open(os.path.join(model_dir, "results.json"), "w") as f:
        json.dump(results, f, indent=2)

if __name__ == "__main__":
    main()
'''

# Write the entry script into the SourceCode source directory.
with open(os.path.join(scripts_dir, "entry_script.py"), 'w') as f:
    f.write(entry_script)
print("Created entry script")
Step 3: Define Custom Distributed Driver#
Create the custom distributed driver class.
class CustomDriver(DistributedConfig):
    """Minimal custom distributed driver configuration.

    ModelTrainer uploads the contents of ``driver_dir`` into the training
    container and executes ``driver_script`` there instead of the entry
    script; the config's fields (``process_count_per_node``) are surfaced
    to the driver through the SM_DISTRIBUTED_CONFIG environment variable.
    """

    # Number of training processes the driver should launch per node.
    # Fixed annotation: was `int = None`, which contradicts the None default.
    process_count_per_node: Optional[int] = None

    @property
    def driver_dir(self) -> str:
        """Local directory containing the driver script to upload."""
        return custom_drivers_dir

    @property
    def driver_script(self) -> str:
        """Driver entry point, relative to ``driver_dir``."""
        return "driver.py"
# Confirm the driver wiring that ModelTrainer will use.
print("Custom distributed driver class defined!")
print(f"Driver directory: {custom_drivers_dir}")
print("Driver script: driver.py")  # fixed: f-string had no placeholder
Step 4: Configure Source Code and Hyperparameters#
Set up the source code and hyperparameters for training.
# Package the scripts directory as the job's source code; entry_script is
# the file the custom driver will launch inside the container.
source_code = SourceCode(
    source_dir=scripts_dir,
    entry_script="entry_script.py",
)
# Hyperparameters are serialized to JSON and exposed to the job via SM_HPS.
hyperparameters = {"epochs": 10}
# Two processes per node; surfaced to the driver via SM_DISTRIBUTED_CONFIG.
custom_driver = CustomDriver(process_count_per_node=2)

print(f"Source directory: {scripts_dir}")
print("Entry script: entry_script.py")  # fixed: f-string had no placeholder
print(f"Hyperparameters: {hyperparameters}")
print(f"Custom driver: {custom_driver}")
Step 5: Create ModelTrainer with Custom Driver#
Initialize ModelTrainer with the custom distributed configuration.
# Wire everything into a ModelTrainer. Passing `distributed=custom_driver`
# makes SageMaker run driver.py inside the container instead of invoking
# the entry script directly.
model_trainer = ModelTrainer(
    sagemaker_session=sagemaker_session,
    training_image=DEFAULT_CPU_IMAGE,
    hyperparameters=hyperparameters,
    source_code=source_code,
    distributed=custom_driver,
    base_job_name="custom-distributed-driver",
)

print("ModelTrainer created with custom distributed driver!")
print("Job name: custom-distributed-driver")  # fixed: f-string had no placeholder
print(f"Distributed configuration: {model_trainer.distributed}")
Step 6: Run Custom Distributed Training#
Start the distributed training job using the custom driver.
print("Starting custom distributed training...")
try:
    # Blocks until the SageMaker training job finishes (or fails).
    model_trainer.train()
    print("Custom distributed training completed successfully!")  # fixed: f-string had no placeholder
    print(f"Job name: {model_trainer._latest_training_job.training_job_name}")
    training_successful = True
except Exception as e:
    # Broad catch is deliberate at this notebook boundary: record the
    # failure so later cells can report on it instead of crashing.
    print(f"Training failed with error: {e}")
    training_successful = False
Step 7: Analyze Training Results#
Examine the results from the custom distributed training.
# Summarize the completed job; skipped when the training step above failed.
if training_successful:
    latest_job = model_trainer._latest_training_job
    job_name = latest_job.training_job_name
    model_artifacts = latest_job.model_artifacts

    print("Custom Distributed Training Results:")
    print("=" * 40)
    print(f"Job Name: {job_name}")
    print(f"Model Artifacts: {model_artifacts}")
    print(f"Training Image: {DEFAULT_CPU_IMAGE}")

    print("\nCustom Driver Configuration:")
    print(f"Driver Class: {custom_driver.__class__.__name__}")
    print(f"Process Count Per Node: {custom_driver.process_count_per_node}")
    print(f"Driver Directory: {custom_driver.driver_dir}")
    print(f"Driver Script: {custom_driver.driver_script}")

    print("\nHyperparameters Used:")
    for hp_name, hp_value in hyperparameters.items():
        print(f" {hp_name}: {hp_value}")

    print("\n✓ Custom distributed training completed successfully!")
else:
    print("Training was not successful.")
Step 8: Clean Up#
Clean up temporary files.
# Best-effort removal of the scratch directory created in Step 1; a failed
# delete should not fail the notebook.
try:
    shutil.rmtree(temp_dir)
    print(f"Cleaned up temporary directory: {temp_dir}")
except OSError as e:  # narrowed from Exception: rmtree raises OSError
    print(f"Could not clean up temp directory: {e}")
print("Cleanup completed!")
Summary#
This notebook demonstrated:
Custom distributed driver creation: Extending DistributedConfig for specialized needs
Driver coordination: How custom drivers manage training processes
ModelTrainer integration: Seamless integration with SageMaker V3 training
Custom training logic: Implementing specialized training patterns
Custom distributed drivers provide flexibility for implementing specialized coordination logic, framework integration, and advanced debugging capabilities for distributed training scenarios.