Source code for sagemaker.train.container_drivers.distributed_drivers.torchrun_driver
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module is the entry point for the Torchrun driver script."""
from __future__ import absolute_import
import os
import sys
import json
from pathlib import Path
from typing import List, Tuple
sys.path.insert(0, str(Path(__file__).parent.parent))
from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611
logger,
hyperparameters_to_cli_args,
get_process_count,
get_python_executable,
execute_commands,
write_failure_file,
SM_EFA_NCCL_INSTANCES,
SM_EFA_RDMA_INSTANCES,
)
[docs]
def pytorch_version() -> Tuple[int, int]:
    """Return the leading two components of the installed PyTorch version.

    ``torch`` is imported inside the function, so it is only loaded when the
    version is actually queried.

    Returns:
        The first (at most two) dot-separated fields of ``torch.__version__``,
        each converted to ``int``.

    Raises:
        ValueError: If one of those fields is not a plain integer.
    """
    import torch

    leading_fields = torch.__version__.split(".")[:2]
    return tuple(int(field) for field in leading_fields)
[docs]
def get_base_pytorch_command() -> List[str]:
    """Pick the Torch Distributed launcher command for the installed PyTorch.

    Returns:
        ``["torchrun"]`` when ``pytorch_version()`` is at least ``(1, 9)``;
        otherwise the ``torch.distributed.launch`` module invoked through the
        interpreter reported by ``get_python_executable()``.
    """
    # Older releases are launched through the module entry point instead.
    if pytorch_version() < (1, 9):
        return [f"{get_python_executable()}", "-m", "torch.distributed.launch"]
    return ["torchrun"]
[docs]
def setup_env():
    """Export the EFA/NCCL environment variables for the current instance.

    Reads ``SM_CURRENT_INSTANCE_TYPE`` (required) and
    ``SM_NETWORK_INTERFACE_NAME`` (falls back to ``"eth0"``), then writes the
    resulting settings into ``os.environ``.

    Raises:
        KeyError: If ``SM_CURRENT_INSTANCE_TYPE`` is not set.
    """
    current_instance = os.environ["SM_CURRENT_INSTANCE_TYPE"]
    nic_name = os.environ.get("SM_NETWORK_INTERFACE_NAME", "eth0")

    # Collect the settings first; dict insertion order keeps the same
    # assignment order when they are applied below.
    exports = {}

    if current_instance in SM_EFA_NCCL_INSTANCES:
        # Enable EFA use
        exports["FI_PROVIDER"] = "efa"

    if current_instance in SM_EFA_RDMA_INSTANCES:
        # Use EFA's RDMA functionality for one-sided and two-sided transfer
        exports["FI_EFA_USE_DEVICE_RDMA"] = "1"
        exports["RDMAV_FORK_SAFE"] = "1"

    # These two are set for every instance type.
    exports["NCCL_SOCKET_IFNAME"] = str(nic_name)
    exports["NCCL_PROTO"] = "simple"

    os.environ.update(exports)
[docs]
def create_commands():
    """Assemble the Torch Distributed launch command as a list of arguments.

    Inputs are taken from the environment:

    * ``SM_ENTRY_SCRIPT`` - script appended after the launcher flags.
    * ``SM_DISTRIBUTED_CONFIG`` - JSON object; its ``process_count_per_node``
      entry (``0`` when falsy) is passed through ``get_process_count``.
    * ``SM_HPS`` - JSON hyperparameters, converted to CLI arguments.
    * ``SM_HOST_COUNT`` - number of nodes.
    * ``SM_MASTER_ADDR``, ``SM_MASTER_PORT``, ``SM_CURRENT_HOST_RANK`` - only
      read when more than one host is used.
    * ``RUN_NEURON_PARALLEL_COMPILE`` - when equal to ``"1"`` the command is
      prefixed with ``neuron_parallel_compile``.

    Returns:
        The full command as a list of strings.

    Raises:
        KeyError: If a required environment variable or the
            ``process_count_per_node`` config key is missing.
    """
    script = os.environ["SM_ENTRY_SCRIPT"]
    dist_cfg = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"])
    hps = json.loads(os.environ["SM_HPS"])

    # A null/empty per-node count is forwarded to get_process_count as 0.
    requested_procs = int(dist_cfg["process_count_per_node"] or 0)
    procs_per_node = get_process_count(requested_procs)
    num_hosts = int(os.environ["SM_HOST_COUNT"])

    command = []
    if os.environ.get("RUN_NEURON_PARALLEL_COMPILE") == "1":
        command.append("neuron_parallel_compile")

    command += get_base_pytorch_command()
    command += [
        f"--nnodes={num_hosts}",
        f"--nproc_per_node={procs_per_node}",
    ]

    # If more than one node is used, add node rank information
    if num_hosts > 1:
        command += [
            f"--master_addr={os.environ['SM_MASTER_ADDR']}",
            f"--master_port={os.environ['SM_MASTER_PORT']}",
            f"--node_rank={os.environ['SM_CURRENT_HOST_RANK']}",
        ]

    command.append(script)
    command += hyperparameters_to_cli_args(hps)
    return command
[docs]
def main():
    """Run the PyTorch distributed training script end to end.

    Execution lifecycle:

    1. Set the environment variables for PyTorch distributed training.
    2. Build the Torch Distributed command.
    3. Log and execute that command (it ends with the user script taken
       from `entry_script`).
    4. On a non-zero exit code, write the failure file and exit with
       that code.
    """
    setup_env()
    torch_cmd = create_commands()
    logger.info(f"Executing command: {' '.join(torch_cmd)}")

    status, failure_trace = execute_commands(torch_cmd)
    if status == 0:
        return

    # Persist the failure details before exiting with the child's exit code.
    write_failure_file(failure_trace)
    sys.exit(status)
# Run the driver only when this file is executed as a script; importing the
# module must not start a training run.
if __name__ == "__main__":
    main()