# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Placeholder docstring"""
from __future__ import absolute_import
import sys
import contextlib
import copy
import errno
import inspect
import logging
import os
import random
import re
import shutil
import tarfile
import tempfile
import time
from functools import lru_cache
from typing import Union, Any, List, Optional, Dict
import json
import abc
import uuid
from datetime import datetime
from os.path import abspath, realpath, dirname, normpath, join as joinpath
from importlib import import_module
import boto3
import botocore
from botocore.utils import merge_dicts
from botocore import exceptions
from botocore.exceptions import ClientError
from six.moves.urllib import parse
from six import viewitems
import sagemaker
from sagemaker.core.enums import RoutingStrategy
from sagemaker.core.session_settings import SessionSettings
from sagemaker.core.workflow import is_pipeline_variable, is_pipeline_parameter_string
from sagemaker.core.helper.pipeline_variable import PipelineVariable
from enum import Enum
# Regions whose AWS endpoints do not use the standard "amazonaws.com" domain,
# keyed by region name (China and the US/EU isolated partitions).
ALTERNATE_DOMAINS = {
    "cn-north-1": "amazonaws.com.cn",
    "cn-northwest-1": "amazonaws.com.cn",
    "us-iso-east-1": "c2s.ic.gov",
    "us-iso-west-1": "c2s.ic.gov",
    "us-isob-east-1": "sc2s.sgov.gov",
    "us-isob-west-1": "sc2s.sgov.gov",
    "us-isof-south-1": "csp.hci.ic.gov",
    "us-isof-east-1": "csp.hci.ic.gov",
    "eu-isoe-west-1": "cloud.adc-e.uk",
}
# Matches an ECR image URI: <account>.dkr.ecr.<region>.<domain>/<repo>:<tag>.
ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$"
# Matches a SageMaker model-package ARN; groups: partition suffix, region,
# account id, package name.
MODEL_PACKAGE_ARN_PATTERN = (
    r"arn:aws([a-z\-]*)?:sagemaker:([a-z0-9\-]*):([0-9]{12}):model-package/(.*)"
)
# Matches a SageMaker model ARN; same group layout as above.
MODEL_ARN_PATTERN = r"arn:aws([a-z\-]*):sagemaker:([a-z0-9\-]*):([0-9]{12}):model/(.*)"
MAX_BUCKET_PATHS_COUNT = 5
S3_PREFIX = "s3://"
HTTP_PREFIX = "http://"
HTTPS_PREFIX = "https://"
# Default pause between retry attempts (seconds) — see `retries`.
DEFAULT_SLEEP_TIME_SECONDS = 10
# Number of dots printed by the `_start_waiting` progress animation.
WAITING_DOT_NUMBER = 10
# Default pagination knobs for list APIs.
MAX_ITEMS = 100
PAGE_SIZE = 10
_MAX_BUFFER_SIZE = 100 * 1024 * 1024  # 100 MB - Maximum buffer size for streaming iterators
# Local paths that user-supplied source/dependency directories must never
# resolve into (credentials, system config, container metadata).
_SENSITIVE_SYSTEM_PATHS = [
    abspath(os.path.expanduser("~/.aws")),
    abspath(os.path.expanduser("~/.ssh")),
    abspath(os.path.expanduser("~/.kube")),
    abspath(os.path.expanduser("~/.docker")),
    abspath(os.path.expanduser("~/.config")),
    abspath(os.path.expanduser("~/.credentials")),
    "/etc",
    "/root",
    "/var/lib",
    "/opt/ml/metadata",
]
logger = logging.getLogger(__name__)
# A tag dictionary maps string keys to strings or pipeline variables; `Tags`
# accepts either a single tag dict or a list of them.
TagsDict = Dict[str, Union[str, PipelineVariable]]
Tags = Union[List[TagsDict], TagsDict]
class ModelApprovalStatusEnum(str, Enum):
    """Model package approval status enumerator.

    Inherits from ``str`` so members compare equal to their literal API values
    (e.g. ``ModelApprovalStatusEnum.APPROVED == "Approved"``).
    """

    APPROVED = "Approved"
    REJECTED = "Rejected"
    PENDING_MANUAL_APPROVAL = "PendingManualApproval"
# Use the base name of the image as the job name if the user doesn't give us one
def name_from_image(image, max_length=63):
    """Create a training job name based on the image name and a timestamp.

    Args:
        image (str): Image name.
        max_length (int): Maximum length for the resulting string (default: 63).

    Returns:
        str: Training job name using the algorithm from the image name and a
            timestamp.
    """
    return name_from_base(base_name_from_image(image), max_length=max_length)
def name_from_base(base, max_length=63, short=False):
    """Append a timestamp to the provided string.

    The result is guaranteed not to exceed ``max_length`` characters; the
    ``base`` prefix is trimmed as needed to make room for the timestamp.

    Args:
        base (str): String used as prefix to generate the unique name.
        max_length (int): Maximum length for the resulting string (default: 63).
        short (bool): Whether or not to use a truncated timestamp (default: False).

    Returns:
        str: Input parameter with appended timestamp.
    """
    suffix = sagemaker_short_timestamp() if short else sagemaker_timestamp()
    room_for_base = max_length - len(suffix) - 1
    return "{}-{}".format(base[:room_for_base], suffix)
def unique_name_from_base_uuid4(base, max_length=63):
    """Append a UUID4 to the provided string.

    This function generates a name using a UUID instead of a timestamp for
    uniqueness, trimming ``base`` so the result never exceeds ``max_length``.

    Args:
        base (str): String used as prefix to generate the unique name.
        max_length (int): Maximum length for the resulting string (default: 63).

    Returns:
        str: Input parameter with an appended UUID4 suffix.
    """
    # Note: the previous revision seeded the global `random` module here, but
    # the suffix comes from uuid4() — the seeded RNG was never used.
    unique = str(uuid.uuid4())
    trimmed_base = base[: max_length - len(unique) - 1]
    return "{}-{}".format(trimmed_base, unique)
def unique_name_from_base(base, max_length=63):
    """Append a unix timestamp and a short random hex suffix to ``base``.

    Args:
        base (str): String used as prefix to generate the unique name.
        max_length (int): Maximum length for the resulting string (default: 63).

    Returns:
        str: ``<trimmed base>-<timestamp>-<4-hex-digit suffix>``.
    """
    # Seed from a UUID so concurrent callers don't collide on the system clock.
    random.seed(int(uuid.uuid4()))
    suffix = "%04x" % random.randrange(16**4)  # 4-digit hex
    timestamp = str(int(time.time()))
    room_for_base = max_length - 2 - len(timestamp) - len(suffix)
    return "-".join((base[:room_for_base], timestamp, suffix))
def base_name_from_image(image, default_base_name=None):
    """Extract the base name of the image to use as the 'algorithm name' for the job.

    Args:
        image (str): Image name.
        default_base_name (str): The default base name

    Returns:
        str: Algorithm name, as extracted from the image name.
    """
    if not is_pipeline_variable(image):
        image_str = image
    elif is_pipeline_parameter_string(image) and image.default_value:
        # A parameter with a concrete default can still yield a real base name.
        image_str = image.default_value
    else:
        # Unresolvable at definition time; fall back to the provided default.
        return default_base_name if default_base_name else "base_name"
    match = re.match("^(.+/)?([^:/]+)(:[^:]+)?$", image_str)
    return match.group(2) if match else image_str
def base_from_name(name):
    """Extract the base name of the resource name (for use with future resource name generation).

    This function looks for timestamps that match the ones produced by
    :func:`~sagemaker.utils.name_from_base`.

    Args:
        name (str): The resource name.

    Returns:
        str: The base name, as extracted from the resource name.
    """
    # Either the full millisecond timestamp or the short yymmdd-HHMM form.
    timestamp_suffix = r"^(.+)-(\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}-\d{3}|\d{6}-\d{4})"
    match = re.match(timestamp_suffix, name)
    if match:
        return match.group(1)
    return name
def sagemaker_timestamp():
    """Return a UTC timestamp string with millisecond precision."""
    now = time.time()
    # Take up to three fractional digits from the float's repr as milliseconds.
    millis = repr(now).split(".")[1][:3]
    return time.strftime("%Y-%m-%d-%H-%M-%S-{}".format(millis), time.gmtime(now))
def sagemaker_short_timestamp():
    """Return a short local-time timestamp of the form ``yymmdd-HHMM``."""
    return time.strftime("%y%m%d-%H%M", time.localtime())
def build_dict(key, value):
    """Return a dict of key and value pair if value is truthy, otherwise an empty dict.

    Args:
        key (str): input key
        value (str): input value

    Returns:
        dict: dict of key and value or an empty dict.
    """
    return {key: value} if value else {}
def get_config_value(key_path, config):
    """Walk a dotted ``key_path`` through a nested ``config`` dict.

    Args:
        key_path (str): Dot-separated path, e.g. ``"local.region_name"``.
        config (dict): Nested configuration dict (may be None).

    Returns:
        The value at the path, or None if ``config`` is None or any key
        along the path is missing.
    """
    if config is None:
        return None
    node = config
    for part in key_path.split("."):
        if part not in node:
            return None
        node = node[part]
    return node
def get_nested_value(dictionary: dict, nested_keys: List[str]):
    """Returns a nested value from the given dictionary, and None if none present.

    Raises:
        ValueError: if the dictionary structure does not match the nested_keys
            (a non-dict value is found before the last key).
    """
    # Guard clauses: anything other than a non-empty dict + non-empty key list
    # yields None.
    if dictionary is None or not isinstance(dictionary, dict):
        return None
    if nested_keys is None or len(nested_keys) == 0:
        return None
    node = dictionary
    for key in nested_keys[:-1]:
        node = node.get(key, None)
        if node is None:
            # Full path doesn't exist, or an intermediate value is None.
            return None
        if not isinstance(node, dict):
            raise ValueError(
                "Unexpected structure of dictionary.",
                "Expected value of type dict at key '{}' but got '{}' for dict '{}'".format(
                    key, node, dictionary
                ),
            )
    return node.get(nested_keys[-1], None)
