# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""An utils function for runtime environment. This must be kept independent of SageMaker PySDK"""
from __future__ import absolute_import
import argparse
import json
import os
import subprocess
import sys
import time
from typing import List
import paramiko
if __package__ is None or __package__ == "":
from runtime_environment_manager import (
get_logger,
)
else:
from sagemaker.train.remote_function.runtime_environment.runtime_environment_manager import (
get_logger,
)
SUCCESS_EXIT_CODE = 0
DEFAULT_FAILURE_CODE = 1
# Created on each worker by the master once the job has ended; workers block until it exists.
FINISHED_STATUS_FILE = "/tmp/done.algo-1"
# Created on the master by each worker to signal readiness; "%s" is the worker's hostname.
READY_FILE = "/tmp/ready.%s"
DEFAULT_SSH_PORT = 22
# Failure-reason file location for SageMaker training containers.
FAILURE_REASON_PATH = "/opt/ml/output/failure"
logger = get_logger()
[docs]
class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy):
    """Host key policy for SSH connections between SageMaker training containers.

    Unknown host keys are accepted (and remembered on the client) for hosts whose
    name starts with ``algo-``; every other unknown host is rejected.

    Example:
        >>> client = paramiko.SSHClient()
        >>> client.set_missing_host_key_policy(CustomHostKeyPolicy())
        >>> # Will succeed for SageMaker algorithm containers
        >>> client.connect('algo-1234.internal')
        >>> # Will raise SSHException for other unknown hosts
        >>> client.connect('unknown-host')  # raises SSHException
    """

    def missing_host_key(self, client, hostname, key):
        """Accept host keys for algo-* hostnames, reject others.

        When a connection uses a non-default port, paramiko passes the host-key
        name as ``"[hostname]:port"``. That form is unwrapped before the
        ``algo-`` check, so the policy behaves the same on any SSH port.

        Args:
            client: The SSHClient instance
            hostname: The hostname attempting to connect, either plain or in the
                ``"[hostname]:port"`` form.
            key: The host key

        Raises:
            paramiko.SSHException: If hostname doesn't match algo-* pattern
        """
        bare_hostname = hostname
        if hostname.startswith("["):
            # "[algo-1]:2222" -> "algo-1"
            bare_hostname = hostname[1:].split("]", 1)[0]
        if bare_hostname.startswith("algo-"):
            # Store the key under the exact name paramiko looked up, so later
            # connections to the same host/port find it.
            client.get_host_keys().add(hostname, key.get_name(), key)
            return
        raise paramiko.SSHException(f"Unknown host key for {hostname}")
def _parse_args(sys_args):
"""Parses CLI arguments."""
parser = argparse.ArgumentParser()
parser.add_argument("--job_ended", type=str, default="0")
args, _ = parser.parse_known_args(sys_args)
return args
def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool:
    """Report whether an SSH connection to ``host`` on ``port`` can be opened.

    System host keys are loaded and unknown keys are vetted by
    ``CustomHostKeyPolicy``. Any exception raised while connecting is logged
    and reported as ``False``.

    Args:
        host: Hostname to connect to.
        port: SSH port.

    Returns:
        True if the connection succeeded, False otherwise.
    """
    try:
        with paramiko.SSHClient() as ssh_client:
            ssh_client.load_system_host_keys()
            ssh_client.set_missing_host_key_policy(CustomHostKeyPolicy())
            ssh_client.connect(host, port=port)
            logger.info("Can connect to host %s", host)
            return True
    except Exception as err:  # pylint: disable=W0703
        logger.info("Cannot connect to host %s", host)
        logger.debug("Connection failed with exception: %s", err)
        return False
def _write_file_to_host(host: str, status_file: str) -> bool:
    """Create an empty file on a remote host by running ``ssh <host> touch <file>``.

    Args:
        host: Hostname passed to the ``ssh`` client.
        status_file: Path of the file to create on ``host``.

    Returns:
        True if the remote ``touch`` succeeded, False if ``ssh`` exited with a
        non-zero status.
    """
    try:
        logger.info("Writing %s to %s", status_file, host)
        subprocess.run(
            ["ssh", host, "touch", status_file],
            capture_output=True,
            text=True,
            check=True,
        )
        logger.info("Finished writing status file")
        return True
    except subprocess.CalledProcessError as err:
        logger.info("Cannot connect to %s", host)
        # stderr is captured above; log it, otherwise the reason for the failure is lost.
        logger.debug("ssh exited with code %s: %s", err.returncode, err.stderr)
        return False
