# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""SageMaker runtime environment module. This must be kept independent of SageMaker PySDK"""
from __future__ import absolute_import
import logging
import sys
import shlex
import os
import subprocess
import time
import dataclasses
import json
class _UTCFormatter(logging.Formatter):
"""Class that overrides the default local time provider in log formatter."""
converter = time.gmtime
def get_logger():
    """Return the logger named ``sagemaker.remote_function``.

    On first use (no handlers attached yet) the logger is set to INFO, given a
    stream handler that formats timestamps in UTC, and detached from the root
    logger's handlers. Later calls return the already-configured logger as is.
    """
    remote_function_logger = logging.getLogger("sagemaker.remote_function")
    if remote_function_logger.handlers:
        # Already configured; adding another handler would duplicate output.
        return remote_function_logger

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(
        _UTCFormatter("%(asctime)s %(name)s %(levelname)-8s %(message)s")
    )
    remote_function_logger.setLevel(logging.INFO)
    remote_function_logger.addHandler(stream_handler)
    # don't stream logs with the root logger handler
    remote_function_logger.propagate = 0
    return remote_function_logger
# Module-level logger used by the classes and helper functions below.
logger = get_logger()
@dataclasses.dataclass
class _DependencySettings:
"""Dependency settings for the remote function.
Instructs the runtime environment script on how to handle dependencies.
If ``dependency_file`` is set, the runtime environment script will attempt
to install the dependencies. If ``dependency_file`` is not set, the runtime
environment script will assume no dependencies are required.
"""
dependency_file: str = None
def to_string(self):
"""Converts the dependency settings to a string."""
return json.dumps(dataclasses.asdict(self))
@staticmethod
def from_string(dependency_settings_string):
"""Converts a json string to dependency settings.
Args:
dependency_settings_string (str): The json string to convert.
"""
if dependency_settings_string is None:
return None
dependency_settings_dict = json.loads(dependency_settings_string)
return _DependencySettings(dependency_settings_dict.get("dependency_file"))
@staticmethod
def from_dependency_file_path(dependency_file_path):
"""Converts a dependency file path to dependency settings.
Args:
dependency_file_path (str): The path to the dependency file.
"""
if dependency_file_path is None:
return _DependencySettings()
if dependency_file_path == "auto_capture":
return _DependencySettings("env_snapshot.yml")
return _DependencySettings(os.path.basename(dependency_file_path))
class RuntimeEnvironmentManager:
    """Runtime Environment Manager class to manage runtime environment.

    ``snapshot`` resolves or generates a dependency file, ``bootstrap``
    installs one with pip or conda/mamba, and the remaining public methods run
    a pre-execution script and change directory permissions.
    """

    def _validate_path(self, path: str) -> str:
        """Reject empty or null-byte paths and return the absolute form.

        This does not confine the path to any base directory; it only
        normalizes it with ``os.path.abspath``.

        Args:
            path (str): The file path to validate.

        Returns:
            str: The absolute path.

        Raises:
            ValueError: If the path is empty or contains a null byte.
        """
        if not path:
            raise ValueError("Path cannot be empty")
        # Inspect the raw input first so the check does not depend on how
        # os.path treats embedded null bytes.
        if "\x00" in path:
            raise ValueError(f"Invalid path contains null byte: {path}")
        return os.path.abspath(path)

    def _validate_env_name(self, env_name: str) -> None:
        """Validate a conda environment name before it is used in a command.

        Args:
            env_name (str): The environment name to validate.

        Raises:
            ValueError: If the name is empty or contains anything other than
                alphanumeric characters, underscores, and hyphens.
        """
        import re  # local import: ``re`` is not imported at module level

        if not env_name:
            raise ValueError("Environment name cannot be empty")
        # fullmatch instead of match with '^...$': '$' also matches just
        # before a trailing newline, which would let "name\n" through.
        if not re.fullmatch(r"[a-zA-Z0-9_-]+", env_name):
            raise ValueError(
                f"Invalid environment name '{env_name}'. "
                "Only alphanumeric characters, underscores, and hyphens are allowed."
            )

    def snapshot(self, dependencies: str = None) -> str:
        """Creates snapshot of the user's environment

        If a req.txt or conda.yml file is provided, it verifies their existence and
        returns the local file path.

        If ``auto_capture`` is set, this method will take the snapshot of
        user's dependencies installed in the local runtime.
        Current support for ``auto_capture``:

        * conda env, generate a yml file and return its local path

        Args:
            dependencies (str): Local path where dependencies file exists, or
                the literal ``"auto_capture"``.

        Returns:
            File path of the existing or generated dependencies file, or None
            when no dependencies were specified.

        Raises:
            ValueError: If the file does not exist or has an unsupported
                extension.
        """
        # No additional dependencies specified
        if dependencies is None:
            return None

        if dependencies == "auto_capture":
            return self._capture_from_local_runtime()

        # Dependencies specified as either req.txt or conda_env.yml
        if dependencies.endswith((".txt", ".yml", ".yaml")):
            self._is_file_exists(dependencies)
            return dependencies

        raise ValueError(f'Invalid dependencies provided: "{dependencies}"')

    def _capture_from_local_runtime(self) -> str:
        """Generates dependencies list from the user's local runtime.

        Currently supports: conda environments. The snapshot is written to
        ``env_snapshot.yml`` in the current working directory.

        Returns:
            str: Path of the generated ``env_snapshot.yml``.

        Raises:
            ValueError: If neither ``CONDA_DEFAULT_ENV`` nor ``CONDA_PREFIX``
                is set.
            RuntimeEnvironmentError: If exporting the conda environment fails.
        """
        # Try to capture dependencies from the conda environment, if any.
        conda_env_name = self._get_active_conda_env_name()
        conda_env_prefix = self._get_active_conda_env_prefix()
        if conda_env_name:
            logger.info("Found conda_env_name: '%s'", conda_env_name)
        elif conda_env_prefix:
            logger.info("Found conda_env_prefix: '%s'", conda_env_prefix)
        else:
            raise ValueError("No conda environment seems to be active.")

        if conda_env_name == "base":
            logger.warning(
                "We recommend using an environment other than base to "
                "isolate your project dependencies from conda dependencies"
            )

        local_dependencies_path = os.path.join(os.getcwd(), "env_snapshot.yml")
        self._export_conda_env_from_prefix(conda_env_prefix, local_dependencies_path)

        return local_dependencies_path

    def _get_active_conda_env_prefix(self) -> str:
        """Returns the conda prefix from the set environment variable. None otherwise."""
        return os.getenv("CONDA_PREFIX")

    def _get_active_conda_env_name(self) -> str:
        """Returns the conda environment name from the set environment variable. None otherwise."""
        return os.getenv("CONDA_DEFAULT_ENV")

    def bootstrap(
        self, local_dependencies_file: str, client_python_version: str, conda_env: str = None
    ):
        """Bootstraps the runtime environment by installing the additional dependencies if any.

        A ``.txt`` file is installed with pip (inside ``conda_env`` when one is
        given). A ``.yml``/``.yaml`` file updates ``conda_env`` when one is
        given; otherwise a new ``sagemaker-runtime-env`` environment is created
        and its python version is checked against ``client_python_version``.
        Whenever a conda environment is used, its name is written to
        ``remote_function_conda_env.txt`` in the current working directory.

        Args:
            local_dependencies_file (str): path where dependencies file exists.
            client_python_version (str): python version used on the client
                side, as ``"<major>.<minor>"``.
            conda_env (str): conda environment to be activated. Default is None.

        Returns: None
        """
        if local_dependencies_file.endswith(".txt"):
            if conda_env:
                self._install_req_txt_in_conda_env(conda_env, local_dependencies_file)
                self._write_conda_env_to_file(conda_env)
            else:
                self._install_requirements_txt(local_dependencies_file, _python_executable())

        elif local_dependencies_file.endswith((".yml", ".yaml")):
            if conda_env:
                self._update_conda_env(conda_env, local_dependencies_file)
            else:
                conda_env = "sagemaker-runtime-env"
                self._create_conda_env(conda_env, local_dependencies_file)
                self._validate_python_version(client_python_version, conda_env)
            self._write_conda_env_to_file(conda_env)

    def run_pre_exec_script(self, pre_exec_script_path: str):
        """Runs script of pre-execution commands if existing.

        Args:
            pre_exec_script_path (str): Path to pre-execution command script file.

        Raises:
            RuntimeEnvironmentError: If the script exits with a non-zero code.
        """
        if os.path.isfile(pre_exec_script_path):
            logger.info("Running pre-execution commands in '%s'", pre_exec_script_path)
            return_code, error_logs = _run_pre_execution_command_script(pre_exec_script_path)

            if return_code:
                error_message = (
                    f"Encountered error while running pre-execution commands. Reason: {error_logs}"
                )
                raise RuntimeEnvironmentError(error_message)
        else:
            logger.info(
                "'%s' does not exist. Assuming no pre-execution commands to run",
                pre_exec_script_path,
            )

    def change_dir_permission(self, dirs: list, new_permission: str):
        """Change the permission of given directories

        Runs ``sudo chmod -R <new_permission> <dirs...>``.

        Args:
            dirs (list[str]): A list of directories for permission update.
            new_permission (str): The new permission for the given directories.

        Raises:
            RuntimeEnvironmentError: If the command fails or an executable
                (for example ``sudo``) cannot be found.
        """
        _ERROR_MSG_PREFIX = "Failed to change directory permissions due to: "

        command = ["sudo", "chmod", "-R", new_permission, *dirs]
        logger.info("Executing '%s'.", " ".join(command))

        try:
            subprocess.run(command, check=True, stderr=subprocess.PIPE)
        except subprocess.CalledProcessError as called_process_err:
            err_msg = called_process_err.stderr.decode("utf-8")
            raise RuntimeEnvironmentError(
                f"{_ERROR_MSG_PREFIX} {err_msg}"
            ) from called_process_err
        except FileNotFoundError as file_not_found_err:
            if "[Errno 2] No such file or directory: 'sudo'" in str(file_not_found_err):
                raise RuntimeEnvironmentError(
                    f"{_ERROR_MSG_PREFIX} {file_not_found_err}. "
                    "Please contact the image owner to install 'sudo' in the job container "
                    "and provide sudo privilege to the container user."
                ) from file_not_found_err
            raise RuntimeEnvironmentError(file_not_found_err) from file_not_found_err

    def _is_file_exists(self, dependencies):
        """Check whether the dependencies file exists at the given location.

        Raises:
            ValueError: If no regular file exists at ``dependencies``.
        """
        if not os.path.isfile(dependencies):
            raise ValueError(f'No dependencies file named "{dependencies}" was found.')

    def _install_requirements_txt(self, local_path, python_executable):
        """Install requirements.txt file with ``<python_executable> -m pip``."""
        validated_path = self._validate_path(local_path)
        cmd = [python_executable, "-m", "pip", "install", "-r", validated_path, "-U"]
        logger.info("Running command: '%s' in the dir: '%s' ", " ".join(cmd), os.getcwd())
        _run_shell_cmd(cmd)
        logger.info("Command %s ran successfully", " ".join(cmd))

    def _create_conda_env(self, env_name, local_path):
        """Create conda env using conda yml file"""
        self._validate_env_name(env_name)
        validated_path = self._validate_path(local_path)
        cmd = [self._get_conda_exe(), "env", "create", "-n", env_name, "--file", validated_path]
        logger.info("Creating conda environment %s using: %s.", env_name, " ".join(cmd))
        _run_shell_cmd(cmd)
        logger.info("Conda environment %s created successfully.", env_name)

    def _install_req_txt_in_conda_env(self, env_name, local_path):
        """Install requirements.txt in the given conda environment"""
        self._validate_env_name(env_name)
        validated_path = self._validate_path(local_path)
        cmd = [
            self._get_conda_exe(),
            "run",
            "-n",
            env_name,
            "pip",
            "install",
            "-r",
            validated_path,
            "-U",
        ]
        logger.info("Activating conda env and installing requirements: %s", " ".join(cmd))
        _run_shell_cmd(cmd)
        logger.info("Requirements installed successfully in conda env %s", env_name)

    def _update_conda_env(self, env_name, local_path):
        """Update conda env using conda yml file"""
        self._validate_env_name(env_name)
        validated_path = self._validate_path(local_path)
        cmd = [self._get_conda_exe(), "env", "update", "-n", env_name, "--file", validated_path]
        logger.info("Updating conda env: %s", " ".join(cmd))
        _run_shell_cmd(cmd)
        logger.info("Conda env %s updated succesfully", env_name)

    def _export_conda_env_from_prefix(self, prefix, local_path):
        """Export the conda env located at ``prefix`` to a yml file at ``local_path``.

        Raises:
            RuntimeEnvironmentError: If the export command exits with a
                non-zero code.
        """
        cmd = [self._get_conda_exe(), "env", "export", "-p", prefix, "--no-builds"]
        logger.info("Exporting conda environment: %s", cmd)
        # The command runs without a shell, so a ">" argument would be handed
        # to conda verbatim instead of redirecting its output. Capture stdout
        # and write the file here instead.
        completed = subprocess.run(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False
        )
        error_logs = completed.stderr.decode("utf-8", errors="replace")
        if completed.returncode:
            raise RuntimeEnvironmentError(
                f"Encountered error while running command '{' '.join(cmd)}'. "
                f"Reason: {error_logs}"
            )
        if error_logs:
            logger.warning(error_logs)
        with open(local_path, "wb") as snapshot_file:
            snapshot_file.write(completed.stdout)
        logger.info("Conda environment %s exported successfully", prefix)

    def _write_conda_env_to_file(self, env_name):
        """Writes conda env to the text file"""
        file_name = "remote_function_conda_env.txt"
        file_path = os.path.join(os.getcwd(), file_name)
        with open(file_path, "w") as output_file:
            output_file.write(env_name)

    def _get_conda_exe(self):
        """Checks whether conda or mamba is available to use

        Returns:
            str: ``"mamba"`` if ``which mamba`` succeeds, else ``"conda"`` if
            ``which conda`` succeeds.

        Raises:
            ValueError: If neither executable is found.
        """
        # ``which`` exits with 0 when the executable is on PATH.
        if not subprocess.Popen(["which", "mamba"]).wait():
            return "mamba"
        if not subprocess.Popen(["which", "conda"]).wait():
            return "conda"
        raise ValueError("Neither conda nor mamba is installed on the image")

    def _python_version_in_conda_env(self, env_name):
        """Returns python version inside a conda environment

        Returns:
            str: The version as ``"<major>.<minor>"``.

        Raises:
            RuntimeEnvironmentError: If the command fails or its output does
                not contain a python version.
        """
        import re  # local import: ``re`` is not imported at module level

        # Pass an argument list directly; formatting a string and re-splitting
        # it would mis-tokenize a name containing whitespace or quotes.
        cmd = [self._get_conda_exe(), "run", "-n", env_name, "python", "--version"]
        try:
            output = (
                subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8").strip()
            )
        except subprocess.CalledProcessError as e:
            failure_output = e.output
            if isinstance(failure_output, bytes):
                failure_output = failure_output.decode("utf-8", errors="replace")
            raise RuntimeEnvironmentError(failure_output) from e

        # convert 'Python 3.7.16' to '3.7'
        found = re.search(r"Python (\d+)\.(\d+)", output)
        if found is None:
            raise RuntimeEnvironmentError(
                f"Unable to determine the python version of conda env '{env_name}' "
                f"from output: {output}"
            )
        return f"{found.group(1)}.{found.group(2)}"

    def _current_python_version(self):
        """Returns the current python version where program is running"""
        return f"{sys.version_info.major}.{sys.version_info.minor}".strip()

    def _current_sagemaker_pysdk_version(self):
        """Returns the current sagemaker python sdk version where program is running"""
        try:
            from importlib import metadata

            return metadata.version("sagemaker")
        except Exception:
            # Best effort: any lookup failure falls back to a placeholder.
            return "3.0.0.dev0"  # Development version fallback

    def _validate_python_version(self, client_python_version: str, conda_env: str = None):
        """Validate the python version

        Validates if the python version where remote function runs
        matches the one used on client side.

        Args:
            client_python_version (str): python version used on the client.
            conda_env (str): conda environment whose python is checked. When
                None, the running interpreter's version is used.

        Raises:
            RuntimeEnvironmentError: If the two versions differ.
        """
        if conda_env:
            job_python_version = self._python_version_in_conda_env(conda_env)
        else:
            job_python_version = self._current_python_version()
        if client_python_version.strip() != job_python_version.strip():
            raise RuntimeEnvironmentError(
                f"Python version found in the container is '{job_python_version}' which "
                f"does not match python version '{client_python_version}' on the local client. "
                f"Please make sure that the python version used in the training container "
                f"is same as the local python version."
            )

    def _validate_sagemaker_pysdk_version(self, client_sagemaker_pysdk_version):
        """Validate the sagemaker python sdk version

        Validates if the sagemaker python sdk version where remote function runs
        matches the one used on client side.
        Otherwise, log a warning to call out that unexpected behaviors
        may occur in this case.
        """
        job_sagemaker_pysdk_version = self._current_sagemaker_pysdk_version()
        if (
            client_sagemaker_pysdk_version
            and client_sagemaker_pysdk_version != job_sagemaker_pysdk_version
        ):
            logger.warning(
                "Inconsistent sagemaker versions found: "
                "sagemaker python sdk version found in the container is "
                "'%s' which does not match the '%s' on the local client. "
                "Please make sure that the sagemaker version used in the training container "
                "is the same as the local sagemaker version in case of unexpected behaviors.",
                job_sagemaker_pysdk_version,
                client_sagemaker_pysdk_version,
            )