def set_nested_value(dictionary: dict, nested_keys: List[str], value_to_set: object):
    """Sets a nested value in a dictionary and returns the dictionary.

    Intermediate keys that are missing, None, or non-dict are replaced with
    fresh dicts. Note: with an unintended key list this can overwrite an
    unexpected part of the dict — recommended to check with
    ``get_nested_value`` first.
    """
    if dictionary is None:
        dictionary = {}
    has_keys = nested_keys is not None and len(nested_keys) > 0
    if isinstance(dictionary, dict) and has_keys:
        node = dictionary
        for key in nested_keys[:-1]:
            # Replace anything that isn't already a dict so we can descend.
            if not isinstance(node.get(key), dict):
                node[key] = {}
            node = node[key]
        node[nested_keys[-1]] = value_to_set
    return dictionary
def get_short_version(framework_version):
    """Return short version in the format of x.x

    Args:
        framework_version: The version string to be shortened.

    Returns:
        str: The short (major.minor) version string
    """
    major_minor = framework_version.split(".")[:2]
    return ".".join(major_minor)
def secondary_training_status_changed(current_job_description, prev_job_description):
    """Returns true if training job's secondary status message has changed.

    Args:
        current_job_description: Current job description, returned from DescribeTrainingJob call.
        prev_job_description: Previous job description, returned from DescribeTrainingJob call.

    Returns:
        boolean: Whether the secondary status message of a training job changed
        or not.
    """
    current_transitions = current_job_description.get("SecondaryStatusTransitions")
    if not current_transitions:
        # No transitions yet (missing or empty) — nothing to compare.
        return False
    prev_transitions = None
    if prev_job_description is not None:
        prev_transitions = prev_job_description.get("SecondaryStatusTransitions")
    last_message = prev_transitions[-1]["StatusMessage"] if prev_transitions else ""
    return current_transitions[-1]["StatusMessage"] != last_message
def secondary_training_status_message(job_description, prev_description):
    """Returns a string containing last modified time and the secondary training job status message.

    Args:
        job_description: Returned response from DescribeTrainingJob call
        prev_description: Previous job description from DescribeTrainingJob call

    Returns:
        str: Job status string to be printed ("" when there are no transitions).
    """
    transitions = job_description.get("SecondaryStatusTransitions") if job_description else None
    if not transitions:
        return ""
    prev_transitions = (
        prev_description.get("SecondaryStatusTransitions") if prev_description is not None else None
    )
    prev_count = len(prev_transitions) if prev_transitions is not None else 0
    if len(transitions) == prev_count:
        # Secondary status unchanged — only the latest message may differ.
        to_print = transitions[-1:]
    else:
        # New transitions appeared — print every new entry (negative slice
        # keeps the last len - prev_count items).
        to_print = transitions[prev_count - len(transitions):]
    # LastModifiedTime is the same for every printed line; format it once.
    time_str = datetime.utcfromtimestamp(
        time.mktime(job_description["LastModifiedTime"].timetuple())
    ).strftime("%Y-%m-%d %H:%M:%S")
    lines = [
        "{} {} - {}".format(time_str, transition["Status"], transition["StatusMessage"])
        for transition in to_print
    ]
    return "\n".join(lines)
def download_folder(bucket_name, prefix, target, sagemaker_session):
    """Download a folder from S3 to a local path

    Args:
        bucket_name (str): S3 bucket name
        prefix (str): S3 prefix within the bucket that will be downloaded. Can
            be a single file.
        target (str): destination path where the downloaded items will be placed
        sagemaker_session (sagemaker.core.helper.session.Session): a sagemaker session to
            interact with S3.

    Raises:
        ValueError: if ``prefix`` contains a ``..`` traversal component.
    """
    s3 = sagemaker_session.s3_resource
    prefix = prefix.lstrip("/")
    if ".." in prefix:
        raise ValueError("Traversal components are not allowed in S3 path!")
    # Try the prefix as a single object first, in case it is a file and not a
    # 'directory' — the object may have broader permissions than the bucket.
    if not prefix.endswith("/"):
        try:
            destination = os.path.join(target, os.path.basename(prefix))
            s3.Object(bucket_name, prefix).download_file(destination)
            return
        except botocore.exceptions.ClientError as e:
            error = e.response["Error"]
            if not (error["Code"] == "404" and error["Message"] == "Not Found"):
                raise
            # S3 also 404s when the key is a "folder"; assume that case here
            # and let the prefix listing below surface a real 404.
    _download_files_under_prefix(bucket_name, prefix, target, s3)
def _download_files_under_prefix(bucket_name, prefix, target, s3):
    """Download all S3 files which match the given prefix

    Args:
        bucket_name (str): S3 bucket name
        prefix (str): S3 prefix within the bucket that will be downloaded
        target (str): destination path where the downloaded items will be placed
        s3 (boto3.resources.base.ServiceResource): S3 resource
    """
    bucket = s3.Bucket(bucket_name)
    for obj_sum in bucket.objects.filter(Prefix=prefix):
        # Skip "folder" placeholder objects.
        if obj_sum.key.endswith("/"):
            continue
        obj = s3.Object(obj_sum.bucket_name, obj_sum.key)
        s3_relative_path = obj_sum.key[len(prefix):].lstrip("/")
        file_path = os.path.join(target, s3_relative_path)
        # exist_ok=True replaces the old manual errno.EEXIST check; any other
        # OS error still propagates.
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        obj.download_file(file_path)
def create_tar_file(source_files, target=None):
    """Create a gzipped tar file containing all the source_files

    Args:
        source_files (List[str]): List of file paths that will be contained in the tar file
        target (str): Path for the created archive; when omitted, a temp file
            is created and its path returned.

    Returns:
        (str): path to created tar file
    """
    if target:
        filename = target
    else:
        # mkstemp returns an *open* OS-level file descriptor; close it so it
        # isn't leaked — tarfile reopens the path by name below.
        fd, filename = tempfile.mkstemp()
        os.close(fd)
    with tarfile.open(filename, mode="w:gz", dereference=True) as t:
        for sf in source_files:
            # Add all files from the directory into the root of the directory
            # structure of the tar.
            t.add(sf, arcname=os.path.basename(sf))
    return filename
@contextlib.contextmanager
def _tmpdir(suffix="", prefix="tmp", directory=None):
"""Create a temporary directory with a context manager.
The file is deleted when the context exits, even when there's an exception.
The prefix, suffix, and dir arguments are the same as for mkstemp().
Args:
suffix (str): If suffix is specified, the file name will end with that
suffix, otherwise there will be no suffix.
prefix (str): If prefix is specified, the file name will begin with that
prefix; otherwise, a default prefix is used.
directory (str): If a directory is specified, the file will be downloaded
in this directory; otherwise, a default directory is used.
Returns:
str: path to the directory
"""
if directory is not None and not (os.path.exists(directory) and os.path.isdir(directory)):
raise ValueError(
"Inputted directory for storing newly generated temporary "
f"directory does not exist: '{directory}'"
)
tmp = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=directory)
try:
yield tmp
finally:
shutil.rmtree(tmp)
def repack_model(
    inference_script,
    source_directory,
    dependencies,
    model_uri,
    repacked_model_uri,
    sagemaker_session,
    kms_key=None,
):
    """Unpack a model tarball and create a new one with the provided code script.

    Steps: uncompress the model tarball (from S3 or the local file system)
    into a temp folder, replace the inference code with the new code provided,
    then compress the new tarball and save it to S3 or the local file system.

    Args:
        inference_script (str): path or basename of the inference script that
            will be packed into the model
        source_directory (str): path including all the files that will be packed
            into the model
        dependencies (list[str]): A list of paths to directories (absolute or
            relative) with any additional libraries that will be exported to the
            container (default: []). The library folders will be copied to
            SageMaker in the same folder where the entrypoint is copied.
        model_uri (str): S3 or file system location of the original model tar
        repacked_model_uri (str): path or file system location where the new
            model will be saved
        sagemaker_session (sagemaker.core.helper.session.Session): a sagemaker session to
            interact with S3.
        kms_key (str): KMS key ARN for encrypting the repacked model file
    """
    dependencies = dependencies or []
    # Honor a user-configured local download directory when one is set.
    settings = sagemaker_session.settings
    local_download_dir = None
    if settings is not None and settings.local_download_dir is not None:
        local_download_dir = settings.local_download_dir
    with _tmpdir(directory=local_download_dir) as tmp:
        model_dir = _extract_model(model_uri, sagemaker_session, tmp)
        _create_or_update_code_dir(
            model_dir,
            inference_script,
            source_directory,
            dependencies,
            sagemaker_session,
            tmp,
        )
        tmp_model_path = os.path.join(tmp, "temp-model.tar.gz")
        with tarfile.open(tmp_model_path, mode="w:gz") as t:
            t.add(model_dir, arcname=os.path.sep)
        _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key=kms_key)