def _write_failure_reason_file(failure_msg):
"""Create a file 'failure' with failure reason written if bootstrap runtime env failed.
See: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html
Args:
failure_msg: The content of file to be written.
"""
if not os.path.exists(FAILURE_REASON_PATH):
with open(FAILURE_REASON_PATH, "w") as f:
f.write("RuntimeEnvironmentError: " + failure_msg)
def _wait_for_master(master_host: str, port: int = DEFAULT_SSH_PORT, timeout: int = 300):
    """Block until this worker can open an SSH connection to the master node.

    The connection is retried every 5 seconds.

    Args:
        master_host: Hostname of the master node.
        port: SSH port on the master.
        timeout: Seconds after which to stop retrying.

    Raises:
        TimeoutError: If the master is still unreachable after ``timeout`` seconds.
    """
    began_at = time.time()
    logger.info("Worker is attempting to connect to the master node %s...", master_host)
    while not _can_connect(master_host, port):
        if time.time() - began_at > timeout:
            raise TimeoutError("Timed out waiting for master %s to be reachable." % master_host)
        time.sleep(5)  # Pause before the next attempt
        logger.info("Worker is attempting to connect to the master node %s...", master_host)
    logger.info("Worker can connect to master node %s.", master_host)
def _wait_for_status_file(status_file: str):
    """Block until ``status_file`` exists on the local filesystem.

    The path is checked every 30 seconds; there is no timeout.

    Args:
        status_file: Local path to wait for.
    """
    logger.info("Waiting for status file %s", status_file)
    while True:
        if os.path.exists(status_file):
            break
        time.sleep(30)
    logger.info("Found status file %s", status_file)
def _wait_for_workers(worker_hosts: List[str], port: int = DEFAULT_SSH_PORT, timeout: int = 300):
    """Block until the master can reach every worker and each worker has signaled readiness.

    A worker counts as ready when the local file ``READY_FILE % worker`` exists and
    an SSH connection to it on ``port`` succeeds. All workers are re-checked every
    5 seconds.

    Args:
        worker_hosts: Hostnames of the worker nodes. Returns immediately if empty.
        port: SSH port used for the connectivity check.
        timeout: Seconds after which to give up. The limit is checked between
            polling rounds, so the total wait can slightly exceed it.

    Raises:
        TimeoutError: If not every worker is ready within ``timeout`` seconds.
    """
    if not worker_hosts:
        logger.info("No worker nodes to connect to.")
        return
    start_time = time.time()
    while True:
        logger.info("Master is attempting to connect to all workers...")
        # Check the cheap local ready file first, so a worker that has not signaled
        # yet does not cost a full SSH handshake on every polling round.
        all_workers_connected = all(
            os.path.exists(READY_FILE % worker) and _can_connect(worker, port)
            for worker in worker_hosts
        )
        if all_workers_connected:
            logger.info("Master can connect to all worker nodes.")
            break
        if time.time() - start_time > timeout:
            raise TimeoutError("Timed out waiting for workers to be reachable.")
        time.sleep(5)  # Wait for 5 seconds before trying again
[docs]
def bootstrap_master_node(worker_hosts: List[str]):
    """Bootstrap the master node by waiting until every worker is ready.

    Args:
        worker_hosts: Hostnames of the worker nodes.
    """
    logger.info("Bootstrapping master node...")
    _wait_for_workers(worker_hosts=worker_hosts)
