# Source code for sagemaker.train.remote_function.runtime_environment.runtime_environment_manager

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""SageMaker runtime environment module. This must be kept independent of SageMaker PySDK"""

from __future__ import absolute_import


import dataclasses
import json
import logging
import os
import re
import shlex
import shutil
import subprocess
import sys
import threading
import time


class _UTCFormatter(logging.Formatter):
    """``logging.Formatter`` whose timestamps are rendered in UTC.

    ``logging.Formatter.formatTime`` turns ``record.created`` into a
    ``time.struct_time`` through the ``converter`` attribute. The base class
    uses ``time.localtime``; swapping in ``time.gmtime`` makes ``%(asctime)s``
    (and any ``datefmt``) come out in UTC regardless of the host's timezone.
    """

    # Replaces the inherited time.localtime so formatted times are UTC.
    converter = time.gmtime


def get_logger():
    """Return the ``sagemaker.remote_function`` logger, configuring it on first use.

    On the first call (when the logger has no handlers yet) the logger is set to
    ``INFO``, given a ``StreamHandler`` whose timestamps are formatted in UTC by
    ``_UTCFormatter``, and stopped from propagating records to ancestor loggers
    so the root logger's handlers do not emit them a second time. Subsequent
    calls find the handler already attached and return the logger unchanged.

    Returns:
        logging.Logger: The ``sagemaker.remote_function`` logger.
    """
    sagemaker_logger = logging.getLogger("sagemaker.remote_function")
    if not sagemaker_logger.handlers:
        sagemaker_logger.setLevel(logging.INFO)
        handler = logging.StreamHandler()
        handler.setFormatter(
            _UTCFormatter("%(asctime)s %(name)s %(levelname)-8s %(message)s")
        )
        sagemaker_logger.addHandler(handler)
        # don't stream logs with the root logger handler
        sagemaker_logger.propagate = False
    return sagemaker_logger
logger = get_logger()


@dataclasses.dataclass
class _DependencySettings:
    """Dependency settings for the remote function.

    Tells the runtime environment script how to treat dependencies: when
    ``dependency_file`` is set, the script attempts to install the dependencies
    it lists; when it is left unset, the script assumes none are required.
    """

    # Base name of the dependencies file (see from_dependency_file_path), or None.
    dependency_file: str = None

    def to_string(self):
        """Serialize these settings to a JSON string."""
        as_dict = dataclasses.asdict(self)
        return json.dumps(as_dict)

    @staticmethod
    def from_string(dependency_settings_string):
        """Rebuild settings from a JSON string such as one produced by ``to_string``.

        Args:
            dependency_settings_string (str): The JSON string to convert, or None.

        Returns:
            _DependencySettings: The decoded settings, or None when the input is None.
        """
        if dependency_settings_string is not None:
            parsed = json.loads(dependency_settings_string)
            return _DependencySettings(dependency_file=parsed.get("dependency_file"))
        return None

    @staticmethod
    def from_dependency_file_path(dependency_file_path):
        """Build settings from a local dependency file path.

        Args:
            dependency_file_path (str): The path to the dependency file, the
                literal ``"auto_capture"``, or None.

        Returns:
            _DependencySettings: Settings with no dependency file when the path is
            None, ``"env_snapshot.yml"`` for ``"auto_capture"``, otherwise the
            path's base name.
        """
        if dependency_file_path is None:
            return _DependencySettings()
        if dependency_file_path == "auto_capture":
            file_name = "env_snapshot.yml"
        else:
            file_name = os.path.basename(dependency_file_path)
        return _DependencySettings(dependency_file=file_name)
class RuntimeEnvironmentManager:
    """Manage the runtime environment of a remote function.

    On the client side ``snapshot`` resolves (or auto-captures) the dependencies
    file. Inside the job container ``bootstrap`` installs those dependencies with
    pip or conda/mamba, ``run_pre_exec_script`` runs the user's pre-execution
    script, and ``change_dir_permission`` adjusts directory permissions.
    Commands are built as argv lists rather than shell strings.
    """

    # Accepted conda environment names: ASCII letters, digits, "_" and "-" only.
    _ENV_NAME_PATTERN = re.compile(r"[a-zA-Z0-9_-]+")

    # Major/minor version in ``python --version`` output, e.g. "Python 3.10.12".
    _PYTHON_VERSION_PATTERN = re.compile(r"Python (\d+)\.(\d+)")

    # File extensions recognised as pip requirements / conda environment files.
    _REQUIREMENTS_SUFFIXES = (".txt",)
    _CONDA_ENV_SUFFIXES = (".yml", ".yaml")

    def _validate_path(self, path: str) -> str:
        """Reject malformed file paths and return the absolute form.

        This does not confine ``path`` to any directory: ``..`` components are
        resolved by ``os.path.abspath``, not rejected. Callers pass the result to
        subprocesses as a single argv element.

        Args:
            path (str): The file path to validate.

        Returns:
            str: ``os.path.abspath(path)``.

        Raises:
            ValueError: If ``path`` is empty or contains a null byte.
        """
        if not path:
            raise ValueError("Path cannot be empty")
        # A null byte can never be part of a real file name; reject it up front
        # with a clear message.
        if "\x00" in path:
            raise ValueError(f"Invalid path contains null byte: {path}")
        return os.path.abspath(path)

    def _validate_env_name(self, env_name: str) -> None:
        """Validate a conda environment name before it is used in a command.

        Args:
            env_name (str): The environment name to validate.

        Raises:
            ValueError: If ``env_name`` is empty or contains anything other than
                ASCII letters, digits, underscores and hyphens.
        """
        if not env_name:
            raise ValueError("Environment name cannot be empty")
        # fullmatch rather than match(r"^...$"): "$" also matches just before a
        # trailing newline, which would let "env\n" through.
        if not self._ENV_NAME_PATTERN.fullmatch(env_name):
            raise ValueError(
                f"Invalid environment name '{env_name}'. "
                "Only alphanumeric characters, underscores, and hyphens are allowed."
            )

    def snapshot(self, dependencies: str = None) -> str:
        """Creates snapshot of the user's environment.

        If a requirements ``.txt`` file or a conda ``.yml``/``.yaml`` file is
        provided, verifies that it exists and returns its local path. If
        ``auto_capture`` is given, exports the active conda environment to a yml
        file and returns its local path.

        Args:
            dependencies (str): Local path of a dependencies file, the literal
                ``"auto_capture"``, or None when there are no extra dependencies.

        Returns:
            str: Path of the existing or generated dependencies file, or None when
            ``dependencies`` is None.

        Raises:
            ValueError: If the file does not exist, has an unsupported extension,
                or ``auto_capture`` is requested with no active conda environment.
        """
        # No additional dependencies specified
        if dependencies is None:
            return None

        if dependencies == "auto_capture":
            return self._capture_from_local_runtime()

        # Dependencies specified as either req.txt or conda_env.yml
        if dependencies.endswith(self._REQUIREMENTS_SUFFIXES + self._CONDA_ENV_SUFFIXES):
            self._is_file_exists(dependencies)
            return dependencies

        raise ValueError(f'Invalid dependencies provided: "{dependencies}"')

    def _capture_from_local_runtime(self) -> str:
        """Export the active conda environment to ``./env_snapshot.yml``.

        Returns:
            str: Absolute path of the generated yml file.

        Raises:
            ValueError: If neither ``CONDA_DEFAULT_ENV`` nor ``CONDA_PREFIX`` is set.
            RuntimeEnvironmentError: If the export command fails.
        """
        # Try to capture dependencies from the conda environment, if any.
        conda_env_name = self._get_active_conda_env_name()
        conda_env_prefix = self._get_active_conda_env_prefix()
        if conda_env_name:
            logger.info("Found conda_env_name: '%s'", conda_env_name)
        elif conda_env_prefix:
            logger.info("Found conda_env_prefix: '%s'", conda_env_prefix)
        else:
            raise ValueError("No conda environment seems to be active.")

        if conda_env_name == "base":
            logger.warning(
                "We recommend using an environment other than base to "
                "isolate your project dependencies from conda dependencies"
            )

        local_dependencies_path = os.path.join(os.getcwd(), "env_snapshot.yml")
        if conda_env_prefix:
            self._export_conda_env_from_prefix(conda_env_prefix, local_dependencies_path)
        else:
            # Only CONDA_DEFAULT_ENV is set: export by name instead of passing
            # None as the ``-p`` argument.
            self._export_conda_env_from_name(conda_env_name, local_dependencies_path)
        return local_dependencies_path

    def _get_active_conda_env_prefix(self) -> str:
        """Returns the conda prefix from the set environment variable. None otherwise."""
        return os.getenv("CONDA_PREFIX")

    def _get_active_conda_env_name(self) -> str:
        """Returns the conda environment name from the set environment variable. None otherwise."""
        return os.getenv("CONDA_DEFAULT_ENV")

    def bootstrap(
        self, local_dependencies_file: str, client_python_version: str, conda_env: str = None
    ):
        """Bootstraps the runtime environment by installing the additional dependencies if any.

        A ``.txt`` file is installed with pip (inside ``conda_env`` when given).
        A ``.yml``/``.yaml`` file updates ``conda_env`` or, when none is given,
        creates ``sagemaker-runtime-env``; the environment's python version is
        then checked against ``client_python_version``. Whenever a conda
        environment is used its name is written to
        ``remote_function_conda_env.txt``. Files with other extensions are ignored.

        Args:
            local_dependencies_file (str): path where dependencies file exists.
            client_python_version (str): "major.minor" python version of the client.
            conda_env (str): conda environment to be activated. Default is None.

        Returns:
            None
        """
        if local_dependencies_file.endswith(self._REQUIREMENTS_SUFFIXES):
            if conda_env:
                self._install_req_txt_in_conda_env(conda_env, local_dependencies_file)
                self._write_conda_env_to_file(conda_env)
            else:
                self._install_requirements_txt(local_dependencies_file, _python_executable())

        elif local_dependencies_file.endswith(self._CONDA_ENV_SUFFIXES):
            if conda_env:
                self._update_conda_env(conda_env, local_dependencies_file)
            else:
                conda_env = "sagemaker-runtime-env"
                self._create_conda_env(conda_env, local_dependencies_file)
            self._validate_python_version(client_python_version, conda_env)
            self._write_conda_env_to_file(conda_env)

    def run_pre_exec_script(self, pre_exec_script_path: str):
        """Runs script of pre-execution commands if existing.

        Args:
            pre_exec_script_path (str): Path to pre-execution command script file.

        Raises:
            RuntimeEnvironmentError: If the script exits with a non-zero status.
        """
        if os.path.isfile(pre_exec_script_path):
            logger.info("Running pre-execution commands in '%s'", pre_exec_script_path)
            return_code, error_logs = _run_pre_execution_command_script(pre_exec_script_path)
            if return_code:
                error_message = (
                    f"Encountered error while running pre-execution commands. Reason: {error_logs}"
                )
                raise RuntimeEnvironmentError(error_message)
        else:
            logger.info(
                "'%s' does not exist. Assuming no pre-execution commands to run",
                pre_exec_script_path,
            )

    def change_dir_permission(self, dirs: list, new_permission: str):
        """Change the permission of given directories with ``sudo chmod -R``.

        Args:
            dirs (list[str]): Directories for permission update (any sequence).
            new_permission (str): The new permission for the given directories.

        Raises:
            RuntimeEnvironmentError: If ``chmod`` fails or ``sudo`` is not installed.
        """
        error_msg_prefix = "Failed to change directory permissions due to: "
        command = ["sudo", "chmod", "-R", new_permission] + list(dirs)
        logger.info("Executing '%s'.", " ".join(command))
        try:
            subprocess.run(command, check=True, stderr=subprocess.PIPE)
        except subprocess.CalledProcessError as called_process_err:
            err_msg = called_process_err.stderr.decode("utf-8")
            raise RuntimeEnvironmentError(f"{error_msg_prefix} {err_msg}") from called_process_err
        except FileNotFoundError as file_not_found_err:
            if "[Errno 2] No such file or directory: 'sudo'" in str(file_not_found_err):
                raise RuntimeEnvironmentError(
                    f"{error_msg_prefix} {file_not_found_err}. "
                    "Please contact the image owner to install 'sudo' in the job container "
                    "and provide sudo privilege to the container user."
                ) from file_not_found_err
            raise RuntimeEnvironmentError(str(file_not_found_err)) from file_not_found_err

    def _is_file_exists(self, dependencies):
        """Raise ValueError unless ``dependencies`` names an existing regular file."""
        if not os.path.isfile(dependencies):
            raise ValueError(f'No dependencies file named "{dependencies}" was found.')

    def _install_requirements_txt(self, local_path, python_executable):
        """Install a requirements.txt file with ``<python_executable> -m pip``."""
        validated_path = self._validate_path(local_path)
        cmd = [python_executable, "-m", "pip", "install", "-r", validated_path, "-U"]
        logger.info("Running command: '%s' in the dir: '%s' ", " ".join(cmd), os.getcwd())
        _run_shell_cmd(cmd)
        logger.info("Command %s ran successfully", " ".join(cmd))

    def _create_conda_env(self, env_name, local_path):
        """Create conda env ``env_name`` from a conda yml file."""
        self._validate_env_name(env_name)
        validated_path = self._validate_path(local_path)
        cmd = [self._get_conda_exe(), "env", "create", "-n", env_name, "--file", validated_path]
        logger.info("Creating conda environment %s using: %s.", env_name, " ".join(cmd))
        _run_shell_cmd(cmd)
        logger.info("Conda environment %s created successfully.", env_name)

    def _install_req_txt_in_conda_env(self, env_name, local_path):
        """Install a requirements.txt file with pip inside conda env ``env_name``."""
        self._validate_env_name(env_name)
        validated_path = self._validate_path(local_path)
        cmd = [
            self._get_conda_exe(), "run", "-n", env_name,
            "pip", "install", "-r", validated_path, "-U",
        ]
        logger.info("Activating conda env and installing requirements: %s", " ".join(cmd))
        _run_shell_cmd(cmd)
        logger.info("Requirements installed successfully in conda env %s", env_name)

    def _update_conda_env(self, env_name, local_path):
        """Update conda env ``env_name`` from a conda yml file."""
        self._validate_env_name(env_name)
        validated_path = self._validate_path(local_path)
        cmd = [self._get_conda_exe(), "env", "update", "-n", env_name, "--file", validated_path]
        logger.info("Updating conda env: %s", " ".join(cmd))
        _run_shell_cmd(cmd)
        logger.info("Conda env %s updated succesfully", env_name)

    def _export_conda_env_from_prefix(self, prefix, local_path):
        """Export the conda env located at ``prefix`` to the yml file ``local_path``."""
        self._export_conda_env(["-p", prefix], prefix, local_path)

    def _export_conda_env_from_name(self, env_name, local_path):
        """Export the conda env named ``env_name`` to the yml file ``local_path``."""
        self._export_conda_env(["-n", env_name], env_name, local_path)

    def _export_conda_env(self, env_selector, env_label, local_path):
        """Run ``conda env export <env_selector> --no-builds`` into ``local_path``.

        The command used to include ``">"`` and the path as argv elements, which
        without a shell are handed to conda verbatim instead of redirecting.
        Here stdout is written to the file directly.
        """
        cmd = [self._get_conda_exe(), "env", "export", *env_selector, "--no-builds"]
        logger.info("Exporting conda environment: %s > %s", " ".join(cmd), local_path)
        self._run_cmd_to_file(cmd, local_path)
        logger.info("Conda environment %s exported successfully", env_label)

    @staticmethod
    def _run_cmd_to_file(cmd, output_path):
        """Run ``cmd`` (argv list, no shell) and write its stdout to ``output_path``.

        Equivalent to the shell form ``cmd > output_path``.

        Raises:
            RuntimeEnvironmentError: If the command exits with a non-zero status.
        """
        with open(output_path, "wb") as output_file:
            completed = subprocess.run(
                cmd, stdout=output_file, stderr=subprocess.PIPE, check=False
            )
        if completed.returncode:
            error_logs = completed.stderr.decode("utf-8", errors="replace")
            raise RuntimeEnvironmentError(
                f"Encountered error while running command '{' '.join(cmd)}'. "
                f"Reason: {error_logs}"
            )

    def _write_conda_env_to_file(self, env_name):
        """Write ``env_name`` to ``remote_function_conda_env.txt`` in the cwd."""
        file_path = os.path.join(os.getcwd(), "remote_function_conda_env.txt")
        with open(file_path, "w", encoding="utf-8") as output_file:
            output_file.write(env_name)

    def _get_conda_exe(self):
        """Return "mamba" if it is on PATH, else "conda"; raise ValueError if neither is."""
        for candidate in ("mamba", "conda"):
            if shutil.which(candidate):
                return candidate
        raise ValueError("Neither conda nor mamba is installed on the image")

    def _python_version_in_conda_env(self, env_name):
        """Return the "major.minor" python version inside conda env ``env_name``.

        Raises:
            RuntimeEnvironmentError: If the command fails or its output cannot be parsed.
        """
        cmd = [self._get_conda_exe(), "run", "-n", env_name, "python", "--version"]
        try:
            output = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8")
        except subprocess.CalledProcessError as e:
            raise RuntimeEnvironmentError(
                e.output.decode("utf-8", errors="replace")
            ) from e
        return self._parse_python_version(output)

    @classmethod
    def _parse_python_version(cls, version_output):
        """Extract "major.minor" from ``python --version`` output.

        Searches the whole output so warnings printed before the version line
        (stderr is merged into stdout) do not break parsing.

        Raises:
            RuntimeEnvironmentError: If no "Python X.Y" token is present.
        """
        match = cls._PYTHON_VERSION_PATTERN.search(version_output)
        if match is None:
            raise RuntimeEnvironmentError(
                f"Unable to determine python version from output: '{version_output.strip()}'"
            )
        return f"{match.group(1)}.{match.group(2)}"

    def _current_python_version(self):
        """Returns the "major.minor" python version where the program is running."""
        return f"{sys.version_info.major}.{sys.version_info.minor}"

    def _current_sagemaker_pysdk_version(self):
        """Returns the installed sagemaker python sdk version (best effort)."""
        try:
            from importlib import metadata

            return metadata.version("sagemaker")
        except Exception:  # best effort: only used for a version-mismatch warning
            return "3.0.0.dev0"  # Development version fallback

    def _validate_python_version(self, client_python_version: str, conda_env: str = None):
        """Validate the python version.

        Raises RuntimeEnvironmentError if the python version where the remote
        function runs (inside ``conda_env`` when given) does not match the one
        used on the client side.
        """
        if conda_env:
            job_python_version = self._python_version_in_conda_env(conda_env)
        else:
            job_python_version = self._current_python_version()
        if client_python_version.strip() != job_python_version.strip():
            raise RuntimeEnvironmentError(
                f"Python version found in the container is '{job_python_version}' which "
                f"does not match python version '{client_python_version}' on the local client. "
                f"Please make sure that the python version used in the training container "
                f"is same as the local python version."
            )

    def _validate_sagemaker_pysdk_version(self, client_sagemaker_pysdk_version):
        """Validate the sagemaker python sdk version.

        Logs a warning (does not raise) if the sagemaker python sdk version where
        the remote function runs differs from the non-empty client-side version,
        since unexpected behaviors may occur in this case.
        """
        job_sagemaker_pysdk_version = self._current_sagemaker_pysdk_version()
        if (
            client_sagemaker_pysdk_version
            and client_sagemaker_pysdk_version != job_sagemaker_pysdk_version
        ):
            logger.warning(
                "Inconsistent sagemaker versions found: "
                "sagemaker python sdk version found in the container is "
                "'%s' which does not match the '%s' on the local client. "
                "Please make sure that the sagemaker version used in the training container "
                "is the same as the local sagemaker version in case of unexpected behaviors.",
                job_sagemaker_pysdk_version,
                client_sagemaker_pysdk_version,
            )
def _run_and_get_output_shell_cmd(cmd: str) -> str:
    """Run ``cmd`` and return its combined stdout/stderr as text.

    Args:
        cmd (str): Command string, tokenised with ``shlex.split`` (no shell is used).

    Returns:
        str: The command's stdout and stderr, decoded as UTF-8.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero status.
    """
    return subprocess.check_output(shlex.split(cmd), stderr=subprocess.STDOUT).decode("utf-8")


def _run_pre_execution_command_script(script_path: str):
    """Run a shell script with ``/bin/bash -eu`` from the script's own directory.

    stdout lines are logged at INFO; stderr lines are logged and collected.

    Args:
        script_path (str): Path to the script. Relative paths are resolved against
            the current working directory before the directory change.

    Returns:
        tuple: ``(return_code, error_logs)`` where ``error_logs`` is the script's
        stderr as a string.
    """
    # Resolve first: the child runs with cwd set to the script's directory, so a
    # relative path would otherwise be re-interpreted from there (and a bare file
    # name would give cwd="").
    script_path = os.path.abspath(script_path)
    process = subprocess.Popen(
        ["/bin/bash", "-eu", script_path],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        cwd=os.path.dirname(script_path),
    )
    error_logs = _stream_process_output(process)
    return_code = process.wait()
    return return_code, error_logs


def _run_shell_cmd(cmd: list):
    """This method runs a given command using subprocess (without a shell).

    Args:
        cmd (list): Command and arguments as a list (e.g., ['pip', 'install', '-r', 'requirements.txt'])

    Raises:
        RuntimeEnvironmentError: If the command fails
        ValueError: If cmd is not a list
    """
    if not isinstance(cmd, list):
        raise ValueError("Command must be a list of arguments for security reasons")

    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False)
    error_logs = _stream_process_output(process)
    return_code = process.wait()
    if return_code:
        error_message = f"Encountered error while running command '{' '.join(cmd)}'. Reason: {error_logs}"
        raise RuntimeEnvironmentError(error_message)


def _stream_process_output(process):
    """Log a Popen process's stdout and stderr as produced; return the stderr text.

    stderr is drained on a helper thread while stdout is read on this thread.
    Reading the pipes one after the other deadlocks once the child fills the
    stderr pipe buffer while we are still waiting for stdout to reach EOF.
    """
    stderr_chunks = []
    stderr_reader = threading.Thread(
        target=lambda: stderr_chunks.append(_log_error(process)), daemon=True
    )
    stderr_reader.start()
    _log_output(process)
    stderr_reader.join()
    return "".join(stderr_chunks)


def _log_output(process):
    """Log each stdout line of a Popen process at INFO until the pipe closes."""
    with process.stdout as pipe:
        for line in iter(pipe.readline, b""):
            # errors="replace": non-UTF-8 tool output must not abort the run.
            logger.info(line.decode("utf-8", errors="replace").rstrip("\r\n"))


def _log_error(process):
    """Log each stderr line of a Popen process and return all of it as a string.

    Lines containing "ERROR:" are logged at ERROR, all others at WARNING.
    """
    error_lines = []
    with process.stderr as pipe:
        for line in iter(pipe.readline, b""):
            error_str = line.decode("utf-8", errors="replace")
            log = logger.error if "ERROR:" in error_str else logger.warning
            log(error_str.rstrip("\r\n"))
            error_lines.append(error_str)
    return "".join(error_lines)


def _python_executable():
    """Return the path of the current Python executable.

    Returns:
        (str): ``sys.executable``.

    Raises:
        RuntimeEnvironmentError: If the interpreter path cannot be determined.
    """
    if not sys.executable:
        raise RuntimeEnvironmentError(
            "Failed to retrieve the path for the Python executable binary"
        )
    return sys.executable
class RuntimeEnvironmentError(Exception):
    """Base exception raised when bootstrapping the runtime environment fails.

    Attributes:
        message: The message the exception was created with.
    """

    def __init__(self, message):
        super().__init__(message)
        self.message = message