def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
    """Upload (S3) or move (local) the repacked model archive to its destination."""
    if not repacked_model_uri.lower().startswith("s3://"):
        # Local destination: a plain move, stripping any file:// scheme.
        shutil.move(tmp_model_path, repacked_model_uri.replace("file://", ""))
        return
    url = parse.urlparse(repacked_model_uri)
    bucket, key = url.netloc, url.path.lstrip("/")
    new_key = key.replace(os.path.basename(key), os.path.basename(repacked_model_uri))
    settings = sagemaker_session.settings if sagemaker_session is not None else SessionSettings()
    encrypt_artifact = settings.encrypt_repacked_artifacts
    if kms_key:
        # Use a customer-managed key when one is supplied.
        extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key}
    elif encrypt_artifact:
        extra_args = {"ServerSideEncryption": "aws:kms"}
    else:
        extra_args = None
    s3 = sagemaker_session.boto_session.resource(
        "s3", region_name=sagemaker_session.boto_region_name
    )
    s3.Object(bucket, new_key).upload_file(tmp_model_path, ExtraArgs=extra_args)
def _validate_source_directory(source_directory):
    """Validate that source_directory is safe to use.

    Ensures the source directory path does not access restricted system locations.

    Args:
        source_directory (str): The source directory path to validate.

    Raises:
        ValueError: If the path is not allowed.
    """
    if not source_directory or source_directory.lower().startswith("s3://"):
        # S3 paths and None are safe
        return
    # Resolve symlinks to get the actual path
    abs_source = abspath(realpath(source_directory))
    # Check if the source path is (or is under) any sensitive directory.
    # Compare on path boundaries so siblings like "/etcetera" do not trip
    # the "/etc" rule.
    for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
        if abs_source != "/" and (
            abs_source == sensitive_path
            or abs_source.startswith(sensitive_path.rstrip(os.sep) + os.sep)
        ):
            raise ValueError(
                f"source_directory cannot access sensitive system paths. "
                f"Got: {source_directory} (resolved to {abs_source})"
            )
def _validate_dependency_path(dependency):
    """Validate that a dependency path is safe to use.

    Ensures the dependency path does not access restricted system locations.

    Args:
        dependency (str): The dependency path to validate.

    Raises:
        ValueError: If the path is not allowed.
    """
    if not dependency:
        return
    # Resolve symlinks to get the actual path
    abs_dependency = abspath(realpath(dependency))
    # Check if the dependency path is (or is under) any sensitive directory.
    # Compare on path boundaries so siblings like "/etcetera" do not trip
    # the "/etc" rule.
    for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
        if abs_dependency != "/" and (
            abs_dependency == sensitive_path
            or abs_dependency.startswith(sensitive_path.rstrip(os.sep) + os.sep)
        ):
            raise ValueError(
                f"dependency path cannot access sensitive system paths. "
                f"Got: {dependency} (resolved to {abs_dependency})"
            )
def _create_or_update_code_dir(
    model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp
):
    """Populate ``<model_dir>/code`` with the inference script, source dir, and dependencies.

    The source directory may be a local path or an ``s3://`` tarball; local
    paths and dependencies are validated against sensitive system locations.
    """
    code_dir = os.path.join(model_dir, "code")
    resolved_code_dir = _get_resolved_path(code_dir)
    # Validate that code_dir does not resolve to a sensitive system path.
    # Compare on path boundaries so siblings like "/etcetera" do not trip
    # the "/etc" rule.
    for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
        if resolved_code_dir != "/" and (
            resolved_code_dir == sensitive_path
            or resolved_code_dir.startswith(sensitive_path.rstrip(os.sep) + os.sep)
        ):
            raise ValueError(
                f"Invalid code_dir path: {code_dir} resolves to sensitive system path {resolved_code_dir}"
            )
    if source_directory and source_directory.lower().startswith("s3://"):
        # Source code is an S3 tarball: download and extract it into code_dir.
        local_code_path = os.path.join(tmp, "local_code.tar.gz")
        download_file_from_url(source_directory, local_code_path, sagemaker_session)
        with tarfile.open(name=local_code_path, mode="r:gz") as t:
            custom_extractall_tarfile(t, code_dir)
    elif source_directory:
        # Validate source_directory for security
        _validate_source_directory(source_directory)
        # Replace any existing code dir wholesale with the local source tree.
        if os.path.exists(code_dir):
            shutil.rmtree(code_dir)
        shutil.copytree(source_directory, code_dir)
    else:
        if not os.path.exists(code_dir):
            os.mkdir(code_dir)
        try:
            shutil.copy2(inference_script, code_dir)
        except FileNotFoundError:
            # The script may already live inside the model's code dir.
            if os.path.exists(os.path.join(code_dir, inference_script)):
                pass
            else:
                raise
    for dependency in dependencies:
        # Validate dependency path for security
        _validate_dependency_path(dependency)
        lib_dir = os.path.join(code_dir, "lib")
        if os.path.isdir(dependency):
            shutil.copytree(dependency, os.path.join(lib_dir, os.path.basename(dependency)))
        else:
            if not os.path.exists(lib_dir):
                os.mkdir(lib_dir)
            shutil.copy2(dependency, lib_dir)
def _extract_model(model_uri, sagemaker_session, tmp):
    """Extract a model tarball (local path or ``s3://`` URI) into ``<tmp>/model``."""
    tmp_model_dir = os.path.join(tmp, "model")
    os.mkdir(tmp_model_dir)
    if not model_uri.lower().startswith("s3://"):
        # Local archive, possibly with a file:// scheme prefix.
        local_model_path = model_uri.replace("file://", "")
    else:
        local_model_path = os.path.join(tmp, "tar_file")
        download_file_from_url(model_uri, local_model_path, sagemaker_session)
    with tarfile.open(name=local_model_path, mode="r:gz") as t:
        custom_extractall_tarfile(t, tmp_model_dir)
    return tmp_model_dir
def download_file_from_url(url, dst, sagemaker_session):
    """Download the S3 object identified by an ``s3://`` URL to local path ``dst``."""
    parsed = parse.urlparse(url)
    bucket = parsed.netloc
    key = parsed.path.lstrip("/")
    download_file(bucket, key, dst, sagemaker_session)
def download_file(bucket_name, path, target, sagemaker_session):
    """Download a Single File from S3 into a local path

    Args:
        bucket_name (str): S3 bucket name
        path (str): file path within the bucket
        target (str): destination directory for the downloaded file.
        sagemaker_session (sagemaker.core.helper.session.Session): a sagemaker session to
            interact with S3.
    """
    key = path.lstrip("/")
    s3 = sagemaker_session.boto_session.resource(
        "s3", region_name=sagemaker_session.boto_region_name
    )
    s3.Bucket(bucket_name).download_file(key, target)
def sts_regional_endpoint(region):
    """Get the AWS STS endpoint specific for the given region.

    We need this function because the AWS SDK does not yet honor
    the ``region_name`` parameter when creating an AWS STS client.

    For the list of regional endpoints, see
    https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html#id_credentials_region-endpoints.

    Args:
        region (str): AWS region name

    Returns:
        str: AWS STS regional endpoint
    """
    endpoint_data = _botocore_resolver().construct_endpoint("sts", region)
    if not endpoint_data and region == "il-central-1":
        # Older botocore bundles lack this region; synthesize the hostname.
        endpoint_data = {"hostname": "sts.{}.amazonaws.com".format(region)}
    return "https://{}".format(endpoint_data["hostname"])
def retries(
    max_retry_count,
    exception_message_prefix,
    seconds_to_sleep=DEFAULT_SLEEP_TIME_SECONDS,
):
    """Yield retry attempt numbers, sleeping between them; raise when exhausted.

    Args:
        max_retry_count (int): The retry count.
        exception_message_prefix (str): The message to include in the exception on failure.
        seconds_to_sleep (int): The number of seconds to sleep between executions.

    Raises:
        Exception: when all ``max_retry_count`` attempts have been yielded.
    """
    for attempt in range(max_retry_count):
        yield attempt
        time.sleep(seconds_to_sleep)
    raise Exception(
        "'{}' has reached the maximum retry count of {}".format(
            exception_message_prefix, max_retry_count
        )
    )
def retry_with_backoff(callable_func, num_attempts=8, botocore_client_error_code=None):
    """Retry with exponential backoff until maximum attempts are reached.

    Args:
        callable_func (Callable): The callable function to retry.
        num_attempts (int): The maximum number of attempts to retry. (Default: 8)
        botocore_client_error_code (str): The specific Botocore ClientError exception
            error code on which to retry.
            If provided, other exceptions are raised directly without retry.
            If not provided, retry on any exception. (Default: None)

    Returns:
        The return value of ``callable_func`` on the first successful attempt.

    Raises:
        ValueError: if ``num_attempts`` is less than 1.
        Exception: the last exception raised by ``callable_func`` when retries
            are exhausted, or immediately when the error is non-retryable.
    """
    if num_attempts < 1:
        raise ValueError(
            "The num_attempts must be >= 1, but the given value is {}.".format(num_attempts)
        )
    for i in range(num_attempts):
        try:
            return callable_func()
        except Exception as ex:  # pylint: disable=broad-except
            # Retry on anything when no code filter is given; otherwise only
            # on a ClientError whose error code matches the filter.
            should_retry = not botocore_client_error_code or (
                isinstance(ex, botocore.exceptions.ClientError)
                and ex.response["Error"]["Code"]  # pylint: disable=no-member
                == botocore_client_error_code
            )
            if not should_retry or i == num_attempts - 1:
                raise ex
            logger.error("Retrying in attempt %s, due to %s", (i + 1), str(ex))
            time.sleep(2**i)
def _botocore_resolver():
    """Build a botocore ``EndpointResolver`` from the bundled endpoints data.

    Returns:
        botocore.regions.EndpointResolver: resolver used to construct regional
        service endpoints (see ``sts_regional_endpoint`` and ``aws_partition``).
    """
    loader = botocore.loaders.create_loader()
    return botocore.regions.EndpointResolver(loader.load_data("endpoints"))
