Source code for sagemaker.train.distributed
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Distributed module."""
from __future__ import absolute_import
import os
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, List
from sagemaker.train.utils import safe_serialize
from sagemaker.train.constants import SM_DRIVERS_LOCAL_PATH
from sagemaker.train.configs import BaseConfig
[docs]
class SMP(BaseConfig):
"""SMP.
This class is used for configuring the SageMaker Model Parallelism v2 parameters.
For more information on the model parallelism parameters, see:
https://docs.aws.amazon.com/sagemaker/latest/dg/distributed-model-parallel-v2-reference.html#distributed-model-parallel-v2-reference-init-config
Parameters:
hybrid_shard_degree (Optional[int]):
Specifies a sharded parallelism degree for the model.
sm_activation_offloading (Optional[bool]):
Specifies whether to enable the SMP activation offloading implementation.
activation_loading_horizon (Optional[int]):
An integer specifying the activation offloading horizon type for FSDP. This is the
maximum number of checkpointed or offloaded layers whose inputs can be in the GPU
memory simultaneously.
fsdp_cache_flush_warnings (Optional[bool]):
Detects and warns if cache flushes happen in the PyTorch memory manager, because they
can degrade computational performance.
allow_empty_shards (Optional[bool]):
Whether to allow empty shards when sharding tensors if tensor is not divisible. This is
an experimental fix for crash during checkpointing in certain scenarios. Disabling this
falls back to the original PyTorch behavior.
tensor_parallel_degree (Optional[int]):
Specifies a tensor parallelism degree. The value must be between 1 and world_size.
context_parallel_degree (Optional[int]):
Specifies the context parallelism degree. The value must be between 1 and world_size ,
and must be <= hybrid_shard_degree.
expert_parallel_degree (Optional[int]):
Specifies a expert parallelism degree. The value must be between 1 and world_size.
random_seed (Optional[int]):
A seed number for the random operations in distributed modules by SMP tensor
parallelism or expert parallelism.
"""
hybrid_shard_degree: Optional[int] = None
sm_activation_offloading: Optional[bool] = None
activation_loading_horizon: Optional[int] = None
fsdp_cache_flush_warnings: Optional[bool] = None
allow_empty_shards: Optional[bool] = None
tensor_parallel_degree: Optional[int] = None
context_parallel_degree: Optional[int] = None
expert_parallel_degree: Optional[int] = None
random_seed: Optional[int] = None
def _to_mp_hyperparameters(self) -> Dict[str, Any]:
"""Converts to the hyperparameters format for the SageMaker Model Parallelism v2."""
mp_parameters = self.model_dump(exclude_none=True)
hyperparameters = {
"mp_parameters": safe_serialize(mp_parameters),
}
return hyperparameters
[docs]
class DistributedConfig(BaseConfig, ABC):
"""Abstract base class for distributed training configurations.
This class defines the interface that all distributed training configurations
must implement. It provides a standardized way to specify driver scripts and
their locations for distributed training jobs.
"""
@property
@abstractmethod
def driver_dir(self) -> str:
"""Directory containing the driver script.
This property should return the path to the directory containing
the driver script, relative to the container's working directory.
Returns:
str: Path to directory containing the driver script
"""
@property
@abstractmethod
def driver_script(self) -> str:
"""Name of the driver script.
This property should return the name of the Python script that implements
the distributed training driver logic.
Returns:
str: Name of the driver script file
"""
[docs]
class Torchrun(DistributedConfig):
"""Torchrun.
The Torchrun class configures a job that uses ``torchrun`` or
``torch.distributed.launch`` in the backend to launch distributed training.
Parameters:
process_count_per_node (int):
The number of processes to run on each node in the training job.
Will default to the number of GPUs available in the container.
smp (Optional[SMP]):
The SageMaker Model Parallelism v2 parameters.
"""
process_count_per_node: Optional[int] = None
smp: Optional["SMP"] = None
@property
def driver_dir(self) -> str:
"""Directory containing the driver script.
Returns:
str: Path to directory containing the driver script
"""
return os.path.join(SM_DRIVERS_LOCAL_PATH, "distributed_drivers")
@property
def driver_script(self) -> str:
"""Name of the driver script.
Returns:
str: Name of the driver script file
"""
return "torchrun_driver.py"
[docs]
class MPI(DistributedConfig):
"""MPI.
The MPI class configures a job that uses ``mpirun`` in the backend to launch
distributed training.
Parameters:
process_count_per_node (int):
The number of processes to run on each node in the training job.
Will default to the number of GPUs available in the container.
mpi_additional_options (Optional[str]):
The custom MPI options to use for the training job.
"""
process_count_per_node: Optional[int] = None
mpi_additional_options: Optional[List[str]] = None
@property
def driver_dir(self) -> str:
"""Directory containing the driver script.
Returns:
str: Path to directory containing the driver script
"""
return os.path.join(SM_DRIVERS_LOCAL_PATH, "distributed_drivers")
@property
def driver_script(self) -> str:
"""Name of the driver script.
Returns:
str: Name of the driver script
"""
return "mpi_driver.py"