sagemaker.train.container_drivers.distributed_drivers.torchrun_driver


This module is the entry point for the Torchrun driver script.

Functions

create_commands()

Create the Torch Distributed command to execute.

get_base_pytorch_command()

Get the base Torch Distributed launcher command to execute.

main()

Main function to execute the PyTorch distributed training script.

pytorch_version()

Get the PyTorch version as a tuple of integers.

setup_env()

Set up the environment variables for PyTorch distributed training.

sagemaker.train.container_drivers.distributed_drivers.torchrun_driver.create_commands()

Create the Torch Distributed command to execute.
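A minimal sketch of what assembling such a command can look like, assuming a multi-node launch through the `torchrun` CLI. The helper `build_torchrun_command` and all of its parameters are hypothetical and not taken from this module; the flags `--nnodes`, `--nproc_per_node`, `--node_rank`, `--master_addr`, and `--master_port` are standard `torchrun` options. The host name `algo-1` is only an example value.

```python
from typing import List, Optional, Sequence


def build_torchrun_command(
    entry_script: str,
    num_hosts: int,
    processes_per_host: int,
    host_rank: int,
    master_addr: str,
    master_port: int = 29500,
    script_args: Optional[Sequence[str]] = None,
    launcher: Sequence[str] = ("torchrun",),
) -> List[str]:
    """Hypothetical helper: assemble a torchrun command as an argv list."""
    if not 0 <= host_rank < num_hosts:
        raise ValueError(f"host_rank {host_rank} out of range for {num_hosts} hosts")
    command = list(launcher)
    # Rendezvous and topology options understood by torchrun.
    command += [
        f"--nnodes={num_hosts}",
        f"--nproc_per_node={processes_per_host}",
        f"--node_rank={host_rank}",
        f"--master_addr={master_addr}",
        f"--master_port={master_port}",
        entry_script,
    ]
    # Anything after the script path is forwarded to the user script.
    command += list(script_args or [])
    return command


cmd = build_torchrun_command(
    "train.py", num_hosts=2, processes_per_host=8, host_rank=0,
    master_addr="algo-1", script_args=["--epochs", "3"],
)
print(" ".join(cmd))
```

Returning an argv list rather than a shell string lets the command be passed straight to `subprocess.run` without shell quoting issues.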

sagemaker.train.container_drivers.distributed_drivers.torchrun_driver.get_base_pytorch_command() → List[str]

Get the base Torch Distributed launcher command to execute.
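One plausible way to choose the base launcher, sketched under assumptions: newer PyTorch releases ship the `torchrun` console script, while older ones only provide the `torch.distributed.launch` module. The `(1, 10)` cutoff and the helper name `base_launcher_for` are assumptions for illustration, not details confirmed by this page.

```python
import sys
from typing import List, Tuple

# Assumed cutoff: the `torchrun` console script first shipped with PyTorch 1.10.
TORCHRUN_MIN_VERSION: Tuple[int, int] = (1, 10)


def base_launcher_for(version: Tuple[int, int]) -> List[str]:
    """Hypothetical helper: return the launcher argv prefix for a PyTorch version."""
    if version >= TORCHRUN_MIN_VERSION:
        return ["torchrun"]
    # Older releases: run the legacy launcher module with the current interpreter.
    return [sys.executable, "-m", "torch.distributed.launch"]


print(base_launcher_for((2, 1)))
print(base_launcher_for((1, 8)))
```

Taking the version as an argument keeps the selection logic testable without PyTorch installed; tuple comparison handles cases like `(1, 8) < (1, 10)` correctly, which string comparison would not.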

sagemaker.train.container_drivers.distributed_drivers.torchrun_driver.main()

Main function to execute the PyTorch distributed training script.

This function sets the environment variables for PyTorch distributed training and then executes the user's PyTorch distributed training script.

Execution lifecycle:

1. Set up the environment variables for PyTorch distributed training.
2. Create the Torch Distributed command.
3. Execute the Torch Distributed command with the user script provided in entry_script.
4. Exit.
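The four lifecycle steps can be sketched with `subprocess`, assuming environment setup and command construction are supplied as callables. `run_lifecycle` is a hypothetical name; the actual `main()` may differ, for example in how it logs output or propagates errors.

```python
import subprocess
import sys
from typing import Callable, List


def run_lifecycle(
    setup_env: Callable[[], None],
    create_commands: Callable[[], List[str]],
) -> int:
    """Hypothetical sketch of the driver lifecycle; returns the child's exit code."""
    setup_env()                  # 1. set environment variables
    command = create_commands()  # 2. build the launcher command
    # 3. run it; the child inherits os.environ, including anything set in step 1
    result = subprocess.run(command, check=False)
    return result.returncode


exit_code = run_lifecycle(
    setup_env=lambda: None,
    create_commands=lambda: [sys.executable, "-c", "print('training...')"],
)
# 4. A real driver would finish with sys.exit(exit_code) so the job's
#    status reflects whether the launched training succeeded.
```

Propagating the child's exit code, rather than always exiting with 0, is what lets a failed training run surface as a failed job.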

sagemaker.train.container_drivers.distributed_drivers.torchrun_driver.pytorch_version() → Tuple[int, int]

Get the PyTorch version as a tuple of integers.
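A sketch of how a version string such as `torch.__version__` can be reduced to a `(major, minor)` tuple. It assumes the string starts with two dot-separated integers, which holds for release builds like `2.1.0+cu121` and nightlies like `2.3.0.dev20240101`. The helper `parse_major_minor` is hypothetical and takes the string as an argument so it runs without PyTorch installed.

```python
from typing import Tuple


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Hypothetical helper: '2.1.0+cu121' -> (2, 1)."""
    # Drop any local version label such as '+cu121' before splitting.
    public = version.split("+", 1)[0]
    major, minor = public.split(".")[:2]
    return int(major), int(minor)


print(parse_major_minor("2.1.0+cu121"))  # (2, 1)
print(parse_major_minor("1.13.1"))       # (1, 13)
```

With PyTorch available, the same helper would be called as `parse_major_minor(torch.__version__)`.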

sagemaker.train.container_drivers.distributed_drivers.torchrun_driver.setup_env()

Set up the environment variables for PyTorch distributed training.
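A minimal sketch of an environment-setup step, assuming the goal is to provide defaults without overriding values that are already exported. `apply_env_defaults` is a hypothetical helper, and `NCCL_DEBUG` (a real NCCL setting) is used only as an example; this page does not document which variables `setup_env()` actually sets.

```python
import os
from typing import Mapping


def apply_env_defaults(defaults: Mapping[str, str]) -> None:
    """Hypothetical helper: set each variable only if it is not already defined."""
    for name, value in defaults.items():
        # setdefault keeps any value the user or platform already exported.
        os.environ.setdefault(name, value)


apply_env_defaults({"NCCL_DEBUG": "WARN"})
print(os.environ["NCCL_DEBUG"])
```

Using `setdefault` instead of plain assignment is a design choice: the driver's defaults never clobber settings the job already provides, and child processes started afterwards inherit the resulting environment.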