def aws_partition(region):
    """Given a region name (ex: "cn-north-1"), return the corresponding aws partition ("aws-cn").

    Args:
        region (str): The region name for which to return the corresponding partition.
            Ex: "cn-north-1"

    Returns:
        str: partition corresponding to the region name passed in. Ex: "aws-cn"
    """
    resolver = _botocore_resolver()
    endpoint_data = resolver.construct_endpoint("sts", region)
    if not endpoint_data and region == "il-central-1":
        # NOTE(review): this fallback dict has no "partition" key, so the
        # lookup below would raise KeyError — confirm whether the branch is
        # still reachable with current botocore data.
        endpoint_data = {"hostname": "sts.{}.amazonaws.com".format(region)}
    return endpoint_data["partition"]
class DeferredError(object):
    """Stores an exception and raises it at a later time if this object is accessed in any way.

    Useful to allow soft-dependencies on imports, so that the ImportError can be raised again
    later if code actually relies on the missing library.

    Example::

        try:
            import obscurelib
        except ImportError as e:
            logger.warning("Failed to import obscurelib. Obscure features will not work.")
            obscurelib = DeferredError(e)
    """

    def __init__(self, exception):
        """Store the exception to be re-raised on any later access."""
        self.exc = exception

    def __getattr__(self, name):
        """Raise the stored exception on any attribute or method access.

        Called by the Python interpreter before using any method or property
        on the object, so this short-circuits essentially any access.

        Args:
            name: attribute being accessed (unused; any access raises).
        """
        raise self.exc
def _module_import_error(py_module, feature, extras):
"""Return error message for module import errors, provide installation details.
Args:
py_module (str): Module that failed to be imported
feature (str): Affected SageMaker feature
extras (str): Name of the `extras_require` to install the relevant dependencies
Returns:
str: Error message with installation instructions.
"""
error_msg = (
"Failed to import {}. {} features will be impaired or broken. "
"Please run \"pip install 'sagemaker[{}]'\" "
"to install all required dependencies."
)
return error_msg.format(py_module, feature, extras)
class DataConfig(abc.ABC):
    """Abstract base class for accessing data config hosted in AWS resources.

    Provides a skeleton for customization by overriding of method fetch_data_config.
    """

    @abc.abstractmethod
    def fetch_data_config(self):
        """Abstract method implementing retrieval of data config from a pre-configured data source.

        Returns:
            object: The data configuration object.
        """
class S3DataConfig(DataConfig):
    """This class extends the DataConfig class to fetch a data config file hosted on S3"""

    def __init__(
        self,
        sagemaker_session,
        bucket_name,
        prefix,
    ):
        """Initialize a ``S3DataConfig`` instance.

        Args:
            sagemaker_session (Session): SageMaker session instance to use for boto configuration.
            bucket_name (str): Required. Bucket name from which data config needs to be fetched.
            prefix (str): Required. The object prefix for the hosted data config.

        Raises:
            ValueError: if ``bucket_name`` or ``prefix`` is None.
        """
        if bucket_name is None or prefix is None:
            raise ValueError(
                "Bucket Name and S3 file Prefix are required arguments and must be provided."
            )
        super(S3DataConfig, self).__init__()
        self.bucket_name = bucket_name
        self.prefix = prefix
        self.sagemaker_session = sagemaker_session

    def fetch_data_config(self):
        """Fetches data configuration from a S3 bucket.

        Returns:
            object: The JSON object containing data configuration.
        """
        raw = self.sagemaker_session.read_s3_file(self.bucket_name, self.prefix)
        return json.loads(raw)

    def get_data_bucket(self, region_requested=None):
        """Provides the bucket containing the data for specified region.

        Args:
            region_requested (str): The region for which the data is being requested.

        Returns:
            str: Name of the S3 bucket containing datasets in the requested region.
        """
        config = self.fetch_data_config()
        region = region_requested if region_requested else self.sagemaker_session.boto_region_name
        if region in config:
            return config[region]
        return config["default"]
def update_container_with_inference_params(
    framework=None,
    framework_version=None,
    nearest_model_name=None,
    data_input_configuration=None,
    container_def=None,
    container_list=None,
):
    """Check if inference recommender parameters exist and update container(s) in place.

    Args:
        framework (str): Machine learning framework of the model package container image
            (default: None).
        framework_version (str): Framework version of the Model Package Container Image
            (default: None).
        nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
            Amazon SageMaker Inference Recommender (default: None).
        data_input_configuration (str): Input object for the model (default: None).
        container_def (dict): object to be updated.
        container_list (list): list to be updated.

    Returns:
        dict: dict with inference recommender params
    """
    if container_list is not None:
        for container in container_list:
            construct_container_object(
                container,
                data_input_configuration,
                framework,
                framework_version,
                nearest_model_name,
            )
    if container_def is not None:
        construct_container_object(
            container_def,
            data_input_configuration,
            framework,
            framework_version,
            nearest_model_name,
        )
    return container_list or container_def
def construct_container_object(
    obj, data_input_configuration, framework, framework_version, nearest_model_name
):
    """Update a container object in place with inference recommender fields.

    Args:
        obj (dict): object to be updated.
        data_input_configuration (str): Input object for the model (default: None).
        framework (str): Machine learning framework of the model package container image
            (default: None).
        framework_version (str): Framework version of the Model Package Container Image
            (default: None).
        nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
            Amazon SageMaker Inference Recommender (default: None).

    Returns:
        dict: container object
    """
    # Only set fields the caller actually provided.
    candidate_fields = {
        "Framework": framework,
        "FrameworkVersion": framework_version,
        "NearestModelName": nearest_model_name,
    }
    for field_name, field_value in candidate_fields.items():
        if field_value is not None:
            obj.update({field_name: field_value})
    if data_input_configuration is not None:
        obj.update({"ModelInput": {"DataInputConfig": data_input_configuration}})
    return obj
[docs]
def pop_out_unused_kwarg(arg_name: str, kwargs: dict, override_val: Optional[str] = None):
    """Remove an unused keyword argument from ``kwargs`` and log a warning.

    Args:
        arg_name (str): The name of the argument to discard if present.
        kwargs (dict): The keyword-argument dict, mutated in place.
        override_val (str): The value used to override the unused argument
            (default: None); mentioned in the warning when truthy.
    """
    if arg_name in kwargs:
        message = "{} supplied in kwargs will be ignored".format(arg_name)
        if override_val:
            message = "{} and further overridden with {}.".format(message, override_val)
        logging.warning(message)
        del kwargs[arg_name]
[docs]
def to_string(obj: object):
    """Convert an object to a string.

    ``PipelineVariable`` objects are rendered via their ``to_string`` method;
    everything else goes through ``str``.

    Args:
        obj (object): The object to be converted.
    """
    if is_pipeline_variable(obj):
        return obj.to_string()
    return str(obj)
def _start_waiting(waiting_time: int):
    """Sleep for ``waiting_time`` seconds while printing a dot animation to stdout.

    The total wait is split evenly across ``WAITING_DOT_NUMBER`` dots; the line
    is blanked out with spaces once the wait completes.

    Args:
        waiting_time (int): The total waiting time.
    """
    interval = float(waiting_time) / WAITING_DOT_NUMBER
    for dots in range(1, WAITING_DOT_NUMBER + 1):
        print("." * dots, end="\r")
        time.sleep(interval)
    # Overwrite the animation so the console line ends up blank.
    print(" " * WAITING_DOT_NUMBER, end="\r")
[docs]
def get_module(module_name):
    """Import a module by name.

    Args:
        module_name (str): name of the module to import.

    Returns:
        object: The imported module.

    Raises:
        Exception: when the module name is not found.
    """
    try:
        module = import_module(module_name)
    except ImportError:
        raise Exception("Cannot import module {}, please try again.".format(module_name))
    return module
[docs]
def check_and_get_run_experiment_config(experiment_config: Optional[dict] = None) -> dict:
    """Return the user-supplied experiment_config or fall back to the current Run's.

    Args:
        experiment_config (dict): The experiment_config supplied by the user.

    Returns:
        dict: The user-supplied experiment_config when given; otherwise the
            experiment_config of the active Run object, or None if neither exists.
    """
    from sagemaker.core.experiments._run_context import _RunContext

    run_obj = _RunContext.get_current_run()
    if not experiment_config:
        return run_obj.experiment_config if run_obj else None
    if run_obj:
        # A Run context exists but the caller passed an explicit config; the
        # explicit config wins and we note the conflict.
        logger.warning(
            "The function is invoked within an Experiment Run context "
            "but another experiment_config (%s) was supplied, so "
            "ignoring the experiment_config fetched from the Run object.",
            experiment_config,
        )
    return experiment_config
[docs]
def resolve_value_from_config(
    direct_input=None,
    config_path: str = None,
    default_value=None,
    sagemaker_session=None,
    sagemaker_config: dict = None,
):
    """Decide which value the caller should use.

    Note: This method incorporates information from the sagemaker config.
    Prioritization order: (1) direct_input, (2) config value, (3) default_value,
    (4) None.

    Args:
        direct_input: The value that the caller of this method starts with,
            usually an input to the caller's class or method.
        config_path (str): Path used to look up the value in the sagemaker config.
        default_value: The value used if not present elsewhere.
        sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker
            Session object, used for SageMaker interactions (default: None).
        sagemaker_config (dict): The sdk defaults config, normally accessed via
            ``session.sagemaker_config`` (default: None). Checked only when
            sagemaker_session is None — for the rare case where a config dict
            is available before any Session can be constructed.

    Returns:
        The value that should be used by the caller.
    """
    config_value = None
    if config_path:
        config_value = get_sagemaker_config_value(
            sagemaker_session, config_path, sagemaker_config=sagemaker_config
        )
    from sagemaker.core.config.config_utils import _log_sagemaker_config_single_substitution

    _log_sagemaker_config_single_substitution(direct_input, config_value, config_path)
    # First non-None candidate wins, in priority order.
    for candidate in (direct_input, config_value, default_value):
        if candidate is not None:
            return candidate
    return None