[docs]
def bootstrap_worker_node(
    master_host: str, current_host: str, status_file: str = FINISHED_STATUS_FILE
):
    """Bootstrap a worker node.

    Waits until the master is reachable over SSH, creates this worker's ready file
    (``READY_FILE % current_host``) on the master, then blocks until
    ``status_file`` appears locally.

    Args:
        master_host: Hostname of the master node.
        current_host: Hostname of this worker; used to name its ready file.
        status_file: Local path whose appearance ends the wait.

    Raises:
        TimeoutError: Propagated from ``_wait_for_master`` if the master is unreachable.
        RuntimeError: If the ready file could not be written to the master.
    """
    logger.info("Bootstrapping worker node...")
    _wait_for_master(master_host)
    ready_file = READY_FILE % current_host
    if not _write_file_to_host(master_host, ready_file):
        # Without the ready file the master times out waiting for this worker,
        # while this worker would block forever on the status file below.
        raise RuntimeError(
            "Failed to write ready file %s to master %s." % (ready_file, master_host)
        )
    _wait_for_status_file(status_file)
[docs]
def start_sshd_daemon():
    """Start the SSH server on the current node as a child process.

    Raises:
        RuntimeError: If ``/usr/sbin/sshd`` does not exist.
    """
    sshd_executable = "/usr/sbin/sshd"
    if not os.path.exists(sshd_executable):
        raise RuntimeError("SSH daemon not found.")
    # -D tells sshd NOT to detach or become a daemon; it stays in the foreground of
    # its own process. Popen returns immediately, so sshd keeps running as a child
    # of this process while bootstrapping continues.
    subprocess.Popen([sshd_executable, "-D"])
    logger.info("Started SSH daemon.")
[docs]
def write_status_file_to_workers(worker_hosts: List[str], status_file: str = FINISHED_STATUS_FILE):
    """Create ``status_file`` on every worker node over SSH.

    Workers are handled one after another. Each worker gets up to six attempts,
    with a 5 second pause after every failed attempt.

    Args:
        worker_hosts: Hostnames of the worker nodes.
        status_file: Path of the file to create on each worker.

    Raises:
        TimeoutError: If a worker still cannot be written to after six attempts.
    """
    for worker_host in worker_hosts:
        failed_attempts = 0
        while True:
            if _write_file_to_host(worker_host, status_file):
                break
            time.sleep(5)
            failed_attempts += 1
            if failed_attempts > 5:
                raise TimeoutError("Timed out waiting for %s to be reachable." % worker_host)
            logger.info("Retrying to write status file to %s", worker_host)
[docs]
def _get_worker_hosts(main_host: str) -> List[str]:
    """Return the hosts listed in the ``SM_HOSTS`` environment variable, minus ``main_host``.

    ``SM_HOSTS`` is parsed as JSON and iterated; it is expected to hold a list of
    hostname strings.
    """
    all_hosts = json.loads(os.environ["SM_HOSTS"])
    return [host for host in all_hosts if host != main_host]


def main(sys_args=None):
    """Entry point for bootstrap script.

    While the job is running (``--job_ended 0``, the default), this starts sshd and
    bootstraps the node:

    - the node whose ``SM_CURRENT_HOST`` equals ``SM_MASTER_ADDR`` acts as master
      and waits for all workers;
    - every other node acts as a worker.

    For any other ``--job_ended`` value, the master writes the finished status
    file to every worker, and workers do nothing.

    Any exception is logged, written to the failure-reason file, and turned into
    ``sys.exit(DEFAULT_FAILURE_CODE)``.

    Args:
        sys_args: Command-line arguments to parse (``None`` means ``sys.argv[1:]``).
    """
    try:
        args = _parse_args(sys_args)
        job_ended = args.job_ended
        main_host = os.environ["SM_MASTER_ADDR"]
        current_host = os.environ["SM_CURRENT_HOST"]

        if job_ended == "0":
            logger.info("Job is running, bootstrapping nodes")
            start_sshd_daemon()
            if current_host != main_host:
                bootstrap_worker_node(main_host, current_host)
            else:
                bootstrap_master_node(_get_worker_hosts(main_host))
        else:
            logger.info("Job ended, writing status file to workers")
            if current_host == main_host:
                write_status_file_to_workers(_get_worker_hosts(main_host))
    except Exception as e:  # pylint: disable=broad-except
        logger.exception("Error encountered while bootstrapping runtime environment: %s", e)
        _write_failure_reason_file(str(e))
        sys.exit(DEFAULT_FAILURE_CODE)
if __name__ == "__main__":
main(sys.argv[1:])