def _run_and_get_output_shell_cmd(cmd: str) -> str:
"""Run and return the output of the given shell command"""
return subprocess.check_output(shlex.split(cmd), stderr=subprocess.STDOUT).decode("utf-8")
def _run_pre_execution_command_script(script_path: str):
    """Run a shell script with ``/bin/bash -eu`` from the script's own directory.

    The script's stdout and stderr are forwarded to the module logger while it
    runs. This function does not raise on script failure; the caller decides
    based on the returned exit code.

    Args:
        script_path (str): Path to the script to execute.

    Returns:
        tuple: ``(return_code, error_logs)`` where ``error_logs`` is the text
        the script wrote to stderr.
    """
    import threading  # local import: ``threading`` is not imported at module level

    # dirname is "" for a bare file name, and Popen rejects cwd=""; None keeps
    # the current working directory instead.
    current_dir = os.path.dirname(script_path) or None

    process = subprocess.Popen(
        ["/bin/bash", "-eu", script_path],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        cwd=current_dir,
    )

    # Drain stderr on a helper thread while this thread drains stdout. Reading
    # the pipes one after the other can deadlock: the child blocks once the
    # unread stderr pipe fills up and therefore never closes stdout.
    captured_stderr = []
    stderr_reader = threading.Thread(
        target=lambda: captured_stderr.append(_log_error(process)),
        daemon=True,
    )
    stderr_reader.start()
    _log_output(process)
    stderr_reader.join()

    error_logs = captured_stderr[0] if captured_stderr else ""
    return_code = process.wait()
    return return_code, error_logs