[docs]
def get_sagemaker_config_value(sagemaker_session, key, sagemaker_config: dict = None):
    """Return the value for ``key`` from the sagemaker configuration.

    Args:
        sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker
            Session object, used for SageMaker interactions.
        key: Key path of the config file entry.
        sagemaker_config (dict): The sdk defaults config, normally accessed via
            ``session.sagemaker_config`` (default: None). Checked only when no
            session (with a ``sagemaker_config`` attribute) is provided.

    Returns:
        object: The corresponding default value in the configuration file, or
            None when no config is available.
    """
    from sagemaker.core.config.config_manager import SageMakerConfig

    source = (
        sagemaker_session.sagemaker_config
        if sagemaker_session and hasattr(sagemaker_session, "sagemaker_config")
        else sagemaker_config
    )
    if not source:
        return None
    SageMakerConfig().validate_sagemaker_config(source)
    # Deep-copy so callers can mutate the result without touching the source config.
    return copy.deepcopy(get_config_value(key, source))
[docs]
def get_resource_name_from_arn(arn):
    """Extract the resource name from an ARN string.

    An ARN has the shape ``arn:partition:service:region:account:type/name``;
    this returns everything after the first ``/`` in the resource portion.

    Args:
        arn (str): An ARN.

    Returns:
        str: The resource name.
    """
    resource_portion = arn.split(":", 5)[5]
    return resource_portion.split("/", 1)[1]
[docs]
def resolve_class_attribute_from_config(
    clazz: Optional[type],
    instance: Optional[object],
    attribute: str,
    config_path: str,
    default_value=None,
    sagemaker_session=None,
):
    """Merge a sagemaker config value into an attribute of a data-class instance.

    If the attribute is not already set, it is populated with a value fetched
    from the sagemaker config, falling back to ``default_value``. Prioritization:
    (1) current attribute value, (2) config value, (3) default_value, (4) leave
    unset.

    Args:
        clazz (Optional[type]): Class of 'instance', used to build a new instance
            when 'instance' is None. If None, no new object is created. Note: if
            provided, the constructor should default its fields to None, or its
            non-None defaults will mask config values.
        instance (Optional[object]): Instance of 'clazz' carrying 'attribute'.
        attribute (str): Attribute to set if not already set.
        config_path (str): Path used to look up the value in the sagemaker config.
        default_value: The value to use if not present elsewhere.
        sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker
            Session object, used for SageMaker interactions (default: None).

    Returns:
        The class instance the caller should use instead of the 'instance'
        parameter that was passed in.

    Raises:
        TypeError: If the (possibly constructed) instance lacks 'attribute'.
    """
    config_value = get_sagemaker_config_value(sagemaker_session, config_path)
    if config_value is None and default_value is None:
        # Nothing to apply; return the instance unmodified (may be None).
        return instance
    if instance is None:
        # Only construct a fresh instance when a real class was provided.
        if clazz is None or not inspect.isclass(clazz):
            return instance
        instance = clazz()
    if not hasattr(instance, attribute):
        raise TypeError(
            "Unexpected structure of object.",
            "Expected attribute {} to be present inside instance {} of class {}".format(
                attribute, instance, clazz
            ),
        )
    current_value = getattr(instance, attribute)
    if current_value is None:
        # The early return above guarantees at least one of these is non-None.
        replacement = config_value if config_value is not None else default_value
        setattr(instance, attribute, replacement)
    from sagemaker.core.config.config_utils import _log_sagemaker_config_single_substitution

    _log_sagemaker_config_single_substitution(current_value, config_value, config_path)
    return instance
[docs]
def resolve_nested_dict_value_from_config(
    dictionary: dict,
    nested_keys: List[str],
    config_path: str,
    default_value: object = None,
    sagemaker_session=None,
):
    """Set the value at a nested key path of a dictionary from the sagemaker config.

    If the nested value is not already set, it is populated with the config
    value or, failing that, ``default_value``. Prioritization: (1) current value
    of the nested key, (2) config value, (3) default_value, (4) leave unset.

    Args:
        dictionary: The dict to update.
        nested_keys: The key path where the value should be checked and set.
        config_path (str): Path used to find the value in the sagemaker config.
        default_value: The value to use if not present elsewhere.
        sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker
            Session object, used for SageMaker interactions (default: None).

    Returns:
        The updated dictionary that should be used by the caller instead of the
        'dictionary' parameter that was passed in.
    """
    config_value = get_sagemaker_config_value(sagemaker_session, config_path)
    if config_value is None and default_value is None:
        # Nothing to apply, so skip traversing or creating nested dicts entirely.
        return dictionary
    try:
        current_nested_value = get_nested_value(dictionary, nested_keys)
    except ValueError as e:
        logging.error("Failed to check dictionary for applying sagemaker config: %s", e)
        return dictionary
    if current_nested_value is None:
        # The early return above guarantees at least one of these is non-None.
        value_to_set = config_value if config_value is not None else default_value
        dictionary = set_nested_value(dictionary, nested_keys, value_to_set)
    from sagemaker.core.config.config_utils import _log_sagemaker_config_single_substitution

    _log_sagemaker_config_single_substitution(current_nested_value, config_value, config_path)
    return dictionary
[docs]
def update_list_of_dicts_with_values_from_config(
    input_list,
    config_key_path,
    required_key_paths: List[str] = None,
    union_key_paths: List[List[str]] = None,
    sagemaker_session=None,
):
    """Updates a list of dictionaries with missing values that are present in Config.

    In some cases, the config file might introduce new parameters which require certain
    other parameters to be provided as part of the input list. Without those parameters,
    the underlying service will throw an exception. This method provides the capability
    to specify such required key paths.

    In some other cases, the config file might introduce new parameters but the service
    API requires either an existing parameter or the new parameter supplied by config —
    not both. ``union_key_paths`` expresses that constraint.

    Args:
        input_list: The input list that was provided as a method parameter.
        config_key_path: The Key Path in the Config file that corresponds to the
            input_list parameter.
        required_key_paths (List[str]): List of required key paths that should be
            verified in the merged output. If a required key path is missing, the merge
            is skipped for that item.
        union_key_paths (List[List[str]]): List of lists of key paths for which exactly
            zero or one of the parameters may exist. For example, if the resultant
            dictionary can have either 'X1' or 'X2' (or neither) but not both, pass
            [['X1', 'X2']].
        sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker Session
            object, used for SageMaker interactions (default: None).

    Returns:
        No output. The merge happens in place on ``input_list``.
    """
    if not input_list:
        return
    # Keep pristine copies of both sides for the merge log emitted at the end.
    inputs_copy = copy.deepcopy(input_list)
    inputs_from_config = get_sagemaker_config_value(sagemaker_session, config_key_path) or []
    unmodified_inputs_from_config = copy.deepcopy(inputs_from_config)
    # Items are merged positionally; extra items on either side are left untouched.
    for i in range(min(len(input_list), len(inputs_from_config))):
        dict_from_inputs = input_list[i]
        dict_from_config = inputs_from_config[i]
        # botocore's merge_dicts merges the second dict into the first, so
        # caller-supplied values take precedence over config values.
        merge_dicts(dict_from_config, dict_from_inputs)
        # Check if required key paths are present in merged dict (dict_from_config)
        required_key_path_check_passed = _validate_required_paths_in_a_dict(
            dict_from_config, required_key_paths
        )
        if not required_key_path_check_passed:
            # Don't do the merge, config is introducing a new parameter which needs a
            # corresponding required parameter.
            continue
        union_key_path_check_passed = _validate_union_key_paths_in_a_dict(
            dict_from_config, union_key_paths
        )
        if not union_key_path_check_passed:
            # Don't do the merge, Union parameters are not obeyed.
            continue
        input_list[i] = dict_from_config
    from sagemaker.core.config.config_utils import _log_sagemaker_config_merge

    _log_sagemaker_config_merge(
        source_value=inputs_copy,
        config_value=unmodified_inputs_from_config,
        merged_source_and_config_value=input_list,
        config_key_path=config_key_path,
    )
def _validate_required_paths_in_a_dict(source_dict, required_key_paths: List[str] = None) -> bool:
    """Return True when every required key path resolves to a value in ``source_dict``."""
    if not required_key_paths:
        return True
    return all(
        get_config_value(path, source_dict) is not None for path in required_key_paths
    )
def _validate_union_key_paths_in_a_dict(
    source_dict, union_key_paths: List[List[str]] = None
) -> bool:
    """Return True when each union group has at most one key path present in ``source_dict``."""
    if not union_key_paths:
        return True
    for group in union_key_paths:
        present = [path for path in group if get_config_value(path, source_dict)]
        if len(present) > 1:
            return False
    return True
