sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils

This module provides MPI-related utility functions for the container drivers.

Functions

bootstrap_master_node(worker_hosts)

Bootstrap the master node.

bootstrap_worker_node(master_host[, status_file])

Bootstrap a worker node.

get_mpirun_command(host_count, host_list, ...)

Return the mpirun command that launches the entry script.

start_sshd_daemon()

Start the SSH daemon on the current node.

validate_smddpmprun()

Check whether smddpmprun is installed.

validate_smddprun()

Check whether smddprun is installed.

write_env_vars_to_file()

Write environment variables to the /etc/environment file.

write_status_file_to_workers(worker_hosts[, ...])

Write the status file to all worker nodes.

Classes

CustomHostKeyPolicy()

Class to handle host key policy for SageMaker distributed training SSH connections.

class sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils.CustomHostKeyPolicy

Bases: paramiko.MissingHostKeyPolicy

Class to handle host key policy for SageMaker distributed training SSH connections.

Example:

    >>> import paramiko
    >>> from sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils import CustomHostKeyPolicy
    >>> client = paramiko.SSHClient()
    >>> client.set_missing_host_key_policy(CustomHostKeyPolicy())
    >>> # The unknown host key is accepted for SageMaker algorithm containers
    >>> client.connect('algo-1234.internal')
    >>> # Other unknown hosts are rejected
    >>> client.connect('unknown-host')  # raises SSHException

missing_host_key(client, hostname, key)

Accept host keys for algo-* hostnames, reject others.

Parameters:
  • client – The SSHClient instance

  • hostname – The hostname attempting to connect

  • key – The host key

Raises:

paramiko.SSHException – If hostname doesn’t match the algo-* pattern.
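The accept/reject rule in missing_host_key reduces to a hostname test. The sketch below isolates that test as a standalone function so it can be exercised without paramiko or a network. The name `is_trusted_training_host` is hypothetical, and matching with `fnmatchcase` against `algo-*` is an assumption about how the pattern is applied; the class itself may use a plain prefix check.

```python
from fnmatch import fnmatchcase

# Hypothetical constant mirroring the documented "algo-*" rule for
# SageMaker training hosts (algo-1, algo-2, ...).
TRUSTED_HOST_PATTERN = "algo-*"


def is_trusted_training_host(hostname: str) -> bool:
    """Return True if an unknown host key for this hostname should be accepted.

    Sketch only: a case-sensitive, shell-style match against "algo-*".
    """
    return fnmatchcase(hostname, TRUSTED_HOST_PATTERN)


for name in ["algo-1", "algo-1234.internal", "unknown-host"]:
    print(name, is_trusted_training_host(name))
```

Inside a paramiko.MissingHostKeyPolicy subclass, a True result corresponds to returning normally from missing_host_key (accepting the key) and False to raising paramiko.SSHException, which matches the behavior documented above.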

sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils.bootstrap_master_node(worker_hosts: List[str])

Bootstrap the master node.
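The summary does not say what bootstrapping the master involves. A common pattern for MPI-over-SSH clusters, and an assumption here, is that the master polls until every worker is reachable before launching mpirun. The sketch below shows such a polling loop with a timeout; `wait_for_hosts`, its `can_connect` callback, and the default TCP probe on port 22 are hypothetical, not this module's API.

```python
import socket
import time
from typing import Callable, Iterable


def tcp_port_open(host: str, port: int = 22, timeout: float = 2.0) -> bool:
    """Return True if a TCP connection to host:port succeeds (e.g. sshd is listening)."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:  # covers refused connections, DNS failures, and timeouts
        return False


def wait_for_hosts(
    hosts: Iterable[str],
    can_connect: Callable[[str], bool] = tcp_port_open,
    timeout: float = 300.0,
    interval: float = 5.0,
) -> None:
    """Block until can_connect(host) is True for every host, or raise TimeoutError."""
    pending = set(hosts)
    deadline = time.monotonic() + timeout
    while pending:
        # Keep only the hosts that are still unreachable.
        pending = {h for h in pending if not can_connect(h)}
        if not pending:
            break
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Unreachable after {timeout}s: {sorted(pending)}")
        time.sleep(interval)
```

A TCP probe only shows that sshd is listening; a full SSH handshake, for example with paramiko and the host key policy above, is a stricter readiness check.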

sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils.bootstrap_worker_node(master_host: str, status_file: str = '/tmp/done.algo-1')

Bootstrap a worker node.
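The default status_file, /tmp/done.algo-1, suggests a handshake in which the worker blocks until a "done" marker from algo-1 appears locally; write_status_file_to_workers uses the same default. Assuming that reading, the worker-side wait can be sketched as a simple file poll. `wait_for_status_file` and its parameters are hypothetical, and the 30-second poll interval is an arbitrary choice.

```python
import os
import time
from typing import Optional


def wait_for_status_file(
    status_file: str = "/tmp/done.algo-1",
    poll_interval: float = 30.0,
    timeout: Optional[float] = None,
) -> None:
    """Block until status_file exists.

    Raises TimeoutError if `timeout` seconds elapse first; timeout=None waits forever.
    """
    deadline = None if timeout is None else time.monotonic() + timeout
    while not os.path.exists(status_file):
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(f"{status_file} did not appear within {timeout}s")
        time.sleep(poll_interval)
```

Under this reading, blocking here keeps the worker container alive, with sshd available to mpirun, until the master signals completion.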

sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils.get_mpirun_command(host_count: int, host_list: List[str], num_processes: int, additional_options: List[str], entry_script_path: str)

Return the mpirun command that launches entry_script_path on the hosts in host_list.
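The parameter list suggests the returned command starts num_processes ranks of entry_script_path across host_list, with additional_options spliced in. The sketch below assembles such an argv for Open MPI's mpirun. The flag selection, the `python3` interpreter, the `host:slots` form of the example host entries, and the name `build_mpirun_argv` are all assumptions; how the real function uses host_count is not documented here, so the sketch omits it.

```python
import shlex
from typing import List


def build_mpirun_argv(
    host_list: List[str],
    num_processes: int,
    additional_options: List[str],
    entry_script_path: str,
    python_executable: str = "python3",
) -> List[str]:
    """Assemble an Open MPI mpirun argv (sketch; flags chosen for illustration)."""
    return [
        "mpirun",
        "--host", ",".join(host_list),  # Open MPI accepts host or host:slots entries
        "-np", str(num_processes),      # total number of ranks
        "--allow-run-as-root",          # training containers often run as root
        *additional_options,            # caller-supplied mpirun flags
        python_executable,
        entry_script_path,
    ]


argv = build_mpirun_argv(["algo-1:8", "algo-2:8"], 16, ["--tag-output"], "train.py")
print(shlex.join(argv))
# mpirun --host algo-1:8,algo-2:8 -np 16 --allow-run-as-root --tag-output python3 train.py
```

Returning a list rather than a shell string lets the caller pass it straight to subprocess without shell quoting concerns.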

sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils.start_sshd_daemon()

Start the SSH daemon on the current node.
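Open MPI's mpirun starts remote ranks over SSH by default, which is presumably why each node needs a running sshd. A minimal sketch, assuming the daemon binary lives at /usr/sbin/sshd and is run with `-D` (do not detach) as a child of the driver process; `start_sshd` and its injectable `launcher` parameter are hypothetical and exist so the logic can be tested without root.

```python
import os
import subprocess
from typing import Callable, List

DEFAULT_SSHD = "/usr/sbin/sshd"  # assumed location; varies by base image


def start_sshd(
    sshd_path: str = DEFAULT_SSHD,
    launcher: Callable[[List[str]], object] = subprocess.Popen,
) -> object:
    """Launch sshd in the foreground (-D) as a child process and return the handle.

    With the default launcher this returns a subprocess.Popen immediately,
    while sshd keeps running in the background.
    """
    if not os.path.exists(sshd_path):
        raise FileNotFoundError(f"sshd not found at {sshd_path}")
    # sshd must be started via an absolute path so it can re-exec itself.
    return launcher([sshd_path, "-D"])
```

Starting sshd this way normally requires root and host keys; `ssh-keygen -A` generates any missing default host keys.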

sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils.validate_smddpmprun() → bool

Check whether smddpmprun is installed.

Returns:

True if smddpmprun is installed

Return type:

bool

sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils.validate_smddprun() → bool

Check whether smddprun is installed.

Returns:

True if installed

Return type:

bool
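Both checks amount to asking whether an executable can be found on PATH. A sketch using the standard library's shutil.which, assuming that is how "installed" is determined; the helper name `is_installed` is hypothetical.

```python
import shutil


def is_installed(executable: str) -> bool:
    """Return True if `executable` resolves to a runnable file on PATH."""
    return shutil.which(executable) is not None


# Mirrors validate_smddprun() / validate_smddpmprun() under the PATH assumption.
print("smddprun:", is_installed("smddprun"))
print("smddpmprun:", is_installed("smddpmprun"))
```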

sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils.write_env_vars_to_file()

Write environment variables to the /etc/environment file.
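Processes that mpirun starts on other nodes arrive over SSH and do not automatically inherit the container's environment. On many Linux systems pam_env reads /etc/environment at login, so appending NAME=value lines there is one way to make the container's variables visible to those sessions; this rationale is an inference, not stated by the module. The sketch below takes the target path and the variables as parameters so it can be run against a temporary file; `format_env_lines` and `append_env_file` are hypothetical names, and values are written without quoting or escaping.

```python
import os
import tempfile
from typing import Mapping


def format_env_lines(env: Mapping[str, str]) -> str:
    """Render a mapping as NAME=value lines (no quoting or escaping)."""
    return "".join(f"{name}={value}\n" for name, value in env.items())


def append_env_file(path: str, env: Mapping[str, str]) -> None:
    """Append NAME=value lines to `path`, creating the file if needed."""
    with open(path, "a", encoding="utf-8") as f:
        f.write(format_env_lines(env))


if __name__ == "__main__":
    # Demo against a temporary file instead of the real /etc/environment.
    # In a container this would be append_env_file("/etc/environment", os.environ).
    with tempfile.TemporaryDirectory() as tmp:
        target = os.path.join(tmp, "environment")
        append_env_file(target, {"SM_CURRENT_HOST": "algo-1", "NCCL_DEBUG": "INFO"})
        with open(target, encoding="utf-8") as f:
            print(f.read(), end="")
```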

sagemaker.core.modules.train.container_drivers.distributed_drivers.mpi_utils.write_status_file_to_workers(worker_hosts: List[str], status_file: str = '/tmp/done.algo-1')

Write the status file to all worker nodes.
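One way to deliver the status file is to run `touch` on each worker over SSH, retrying while a worker is not yet reachable. The sketch assumes that approach; `write_status_to_hosts`, `ssh_touch`, and the retry parameters are hypothetical, and the per-host writer is injectable so the retry logic can be tested without SSH.

```python
import subprocess
import time
from typing import Callable, List


def ssh_touch(host: str, path: str) -> bool:
    """Create (or update) `path` on `host` via `ssh host touch path`; True on success."""
    try:
        result = subprocess.run(["ssh", host, "touch", path], capture_output=True, text=True)
    except OSError:  # e.g. the ssh client is not installed
        return False
    return result.returncode == 0


def write_status_to_hosts(
    hosts: List[str],
    status_file: str = "/tmp/done.algo-1",
    writer: Callable[[str, str], bool] = ssh_touch,
    max_attempts: int = 6,
    delay: float = 5.0,
) -> None:
    """Write status_file to every host, trying each up to max_attempts times."""
    for host in hosts:
        for attempt in range(1, max_attempts + 1):
            if writer(host, status_file):
                break
            if attempt == max_attempts:
                raise TimeoutError(f"Could not write {status_file} to {host}")
            time.sleep(delay)
```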