def _run_shell_cmd(cmd: list):
    """Run a command without a shell and forward its output to the module logger.

    Args:
        cmd (list): Command and arguments as a list
            (e.g., ['pip', 'install', '-r', 'requirements.txt'])

    Raises:
        RuntimeEnvironmentError: If the command exits with a non-zero code.
        ValueError: If cmd is not a list
    """
    import threading  # local import: ``threading`` is not imported at module level

    if not isinstance(cmd, list):
        raise ValueError("Command must be a list of arguments for security reasons")

    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False)

    # Drain stderr on a helper thread while this thread drains stdout. Reading
    # the pipes one after the other can deadlock: the child blocks once the
    # unread stderr pipe fills up and therefore never closes stdout.
    captured_stderr = []
    stderr_reader = threading.Thread(
        target=lambda: captured_stderr.append(_log_error(process)),
        daemon=True,
    )
    stderr_reader.start()
    _log_output(process)
    stderr_reader.join()

    error_logs = captured_stderr[0] if captured_stderr else ""
    return_code = process.wait()
    if return_code:
        error_message = (
            f"Encountered error while running command '{' '.join(cmd)}'. Reason: {error_logs}"
        )
        raise RuntimeEnvironmentError(error_message)
def _log_output(process):
    """Log every line a Popen process writes to stdout at INFO level.

    Reads ``process.stdout`` until EOF and closes the pipe afterwards.
    """
    with process.stdout as stdout_pipe:
        while True:
            raw_line = stdout_pipe.readline()
            if raw_line == b"":
                # An empty read means the child closed its end of the pipe.
                break
            logger.info(str(raw_line, "UTF-8"))
def _log_error(process):
    """Log every line a Popen process writes to stderr and return the text.

    Lines containing ``ERROR:`` are logged at ERROR level, all others at
    WARNING level. Reads ``process.stderr`` until EOF and closes the pipe
    afterwards.

    Returns:
        str: All stderr lines concatenated in the order they were read.
    """
    collected_lines = []
    with process.stderr as stderr_pipe:
        while True:
            raw_line = stderr_pipe.readline()
            if raw_line == b"":
                break
            text = str(raw_line, "UTF-8")
            emit = logger.error if "ERROR:" in text else logger.warning
            emit(text)
            collected_lines.append(text)
    return "".join(collected_lines)
def _python_executable():
"""Return the real path for the Python executable, if it exists.
Return RuntimeEnvironmentError otherwise.
Returns:
(str): The real path of the current Python executable.
"""
if not sys.executable:
raise RuntimeEnvironmentError(
"Failed to retrieve the path for the Python executable binary"
)
return sys.executable
class RuntimeEnvironmentError(Exception):
    """The base exception class for bootstrap env exceptions.

    The message passed to the constructor is kept on ``self.message`` and is
    also forwarded to ``Exception``.
    """

    def __init__(self, message):
        super().__init__(message)
        self.message = message