[docs]
def update_nested_dictionary_with_values_from_config(
    source_dict, config_key_path, sagemaker_session=None
) -> dict:
    """Update a nested dictionary with missing values that are present in Config.

    Args:
        source_dict: The input nested dictionary provided as a method parameter.
        config_key_path: The Key Path in the Config file which corresponds to
            this source_dict parameter.
        sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker
            Session object, used for SageMaker interactions (default: None).

    Returns:
        dict: The merged nested dictionary, updated with missing values present
            in the Config file.
    """
    inferred_config_dict = get_sagemaker_config_value(sagemaker_session, config_key_path) or {}
    original_config_dict_value = copy.deepcopy(inferred_config_dict)
    merge_dicts(inferred_config_dict, source_dict or {})
    if not original_config_dict_value:
        # No config value existed, so the merge result is just source_dict (or {}).
        # Return the original input untouched: returning {} where the caller passed
        # None can trip boto parameter validation — e.g. a VpcConfig of {} fails with
        # ParamValidationError because its required fields are missing.
        # No logging here because no config value was used or defined.
        return source_dict
    from sagemaker.core.config.config_utils import _log_sagemaker_config_merge

    _log_sagemaker_config_merge(
        source_value=source_dict,
        config_value=original_config_dict_value,
        merged_source_and_config_value=inferred_config_dict,
        config_key_path=config_key_path,
    )
    return inferred_config_dict
[docs]
def stringify_object(obj: Any) -> str:
    """Return a string representation of ``obj`` listing only its non-None attributes."""
    populated = {}
    for key, value in obj.__dict__.items():
        if value is not None:
            populated[key] = value
    return "{}: {}".format(type(obj).__name__, populated)
[docs]
def volume_size_supported(instance_type: str) -> bool:
    """Return True when SageMaker allows ``volume_size`` for the instance type.

    Raises:
        ValueError: If the instance type is improperly formatted.
    """
    try:
        # Local mode and pipeline-parameter instance types never support volume
        # size. The pipeline-variable check must come first: such objects are
        # not plain strings, so ``startswith`` would misbehave on them.
        if is_pipeline_variable(instance_type) or instance_type.startswith("local"):
            return False
        parts: List[str] = instance_type.split(".")
        if len(parts) == 3 and parts[0] == "ml":
            parts = parts[1:]
        if len(parts) != 2:
            raise ValueError(f"Failed to parse instance type '{instance_type}'")
        # Families containing "d" (c5d, p4d, ...) and the g5/g6/p5/trn1 families
        # do not support attaching an EBS volume.
        family = parts[0]
        if "d" in family:
            return False
        return not family.startswith(("g5", "g6", "p5", "trn1"))
    except Exception as e:
        raise ValueError(f"Failed to parse instance type '{instance_type}': {str(e)}")
[docs]
def instance_supports_kms(instance_type: str) -> bool:
    """Returns True if SageMaker allows KMS keys to be attached to the instance.

    KMS support follows the same instance-family rules as EBS volume support,
    so this simply delegates to ``volume_size_supported``.

    Args:
        instance_type (str): The instance type to check.

    Returns:
        bool: True when the instance type supports an attached KMS key.

    Raises:
        ValueError: If the instance type is improperly formatted.
    """
    return volume_size_supported(instance_type)
[docs]
def get_instance_type_family(instance_type: str) -> str:
    """Return the family portion of an instance type string.

    Matches either "ml.<family>.<size>" or "ml_<family>". If the input is not a
    string or there is no match, returns an empty string.
    """
    if not isinstance(instance_type, str):
        return ""
    match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
    return match[1] if match is not None else ""
[docs]
def create_paginator_config(max_items: int = None, page_size: int = None) -> Dict[str, int]:
    """Build a boto3 paginator configuration, falling back to module defaults.

    Args:
        max_items (int): Maximum total items; defaults to ``MAX_ITEMS`` when falsy.
        page_size (int): Items per page; defaults to ``PAGE_SIZE`` when falsy.

    Returns:
        Dict[str, int]: Pagination configuration for boto3 paginators.
    """
    config = {"MaxItems": max_items if max_items else MAX_ITEMS}
    config["PageSize"] = page_size if page_size else PAGE_SIZE
    return config
def _get_resolved_path(path):
"""Return the normalized absolute path of a given path.
abspath - returns the absolute path without resolving symlinks
realpath - resolves the symlinks and gets the actual path
normpath - normalizes paths (e.g. remove redudant separators)
and handles platform-specific differences
"""
return normpath(realpath(abspath(path)))
def _is_bad_path(path, base):
    """Check whether ``base`` joined with ``path`` escapes the base directory.

    Used to guard tar extraction against entries that attempt to access paths
    outside the expected directory structure.

    Args:
        path (str): The file path.
        base (str): The base directory.

    Returns:
        bool: True if the joined path is not rooted under the base directory,
            False otherwise.
    """
    # joinpath ignores ``base`` when ``path`` is absolute — exactly the case
    # the prefix check below then rejects.
    resolved = _get_resolved_path(joinpath(base, path))
    return not resolved.startswith(base)
def _is_bad_link(info, base):
    """Check whether a tar link entry points outside the base directory.

    Args:
        info (tarfile.TarInfo): The tar file info describing the link.
        base (str): The base directory.

    Returns:
        bool: True if the link is not rooted under the base directory,
            False otherwise.
    """
    # Link targets are interpreted relative to the directory containing the link.
    containing_dir = _get_resolved_path(joinpath(base, dirname(info.name)))
    return _is_bad_path(info.linkname, base=containing_dir)
def _get_safe_members(members):
    """Yield only the tar members that are safe to extract.

    Filters out entries whose paths, symlinks, or hard links would land outside
    the extraction directory, logging each blocked member.

    Args:
        members (list): A list of tar members to vet.

    Yields:
        tarfile.TarInfo: The tar file info of each safe member.
    """
    base = _get_resolved_path("")
    for member in members:
        if _is_bad_path(member.name, base):
            logger.error("%s is blocked (illegal path)", member.name)
        elif member.issym() and _is_bad_link(member, base):
            logger.error("%s is blocked: Symlink to %s", member.name, member.linkname)
        elif member.islnk() and _is_bad_link(member, base):
            logger.error("%s is blocked: Hard link to %s", member.name, member.linkname)
        else:
            yield member
def _validate_extracted_paths(extract_path):
    """Validate that extracted paths remain within the expected directory.

    Performs post-extraction validation: resolves every extracted file and
    directory (following symlinks) and verifies it is still rooted under the
    resolved extraction path.

    Args:
        extract_path (str): The path where files were extracted.

    Raises:
        ValueError: If any extracted entry resolves outside the expected
            extraction path.
    """
    base = _get_resolved_path(extract_path)
    # Compare against "base + separator" rather than a bare startswith(base):
    # a bare prefix check would wrongly accept a sibling such as "/tmp/base2"
    # as being inside "/tmp/base".
    base_prefix = base if base.endswith(os.sep) else base + os.sep
    for root, dirs, files in os.walk(extract_path):
        # Check directories
        for dir_name in dirs:
            dir_path = os.path.join(root, dir_name)
            resolved = _get_resolved_path(dir_path)
            if resolved != base and not resolved.startswith(base_prefix):
                logger.error("Extracted directory escaped extraction path: %s", dir_path)
                raise ValueError(f"Extracted path outside expected directory: {dir_path}")
        # Check files
        for file_name in files:
            file_path = os.path.join(root, file_name)
            resolved = _get_resolved_path(file_path)
            if resolved != base and not resolved.startswith(base_prefix):
                logger.error("Extracted file escaped extraction path: %s", file_path)
                raise ValueError(f"Extracted path outside expected directory: {file_path}")
[docs]
def can_model_package_source_uri_autopopulate(source_uri: str):
    """Check whether ``source_uri`` can auto-populate Model Registry information.

    Args:
        source_uri (str): The source uri.

    Returns:
        bool: True if the source_uri matches the model-package or model ARN
            pattern, False otherwise.
    """
    matches_model_package = re.match(MODEL_PACKAGE_ARN_PATTERN, source_uri)
    matches_model = re.match(MODEL_ARN_PATTERN, source_uri)
    return bool(matches_model_package or matches_model)
[docs]
def flatten_dict(
    d: Dict[str, Any],
    max_flatten_depth=None,
) -> Dict[str, Any]:
    """Flatten a nested dictionary into one keyed by tuples of nested keys.

    Args:
        d (Dict[str, Any]): The dict that will be flattened.
        max_flatten_depth (Optional[int]): Maximum depth to merge; dicts nested
            deeper than this are kept whole as values. Must be >= 1 when given.

    Returns:
        Dict[str, Any]: Mapping from key tuples to leaf values.

    Raises:
        ValueError: If max_flatten_depth is less than 1, or two flattened keys
            collide.
    """
    if max_flatten_depth is not None and max_flatten_depth < 1:
        raise ValueError("max_flatten_depth should not be less than 1.")

    flat_dict = {}

    def _extend_key(prefix, key):
        # Tuple keys grow incrementally; a None prefix starts a fresh tuple.
        return (key,) if prefix is None else prefix + (key,)

    def _flatten(node, depth, parent=None):
        has_item = False
        for key, value in node.items():
            has_item = True
            flat_key = _extend_key(parent, key)
            can_descend = max_flatten_depth is None or depth < max_flatten_depth
            # Recurse into non-empty dicts; an empty dict (recursion yields no
            # items) falls through and is stored as a leaf value.
            if isinstance(value, dict) and can_descend and _flatten(value, depth + 1, flat_key):
                continue
            if flat_key in flat_dict:
                raise ValueError("duplicated key '{}'".format(flat_key))
            flat_dict[flat_key] = value
        return has_item

    _flatten(d, depth=1)
    return flat_dict
[docs]
def nested_set_dict(d: Dict[str, Any], keys: List[str], value: Any) -> None:
    """Set ``value`` at the position given by a sequence of nested keys.

    Intermediate dictionaries are created as needed.
    """
    for key in keys[:-1]:
        d = d.setdefault(key, {})
    d[keys[-1]] = value
