Source code for sagemaker.train.container_drivers.scripts.environment

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module is used to define the environment variables for the training job container."""
from __future__ import absolute_import

from typing import Dict, Any
import multiprocessing
import subprocess
import json
import os
import sys
from pathlib import Path
import logging

sys.path.insert(0, str(Path(__file__).parent.parent))

from common.utils import (  # noqa: E402 # pylint: disable=C0413,E0611
    safe_serialize,
    safe_deserialize,
    read_distributed_json,
    read_source_code_json,
)

# Initialize logger
# The level is taken from the SM_LOG_LEVEL environment variable (a string when
# set) and falls back to 20, the numeric value of logging.INFO.
SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20)
logger = logging.getLogger(__name__)
# Emit this module's log records on stdout.
console_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(console_handler)
# int() means SM_LOG_LEVEL must be numeric; a level name such as "INFO" would
# raise ValueError here at import time.
logger.setLevel(int(SM_LOG_LEVEL))

# Fixed container paths; set_env() exports each of these under the same name.
SM_MODEL_DIR = "/opt/ml/model"

SM_INPUT_DIR = "/opt/ml/input"
SM_INPUT_DATA_DIR = "/opt/ml/input/data"
SM_INPUT_CONFIG_DIR = "/opt/ml/input/config"

SM_OUTPUT_DIR = "/opt/ml/output"
SM_OUTPUT_FAILURE = "/opt/ml/output/failure"
SM_OUTPUT_DATA_DIR = "/opt/ml/output/data"
# Exported as SM_SOURCE_DIR / SM_DISTRIBUTED_DRIVER_DIR by set_env() when the
# corresponding source-code / distributed JSON is present.
SM_SOURCE_DIR_PATH = "/opt/ml/input/data/code"
SM_DISTRIBUTED_DRIVER_DIR_PATH = "/opt/ml/input/data/sm_drivers/distributed_drivers"

# Exported as SM_MASTER_ADDR / SM_MASTER_PORT by set_env().
SM_MASTER_ADDR = "algo-1"
SM_MASTER_PORT = 7777

# JSON configuration files that main() loads and passes to set_env().
RESOURCE_CONFIG = f"{SM_INPUT_CONFIG_DIR}/resourceconfig.json"
INPUT_DATA_CONFIG = f"{SM_INPUT_CONFIG_DIR}/inputdataconfig.json"
HYPERPARAMETERS_CONFIG = f"{SM_INPUT_CONFIG_DIR}/hyperparameters.json"

# Default file to which set_env() writes the `export NAME='value'` lines.
ENV_OUTPUT_FILE = "/opt/ml/input/sm_training.env"

# A key is treated as sensitive when it contains any of these keywords
# (case-insensitive substring match); its value is logged as HIDDEN_VALUE.
SENSITIVE_KEYWORDS = ["SECRET", "PASSWORD", "KEY", "TOKEN", "PRIVATE", "CREDS", "CREDENTIALS"]
HIDDEN_VALUE = "******"


def num_cpus() -> int:
    """Report how many CPUs the current container can see.

    Returns:
        int: The CPU count reported by ``multiprocessing.cpu_count()``.
    """
    cpu_total = multiprocessing.cpu_count()
    return cpu_total
def num_gpus() -> int:
    """Count the GPUs that ``nvidia-smi`` lists for the current container.

    Returns:
        int: Number of lines in the ``nvidia-smi --list-gpus`` output that
            start with ``"GPU "``, or 0 when the command cannot be started
            or exits with a non-zero status.
    """
    try:
        raw_listing = subprocess.check_output(["nvidia-smi", "--list-gpus"])
    except (OSError, subprocess.CalledProcessError):
        # Missing binary or failing driver both mean "no usable GPUs".
        logger.info("No GPUs detected (normal if no gpus installed)")
        return 0

    gpu_rows = [row for row in raw_listing.decode("utf-8").splitlines() if row.startswith("GPU ")]
    return len(gpu_rows)
def num_neurons() -> int:
    """Return the number of neuron cores available in the current container.

    Runs ``neuron-ls -j`` and sums the ``nc_count`` field of every entry in
    the JSON array it prints. Detection is best-effort: any failure to run
    the tool or to parse its output is logged and reported as 0 cores.

    Returns:
        int: Number of Neuron Cores available in the current container.
    """
    try:
        output = subprocess.check_output(
            ["neuron-ls", "-j"], stderr=subprocess.STDOUT
        ).decode("utf-8")
    except OSError:
        # The neuron-ls binary is not installed / not executable.
        logger.info("No Neurons detected (normal if no neurons installed)")
        return 0
    except subprocess.CalledProcessError as e:
        detail = None
        if e.output is not None:
            try:
                # Keep only the text following "error=" in the tool's output.
                detail = e.output.decode("utf-8").partition("error=")[2]
            except AttributeError:
                detail = None
        if detail is None:
            logger.info("No Neurons detected (normal if no neurons installed)")
        else:
            # Implicit string concatenation instead of a backslash continuation
            # inside the literal, which leaked source indentation into the log.
            logger.info(
                "No Neurons detected (normal if no neurons installed). "
                "If neuron installed then %s",
                detail,
            )
        return 0

    try:
        devices = json.loads(output)
    except ValueError:
        # stderr is merged into stdout above, so a warning printed by the tool
        # can make the captured text invalid JSON; treat that as "no neurons"
        # rather than aborting the whole environment setup.
        logger.warning("Could not parse neuron-ls output as JSON; assuming no Neurons")
        return 0

    neuron_cores = sum(item.get("nc_count", 0) for item in devices)
    logger.info("Found %s neurons on this instance", neuron_cores)
    return neuron_cores
def deserialize_hyperparameters(hyperparameters: Dict[str, str]) -> Dict[str, Any]:
    """Convert string-encoded hyperparameters back into typed values.

    Args:
        hyperparameters (Dict[str, str]): Mapping of hyperparameter name to
            its serialized string value.

    Returns:
        Dict[str, Any]: A new mapping with the same keys, where every value
            has been passed through ``safe_deserialize``.
    """
    return {name: safe_deserialize(raw) for name, raw in hyperparameters.items()}
def _shell_single_quote(text: str) -> str:
    """Wrap text in single quotes so a POSIX shell reads it back verbatim."""
    # A single quote cannot appear inside a single-quoted shell word, so each
    # one is written as: close quote, escaped quote, reopen quote ('\'').
    return "'" + text.replace("'", "'\\''") + "'"


def set_env(
    resource_config: Dict[str, Any],
    input_data_config: Dict[str, Any],
    hyperparameters_config: Dict[str, Any],
    output_file: str = ENV_OUTPUT_FILE,
):
    """Set environment variables for the training job container.

    Builds the full set of ``SM_*`` variables, writes them to ``output_file``
    as one ``export NAME='value'`` line per variable, and then logs them.
    Values are serialized with ``safe_serialize`` and single-quoted with any
    embedded single quote escaped, so the file stays valid shell even when a
    hyperparameter or config value contains ``'``.

    Args:
        resource_config (Dict[str, Any]): Resource configuration for the training job.
            Must contain ``current_host``, ``current_instance_type``, ``hosts`` and
            ``network_interface_name``.
        input_data_config (Dict[str, Any]): Input data configuration for the training job.
            Its keys are used as the channel names.
        hyperparameters_config (Dict[str, Any]): Hyperparameters configuration for the
            training job.
        output_file (str): Output file to write the environment variables.

    Raises:
        KeyError: If a required ``resource_config`` key or the ``TRAINING_JOB_NAME``
            environment variable is missing.
        ValueError: If ``current_host`` is not listed in ``hosts``.
    """
    # Constants
    env_vars = {
        "SM_MODEL_DIR": SM_MODEL_DIR,
        "SM_INPUT_DIR": SM_INPUT_DIR,
        "SM_INPUT_DATA_DIR": SM_INPUT_DATA_DIR,
        "SM_INPUT_CONFIG_DIR": SM_INPUT_CONFIG_DIR,
        "SM_OUTPUT_DIR": SM_OUTPUT_DIR,
        "SM_OUTPUT_FAILURE": SM_OUTPUT_FAILURE,
        "SM_OUTPUT_DATA_DIR": SM_OUTPUT_DATA_DIR,
        "SM_LOG_LEVEL": SM_LOG_LEVEL,
        "SM_MASTER_ADDR": SM_MASTER_ADDR,
        "SM_MASTER_PORT": SM_MASTER_PORT,
    }

    # SourceCode and DistributedConfig Environment Variables
    # (only exported when the corresponding reader returns something truthy)
    source_code = read_source_code_json()
    if source_code:
        env_vars["SM_SOURCE_DIR"] = SM_SOURCE_DIR_PATH
        env_vars["SM_ENTRY_SCRIPT"] = source_code.get("entry_script", "")

    distributed = read_distributed_json()
    if distributed:
        env_vars["SM_DISTRIBUTED_DRIVER_DIR"] = SM_DISTRIBUTED_DRIVER_DIR_PATH
        env_vars["SM_DISTRIBUTED_CONFIG"] = distributed

    # Data Channels
    channels = list(input_data_config.keys())
    for channel in channels:
        env_vars[f"SM_CHANNEL_{channel.upper()}"] = f"{SM_INPUT_DATA_DIR}/{channel}"
    env_vars["SM_CHANNELS"] = channels

    # Hyperparameters
    hps = deserialize_hyperparameters(hyperparameters_config)
    for key, value in hps.items():
        # Dashes are not valid in shell variable names.
        key_upper = key.replace("-", "_").upper()
        env_vars[f"SM_HP_{key_upper}"] = value
    env_vars["SM_HPS"] = hps

    # Host Variables
    current_host = resource_config["current_host"]
    current_instance_type = resource_config["current_instance_type"]
    hosts = resource_config["hosts"]
    # Sorting gives every host the same ordering, so the rank below is consistent.
    sorted_hosts = sorted(hosts)

    env_vars["SM_CURRENT_HOST"] = current_host
    env_vars["SM_CURRENT_INSTANCE_TYPE"] = current_instance_type
    env_vars["SM_HOSTS"] = sorted_hosts
    env_vars["SM_NETWORK_INTERFACE_NAME"] = resource_config["network_interface_name"]
    env_vars["SM_HOST_COUNT"] = len(sorted_hosts)
    env_vars["SM_CURRENT_HOST_RANK"] = sorted_hosts.index(current_host)

    env_vars["SM_NUM_CPUS"] = num_cpus()
    env_vars["SM_NUM_GPUS"] = num_gpus()
    env_vars["SM_NUM_NEURONS"] = num_neurons()

    # Misc.
    env_vars["SM_RESOURCE_CONFIG"] = resource_config
    env_vars["SM_INPUT_DATA_CONFIG"] = input_data_config

    # All Training Environment Variables
    env_vars["SM_TRAINING_ENV"] = {
        "channel_input_dirs": {
            channel: env_vars[f"SM_CHANNEL_{channel.upper()}"] for channel in channels
        },
        "current_host": env_vars["SM_CURRENT_HOST"],
        "current_instance_type": env_vars["SM_CURRENT_INSTANCE_TYPE"],
        "hosts": env_vars["SM_HOSTS"],
        "master_addr": env_vars["SM_MASTER_ADDR"],
        "master_port": env_vars["SM_MASTER_PORT"],
        "hyperparameters": env_vars["SM_HPS"],
        "input_data_config": input_data_config,
        "input_config_dir": env_vars["SM_INPUT_CONFIG_DIR"],
        "input_data_dir": env_vars["SM_INPUT_DATA_DIR"],
        "input_dir": env_vars["SM_INPUT_DIR"],
        "job_name": os.environ["TRAINING_JOB_NAME"],
        "log_level": env_vars["SM_LOG_LEVEL"],
        "model_dir": env_vars["SM_MODEL_DIR"],
        "network_interface_name": env_vars["SM_NETWORK_INTERFACE_NAME"],
        "num_cpus": env_vars["SM_NUM_CPUS"],
        "num_gpus": env_vars["SM_NUM_GPUS"],
        "num_neurons": env_vars["SM_NUM_NEURONS"],
        "output_data_dir": env_vars["SM_OUTPUT_DATA_DIR"],
        "resource_config": env_vars["SM_RESOURCE_CONFIG"],
    }

    # Explicit UTF-8 so non-ASCII values do not depend on the container locale.
    with open(output_file, "w", encoding="utf-8") as f:
        for key, value in env_vars.items():
            quoted = _shell_single_quote(str(safe_serialize(value)))
            f.write(f"export {key}={quoted}\n")

    logger.info("Environment Variables:")
    log_env_variables(env_vars_dict=env_vars)
def mask_sensitive_info(data):
    """Return a copy of ``data`` with sensitive string values masked.

    A value is replaced by ``HIDDEN_VALUE`` when it is a string and its key
    contains any entry of ``SENSITIVE_KEYWORDS`` (case-insensitive substring
    match). Nested dictionaries are processed recursively. The input is never
    modified, so masking a dictionary for logging does not alter the values
    the caller keeps using afterwards.

    Args:
        data: The object to mask. Anything that is not a ``dict`` is returned
            unchanged.

    Returns:
        A new ``dict`` with sensitive values masked, or ``data`` itself when
        it is not a dictionary.
    """
    if not isinstance(data, dict):
        return data

    masked = {}
    for k, v in data.items():
        if isinstance(v, dict):
            # Nested mappings are always recursed into, whatever their key is.
            masked[k] = mask_sensitive_info(v)
        elif isinstance(v, str) and any(
            keyword.lower() in str(k).lower() for keyword in SENSITIVE_KEYWORDS
        ):
            masked[k] = HIDDEN_VALUE
        else:
            masked[k] = v
    return masked
def log_key_value(key: str, value: str):
    """Log ``key=value`` at INFO level, hiding sensitive content.

    If the key itself matches a sensitive keyword, the whole value is logged
    as ``HIDDEN_VALUE``. Dictionary values, and string values that decode as a
    JSON object, are passed through ``mask_sensitive_info`` and logged as
    JSON. Other JSON-decodable values are logged in decoded form, and anything
    that cannot be decoded is logged unchanged.
    """
    key_lowered = key.lower()
    key_is_sensitive = any(word.lower() in key_lowered for word in SENSITIVE_KEYWORDS)
    if key_is_sensitive:
        logger.info("%s=%s", key, HIDDEN_VALUE)
        return

    if isinstance(value, dict):
        logger.info("%s=%s", key, json.dumps(mask_sensitive_info(value)))
        return

    try:
        parsed = json.loads(value)
        if isinstance(parsed, dict):
            rendered = json.dumps(mask_sensitive_info(parsed))
        else:
            rendered = parsed
        logger.info("%s=%s", key, rendered)
    except (json.JSONDecodeError, TypeError):
        # Not JSON text (or not a string at all): log the value unchanged.
        logger.info("%s=%s", key, value)
def log_env_variables(env_vars_dict: Dict[str, Any]):
    """Log every process environment variable, then every entry of ``env_vars_dict``.

    Each pair is passed to ``log_key_value``, which takes care of masking.
    """
    for mapping in (os.environ, env_vars_dict):
        for name, val in mapping.items():
            log_key_value(name, val)
def main():
    """Load the job's JSON configuration files and generate the environment file.

    Reads the resource, input-data and hyperparameter configs (in that order)
    and passes them to ``set_env``, writing to ``ENV_OUTPUT_FILE``.
    """
    config_sources = (
        ("resource_config", RESOURCE_CONFIG),
        ("input_data_config", INPUT_DATA_CONFIG),
        ("hyperparameters_config", HYPERPARAMETERS_CONFIG),
    )

    loaded = {}
    for argument_name, config_path in config_sources:
        with open(config_path, "r") as handle:
            loaded[argument_name] = json.load(handle)

    set_env(
        resource_config=loaded["resource_config"],
        input_data_config=loaded["input_data_config"],
        hyperparameters_config=loaded["hyperparameters_config"],
        output_file=ENV_OUTPUT_FILE,
    )
# Script entry point: build the environment file only when this module is run
# directly, not when it is imported.
if __name__ == "__main__": main()