[docs]
def unflatten_dict(d: Dict[str, Any]) -> Dict[str, Any]:
    """Rebuild a nested dictionary from one keyed by tuples of nested keys.

    Args:
        d (Dict[str, Any]): The dict that will be unflattened.

    Returns:
        Dict[str, Any]: The nested dictionary.
    """
    nested = {}
    for key_tuple, value in d.items():
        nested_set_dict(nested, key_tuple, value)
    return nested
[docs]
def deep_override_dict(
    dict1: Dict[str, Any], dict2: Dict[str, Any], skip_keys: Optional[List[str]] = None
) -> Dict[str, Any]:
    """Override any overlapping contents of ``dict1`` with the contents of ``dict2``.

    Top-level keys of ``dict2`` listed in ``skip_keys`` are ignored; entries of
    ``dict1`` whose flattened value is None are dropped before merging.
    """
    skip = set(skip_keys or [])
    base = {key: value for key, value in flatten_dict(dict1).items() if value is not None}
    override = flatten_dict({key: value for key, value in dict2.items() if key not in skip})
    base.update(override)
    return unflatten_dict(base) if base else {}
def _resolve_routing_config(routing_config: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Normalize a routing config into the form the SageMaker API expects.

    Args:
        routing_config (Optional[Dict[str, Any]]): The routing config.

    Returns:
        Optional[Dict[str, Any]]: Config with the strategy name upper-cased, or
            None when no strategy is configured.

    Raises:
        ValueError: If the RoutingStrategy is invalid.
    """
    if not routing_config:
        return None
    strategy = routing_config.get("RoutingStrategy", None)
    if not strategy:
        return None
    if isinstance(strategy, RoutingStrategy):
        return {"RoutingStrategy": strategy.name}
    if isinstance(strategy, str):
        normalized = strategy.upper()
        valid_names = (
            RoutingStrategy.RANDOM.name,
            RoutingStrategy.LEAST_OUTSTANDING_REQUESTS.name,
        )
        if normalized in valid_names:
            return {"RoutingStrategy": normalized}
    raise ValueError(
        "RoutingStrategy must be either RoutingStrategy.RANDOM "
        "or RoutingStrategy.LEAST_OUTSTANDING_REQUESTS"
    )
[docs]
@lru_cache
def get_instance_rate_per_hour(
    instance_type: str,
    region: str,
) -> Optional[Dict[str, str]]:
    """Gets instance rate per hour for the given instance type.

    Results are memoized via ``lru_cache``, so repeated lookups for the same
    (instance_type, region) pair hit the Pricing API only once per process.

    Args:
        instance_type (str): The instance type.
        region (str): The region.

    Returns:
        Optional[Dict[str, str]]: Instance rate per hour.
            Example: {'name': 'Instance Rate', 'unit': 'USD/Hrs', 'value': '1.125'}.

    Raises:
        Exception: An exception is raised if the IAM role is not authorized to
            perform pricing:GetProducts, or an unexpected event happened.
    """
    # Route the query to a nearby Pricing API endpoint — presumably only these
    # regions serve the API; the actual instance region goes in the filter below.
    region_name = "us-east-1"
    if region.startswith("eu") or region.startswith("af"):
        region_name = "eu-central-1"
    elif region.startswith("ap") or region.startswith("cn"):
        region_name = "ap-south-1"
    pricing_client: boto3.client = boto3.client("pricing", region_name=region_name)
    res = pricing_client.get_products(
        ServiceCode="AmazonSageMaker",
        Filters=[
            {"Type": "TERM_MATCH", "Field": "instanceName", "Value": instance_type},
            {"Type": "TERM_MATCH", "Field": "locationType", "Value": "AWS Region"},
            {"Type": "TERM_MATCH", "Field": "regionCode", "Value": region},
        ],
    )
    price_list = res.get("PriceList", [])
    if len(price_list) > 0:
        price_data = price_list[0]
        # PriceList entries may come back as JSON strings; normalize to a dict.
        if isinstance(price_data, str):
            price_data = json.loads(price_data)
        instance_rate_per_hour = extract_instance_rate_per_hour(price_data)
        if instance_rate_per_hour is not None:
            return instance_rate_per_hour
    raise Exception(f"Unable to get instance rate per hour for instance type: {instance_type}.")
[docs]
def camel_case_to_pascal_case(data: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively convert every snake_case key in ``data`` to PascalCase.

    Nested dicts and dicts inside lists are converted as well; all other
    values are left untouched.

    Args:
        data (dict): The dictionary to be updated.

    Returns:
        dict: A new dictionary with keys in PascalCase.
    """

    def _pascal(key):
        # "my_key" -> "MyKey"
        return "".join(segment.capitalize() for segment in key.split("_"))

    def _transform(value):
        if isinstance(value, dict):
            return camel_case_to_pascal_case(value)
        if isinstance(value, list):
            return [_transform(element) for element in value]
        return value

    return {_pascal(key): _transform(value) for key, value in data.items()}
[docs]
def tag_exists(tag: TagsDict, curr_tags: Optional[Tags]) -> bool:
    """Return True when a tag with the same "Key" already exists in ``curr_tags``.

    Args:
        tag (TagsDict): The tag dictionary.
        curr_tags (Optional[Tags]): The current tags.

    Returns:
        bool: True if the tag exists.
    """
    if curr_tags is None:
        return False
    return any(existing["Key"] == tag["Key"] for existing in curr_tags)
def _validate_new_tags(new_tags: Optional[Tags], curr_tags: Optional[Tags]) -> Optional[Tags]:
    """Merge ``new_tags`` into ``curr_tags``, skipping keys that already exist.

    Args:
        new_tags (Optional[Tags]): A tag dict or list of tag dicts to add.
        curr_tags (Optional[Tags]): The current tags.

    Returns:
        Optional[Tags]: The updated tags (``new_tags`` unchanged when
            ``curr_tags`` is None).
    """
    if curr_tags is None:
        return new_tags
    if curr_tags and isinstance(curr_tags, dict):
        # Normalize a single tag dict into a one-element list.
        curr_tags = [curr_tags]
    candidates = [new_tags] if isinstance(new_tags, dict) else (new_tags or [])
    for candidate in candidates:
        if not tag_exists(candidate, curr_tags):
            curr_tags.append(candidate)
    return curr_tags
[docs]
def remove_tag_with_key(key: str, tags: Optional[Tags]) -> Optional[Tags]:
    """Remove every tag whose "Key" equals ``key`` from ``tags``.

    Args:
        key (str): The key of the tag to remove.
        tags (Optional[Tags]): The current tags (a dict or a list of dicts).

    Returns:
        Optional[Tags]: None when nothing remains, the single remaining tag
            dict when exactly one remains, otherwise the filtered list.
    """
    if tags is None:
        return tags
    if isinstance(tags, dict):
        tags = [tags]
    remaining = [tag for tag in tags if tag["Key"] != key]
    if not remaining:
        return None
    if len(remaining) == 1:
        return remaining[0]
    return remaining
[docs]
def get_domain_for_region(region: str) -> str:
    """Return the DNS domain for the given region.

    Args:
        region (str): AWS region name.

    Returns:
        str: The partition's domain — "amazonaws.com" unless the region is
            listed in ``ALTERNATE_DOMAINS``.
    """
    if region in ALTERNATE_DOMAINS:
        return ALTERNATE_DOMAINS[region]
    return "amazonaws.com"
[docs]
def camel_to_snake(camel_case_string: str) -> str:
    """Convert a camelCase string to snake_case using a regex.

    Inserts "_" before every capital letter that is not at the start, then
    lower-cases the result. This regex cannot handle whitespace
    ("camelString TwoWords").
    """
    pattern = r"(?<!^)(?=[A-Z])"
    return re.sub(pattern, "_", camel_case_string).lower()
[docs]
def walk_and_apply_json(
    json_obj: Dict[Any, Any], apply, stop_keys: Optional[List[str]] = ("metrics",)
) -> Dict[Any, Any]:
    """Recursively walk a json-like object and apply ``apply`` to each dict key.

    Args:
        json_obj (Dict[Any, Any]): The object to transform.
        apply: Callable invoked on every dict key; its return value becomes the
            new key.
        stop_keys (Optional[List[str]]): Field keys (compared after ``apply``)
            whose children are copied verbatim without further transformation.
            ``None`` means never stop. The default is an immutable tuple rather
            than a mutable list (avoids the shared-mutable-default pitfall);
            membership semantics are identical.

    Returns:
        Dict[Any, Any]: A new structure with transformed keys.
    """

    def _walk(node, new):
        if isinstance(node, dict) and isinstance(new, dict):
            for key, value in node.items():
                new_key = apply(key)
                # NOTE: a non-None falsy stop_keys (e.g. []) makes this condition
                # False for every key, so children are copied verbatim.
                if (stop_keys and new_key not in stop_keys) or stop_keys is None:
                    if isinstance(value, dict):
                        new[new_key] = {}
                        _walk(value, new=new[new_key])
                    elif isinstance(value, list):
                        new[new_key] = []
                        for item in value:
                            _walk(item, new=new[new_key])
                    else:
                        new[new_key] = value
                else:
                    new[new_key] = value
        elif isinstance(node, dict) and isinstance(new, list):
            new.append(_walk(node, new={}))
        elif isinstance(node, list) and isinstance(new, dict):
            new.update(node)
        elif isinstance(node, list) and isinstance(new, list):
            new.append(node)
        elif isinstance(node, str) and isinstance(new, list):
            new.append(node)
        return new

    return _walk(json_obj, new={})
def _wait_until(callable_fn, poll=5):
"""Placeholder docstring"""
elapsed_time = 0
result = None
while result is None:
try:
elapsed_time += poll
time.sleep(poll)
result = callable_fn()
except botocore.exceptions.ClientError as err:
# For initial 5 mins we accept/pass AccessDeniedException.
# The reason is to await tag propagation to avoid false AccessDenied claims for an
# access policy based on resource tags, The caveat here is for true AccessDenied
# cases the routine will fail after 5 mins
if err.response["Error"]["Code"] == "AccessDeniedException" and elapsed_time <= 300:
logger.warning(
"Received AccessDeniedException. This could mean the IAM role does not "
"have the resource permissions, in which case please add resource access "
"and retry. For cases where the role has tag based resource policy, "
"continuing to wait for tag propagation.."
)
continue
raise err
return result
def _flush_log_streams(
    stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap
):
    """Fetch and print new CloudWatch log events for a job's log streams.

    While fewer streams are known than instances, refreshes the stream list
    (paginating via ``nextToken``); then prints new events from every stream and
    advances the stored read positions. Prints a progress dot instead when no
    streams exist yet.

    Args:
        stream_names (list): Known log stream names; may be refreshed here.
        instance_count (int): Number of instances (one stream expected each).
        client: CloudWatch Logs client used for describe/iterate calls.
        log_group (str): Log group containing the job's streams.
        job_name (str): Job name; stream names are prefixed with "<job_name>/".
        positions (dict): Map of stream name -> Position(timestamp, skip),
            updated in place as events are consumed.
        dot (bool): Whether the caller is currently printing progress dots.
        color_wrap: Callable that prints an event message with per-stream color.
    """
    if len(stream_names) < instance_count:
        # Log streams are created whenever a container starts writing to
        # stdout/err, so this list may be dynamic until we have a stream for
        # every instance.
        try:
            streams = client.describe_log_streams(
                logGroupName=log_group,
                logStreamNamePrefix=job_name + "/",
                orderBy="LogStreamName",
                limit=min(instance_count, 50),
            )
            stream_names = [s["logStreamName"] for s in streams["logStreams"]]
            while "nextToken" in streams:
                # Forward the pagination token; without it this call would
                # refetch the first page and the loop would never terminate.
                streams = client.describe_log_streams(
                    logGroupName=log_group,
                    logStreamNamePrefix=job_name + "/",
                    orderBy="LogStreamName",
                    nextToken=streams["nextToken"],
                    limit=50,
                )
                stream_names.extend([s["logStreamName"] for s in streams["logStreams"]])
            positions.update(
                [
                    (s, sagemaker.core.logs.Position(timestamp=0, skip=0))
                    for s in stream_names
                    if s not in positions
                ]
            )
        except ClientError as e:
            # On the very first training job run on an account, there's no log
            # group until the container starts logging, so ignore any errors
            # thrown about that.
            err = e.response.get("Error", {})
            if err.get("Code", None) != "ResourceNotFoundException":
                raise
    if len(stream_names) > 0:
        if dot:
            print("")
            dot = False
        for idx, event in sagemaker.core.logs.multi_stream_iter(
            client, log_group, stream_names, positions
        ):
            color_wrap(idx, event["message"])
            ts, count = positions[stream_names[idx]]
            if event["timestamp"] == ts:
                # Same timestamp as the last read event: skip one more
                # already-seen event next time.
                positions[stream_names[idx]] = sagemaker.core.logs.Position(
                    timestamp=ts, skip=count + 1
                )
            else:
                positions[stream_names[idx]] = sagemaker.core.logs.Position(
                    timestamp=event["timestamp"], skip=1
                )
    else:
        dot = True
        print(".", end="")
        sys.stdout.flush()
class LogState(object):
    """States of the log-tailing loop: from waiting for streams to appear,
    through tailing a running job, to final completion.
    """

    # Note: the stray "[docs]" sphinx-extraction artifact that preceded this
    # class was removed; it referenced an undefined name and would raise
    # NameError at import time.
    STARTING = 1
    WAIT_IN_PROGRESS = 2
    TAILING = 3
    JOB_COMPLETE = 4
    COMPLETE = 5
_STATUS_CODE_TABLE = {
"COMPLETED": "Completed",
"INPROGRESS": "InProgress",
"IN_PROGRESS": "InProgress",
"FAILED": "Failed",
"STOPPED": "Stopped",
"STOPPING": "Stopping",
"STARTING": "Starting",
"PENDING": "Pending",
}
def _get_initial_job_state(description, status_key, wait):
    """Pick the starting ``LogState`` for a log-tailing loop.

    Args:
        description (dict): Result of the relevant ``describe_*`` API call.
        status_key (str): Key in ``description`` holding the job status.
        wait (bool): Whether the caller intends to block until completion.

    Returns:
        int: ``LogState.TAILING`` when waiting on a still-running job,
        ``LogState.COMPLETE`` otherwise.
    """
    terminal_statuses = ("Completed", "Failed", "Stopped")
    already_done = description[status_key] in terminal_statuses
    if wait and not already_done:
        return LogState.TAILING
    return LogState.COMPLETE
def _logs_init(boto_session, description, job):
    """Build the initial state needed to tail a job's CloudWatch logs.

    Args:
        boto_session: Boto session used to create the ``logs`` client.
        description (dict): Result of the job's ``describe_*`` call; the keys
            read depend on ``job``.
        job (str): One of ``"Training"``, ``"Transform"``, ``"Processing"``,
            or ``"AutoML"``; also used to form the log group name.

    Returns:
        tuple: ``(instance_count, stream_names, positions, client, log_group,
        dot, color_wrap)`` ready to pass to the log-flushing loop.
    """
    if job == "Training":
        resource_config = description["ResourceConfig"]
        if "InstanceGroups" in resource_config:
            # Heterogeneous clusters: total instances across all groups.
            instance_count = sum(
                group["InstanceCount"] for group in resource_config["InstanceGroups"]
            )
        else:
            instance_count = resource_config["InstanceCount"]
    elif job == "Transform":
        instance_count = description["TransformResources"]["InstanceCount"]
    elif job == "Processing":
        instance_count = description["ProcessingResources"]["ClusterConfig"]["InstanceCount"]
    elif job == "AutoML":
        instance_count = 0

    stream_names = []  # The list of log streams
    positions = {}  # The current position in each stream, map of stream name -> position

    # Increase retries allowed (from default of 4), as we don't want waiting for a training job
    # to be interrupted by a transient exception.
    config = botocore.config.Config(retries={"max_attempts": 15})
    client = boto_session.client("logs", config=config)
    log_group = "/aws/sagemaker/" + job + "Jobs"
    dot = False
    color_wrap = sagemaker.core.logs.ColorWrap()
    return instance_count, stream_names, positions, client, log_group, dot, color_wrap
def _check_job_status(job, desc, status_key_name):
    """Check to see if the job completed successfully.

    Logs a warning for a stopped job; raises for any other non-completed
    terminal status.

    Args:
        job (str): The name of the job to check.
        desc (dict[str, str]): The result of ``describe_training_job()``.
        status_key_name (str): Status key name to check for.

    Raises:
        exceptions.CapacityError: If the job failed due to a CapacityError.
        exceptions.UnexpectedStatusException: If the job otherwise failed.
    """
    raw_status = desc[status_key_name]
    # If the status is capital case, then convert it to Camel case
    status = _STATUS_CODE_TABLE.get(raw_status, raw_status)

    if status == "Stopped":
        logger.warning(
            "Job ended with status 'Stopped' rather than 'Completed'. "
            "This could mean the job timed out or stopped early for some other reason: "
            "Consider checking whether it completed as you expect."
        )
        return
    if status == "Completed":
        return

    reason = desc.get("FailureReason", "(No reason provided)")
    troubleshooting = (
        "https://docs.aws.amazon.com/sagemaker/latest/dg/"
        "sagemaker-python-sdk-troubleshooting.html"
    )
    message = (
        "Error for {job_type} {job_name}: {status}. Reason: {reason}. "
        "Check troubleshooting guide for common errors: {troubleshooting}"
    ).format(
        job_type=status_key_name.replace("JobStatus", " job"),
        job_name=job,
        status=status,
        reason=reason,
        troubleshooting=troubleshooting,
    )
    # Capacity failures get a dedicated exception type so callers can retry.
    error_cls = (
        exceptions.CapacityError
        if "CapacityError" in str(reason)
        else exceptions.UnexpectedStatusException
    )
    raise error_cls(
        message=message,
        allowed_statuses=["Completed", "Stopped"],
        actual_status=status,
    )
def _create_resource(create_fn):
    """Invoke a create call, tolerating "already exists" failures.

    Helper that lets callers reuse a pre-existing resource: a ClientError
    whose code and message indicate the resource already exists is swallowed;
    any other ClientError is re-raised.

    Args:
        create_fn: Zero-argument callable that performs the create call.

    Returns:
        bool: True if a new resource was created, False if it already existed.
    """
    try:
        create_fn()
    except ClientError as ce:
        error = ce.response["Error"]
        benign_codes = ("ValidationException", "ResourceInUse")
        benign_fragments = ("Cannot create already existing", "already exists")
        is_already_exists = error["Code"] in benign_codes and any(
            fragment in error["Message"] for fragment in benign_fragments
        )
        if not is_already_exists:
            raise ce
        # no new resource created as resource already exists
        return False
    # Create call succeeded, so the resource did not previously exist.
    return True
def _is_s3_uri(s3_uri: Optional[str]) -> bool:
"""Checks whether an S3 URI is valid.
Args:
s3_uri (Optional[str]): The S3 URI.
Returns:
bool: Whether the S3 URI is valid.
"""
if s3_uri is None:
return False
return re.match("^s3://([^/]+)/?(.*)$", s3_uri) is not None