# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import, annotations, print_function
import json
import logging
import os
import re
import sys
import time
import uuid
import six
import warnings
from functools import reduce
from typing import Dict, Any, Optional, List
import boto3
import botocore
import botocore.config
from botocore.exceptions import ClientError
from sagemaker.core import exceptions
from sagemaker.core.common_utils import (
secondary_training_status_changed,
secondary_training_status_message,
)
import sagemaker.core.logs
from sagemaker.core.session_settings import SessionSettings
from sagemaker.core.common_utils import (
secondary_training_status_changed,
secondary_training_status_message,
sts_regional_endpoint,
retries,
resolve_value_from_config,
get_sagemaker_config_value,
resolve_class_attribute_from_config,
resolve_nested_dict_value_from_config,
update_nested_dictionary_with_values_from_config,
update_list_of_dicts_with_values_from_config,
format_tags,
Tags,
TagsDict,
instance_supports_kms,
create_paginator_config,
)
from sagemaker.core.config.config_utils import _log_sagemaker_config_merge
from sagemaker.core._studio import _append_project_tags
from sagemaker.core.config.config import load_sagemaker_config, validate_sagemaker_config
from sagemaker.core.config.config_schema import (
KEY,
TRANSFORM_JOB,
TRANSFORM_JOB_ENVIRONMENT_PATH,
TRANSFORM_JOB_KMS_KEY_ID_PATH,
TRANSFORM_OUTPUT_KMS_KEY_ID_PATH,
VOLUME_KMS_KEY_ID,
TRANSFORM_JOB_VOLUME_KMS_KEY_ID_PATH,
MODEL,
MODEL_CONTAINERS_PATH,
MODEL_EXECUTION_ROLE_ARN_PATH,
MODEL_ENABLE_NETWORK_ISOLATION_PATH,
MODEL_PRIMARY_CONTAINER_PATH,
MODEL_VPC_CONFIG_PATH,
ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH,
KMS_KEY_ID,
ENDPOINT_CONFIG_KMS_KEY_ID_PATH,
ENDPOINT_CONFIG,
ENDPOINT_CONFIG_DATA_CAPTURE_PATH,
ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH,
ENDPOINT_CONFIG_VPC_CONFIG_PATH,
ENDPOINT_CONFIG_ENABLE_NETWORK_ISOLATION_PATH,
ENDPOINT_CONFIG_EXECUTION_ROLE_ARN_PATH,
ENDPOINT,
INFERENCE_COMPONENT,
SAGEMAKER,
TAGS,
SESSION_DEFAULT_S3_BUCKET_PATH,
SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH,
)
# Setting LOGGER for backward compatibility, in case users import it...
logger = LOGGER = logging.getLogger("sagemaker")

# Metadata file present inside SageMaker-hosted notebook/Studio environments;
# read by ``Session.get_caller_identity_arn`` to discover the execution role.
NOTEBOOK_METADATA_FILE = "/opt/ml/metadata/resource-metadata.json"

# Schedule value representing a Model Monitor one-time (immediate) schedule.
MODEL_MONITOR_ONE_TIME_SCHEDULE = "NOW"

# Maps upper-case status strings (as returned by some SageMaker APIs) to the
# canonical CamelCase status names used elsewhere in the SDK.
_STATUS_CODE_TABLE = {
    "COMPLETED": "Completed",
    "INPROGRESS": "InProgress",
    "IN_PROGRESS": "InProgress",
    "FAILED": "Failed",
    "STOPPED": "Stopped",
    "STOPPING": "Stopping",
    "STARTING": "Starting",
    "PENDING": "Pending",
}

# Polling intervals in seconds (presumably consumed by endpoint waiters/log
# tailing; the consumers are not visible in this chunk — verify before reuse).
EP_LOGGER_POLL = 30
DEFAULT_EP_POLL = 30
[docs]
class LogState(object):
    """States of a job's log stream while it is being tailed.

    The names suggest a progression from waiting for a job to start, through
    tailing live output, to the job and then the log tail being complete; the
    consumers of these states are not visible in this chunk.
    """

    STARTING = 1
    WAIT_IN_PROGRESS = 2
    TAILING = 3
    JOB_COMPLETE = 4
    COMPLETE = 5
[docs]
class Session(object): # pylint: disable=too-many-public-methods
"""Manage interactions with the Amazon SageMaker APIs and any other AWS services needed.
This class provides convenient methods for manipulating entities and resources that Amazon
SageMaker uses, such as training jobs, endpoints, and input datasets in S3.
AWS service calls are delegated to an underlying Boto3 session, which by default
is initialized using the AWS configuration chain. When you make an Amazon SageMaker API call
that accesses an S3 bucket location and one is not specified, the ``Session`` creates a default
bucket based on a naming convention which includes the current AWS account ID.
"""
    def __init__(
        self,
        boto_session=None,
        sagemaker_client=None,
        sagemaker_runtime_client=None,
        sagemaker_featurestore_runtime_client=None,
        default_bucket=None,
        sagemaker_config: Optional[dict] = None,
        settings=None,
        sagemaker_metrics_client=None,
        default_bucket_prefix: Optional[str] = None,
    ):
        """Initialize a SageMaker ``Session``.

        Args:
            boto_session (boto3.session.Session): The underlying Boto3 session which AWS service
                calls are delegated to (default: None). If not provided, one is created with
                default AWS configuration chain.
            sagemaker_client (boto3.SageMaker.Client): Client which makes Amazon SageMaker service
                calls other than ``InvokeEndpoint`` (default: None). Estimators created using this
                ``Session`` use this client. If not provided, one will be created using this
                instance's ``boto_session``.
            sagemaker_runtime_client (boto3.SageMakerRuntime.Client): Client which makes
                ``InvokeEndpoint`` calls to Amazon SageMaker (default: None). Predictors created
                using this ``Session`` use this client. If not provided, one will be created using
                this instance's ``boto_session``.
            sagemaker_featurestore_runtime_client (boto3.SageMakerFeatureStoreRuntime.Client):
                Client which makes SageMaker FeatureStore record related calls to Amazon SageMaker
                (default: None). If not provided, one will be created using
                this instance's ``boto_session``.
            default_bucket (str): The default Amazon S3 bucket to be used by this session.
                This will be created the next time an Amazon S3 bucket is needed (by calling
                :func:`default_bucket`).
                If not provided, it will be fetched from the sagemaker_config. If not configured
                there either, a default bucket will be created based on the following format:
                "sagemaker-{region}-{aws-account-id}".
                Example: "sagemaker-my-custom-bucket".
            sagemaker_config (dict): A dictionary of default values for the SDK
                (default: None). Validated via ``validate_sagemaker_config`` inside
                :func:`_initialize`; if not provided, loaded there with
                ``load_sagemaker_config``.
            settings (sagemaker.core.session_settings.SessionSettings): Optional session
                settings (default: None). A default ``SessionSettings`` instance is created
                when not supplied.
            sagemaker_metrics_client (boto3.SageMakerMetrics.Client):
                Client which makes SageMaker Metrics related calls to Amazon SageMaker
                (default: None). If not provided, one will be created using
                this instance's ``boto_session``.
            default_bucket_prefix (str): The default prefix to use for S3 Object Keys. (default:
                None). If provided and where applicable, it will be used by the SDK to construct
                default S3 URIs, in the format:
                `s3://{default_bucket}/{default_bucket_prefix}/<rest of object key>`
                This parameter can also be specified via `{sagemaker_config}` instead of here. If
                not provided here or within `{sagemaker_config}`, default S3 URIs will have the
                format: `s3://{default_bucket}/<rest of object key>`
        """
        # sagemaker_config is validated and initialized inside :func:`_initialize`,
        # so if default_bucket is None and the sagemaker_config has a default S3 bucket configured,
        # _default_bucket_name_override will be set again inside :func:`_initialize`.
        self.endpoint_arn = None
        # Cached resolved bucket name; populated lazily by :func:`default_bucket`.
        self._default_bucket = None
        self._default_bucket_name_override = default_bucket
        # this may also be set again inside :func:`_initialize` if it is None
        self.default_bucket_prefix = default_bucket_prefix
        # True only when the SDK generated the bucket name itself (see default_bucket).
        self._default_bucket_set_by_sdk = False

        # Lazily-created AWS clients/resources; None until first needed.
        self.s3_resource = None
        self.s3_client = None
        self.resource_groups_client = None
        self.resource_group_tagging_client = None
        self._config = None
        self.lambda_client = None
        self.settings = settings if settings else SessionSettings()

        self._initialize(
            boto_session=boto_session,
            sagemaker_client=sagemaker_client,
            sagemaker_runtime_client=sagemaker_runtime_client,
            sagemaker_featurestore_runtime_client=sagemaker_featurestore_runtime_client,
            sagemaker_metrics_client=sagemaker_metrics_client,
            sagemaker_config=sagemaker_config,
        )
    def _initialize(
        self,
        boto_session,
        sagemaker_client,
        sagemaker_runtime_client,
        sagemaker_featurestore_runtime_client,
        sagemaker_metrics_client,
        sagemaker_config: Optional[dict] = None,
    ):
        """Initialize this SageMaker Session.

        Creates or uses a boto_session, sagemaker_client and sagemaker_runtime_client.
        Sets the region_name.

        Raises:
            ValueError: If no AWS region can be determined from the boto session.
        """
        self.boto_session = boto_session or boto3.DEFAULT_SESSION or boto3.Session()

        self._region_name = self.boto_session.region_name
        if self._region_name is None:
            raise ValueError(
                "Must setup local AWS configuration with a region supported by SageMaker."
            )

        # Make use of user_agent_extra field of the botocore_config object
        # to append SageMaker Python SDK specific user_agent suffix
        # to the current User-Agent header value from boto3
        # This config will also make sure that user_agent never fails to log the User-Agent string
        # even if boto User-Agent header format is updated in the future
        # Ref: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html

        # Create sagemaker_client with the botocore_config object
        # This config is customized to append SageMaker Python SDK specific user_agent suffix
        if sagemaker_client is not None:
            self.sagemaker_client = sagemaker_client
        else:
            # Local import — presumably avoids a circular import at module load; verify.
            from sagemaker.core.user_agent import get_user_agent_extra_suffix

            config = botocore.config.Config(user_agent_extra=get_user_agent_extra_suffix())
            self.sagemaker_client = self.boto_session.client("sagemaker", config=config)

        if sagemaker_runtime_client is not None:
            self.sagemaker_runtime_client = sagemaker_runtime_client
        else:
            # Longer read timeout for potentially slow InvokeEndpoint responses.
            config = botocore.config.Config(read_timeout=80)
            self.sagemaker_runtime_client = self.boto_session.client(
                "runtime.sagemaker", config=config
            )

        if sagemaker_featurestore_runtime_client:
            self.sagemaker_featurestore_runtime_client = sagemaker_featurestore_runtime_client
        else:
            self.sagemaker_featurestore_runtime_client = self.boto_session.client(
                "sagemaker-featurestore-runtime"
            )

        if sagemaker_metrics_client:
            self.sagemaker_metrics_client = sagemaker_metrics_client
        else:
            self.sagemaker_metrics_client = self.boto_session.client("sagemaker-metrics")

        self.s3_client = self.boto_session.client("s3", region_name=self.boto_region_name)
        self.s3_resource = self.boto_session.resource("s3", region_name=self.boto_region_name)

        self.local_mode = False

        if sagemaker_config:
            # A caller-supplied config is validated but used as-is.
            validate_sagemaker_config(sagemaker_config)
            self.sagemaker_config = sagemaker_config
        else:
            # self.s3_resource might be None. If it is None, load_sagemaker_config will
            # create a default S3 resource, but only if it needs to fetch from S3
            self.sagemaker_config = load_sagemaker_config(s3_resource=self.s3_resource)

        # after sagemaker_config initialization, update self._default_bucket_name_override if needed
        self._default_bucket_name_override = resolve_value_from_config(
            direct_input=self._default_bucket_name_override,
            config_path=SESSION_DEFAULT_S3_BUCKET_PATH,
            sagemaker_session=self,
        )
        # after sagemaker_config initialization, update self.default_bucket_prefix if needed
        self.default_bucket_prefix = resolve_value_from_config(
            direct_input=self.default_bucket_prefix,
            config_path=SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH,
            sagemaker_session=self,
        )
[docs]
def account_id(self) -> str:
"""Get the AWS account id of the caller.
Returns:
AWS account ID.
"""
region = self.boto_session.region_name
sts_client = self.boto_session.client(
"sts", region_name=region, endpoint_url=sts_regional_endpoint(region)
)
return sts_client.get_caller_identity()["Account"]
    @property
    def config(self) -> Dict | None:
        """The config for the local mode, unused in a normal session."""
        return self._config

    @config.setter
    def config(self, value: Dict | None):
        """The config for the local mode, unused in a normal session."""
        self._config = value
    @property
    def boto_region_name(self):
        """str: The AWS region of the underlying boto session (read-only)."""
        return self._region_name
[docs]
    def get_caller_identity_arn(self):
        """Returns the ARN user or role whose credentials are used to call the API.

        Resolution order:

        1. Inside a SageMaker notebook/Studio environment, the role from the
           resource-metadata file (directly, or via the notebook instance,
           user profile, or domain it names).
        2. Otherwise, the STS caller identity, with assumed-role ARNs rewritten
           to the underlying IAM role ARN (refined via IAM ``get_role`` when
           permissions allow).

        Returns:
            str: The ARN user or role
        """
        if os.path.exists(NOTEBOOK_METADATA_FILE):
            # Running inside a SageMaker-hosted environment: the metadata file
            # identifies the current instance / domain / user profile.
            with open(NOTEBOOK_METADATA_FILE, "rb") as f:
                metadata = json.loads(f.read())
            instance_name = metadata.get("ResourceName")
            domain_id = metadata.get("DomainId")
            user_profile_name = metadata.get("UserProfileName")
            execution_role_arn = metadata.get("ExecutionRoleArn")
            try:
                # find execution role from the metadata file if present
                if execution_role_arn is not None:
                    return execution_role_arn

                if domain_id is None:
                    # Classic notebook instance (no Studio domain): ask SageMaker
                    # for the instance's role directly.
                    instance_desc = self.sagemaker_client.describe_notebook_instance(
                        NotebookInstanceName=instance_name
                    )
                    return instance_desc["RoleArn"]

                user_profile_desc = self.sagemaker_client.describe_user_profile(
                    DomainId=domain_id, UserProfileName=user_profile_name
                )

                # First, try to find role in userSettings
                if user_profile_desc.get("UserSettings", {}).get("ExecutionRole"):
                    return user_profile_desc["UserSettings"]["ExecutionRole"]

                # If not found, fallback to the domain
                domain_desc = self.sagemaker_client.describe_domain(DomainId=domain_id)
                return domain_desc["DefaultUserSettings"]["ExecutionRole"]
            except ClientError:
                # Metadata-based lookup failed (e.g. missing permissions);
                # fall through to the STS-based resolution below.
                logger.debug(
                    "Couldn't call 'describe_notebook_instance' to get the Role "
                    "ARN of the instance %s.",
                    instance_name,
                )

        assumed_role = self.boto_session.client(
            "sts",
            region_name=self.boto_region_name,
            endpoint_url=sts_regional_endpoint(self.boto_region_name),
        ).get_caller_identity()["Arn"]

        # Rewrite "...:sts::<acct>:assumed-role/<name>/<session>" into the
        # corresponding IAM role ARN "...:iam::<acct>:role/<name>".
        role = re.sub(r"^(.+)sts::(\d+):assumed-role/(.+?)/.*$", r"\1iam::\2:role/\3", assumed_role)

        # Call IAM to get the role's path
        role_name = role[role.rfind("/") + 1 :]
        try:
            role = self.boto_session.client("iam").get_role(RoleName=role_name)["Role"]["Arn"]
        except ClientError:
            logger.warning(
                "Couldn't call 'get_role' to get Role ARN from role name %s to get Role path.",
                role_name,
            )

            # This conditional has been present since the inception of SageMaker
            # Guessing this conditional's purpose was to handle lack of IAM permissions
            # https://github.com/aws/sagemaker-python-sdk/issues/2089#issuecomment-791802713
            if "AmazonSageMaker-ExecutionRole" in assumed_role:
                logger.warning(
                    "Assuming role was created in SageMaker AWS console, "
                    "as the name contains `AmazonSageMaker-ExecutionRole`. "
                    "Defaulting to Role ARN with service-role in path. "
                    "If this Role ARN is incorrect, please add "
                    "IAM read permissions to your role or supply the "
                    "Role Arn directly."
                )
                role = re.sub(
                    r"^(.+)sts::(\d+):assumed-role/(.+?)/.*$",
                    r"\1iam::\2:role/service-role/\3",
                    assumed_role,
                )

        return role
[docs]
    def upload_data(self, path, bucket=None, key_prefix="data", callback=None, extra_args=None):
        """Upload local file or directory to S3.

        If a single file is specified for upload, the resulting S3 object key is
        ``{key_prefix}/{filename}`` (filename does not include the local path, if any specified).

        If a directory is specified for upload, the API uploads all content, recursively,
        preserving relative structure of subdirectories. The resulting object key names are:
        ``{key_prefix}/{relative_subdirectory_path}/filename``.

        Args:
            path (str): Path (absolute or relative) of local file or directory to upload.
            bucket (str): Name of the S3 Bucket to upload to (default: None). If not specified, the
                default bucket of the ``Session`` is used (if default bucket does not exist, the
                ``Session`` creates it).
            key_prefix (str): Optional S3 object key name prefix (default: 'data'). S3 uses the
                prefix to create a directory structure for the bucket content that it display in
                the S3 console.
            callback: Optional callback passed through to the S3 ``upload_file`` call as its
                ``Callback`` parameter (default: None).
            extra_args (dict): Optional extra arguments that may be passed to the upload operation.
                Similar to ExtraArgs parameter in S3 upload_file function. Please refer to the
                ExtraArgs parameter documentation here:
                https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-uploading-files.html#the-extraargs-parameter

        Returns:
            str: The S3 URI of the uploaded file(s). If a file is specified in the path argument,
                the URI format is: ``s3://{bucket name}/{key_prefix}/{original_file_name}``.
                If a directory is specified in the path argument, the URI format is
                ``s3://{bucket name}/{key_prefix}``.
        """
        bucket, key_prefix = self.determine_bucket_and_prefix(
            bucket=bucket, key_prefix=key_prefix, sagemaker_session=self
        )

        # Generate a tuple for each file that we want to upload of the form (local_path, s3_key).
        files = []
        key_suffix = None
        if os.path.isdir(path):
            for dirpath, _, filenames in os.walk(path):
                for name in filenames:
                    local_path = os.path.join(dirpath, name)
                    # Preserve subdirectory structure relative to the uploaded root.
                    s3_relative_prefix = (
                        "" if path == dirpath else os.path.relpath(dirpath, start=path) + "/"
                    )
                    s3_key = "{}/{}{}".format(key_prefix, s3_relative_prefix, name)
                    files.append((local_path, s3_key))
        else:
            _, name = os.path.split(path)
            s3_key = "{}/{}".format(key_prefix, name)
            files.append((path, s3_key))
            key_suffix = name

        if self.s3_resource is None:
            s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)
        else:
            s3 = self.s3_resource

        for local_path, s3_key in files:
            s3.Object(bucket, s3_key).upload_file(
                local_path, Callback=callback, ExtraArgs=extra_args
            )

        s3_uri = "s3://{}/{}".format(bucket, key_prefix)
        # If a specific file was used as input (instead of a directory), we return the full S3 key
        # of the uploaded object. This prevents unintentionally using other files under the same
        # prefix during training.
        if key_suffix:
            s3_uri = "{}/{}".format(s3_uri, key_suffix)
        return s3_uri
[docs]
def upload_string_as_file_body(self, body, bucket, key, kms_key=None):
"""Upload a string as a file body.
Args:
body (str): String representing the body of the file.
bucket (str): Name of the S3 Bucket to upload to (default: None). If not specified, the
default bucket of the ``Session`` is used (if default bucket does not exist, the
``Session`` creates it).
key (str): S3 object key. This is the s3 path to the file.
kms_key (str): The KMS key to use for encrypting the file.
Returns:
str: The S3 URI of the uploaded file.
The URI format is: ``s3://{bucket name}/{key}``.
"""
if self.s3_resource is None:
s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)
else:
s3 = self.s3_resource
s3_object = s3.Object(bucket_name=bucket, key=key)
if kms_key is not None:
s3_object.put(Body=body, SSEKMSKeyId=kms_key, ServerSideEncryption="aws:kms")
else:
s3_object.put(Body=body)
s3_uri = "s3://{}/{}".format(bucket, key)
return s3_uri
[docs]
    def download_data(self, path, bucket, key_prefix="", extra_args=None):
        """Download file or directory from S3.

        Args:
            path (str): Local path where the file or directory should be downloaded to.
            bucket (str): Name of the S3 Bucket to download from.
            key_prefix (str): Optional S3 object key name prefix.
            extra_args (dict): Optional extra arguments that may be passed to the
                download operation. Please refer to the ExtraArgs parameter in the boto3
                documentation here:
                https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-example-download-file.html

        Returns:
            list[str]: List of local paths of downloaded files
        """
        # Initialize the S3 client.
        if self.s3_client is None:
            s3 = self.boto_session.client("s3", region_name=self.boto_region_name)
        else:
            s3 = self.s3_client

        # Initialize the variables used to loop through the contents of the S3 bucket.
        keys = []
        directories = []
        next_token = ""
        base_parameters = {"Bucket": bucket, "Prefix": key_prefix}

        # Loop through the contents of the bucket, 1,000 objects at a time. Gathering all keys into
        # a "keys" list.
        while next_token is not None:
            request_parameters = base_parameters.copy()
            if next_token != "":
                request_parameters.update({"ContinuationToken": next_token})
            response = s3.list_objects_v2(**request_parameters)
            contents = response.get("Contents", None)
            if not contents:
                # NOTE(review): this also returns [] and discards keys gathered from
                # earlier pages if a continuation page has no "Contents"; in practice
                # S3 continuation pages are non-empty, but worth confirming.
                logger.info(
                    "Nothing to download from bucket: %s, key_prefix: %s.", bucket, key_prefix
                )
                return []
            # For each object, save its key or directory.
            for s3_object in contents:
                key: str = s3_object.get("Key")
                obj_size = s3_object.get("Size")
                # Zero-byte keys ending in "/" are S3 "folder" placeholder objects.
                if key.endswith("/") and int(obj_size) == 0:
                    directories.append(os.path.join(path, key))
                else:
                    keys.append(key)
            # None when there is no further page, which terminates the loop.
            next_token = response.get("NextContinuationToken")

        # For each object key, create the directory on the local machine if needed, and then
        # download the file.
        downloaded_paths = []
        for dir_path in directories:
            os.makedirs(os.path.dirname(dir_path), exist_ok=True)
        for key in keys:
            tail_s3_uri_path = os.path.basename(key)
            # A key_prefix without a file extension is treated as a directory prefix:
            # preserve each key's path relative to it on the local filesystem.
            if not os.path.splitext(key_prefix)[1]:
                tail_s3_uri_path = os.path.relpath(key, key_prefix)
            destination_path = os.path.join(path, tail_s3_uri_path)
            if not os.path.exists(os.path.dirname(destination_path)):
                os.makedirs(os.path.dirname(destination_path), exist_ok=True)
            s3.download_file(
                Bucket=bucket, Key=key, Filename=destination_path, ExtraArgs=extra_args
            )
            downloaded_paths.append(destination_path)
        return downloaded_paths
[docs]
def read_s3_file(self, bucket, key_prefix):
"""Read a single file from S3.
Args:
bucket (str): Name of the S3 Bucket to download from.
key_prefix (str): S3 object key name prefix.
Returns:
str: The body of the s3 file as a string.
"""
if self.s3_client is None:
s3 = self.boto_session.client("s3", region_name=self.boto_region_name)
else:
s3 = self.s3_client
# Explicitly passing a None kms_key to boto3 throws a validation error.
s3_object = s3.get_object(Bucket=bucket, Key=key_prefix)
return s3_object["Body"].read().decode("utf-8")
[docs]
def list_s3_files(self, bucket, key_prefix):
"""Lists the S3 files given an S3 bucket and key.
Args:
bucket (str): Name of the S3 Bucket to download from.
key_prefix (str): S3 object key name prefix.
Returns:
[str]: The list of files at the S3 path.
"""
if self.s3_resource is None:
s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)
else:
s3 = self.s3_resource
s3_bucket = s3.Bucket(name=bucket)
s3_objects = s3_bucket.objects.filter(Prefix=key_prefix).all()
return [s3_object.key for s3_object in s3_objects]
[docs]
    def default_bucket(self):
        """Return the name of the default bucket to use in relevant Amazon SageMaker interactions.

        This function will create the s3 bucket if it does not exist.

        Returns:
            str: The name of the default bucket. If the name was not explicitly specified through
                the Session or sagemaker_config, the bucket will take the form:
                ``sagemaker-{region}-{AWS account ID}``.
        """
        # Resolved name is cached after the first call; subsequent calls are cheap.
        if self._default_bucket:
            return self._default_bucket

        region = self.boto_session.region_name

        default_bucket = self._default_bucket_name_override
        if not default_bucket:
            default_bucket = self.generate_default_sagemaker_bucket_name(self.boto_session)
            # Record that the SDK generated this name (not user/config supplied) so
            # _create_s3_bucket_if_it_does_not_exist can apply extra ownership checks.
            self._default_bucket_set_by_sdk = True

        self._create_s3_bucket_if_it_does_not_exist(
            bucket_name=default_bucket,
            region=region,
        )

        self._default_bucket = default_bucket
        return self._default_bucket
    def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
        """Creates an S3 Bucket if it does not exist.

        Also swallows a few common exceptions that indicate that the bucket already exists or
        that it is being created.

        Args:
            bucket_name (str): Name of the S3 bucket to be created.
            region (str): The region in which to create the bucket.

        Raises:
            botocore.exceptions.ClientError: If S3 throws an unexpected exception during bucket
                creation.
                If the exception is due to the bucket already existing or
                already being created, no exception is raised.
        """
        if self.s3_resource is None:
            s3 = self.boto_session.resource("s3", region_name=region)
        else:
            s3 = self.s3_resource

        bucket = s3.Bucket(name=bucket_name)
        if bucket.creation_date is None:
            # Bucket not visible to this account yet: check access and create it.
            self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, True)
        elif self._default_bucket_set_by_sdk:
            # Bucket exists and its name was generated by the SDK: verify access and
            # that it is actually owned by the caller's account (guards against
            # squatted "sagemaker-{region}-{account}" buckets in other accounts).
            self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False)

            expected_bucket_owner_id = self.account_id()
            self.expected_bucket_owner_id_bucket_check(bucket_name, s3, expected_bucket_owner_id)
[docs]
def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket_owner_id):
"""Checks if the bucket belongs to a particular owner and throws a Client Error if it is not
Args:
bucket_name (str): Name of the S3 bucket
s3 (str): S3 object from boto session
expected_bucket_owner_id (str): Owner ID string
"""
try:
if self.default_bucket_prefix:
s3.meta.client.list_objects_v2(
Bucket=bucket_name,
Prefix=self.default_bucket_prefix,
ExpectedBucketOwner=expected_bucket_owner_id,
)
else:
s3.meta.client.head_bucket(
Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
)
except ClientError as e:
error_code = e.response["Error"]["Code"]
message = e.response["Error"]["Message"]
if error_code == "403" and message == "Forbidden":
LOGGER.error(
"Since default_bucket param was not set, SageMaker Python SDK tried to use "
"%s bucket. "
"This bucket cannot be configured to use as it is not owned by Account %s. "
"To unblock it's recommended to use custom default_bucket "
"parameter in sagemaker.Session",
bucket_name,
expected_bucket_owner_id,
)
raise
[docs]
    def general_bucket_check_if_user_has_permission(
        self, bucket_name, s3, bucket, region, bucket_creation_date_none
    ):
        """Checks if the person running has the permissions to the bucket

        If there is any other error that comes up with calling head bucket, it is raised up here
        If there is no bucket , it will create one

        Args:
            bucket_name (str): Name of the S3 bucket
            s3 (str): S3 object from boto session
            bucket: S3 ``Bucket`` resource for ``bucket_name``, used only for logging its name.
            region (str): The region in which to create the bucket.
            bucket_creation_date_none (bool):Indicating whether S3 bucket already exists or not
        """
        try:
            # Probe with a cheap request; list_objects_v2 is used instead of
            # head_bucket when a prefix is configured.
            if self.default_bucket_prefix:
                s3.meta.client.list_objects_v2(
                    Bucket=bucket_name, Prefix=self.default_bucket_prefix
                )
            else:
                s3.meta.client.head_bucket(Bucket=bucket_name)
        except ClientError as e:
            error_code = e.response["Error"]["Code"]
            message = e.response["Error"]["Message"]

            # bucket does not exist or forbidden to access
            if bucket_creation_date_none:
                if error_code == "404" and message == "Not Found":
                    self.create_bucket_for_not_exist_error(bucket_name, region, s3)
                elif error_code == "403" and message == "Forbidden":
                    LOGGER.error(
                        "Bucket %s exists, but access is forbidden. Please try again after "
                        "adding appropriate access.",
                        bucket.name,
                    )
                    raise
                else:
                    raise
            # NOTE(review): when bucket_creation_date_none is False the ClientError is
            # swallowed here — presumably a deliberate best-effort check on an existing
            # bucket; confirm before changing.
[docs]
def create_bucket_for_not_exist_error(self, bucket_name, region, s3):
"""Creates the S3 bucket in the given region
Args:
bucket_name (str): Name of the S3 bucket
s3 (str): S3 object from boto session
region (str): The region in which to create the bucket.
"""
# bucket does not exist, create one
try:
if region == "us-east-1":
# 'us-east-1' cannot be specified because it is the default region:
# https://github.com/boto/boto3/issues/125
s3.create_bucket(Bucket=bucket_name)
else:
s3.create_bucket(
Bucket=bucket_name,
CreateBucketConfiguration={"LocationConstraint": region},
)
logger.info("Created S3 bucket: %s", bucket_name)
except ClientError as e:
error_code = e.response["Error"]["Code"]
message = e.response["Error"]["Message"]
if error_code == "OperationAborted" and "conflicting conditional operation" in message:
# If this bucket is already being concurrently created,
# we don't need to create it again.
pass
else:
raise
[docs]
def generate_default_sagemaker_bucket_name(self, boto_session):
"""Generates a name for the default sagemaker S3 bucket.
Args:
boto_session (boto3.session.Session): The underlying Boto3 session which AWS service
"""
region = boto_session.region_name
account = boto_session.client(
"sts", region_name=region, endpoint_url=sts_regional_endpoint(region)
).get_caller_identity()["Account"]
return "sagemaker-{}-{}".format(region, account)
[docs]
    def determine_bucket_and_prefix(
        self, bucket: Optional[str] = None, key_prefix: Optional[str] = None, sagemaker_session=None
    ):
        """Helper function that returns the correct S3 bucket and prefix to use depending on the inputs.

        Args:
            bucket (Optional[str]): S3 Bucket to use (if it exists)
            key_prefix (Optional[str]): S3 Object Key Prefix to use or append to (if it exists)
            sagemaker_session (sagemaker.core.session.Session): Session to fetch a default bucket and
                prefix from, if bucket doesn't exist. Expected to exist

        Returns: The correct S3 Bucket and S3 Object Key Prefix that should be used
        """
        if bucket:
            final_bucket = bucket
            final_key_prefix = key_prefix
        else:
            final_bucket = sagemaker_session.default_bucket()

            # default_bucket_prefix (if it exists) should be appended if (and only if) 'bucket' does not
            # exist and we are using the Session's default_bucket.
            # NOTE(review): `s3_path_join` does not appear in this module's visible import
            # block — presumably defined or imported elsewhere in the file; verify.
            final_key_prefix = s3_path_join(sagemaker_session.default_bucket_prefix, key_prefix)

        # We should not append default_bucket_prefix even if the bucket exists but is equal to the
        # default_bucket, because either:
        # (1) the bucket was explicitly passed in by the user and just happens to be the same as the
        # default_bucket (in which case we don't want to change the user's input), or
        # (2) the default_bucket was fetched from Session earlier already (and the default prefix
        # should have been fetched then as well), and then this function was
        # called with it. If we appended the default prefix here, we would be appending it more than
        # once in total.
        return final_bucket, final_key_prefix
def _append_sagemaker_config_tags(self, tags: List[TagsDict], config_path_to_tags: str):
"""Appends tags specified in the sagemaker_config to the given list of tags.
To minimize the chance of duplicate tags being applied, this is intended to be used
immediately before calls to sagemaker_client, rather than during initialization of
classes like EstimatorBase.
Args:
tags: The list of tags to append to.
config_path_to_tags: The path to look up tags in the config.
Returns:
A list of tags.
"""
config_tags = get_sagemaker_config_value(self, config_path_to_tags)
if config_tags is None or len(config_tags) == 0:
return tags
all_tags = tags or []
for config_tag in config_tags:
config_tag_key = config_tag[KEY]
if not any(tag.get("Key", None) == config_tag_key for tag in all_tags):
# This check prevents new tags with duplicate keys from being added
# (to prevent API failure and/or overwriting of tags). If there is a conflict,
# the user-provided tag should take precedence over the config-provided tag.
# Note: this does not check user-provided tags for conflicts with other
# user-provided tags.
all_tags.append(config_tag)
_log_sagemaker_config_merge(
source_value=tags,
config_value=config_tags,
merged_source_and_config_value=all_tags,
config_key_path=config_path_to_tags,
)
return all_tags
[docs]
def endpoint_from_production_variants(
    self,
    name,
    production_variants,
    tags=None,
    kms_key=None,
    wait=True,
    data_capture_config_dict=None,
    async_inference_config_dict=None,
    explainer_config_dict=None,
    live_logging=False,
    vpc_config=None,
    enable_network_isolation=None,
    role=None,
):
    """Create an SageMaker ``Endpoint`` from a list of production variants.

    Builds the endpoint config (merging in defaults from the sagemaker_config
    where the caller did not provide values), creates it, then creates the
    endpoint itself via ``create_endpoint``.

    Args:
        name (str): The name of the ``Endpoint`` to create.
        production_variants (list[dict[str, str]]): The list of production variants to deploy.
        tags (Optional[Tags]): A list of key-value pairs for tagging the endpoint
            (default: None).
        kms_key (str): The KMS key that is used to encrypt the data on the storage volume
            attached to the instance hosting the endpoint.
        wait (bool): Whether to wait for the endpoint deployment to complete before returning
            (default: True).
        data_capture_config_dict (dict): Specifies configuration related to Endpoint data
            capture for use with Amazon SageMaker Model Monitoring. Default: None.
        async_inference_config_dict (dict) : specifies configuration related to async endpoint.
            Use this configuration when trying to create async endpoint and make async inference
            (default: None)
        explainer_config_dict (dict) : Specifies configuration related to explainer.
            Use this configuration when trying to use online explainability.
            (default: None).
        live_logging (bool): Whether to enable live logging while waiting for the
            deployment; forwarded to ``create_endpoint`` (default: False).
        vpc_config (dict[str, list[str]]:
            The VpcConfig set on the model (default: None).
            * 'Subnets' (list[str]): List of subnet ids.
            * 'SecurityGroupIds' (list[str]): List of security group ids.
        enable_network_isolation (Boolean): Default False.
            If True, enables network isolation in the endpoint, isolating the model
            container. No inbound or outbound network calls can be made to
            or from the model container.
        role (str): An AWS IAM role (either name or full ARN). The Amazon
            SageMaker training jobs and APIs that create Amazon SageMaker
            endpoints use this role to access training data and model
            artifacts. After the endpoint is created, the inference code
            might use the IAM role if it needs to access some AWS resources.
            (default: None).

    Returns:
        str: The name of the created ``Endpoint``.
    """
    # Only resolve a KMS key from the sagemaker_config when at least one
    # requested instance type actually supports KMS-encrypted volumes.
    supports_kms = any(
        [
            instance_supports_kms(production_variant["InstanceType"])
            for production_variant in production_variants
            if "InstanceType" in production_variant
        ]
    )
    # Merge config-file defaults into each production variant; the
    # CoreDumpConfig destination is required when that section is present.
    update_list_of_dicts_with_values_from_config(
        production_variants,
        ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH,
        required_key_paths=["CoreDumpConfig.DestinationS3Uri"],
        sagemaker_session=self,
    )
    config_options = {"EndpointConfigName": name, "ProductionVariants": production_variants}
    kms_key = (
        resolve_value_from_config(
            kms_key, ENDPOINT_CONFIG_KMS_KEY_ID_PATH, sagemaker_session=self
        )
        if supports_kms
        else kms_key
    )
    vpc_config = resolve_value_from_config(
        vpc_config,
        ENDPOINT_CONFIG_VPC_CONFIG_PATH,
        sagemaker_session=self,
    )
    enable_network_isolation = resolve_value_from_config(
        enable_network_isolation,
        ENDPOINT_CONFIG_ENABLE_NETWORK_ISOLATION_PATH,
        sagemaker_session=self,
    )
    # NOTE(review): for a bound method `self is None` never holds, so the
    # explicit load_sagemaker_config() fallback appears to be dead code here;
    # resolution normally uses the session's own config.
    role = resolve_value_from_config(
        role,
        ENDPOINT_CONFIG_EXECUTION_ROLE_ARN_PATH,
        sagemaker_session=self,
        sagemaker_config=load_sagemaker_config() if (self is None) else None,
    )
    # For Amazon SageMaker inference component based endpoint, it will not pass
    # Model names during endpoint creation. Instead, ExecutionRoleArn will be
    # needed in the endpoint config to create Endpoint
    model_names = [pv["ModelName"] for pv in production_variants if "ModelName" in pv]
    if len(model_names) == 0:
        # Currently, SageMaker Python SDK allow using RoleName to deploy models.
        # Use expand_role method to handle this situation.
        role = self.expand_role(role)
        config_options["ExecutionRoleArn"] = role
    # Endpoint-config tags and endpoint tags are tracked separately: only the
    # former are merged with the ENDPOINT_CONFIG section of sagemaker_config.
    endpoint_config_tags = _append_project_tags(format_tags(tags))
    endpoint_tags = _append_project_tags(format_tags(tags))
    endpoint_config_tags = self._append_sagemaker_config_tags(
        endpoint_config_tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT_CONFIG, TAGS)
    )
    if endpoint_config_tags:
        config_options["Tags"] = endpoint_config_tags
    if kms_key:
        config_options["KmsKeyId"] = kms_key
    if data_capture_config_dict is not None:
        inferred_data_capture_config_dict = update_nested_dictionary_with_values_from_config(
            data_capture_config_dict, ENDPOINT_CONFIG_DATA_CAPTURE_PATH, sagemaker_session=self
        )
        config_options["DataCaptureConfig"] = inferred_data_capture_config_dict
    if async_inference_config_dict is not None:
        inferred_async_inference_config_dict = update_nested_dictionary_with_values_from_config(
            async_inference_config_dict,
            ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH,
            sagemaker_session=self,
        )
        config_options["AsyncInferenceConfig"] = inferred_async_inference_config_dict
    if explainer_config_dict is not None:
        config_options["ExplainerConfig"] = explainer_config_dict
    if vpc_config is not None:
        config_options["VpcConfig"] = vpc_config
    if enable_network_isolation is not None:
        config_options["EnableNetworkIsolation"] = enable_network_isolation
    if role is not None:
        config_options["ExecutionRoleArn"] = role
    logger.info("Creating endpoint-config with name %s", name)
    self.sagemaker_client.create_endpoint_config(**config_options)
    # Reuse the endpoint-config name as the endpoint name.
    return self.create_endpoint(
        endpoint_name=name,
        config_name=name,
        tags=endpoint_tags,
        wait=wait,
        live_logging=live_logging,
    )
[docs]
def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True, live_logging=False):
    """Create an Amazon SageMaker ``Endpoint`` from an existing endpoint config.

    Once the ``Endpoint`` is created, client applications can send requests to
    obtain inferences. The endpoint configuration must already exist (see the
    ``CreateEndpointConfig`` API).

    Args:
        endpoint_name (str): Name of the Amazon SageMaker ``Endpoint`` being created.
        config_name (str): Name of the Amazon SageMaker endpoint configuration to deploy.
        tags (Optional[Tags]): A list of key-value pairs for tagging the endpoint
            (default: None).
        wait (bool): Whether to wait for the endpoint deployment to complete before
            returning (default: True).
        live_logging (bool): Whether to enable live logging while waiting for the
            deployment (default: False).

    Returns:
        str: Name of the Amazon SageMaker ``Endpoint`` created.

    Raises:
        botocore.exceptions.ClientError: If SageMaker throws an exception while
            creating the endpoint.
    """
    logger.info("Creating endpoint with name %s", endpoint_name)

    # Layer project tags and sagemaker_config tags on top of the caller's tags.
    request_tags = _append_project_tags(format_tags(tags) or [])
    request_tags = self._append_sagemaker_config_tags(
        request_tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT, TAGS)
    )

    try:
        response = self.sagemaker_client.create_endpoint(
            EndpointName=endpoint_name, EndpointConfigName=config_name, Tags=request_tags
        )
        if response:
            self.endpoint_arn = response["EndpointArn"]
        if wait:
            self.wait_for_endpoint(endpoint_name, live_logging=live_logging)
        return endpoint_name
    except Exception as err:
        troubleshooting = (
            "https://docs.aws.amazon.com/sagemaker/latest/dg/"
            "sagemaker-python-sdk-troubleshooting.html"
            "#sagemaker-python-sdk-troubleshooting-create-endpoint"
        )
        logger.error(
            "Please check the troubleshooting guide for common errors: %s", troubleshooting
        )
        raise err
[docs]
def wait_for_endpoint(self, endpoint, poll=DEFAULT_EP_POLL, live_logging=False):
    """Wait for an Amazon SageMaker endpoint deployment to complete.

    Args:
        endpoint (str): Name of the ``Endpoint`` to wait for.
        poll (int): Polling interval in seconds (default: 30). Ignored in
            live-logging mode, which polls at ``EP_LOGGER_POLL`` instead.
        live_logging (bool): Whether to stream CloudWatch log events while
            waiting (default: False). Falls back to plain polling when the
            caller lacks the required CloudWatch permissions.

    Raises:
        exceptions.CapacityError: If the endpoint creation job fails with CapacityError.
        exceptions.UnexpectedStatusException: If the endpoint creation job fails.

    Returns:
        dict: Return value from the ``DescribeEndpoint`` API.
    """
    if not live_logging or not _has_permission_for_live_logging(self.boto_session, endpoint):
        # Plain polling against DescribeEndpoint until a terminal status.
        desc = _wait_until(lambda: _deploy_done(self.sagemaker_client, endpoint), poll)
    else:
        # Stream container log events from CloudWatch while polling.
        cloudwatch_client = self.boto_session.client("logs")
        paginator = cloudwatch_client.get_paginator("filter_log_events")
        paginator_config = create_paginator_config()
        desc = _wait_until(
            lambda: _live_logging_deploy_done(
                self.sagemaker_client, endpoint, paginator, paginator_config, EP_LOGGER_POLL
            ),
            poll=EP_LOGGER_POLL,
        )
    status = desc["EndpointStatus"]
    if status != "InService":
        reason = desc.get("FailureReason", None)
        trouble_shooting = (
            "Try changing the instance type or reference the troubleshooting page "
            "https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference-troubleshooting"
            ".html"
        )
        message = "Error hosting endpoint {}: {}. Reason: {}. {}".format(
            endpoint, status, reason, trouble_shooting
        )
        # Capacity failures get a dedicated exception type so callers can
        # retry with a different instance type.
        if "CapacityError" in str(reason):
            raise exceptions.CapacityError(
                message=message,
                allowed_statuses=["InService"],
                actual_status=status,
            )
        raise exceptions.UnexpectedStatusException(
            message=message,
            allowed_statuses=["InService"],
            actual_status=status,
        )
    return desc
[docs]
def create_inference_component(
    self,
    inference_component_name: str,
    endpoint_name: str,
    variant_name: str,
    specification: Dict[str, Any],
    runtime_config: Optional[Dict[str, Any]] = None,
    tags: Optional[Tags] = None,
    wait: bool = True,
):
    """Create an Amazon SageMaker Inference Component.

    Args:
        inference_component_name (str): Name of the inference component to create.
        endpoint_name (str): Name of the endpoint the inference component deploys to.
        variant_name (str): Name of the variant the inference component deploys to.
        specification (Dict[str, Any]): The inference component specification.
        runtime_config (Optional[Dict[str, Any]]): The inference component runtime
            configuration; defaults to a single copy (``{"CopyCount": 1}``).
        tags (Optional[Tags]): Either a dictionary or a list of dictionaries
            containing key-value pairs. (Default: None).
        wait (bool): Whether to wait for the inference component to finish being
            created before returning. (Default: True).

    Returns:
        str: Name of the Amazon SageMaker ``InferenceComponent`` if created.
    """
    LOGGER.info(
        "Creating inference component with name %s for endpoint %s",
        inference_component_name,
        endpoint_name,
    )
    request = {
        "InferenceComponentName": inference_component_name,
        "EndpointName": endpoint_name,
        "VariantName": variant_name,
        "Specification": specification,
        # Default to one copy when the caller did not specify a runtime config.
        "RuntimeConfig": {"CopyCount": 1} if runtime_config is None else runtime_config,
    }
    merged_tags = self._append_sagemaker_config_tags(
        _append_project_tags(format_tags(tags)),
        "{}.{}.{}".format(SAGEMAKER, INFERENCE_COMPONENT, TAGS),
    )
    if merged_tags:
        request["Tags"] = merged_tags
    self.sagemaker_client.create_inference_component(**request)
    if wait:
        self.wait_for_inference_component(inference_component_name)
    return inference_component_name
[docs]
def wait_for_inference_component(self, inference_component_name, poll=20):
    """Wait for an Amazon SageMaker ``Inference Component`` deployment to complete.

    Args:
        inference_component_name (str): Name of the ``Inference Component`` to wait for.
        poll (int): Polling interval in seconds (default: 20).

    Raises:
        exceptions.CapacityError: If the inference component creation fails with
            CapacityError.
        exceptions.UnexpectedStatusException: If the inference component creation fails.

    Returns:
        dict: Return value from the ``DescribeInferenceComponent`` API.
    """
    desc = _wait_until(
        lambda: self._inference_component_done(self.sagemaker_client, inference_component_name),
        poll,
    )
    status = desc["InferenceComponentStatus"]
    if status == "InService":
        return desc

    # Terminal non-InService status: build an error message and raise.
    message = f"Error creating inference component '{inference_component_name}'"
    reason = desc.get("FailureReason")
    if reason:
        message = f"{message}: {reason}"
    # Capacity failures get their own exception type so callers can retry.
    error_cls = (
        exceptions.CapacityError
        if "CapacityError" in str(reason)
        else exceptions.UnexpectedStatusException
    )
    raise error_cls(
        message=message,
        allowed_statuses=["InService"],
        actual_status=status,
    )
[docs]
def describe_inference_component(self, inference_component_name):
    """Describe an Amazon SageMaker ``InferenceComponent``.

    Args:
        inference_component_name (str): Name of the Amazon SageMaker ``InferenceComponent``.

    Returns:
        dict[str,str]: Inference component details.
    """
    request = {"InferenceComponentName": inference_component_name}
    return self.sagemaker_client.describe_inference_component(**request)
def _inference_component_done(self, sagemaker_client, inference_component_name):
"""Check if creation of inference component is done.
Args:
sagemaker_client (boto3.SageMaker.Client): Client which makes Amazon SageMaker
service calls
inference_component_name (str): Name of the Amazon SageMaker ``InferenceComponent``.
Returns:
dict[str,str]: Inference component details.
"""
create_inference_component_codes = {
"InService": "!",
"Creating": "-",
"Updating": "-",
"Failed": "*",
"Deleting": "o",
}
in_progress_statuses = ["Creating", "Updating", "Deleting"]
desc = sagemaker_client.describe_inference_component(
InferenceComponentName=inference_component_name
)
status = desc["InferenceComponentStatus"]
print(create_inference_component_codes.get(status, "?"), end="", flush=True)
return None if status in in_progress_statuses else desc
[docs]
def update_endpoint(self, endpoint_name, endpoint_config_name, wait=True):
    """Update an existing Amazon SageMaker ``Endpoint`` to a new endpoint config.

    Args:
        endpoint_name (str): Name of the Amazon SageMaker ``Endpoint`` to update.
        endpoint_config_name (str): Name of the Amazon SageMaker endpoint configuration
            to deploy.
        wait (bool): Whether to wait for the endpoint deployment to complete before
            returning (default: True).

    Returns:
        str: Name of the Amazon SageMaker ``Endpoint`` being updated.

    Raises:
        - ValueError: if the endpoint does not already exist
        - botocore.exceptions.ClientError: If SageMaker throws an error while
          creating endpoint config, describing endpoint or updating endpoint
    """
    endpoint_exists = _deployment_entity_exists(
        lambda: self.sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
    )
    if not endpoint_exists:
        raise ValueError(
            "Endpoint with name '{}' does not exist; please use an "
            "existing endpoint name".format(endpoint_name)
        )
    try:
        response = self.sagemaker_client.update_endpoint(
            EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
        )
        if response:
            self.endpoint_arn = response["EndpointArn"]
        if wait:
            self.wait_for_endpoint(endpoint_name)
        return endpoint_name
    except Exception as err:
        troubleshooting = (
            "https://docs.aws.amazon.com/sagemaker/latest/dg/"
            "sagemaker-python-sdk-troubleshooting.html"
            "#sagemaker-python-sdk-troubleshooting-update-endpoint"
        )
        logger.error(
            "Please check the troubleshooting guide for common errors: %s", troubleshooting
        )
        raise err
[docs]
def endpoint_in_service_or_not(self, endpoint_name: str):
    """Check whether an Amazon SageMaker ``Endpoint`` is in InService status.

    Any client error that is not recognized as "not found" is re-raised.

    Args:
        endpoint_name (str): Name of the Amazon SageMaker ``Endpoint`` to check.

    Returns:
        bool: True if the ``Endpoint`` is InService; False if it does not exist
        or is in any other status.
    """
    try:
        desc = self.sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
        return desc["EndpointStatus"] == "InService"
    except botocore.exceptions.ClientError as e:
        # Treat "not found" style errors as a plain False; surface the rest.
        lowered = str(e).lower()
        if "could not find" in lowered or "not found" in lowered:
            return False
        raise
def _intercept_create_request(
    self,
    request: Dict,
    create,
    func_name: Optional[str] = None,
    # pylint: disable=unused-argument
):
    """Invoke ``create`` with ``request``; an interception hook for subclasses.

    PipelineSession inherits this Session class and overrides this method to
    intercept the create request instead of sending it directly.

    Args:
        request (dict): The create job request.
        create (functor): A functor that calls the sagemaker client create method.
        func_name (Optional[str]): The name of the function being intercepted;
            unused here, consumed by overriding subclasses.

    Returns:
        Whatever ``create(request)`` returns.
    """
    return create(request)
def _create_inference_recommendations_job_request(
    self,
    role: str,
    job_name: str,
    job_description: str,
    framework: str,
    sample_payload_url: str,
    supported_content_types: List[str],
    tags: Optional[Tags],
    model_name: str = None,
    model_package_version_arn: str = None,
    job_duration_in_seconds: int = None,
    job_type: str = "Default",
    framework_version: str = None,
    nearest_model_name: str = None,
    supported_instance_types: List[str] = None,
    endpoint_configurations: List[Dict[str, Any]] = None,
    traffic_pattern: Dict[str, Any] = None,
    stopping_conditions: Dict[str, Any] = None,
    resource_limit: Dict[str, Any] = None,
) -> Dict[str, Any]:
    """Build the request dictionary for the CreateInferenceRecommendationsJob API.

    Args:
        role (str): An AWS IAM role (either name or full ARN) used by the job.
        job_name (str): The name of the Inference Recommendations Job.
        job_description (str): A description of the Inference Recommendations Job.
        framework (str): The machine learning framework of the Image URI.
        sample_payload_url (str): The S3 path where the sample payload is stored.
        supported_content_types (List[str]): The supported MIME types for the input data.
        tags (Optional[Tags]): Tags used to identify where the Inference
            Recommendations call was made from.
        model_name (str): Name of the Amazon SageMaker ``Model`` to be used.
        model_package_version_arn (str): The ARN of a versioned model package.
            Takes precedence over ``model_name`` when both are supplied.
        job_duration_in_seconds (int): The maximum job duration. `Advanced` jobs.
        job_type (str): Either `Default` or `Advanced`.
        framework_version (str): The framework version of the Image URI.
        nearest_model_name (str): The name of a benchmarked pre-trained model that
            matches the model being recommended.
        supported_instance_types (List[str]): Instance types used to generate
            inferences in real-time.
        endpoint_configurations (List[Dict[str, Any]]): Endpoint configurations
            for the job. `Advanced` jobs only.
        traffic_pattern (Dict[str, Any]): Traffic pattern for the job.
            `Advanced` jobs only.
        stopping_conditions (Dict[str, Any]): Conditions that automatically stop
            the job when met. `Advanced` jobs only.
        resource_limit (Dict[str, Any]): Resource limit for the job.
            `Advanced` jobs only.

    Returns:
        Dict[str, Any]: request dictionary for the CreateInferenceRecommendationsJob API
    """
    container_config = {
        "Domain": "MACHINE_LEARNING",
        "Task": "OTHER",
        "Framework": framework,
        "PayloadConfig": {
            "SamplePayloadUrl": sample_payload_url,
            "SupportedContentTypes": supported_content_types,
        },
    }
    # Optional container-level fields: only set when provided.
    for field_name, field_value in (
        ("FrameworkVersion", framework_version),
        ("NearestModelName", nearest_model_name),
        ("SupportedInstanceTypes", supported_instance_types),
    ):
        if field_value:
            container_config[field_name] = field_value

    input_config = {"ContainerConfig": container_config}
    # Exactly one of ModelPackageVersionArn / ModelName identifies the model.
    if model_package_version_arn:
        input_config["ModelPackageVersionArn"] = model_package_version_arn
    else:
        input_config["ModelName"] = model_name

    request = {
        "JobName": job_name,
        "JobType": job_type,
        "RoleArn": role,
        "InputConfig": input_config,
        "Tags": format_tags(tags),
    }
    if job_description:
        request["JobDescription"] = job_description
    if job_duration_in_seconds:
        input_config["JobDurationInSeconds"] = job_duration_in_seconds
    # The remaining knobs are only honored for Advanced jobs.
    if job_type == "Advanced":
        if stopping_conditions:
            request["StoppingConditions"] = stopping_conditions
        if resource_limit:
            input_config["ResourceLimit"] = resource_limit
        if traffic_pattern:
            input_config["TrafficPattern"] = traffic_pattern
        if endpoint_configurations:
            input_config["EndpointConfigurations"] = endpoint_configurations
    return request
[docs]
def create_inference_recommendations_job(
    self,
    role: str,
    sample_payload_url: str,
    supported_content_types: List[str],
    job_name: str = None,
    job_type: str = "Default",
    model_name: str = None,
    model_package_version_arn: str = None,
    job_duration_in_seconds: int = None,
    nearest_model_name: str = None,
    supported_instance_types: List[str] = None,
    framework: str = None,
    framework_version: str = None,
    endpoint_configurations: List[Dict[str, any]] = None,
    traffic_pattern: Dict[str, any] = None,
    stopping_conditions: Dict[str, any] = None,
    resource_limit: Dict[str, any] = None,
):
    """Creates an Inference Recommendations Job.

    Args:
        role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training
            jobs and APIs that create Amazon SageMaker endpoints use this role to access
            training data and model artifacts.
            You must grant sufficient permissions to this role.
        sample_payload_url (str): The S3 path where the sample payload is stored.
        supported_content_types (List[str]): The supported MIME types for the input data.
        model_name (str): Name of the Amazon SageMaker ``Model`` to be used. Exactly one
            of ``model_name`` / ``model_package_version_arn`` must be provided.
        model_package_version_arn (str): The Amazon Resource Name (ARN) of a
            versioned model package.
        job_name (str): The name of the job being run; auto-generated as
            ``SMPYTHONSDK-<uuid>`` when not provided.
        job_type (str): The type of job being run. Must either be `Default` or `Advanced`.
        job_duration_in_seconds (int): The maximum job duration that a job
            can run for. Will be used for `Advanced` jobs.
        nearest_model_name (str): The name of a pre-trained machine learning model
            benchmarked by Amazon SageMaker Inference Recommender that matches your model.
        supported_instance_types (List[str]): A list of the instance types that are used
            to generate inferences in real-time.
        framework (str): The machine learning framework of the Image URI.
        framework_version (str): The framework version of the Image URI.
        endpoint_configurations (List[Dict[str, any]]): Specifies the endpoint configurations
            to use for a job. Will be used for `Advanced` jobs.
        traffic_pattern (Dict[str, any]): Specifies the traffic pattern for the job.
            Will be used for `Advanced` jobs.
        stopping_conditions (Dict[str, any]): A set of conditions for stopping a
            recommendation job.
            If any of the conditions are met, the job is automatically stopped.
            Will be used for `Advanced` jobs.
        resource_limit (Dict[str, any]): Defines the resource limit for the job.
            Will be used for `Advanced` jobs.

    Raises:
        ValueError: If neither or both of ``model_name`` and
            ``model_package_version_arn`` are provided.

    Returns:
        str: The name of the job created. In the form of `SMPYTHONSDK-<timestamp>`
    """
    # Exactly one model identifier is allowed: neither or both is an error.
    if model_name is None and model_package_version_arn is None:
        raise ValueError("Please provide either model_name or model_package_version_arn.")
    if model_name is not None and model_package_version_arn is not None:
        raise ValueError("Please provide either model_name or model_package_version_arn.")
    if not job_name:
        unique_tail = uuid.uuid4()
        job_name = "SMPYTHONSDK-" + str(unique_tail)
    job_description = "#python-sdk-create"
    # Tag marks the job as originating from the SDK's right-sizing flow.
    tags = [{"Key": "ClientType", "Value": "PythonSDK-RightSize"}]
    create_inference_recommendations_job_request = (
        self._create_inference_recommendations_job_request(
            role=role,
            model_name=model_name,
            model_package_version_arn=model_package_version_arn,
            job_name=job_name,
            job_type=job_type,
            job_duration_in_seconds=job_duration_in_seconds,
            job_description=job_description,
            framework=framework,
            framework_version=framework_version,
            nearest_model_name=nearest_model_name,
            sample_payload_url=sample_payload_url,
            supported_content_types=supported_content_types,
            supported_instance_types=supported_instance_types,
            endpoint_configurations=endpoint_configurations,
            traffic_pattern=traffic_pattern,
            stopping_conditions=stopping_conditions,
            resource_limit=resource_limit,
            tags=tags,
        )
    )

    def submit(request):
        # Actual API call; routed through _intercept_create_request so that
        # subclasses (e.g. PipelineSession) can intercept it.
        logger.info("Creating Inference Recommendations job with name: %s", job_name)
        logger.debug("process request: %s", json.dumps(request, indent=4))
        self.sagemaker_client.create_inference_recommendations_job(**request)

    self._intercept_create_request(
        create_inference_recommendations_job_request,
        submit,
        self.create_inference_recommendations_job.__name__,
    )
    return job_name
[docs]
def wait_for_inference_recommendations_job(
    self, job_name: str, poll: int = 120, log_level: str = "Verbose"
) -> Dict[str, Any]:
    """Wait for an Amazon SageMaker Inference Recommender job to complete.

    Args:
        job_name (str): Name of the Inference Recommender job to wait for.
        poll (int): Polling interval in seconds (default: 120). Only used in
            "Quiet" mode.
        log_level (str): The level of verbosity for the logs.
            Can be "Quiet" or "Verbose" (default: "Verbose").

    Returns:
        (dict): Return value from the ``DescribeInferenceRecommendationsJob`` API.

    Raises:
        ValueError: If ``log_level`` is neither "Quiet" nor "Verbose".
        exceptions.CapacityError: If the Inference Recommender job fails with CapacityError.
        exceptions.UnexpectedStatusException: If the Inference Recommender job fails.
    """
    if log_level == "Quiet":
        # Silent polling until the job reaches a terminal status.
        _wait_until(
            lambda: _describe_inference_recommendations_job_status(
                self.sagemaker_client, job_name
            ),
            poll,
        )
    elif log_level == "Verbose":
        # Display per-step status while waiting.
        _display_inference_recommendations_job_steps_status(
            self, self.sagemaker_client, job_name
        )
    else:
        raise ValueError("log_level must be either Quiet or Verbose")
    # Fetch the final description and raise if the job did not succeed.
    desc = _describe_inference_recommendations_job_status(self.sagemaker_client, job_name)
    _check_job_status(job_name, desc, "Status")
    return desc
[docs]
def delete_model(self, model_name):
    """Delete an Amazon SageMaker Model.

    Args:
        model_name (str): Name of the Amazon SageMaker model to delete.
    """
    logger.info("Deleting model with name: %s", model_name)
    request = {"ModelName": model_name}
    self.sagemaker_client.delete_model(**request)
[docs]
def delete_endpoint(self, endpoint_name):
    """Delete an Amazon SageMaker ``Endpoint``.

    Args:
        endpoint_name (str): Name of the Amazon SageMaker ``Endpoint`` to delete.
    """
    logger.info("Deleting endpoint with name: %s", endpoint_name)
    request = {"EndpointName": endpoint_name}
    self.sagemaker_client.delete_endpoint(**request)
[docs]
def delete_endpoint_config(self, endpoint_config_name):
    """Delete an Amazon SageMaker endpoint configuration.

    Args:
        endpoint_config_name (str): Name of the Amazon SageMaker endpoint configuration
            to delete.
    """
    logger.info("Deleting endpoint configuration with name: %s", endpoint_config_name)
    request = {"EndpointConfigName": endpoint_config_name}
    self.sagemaker_client.delete_endpoint_config(**request)
[docs]
def wait_for_optimization_job(self, job, poll=5):
    """Wait for an Amazon SageMaker Optimization job to complete.

    Args:
        job (str): Name of optimization job to wait for.
        poll (int): Polling interval in seconds (default: 5).

    Returns:
        (dict): Return value from the ``DescribeOptimizationJob`` API.

    Raises:
        exceptions.CapacityError: If the optimization job fails with CapacityError.
        exceptions.UnexpectedStatusException: If the optimization job fails.
    """

    def _poll_status():
        return _optimization_job_status(self.sagemaker_client, job)

    desc = _wait_until(_poll_status, poll)
    # Raises the appropriate exception when the job did not complete.
    _check_job_status(job, desc, "OptimizationJobStatus")
    return desc
[docs]
def update_inference_component(
    self, inference_component_name, specification=None, runtime_config=None, wait=True
):
    """Update an Amazon SageMaker ``InferenceComponent``.

    Args:
        inference_component_name (str): Name of the Amazon SageMaker ``InferenceComponent``.
        specification ([dict[str,int]]): Resource configuration. Optional.
            Example: {
            "MinMemoryRequiredInMb": 1024,
            "NumberOfCpuCoresRequired": 1,
            "NumberOfAcceleratorDevicesRequired": 1,
            "MaxMemoryRequiredInMb": 4096,
            },
        runtime_config ([dict[str,int]]): Number of copies. Optional.
            Default: {
            "copyCount": 1
            }
        wait: Wait for the inference component to be updated before returning.
            Optional. Default is True.

    Return:
        str: inference component name

    Raises:
        ValueError: If the inference_component_name does not exist.
    """

    def _describe():
        return self.sagemaker_client.describe_inference_component(
            InferenceComponentName=inference_component_name
        )

    if not _deployment_entity_exists(_describe):
        raise ValueError(
            "InferenceComponent with name '{}' does not exist; please use an "
            "existing model name".format(inference_component_name)
        )
    self.sagemaker_client.update_inference_component(
        InferenceComponentName=inference_component_name,
        Specification=specification,
        RuntimeConfig=runtime_config,
    )
    if wait:
        self.wait_for_inference_component(inference_component_name)
    return inference_component_name
def _create_model_request(
    self,
    name,
    role,
    container_defs,
    vpc_config=None,
    enable_network_isolation=False,
    primary_container=None,
    tags=None,
):  # pylint: disable=redefined-outer-name
    """Build the request dictionary for the ``CreateModel`` API call.

    Args:
        name (str): Name of the model.
        role (str): IAM role name or full ARN; expanded to an ARN here.
        container_defs (dict or list[dict]): A single container definition or a
            list of definitions for an inference pipeline.
        vpc_config (dict): Optional VpcConfig for the model.
        enable_network_isolation (bool): Whether to isolate the model container;
            may be a pipeline variable resolved at execution time.
        primary_container (dict): Deprecated alias for ``container_defs``.
        tags (Optional[Tags]): Tags to attach to the model.

    Returns:
        dict: Keyword arguments for ``sagemaker_client.create_model``.

    Raises:
        ValueError: If both ``container_defs`` and ``primary_container`` are passed.
    """
    if container_defs and primary_container:
        raise ValueError("Both container_defs and primary_container can not be passed as input")
    if primary_container:
        # Still accepted for backward compatibility, but warn the caller.
        warnings.warn(
            "primary_container is going to be deprecated in a future release. Please use "
            "container_defs instead.",
            DeprecationWarning,
        )
        container_defs = primary_container
    role = self.expand_role(role)
    if isinstance(container_defs, list):
        # Pipeline of containers: merge config defaults into each definition.
        update_list_of_dicts_with_values_from_config(
            container_defs, MODEL_CONTAINERS_PATH, sagemaker_session=self
        )
        container_definition = container_defs
    else:
        # Single container: normalize, then merge config defaults.
        container_definition = update_nested_dictionary_with_values_from_config(
            _expand_container_def(container_defs),
            MODEL_PRIMARY_CONTAINER_PATH,
            sagemaker_session=self,
        )
    request = {"ModelName": name, "ExecutionRoleArn": role}
    if isinstance(container_definition, list):
        request["Containers"] = container_definition
    elif "ModelPackageName" in container_definition:
        # Model-package containers must be sent via the Containers field.
        request["Containers"] = [container_definition]
    else:
        request["PrimaryContainer"] = container_definition
    if tags:
        request["Tags"] = format_tags(tags)
    if vpc_config:
        request["VpcConfig"] = vpc_config
    if enable_network_isolation:
        request["EnableNetworkIsolation"] = enable_network_isolation
    return request
[docs]
def create_model(
    self,
    name,
    role=None,
    container_defs=None,
    vpc_config=None,
    enable_network_isolation=None,
    primary_container=None,
    tags=None,
):
    """Create an Amazon SageMaker ``Model``.

    Specify the S3 location of the model artifacts and Docker image containing
    the inference code. Amazon SageMaker uses this information to deploy the
    model in Amazon SageMaker. This method can also be used to create a Model for an Inference
    Pipeline if you pass the list of container definitions through the containers parameter.

    Args:
        name (str): Name of the Amazon SageMaker ``Model`` to create.
        role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training
            jobs and APIs that create Amazon SageMaker endpoints use this role to access
            training data and model artifacts. You must grant sufficient permissions to this
            role.
        container_defs (list[dict[str, str]] or [dict[str, str]]): A single container
            definition or a list of container definitions which will be invoked sequentially
            while performing the prediction. If the list contains only one container, then
            it'll be passed to SageMaker Hosting as the ``PrimaryContainer`` and otherwise,
            it'll be passed as ``Containers``.You can also specify the return value of
            ``sagemaker.get_container_def()`` or ``sagemaker.pipeline_container_def()``,
            which will used to create more advanced container configurations, including model
            containers which need artifacts from S3.
        vpc_config (dict[str, list[str]]): The VpcConfig set on the model (default: None)
            * 'Subnets' (list[str]): List of subnet ids.
            * 'SecurityGroupIds' (list[str]): List of security group ids.
        enable_network_isolation (bool): Whether the model requires network isolation or not.
        primary_container (str or dict[str, str]): Docker image which defines the inference
            code. You can also specify the return value of ``sagemaker.container_def()``,
            which is used to create more advanced container configurations, including model
            containers which need artifacts from S3. This field is deprecated, please use
            container_defs instead.
        tags(Optional[Tags]): Optional. The list of tags to add to the model.

    Example:
        >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
        For more information about tags, see https://boto3.amazonaws.com/v1/documentation\
        /api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags

    Returns:
        str: Name of the Amazon SageMaker ``Model`` created.
    """
    # Layer project tags and sagemaker_config tags on top of the caller's tags.
    tags = _append_project_tags(format_tags(tags))
    tags = self._append_sagemaker_config_tags(tags, "{}.{}.{}".format(SAGEMAKER, MODEL, TAGS))
    # Resolve unset arguments from the sagemaker_config, where configured.
    role = resolve_value_from_config(
        role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self
    )
    vpc_config = resolve_value_from_config(
        vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self
    )
    enable_network_isolation = resolve_value_from_config(
        direct_input=enable_network_isolation,
        config_path=MODEL_ENABLE_NETWORK_ISOLATION_PATH,
        default_value=False,
        sagemaker_session=self,
    )
    # Due to ambuiguity in container_defs which accepts both a single
    # container definition(dtype: dict) and a list of container definitions (dtype: list),
    # we need to inject environment variables into the container_defs in the helper function
    # _create_model_request.
    create_model_request = self._create_model_request(
        name=name,
        role=role,
        container_defs=container_defs,
        vpc_config=vpc_config,
        enable_network_isolation=enable_network_isolation,
        primary_container=primary_container,
        tags=tags,
    )

    def submit(request):
        # Actual API call; routed through _intercept_create_request so that
        # subclasses (e.g. PipelineSession) can intercept it.
        logger.info("Creating model with name: %s", name)
        logger.debug("CreateModel request: %s", json.dumps(request, indent=4))
        try:
            self.sagemaker_client.create_model(**request)
        except ClientError as e:
            error_code = e.response["Error"]["Code"]
            message = e.response["Error"]["Message"]
            # An already-existing model with the same name is treated as
            # success (idempotent create); anything else is re-raised.
            if (
                error_code == "ValidationException"
                and "Cannot create already existing model" in message
            ):
                logger.warning("Using already existing model: %s", name)
            else:
                raise

    self._intercept_create_request(create_model_request, submit, self.create_model.__name__)
    return name
[docs]
def create_model_package_from_algorithm(self, name, description, algorithm_arn, model_data):
    """Create a SageMaker Model Package from the results of training with an Algorithm Package.

    Args:
        name (str): ModelPackage name
        description (str): Model Package description
        algorithm_arn (str): arn or name of the algorithm used for training.
        model_data (str or dict[str, Any]): s3 URI or a dictionary representing a
            ``ModelDataSource`` to the model artifacts produced by training
    """
    source_algorithm = {"AlgorithmName": algorithm_arn}
    # A dict is a ModelDataSource structure; a plain string is an S3 URI.
    data_key = "ModelDataSource" if isinstance(model_data, dict) else "ModelDataUrl"
    source_algorithm[data_key] = model_data

    request = {
        "ModelPackageName": name,
        "ModelPackageDescription": description,
        "SourceAlgorithmSpecification": {"SourceAlgorithms": [source_algorithm]},
    }
    try:
        logger.info("Creating model package with name: %s", name)
        self.sagemaker_client.create_model_package(**request)
    except ClientError as e:
        error_code = e.response["Error"]["Code"]
        message = e.response["Error"]["Message"]
        # Idempotent create: an already-existing package is treated as success.
        if error_code == "ValidationException" and "ModelPackage already exists" in message:
            logger.warning("Using already existing model package: %s", name)
        else:
            raise
[docs]
def expand_role(self, role):
    """Expand an IAM role name into an ARN.

    If the role is already in the form of an ARN, then the role is simply returned. Otherwise
    we retrieve the full ARN and return it.

    NOTE(review): this method appears twice in this module (a byte-identical duplicate is
    defined further down); the later definition wins at class-creation time — consider
    removing one copy.

    Args:
        role (str): An AWS IAM role (either name or full ARN).

    Returns:
        str: The corresponding AWS IAM role ARN.
    """
    # Full ARNs contain a "/" (e.g. arn:aws:iam::123456789012:role/MyRole); bare names do not.
    if "/" not in role:
        role = self.boto_session.resource("iam").Role(role).arn
    return role
def _expand_container_def(c_def):
    """Normalize a container definition: wrap a bare image-URI string via ``container_def``."""
    if not isinstance(c_def, six.string_types):
        # Already a container definition dict (or equivalent) — pass through untouched.
        return c_def
    return container_def(c_def)
[docs]
def expand_role(self, role):
    """Expand an IAM role name into an ARN.

    If the role is already in the form of an ARN, then the role is simply returned. Otherwise
    we retrieve the full ARN and return it.

    NOTE(review): duplicate of an identical ``expand_role`` defined earlier in this module;
    this later definition is the one that takes effect.

    Args:
        role (str): An AWS IAM role (either name or full ARN).

    Returns:
        str: The corresponding AWS IAM role ARN.
    """
    # A "/" marks a full ARN (arn:...:role/Name); otherwise resolve the name via IAM.
    return role if "/" in role else self.boto_session.resource("iam").Role(role).arn
[docs]
def s3_path_join(*args, with_end_slash: bool = False):
    """Returns the arguments joined by a slash ("/"), similar to ``os.path.join()`` (on Unix).

    Behavior of this function:
    - If the first argument is "s3://", then that is preserved.
    - The output by default will have no slashes at the beginning or end. There is one exception
      (see `with_end_slash`). For example, `s3_path_join("/foo", "bar/")` will yield
      `"foo/bar"` and `s3_path_join("foo", "bar", with_end_slash=True)` will yield `"foo/bar/"`
    - Any repeat slashes will be removed in the output (except for "s3://" if provided at the
      beginning). For example, `s3_path_join("s3://", "//foo/", "/bar///baz")` will yield
      `"s3://foo/bar/baz"`.
    - Empty or None arguments will be skipped. For example
      `s3_path_join("foo", "", None, "bar")` will yield `"foo/bar"`

    Alternatives to this function that are NOT recommended for S3 paths:
    - `os.path.join(...)` will have different behavior on Unix machines vs non-Unix machines
    - `pathlib.PurePosixPath(...)` will apply potentially unintended simplification of single
      dots (".") and root directories. (for example
      `pathlib.PurePosixPath("foo", "/bar/./", "baz")` would yield `"/bar/baz"`)
    - `"{}/{}/{}".format(...)` and similar may result in unintended repeat slashes

    Args:
        *args: The strings to join with a slash.
        with_end_slash (bool): (default: False) If true and if the path is not empty, appends a "/"
            to the end of the path

    Returns:
        str: The joined string, without a slash at the end unless with_end_slash is True.
    """
    delimiter = "/"
    parts = [part for part in args if part is not None and part != ""]
    if not parts:
        return ""

    # Join the pieces, inserting a slash only where neither side already supplies one.
    joined = parts[0]
    for part in parts[1:]:
        if joined.endswith(delimiter) or part.startswith(delimiter):
            joined += part
        else:
            joined += delimiter + part

    if with_end_slash and not joined.endswith(delimiter):
        joined += delimiter

    # Collapse runs of slashes, keeping the double slash of a leading "s3://" intact,
    # and drop leading slashes. For "s3://..." inputs any slashes immediately after the
    # scheme merge into it (mirroring the original reduce-based de-duplication).
    if joined.startswith("s3://"):
        joined = "s3://" + re.sub(r"/+", delimiter, joined[5:].lstrip(delimiter))
    else:
        joined = re.sub(r"/+", delimiter, joined).lstrip(delimiter)

    # Strip the trailing slash unless explicitly requested (never strip "s3://" itself).
    if not with_end_slash and joined != "s3://":
        joined = joined.rstrip(delimiter)

    return joined
[docs]
def botocore_resolver():
    """Build a botocore endpoint resolver from the SDK's bundled "endpoints" data.

    (The previous docstring documented a nonexistent ``region`` argument and claimed
    a DNS-suffix return value; this function takes no arguments.)

    Returns:
        botocore.regions.EndpointResolver: resolver usable to construct per-region
        service endpoints (e.g. via ``construct_endpoint``).
    """
    loader = botocore.loaders.create_loader()
    return botocore.regions.EndpointResolver(loader.load_data("endpoints"))
[docs]
def sts_regional_endpoint(region):
    """Get the AWS STS endpoint specific for the given region.

    We need this function because the AWS SDK does not yet honor
    the ``region_name`` parameter when creating an AWS STS client.

    For the list of regional endpoints, see
    https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html#id_credentials_region-endpoints.

    NOTE(review): a function of the same name is also imported from
    ``sagemaker.core.common_utils`` at the top of this module; this local definition
    shadows it — confirm which one callers should resolve to.

    Args:
        region (str): AWS region name

    Returns:
        str: AWS STS regional endpoint
    """
    endpoint_data = botocore_resolver().construct_endpoint("sts", region)
    # il-central-1 may be absent from older bundled endpoint data; fall back to the
    # standard regional hostname pattern for that region.
    if not endpoint_data and region == "il-central-1":
        endpoint_data = {"hostname": "sts.{}.amazonaws.com".format(region)}
    return "https://{}".format(endpoint_data["hostname"])
[docs]
def get_execution_role(sagemaker_session=None, use_default=False):
    """Return the role ARN whose credentials are used to call the API.

    Throws an exception if role doesn't exist.

    Args:
        sagemaker_session (Session): Current sagemaker session.
        use_default (bool): Use a default role if ``get_caller_identity_arn`` does not
            return a correct role. This default role will be created if needed.
            Defaults to ``False``.

    Returns:
        (str): The role ARN
    """
    session = sagemaker_session or Session()
    arn = session.get_caller_identity_arn()

    if ":role/" in arn:
        return arn

    if not use_default:
        message = (
            "The current AWS identity is not a role: {}, therefore it cannot be used as a "
            "SageMaker execution role"
        )
        raise ValueError(message.format(arn))

    default_role_name = "AmazonSageMaker-DefaultRole"
    LOGGER.warning("Using default role: %s", default_role_name)

    # This is the role *trust* policy (AssumeRolePolicyDocument) letting the SageMaker
    # service assume the role — not a permissions policy, despite the original local name.
    assume_role_policy = json.dumps(
        {
            "Version": "2012-10-17",
            "Statement": [
                {
                    "Effect": "Allow",
                    "Principal": {"Service": ["sagemaker.amazonaws.com"]},
                    "Action": "sts:AssumeRole",
                }
            ],
        }
    )

    iam_client = session.boto_session.client("iam")
    try:
        iam_client.get_role(RoleName=default_role_name)
    except iam_client.exceptions.NoSuchEntityException:
        iam_client.create_role(
            RoleName=default_role_name, AssumeRolePolicyDocument=str(assume_role_policy)
        )
        LOGGER.warning("Created new sagemaker execution role: %s", default_role_name)

    iam_client.attach_role_policy(
        PolicyArn="arn:aws:iam::aws:policy/AmazonSageMakerFullAccess",
        RoleName=default_role_name,
    )
    return iam_client.get_role(RoleName=default_role_name)["Role"]["Arn"]
[docs]
def get_add_model_package_inference_args(
    model_package_arn,
    name,
    containers=None,
    content_types=None,
    response_types=None,
    inference_instances=None,
    transform_instances=None,
    description=None,
):
    """Get request dictionary for UpdateModelPackage API for additional inference.

    (The previous docstring documented a nonexistent ``image_uris`` parameter; removed.)

    Args:
        model_package_arn (str): Arn for the model package.
        name (str): Name to identify the additional inference specification
        containers (list[dict]): The containers of the additional inference specification,
            each holding the Amazon ECR registry path of the Docker image
            that contains the inference code.
        description (str): Description for the additional inference specification
        content_types (list[str]): The supported MIME types
            for the input data.
        response_types (list[str]): The supported MIME types
            for the output data.
        inference_instances (list[str]): A list of the instance
            types that are used to generate inferences in real-time (default: None).
        transform_instances (list[str]): A list of the instance
            types on which a transformation job can be run or on which an endpoint can be
            deployed (default: None).

    Returns:
        dict: Arguments for an ``UpdateModelPackage`` call; empty when ``containers``
        is ``None`` (nothing to add).
    """
    request_dict = {}
    if containers is not None:
        inference_specification = {"Containers": containers}
        # Optional fields are added only when provided — the service API rejects nulls.
        optional_fields = {
            "Name": name,
            "Description": description,
            "SupportedContentTypes": content_types,
            "SupportedResponseMIMETypes": response_types,
            "SupportedRealtimeInferenceInstanceTypes": inference_instances,
            "SupportedTransformInstanceTypes": transform_instances,
        }
        inference_specification.update(
            {key: value for key, value in optional_fields.items() if value is not None}
        )
        request_dict["AdditionalInferenceSpecificationsToAdd"] = [inference_specification]
        request_dict.update({"ModelPackageArn": model_package_arn})
    return request_dict
[docs]
def get_update_model_package_inference_args(
    model_package_arn,
    containers=None,
    content_types=None,
    response_types=None,
    inference_instances=None,
    transform_instances=None,
):
    """Get request dictionary for UpdateModelPackage API for inference specification.

    Args:
        model_package_arn (str): Arn for the model package.
        containers (dict): The Amazon ECR registry path of the Docker image
            that contains the inference code.
        content_types (list[str]): The supported MIME types
            for the input data.
        response_types (list[str]): The supported MIME types
            for the output data.
        inference_instances (list[str]): A list of the instance
            types that are used to generate inferences in real-time (default: None).
        transform_instances (list[str]): A list of the instance
            types on which a transformation job can be run or on which an endpoint can be
            deployed (default: None).
    """
    request_dict = {}
    if containers is None:
        # Nothing to update — return an empty request.
        return request_dict

    inference_specification = {"Containers": containers}
    # Add each optional field only when a value was supplied.
    optional_fields = (
        ("SupportedContentTypes", content_types),
        ("SupportedResponseMIMETypes", response_types),
        ("SupportedRealtimeInferenceInstanceTypes", inference_instances),
        ("SupportedTransformInstanceTypes", transform_instances),
    )
    for field_name, field_value in optional_fields:
        if field_value is not None:
            inference_specification[field_name] = field_value

    request_dict["InferenceSpecification"] = inference_specification
    request_dict["ModelPackageArn"] = model_package_arn
    return request_dict
def _logs_for_job(  # noqa: C901 - suppress complexity warning for this method
    sagemaker_session, job_name, wait=False, poll=10, log_type="All", timeout=None
):
    """Display logs for a given training job, optionally tailing them until job is complete.

    If the output is a tty or a Jupyter cell, it will be color-coded
    based on which instance the log entry is from.

    Args:
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions.
        job_name (str): Name of the training job to display the logs for.
        wait (bool): Whether to keep looking for new log entries until the job completes
            (default: False).
        poll (int): The interval in seconds between polling for new log entries and job
            completion (default: 10).
        log_type ([str]): A list of strings specifying which logs to print. Acceptable
            strings are "All", "None", "Training", or "Rules". To maintain backwards
            compatibility, boolean values are also accepted and converted to strings.
        timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by
            default.

    Returns:
        Last call to sagemaker DescribeTrainingJob

    Raises:
        exceptions.CapacityError: If the training job fails with CapacityError.
        exceptions.UnexpectedStatusException: If waiting and the training job fails.
    """
    sagemaker_client = sagemaker_session.sagemaker_client
    request_end_time = time.time() + timeout if timeout else None
    # _wait_until retries transient AccessDenied for up to 5 minutes (IAM tag propagation).
    description = _wait_until(
        lambda: sagemaker_client.describe_training_job(TrainingJobName=job_name)
    )
    print(secondary_training_status_message(description, None), end="")

    instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init(
        sagemaker_session.boto_session, description, job="Training"
    )

    state = _get_initial_job_state(description, "TrainingJobStatus", wait)

    # The loop below implements a state machine that alternates between checking the job status
    # and reading whatever is available in the logs at this point. Note, that if we were
    # called with wait == False, we never check the job status.
    #
    # If wait == TRUE and job is not completed, the initial state is TAILING
    # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is
    # complete).
    #
    # The state table:
    #
    # STATE               ACTIONS                        CONDITION             NEW STATE
    # ----------------    ----------------               -----------------     ----------------
    # TAILING             Read logs, Pause, Get status   Job complete          JOB_COMPLETE
    #                                                    Else                  TAILING
    # JOB_COMPLETE        Read logs, Pause               Any                   COMPLETE
    # COMPLETE            Read logs, Exit                N/A
    #
    # Notes:
    # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to
    #   Cloudwatch after the job was marked complete.
    last_describe_job_call = time.time()
    last_description = description
    last_debug_rule_statuses = None
    last_profiler_rule_statuses = None

    while True:
        _flush_log_streams(
            stream_names,
            instance_count,
            client,
            log_group,
            job_name,
            positions,
            dot,
            color_wrap,
        )
        if timeout and time.time() > request_end_time:
            print("Timeout Exceeded. {} seconds elapsed.".format(timeout))
            break
        if state == LogState.COMPLETE:
            break

        time.sleep(poll)

        if state == LogState.JOB_COMPLETE:
            # One final extra pass to pick up events that reached CloudWatch after completion.
            state = LogState.COMPLETE
        elif time.time() - last_describe_job_call >= 30:
            # Re-describe the job at most every 30 seconds to limit API traffic.
            description = sagemaker_client.describe_training_job(TrainingJobName=job_name)
            last_describe_job_call = time.time()

            if secondary_training_status_changed(description, last_description):
                print()
                print(secondary_training_status_message(description, last_description), end="")
                last_description = description

            status = description["TrainingJobStatus"]

            if status in ("Completed", "Failed", "Stopped"):
                print()
                state = LogState.JOB_COMPLETE

            # Print prettified logs related to the status of SageMaker Debugger rules.
            # NOTE(review): the loop variables below reuse the name ``status`` and shadow
            # the job-level status; harmless today since it is re-read each iteration,
            # but fragile if code is added after these loops.
            debug_rule_statuses = description.get("DebugRuleEvaluationStatuses", {})
            if (
                debug_rule_statuses
                and _rule_statuses_changed(debug_rule_statuses, last_debug_rule_statuses)
                and (log_type in {"All", "Rules"})
            ):
                for status in debug_rule_statuses:
                    rule_log = (
                        f"{status['RuleConfigurationName']}: {status['RuleEvaluationStatus']}"
                    )
                    print(rule_log)

                last_debug_rule_statuses = debug_rule_statuses

            # Print prettified logs related to the status of SageMaker Profiler rules.
            profiler_rule_statuses = description.get("ProfilerRuleEvaluationStatuses", {})
            if (
                profiler_rule_statuses
                and _rule_statuses_changed(profiler_rule_statuses, last_profiler_rule_statuses)
                and (log_type in {"All", "Rules"})
            ):
                for status in profiler_rule_statuses:
                    rule_log = (
                        f"{status['RuleConfigurationName']}: {status['RuleEvaluationStatus']}"
                    )
                    print(rule_log)

                last_profiler_rule_statuses = profiler_rule_statuses

    if wait:
        _check_job_status(job_name, description, "TrainingJobStatus")
        if dot:
            print()
        # Customers are not billed for hardware provisioning, so billable time is less than
        # total time
        training_time = description.get("TrainingTimeInSeconds")
        billable_time = description.get("BillableTimeInSeconds")
        if training_time is not None:
            print("Training seconds:", training_time * instance_count)
        if billable_time is not None:
            print("Billable seconds:", billable_time * instance_count)
            if description.get("EnableManagedSpotTraining"):
                saving = (1 - float(billable_time) / training_time) * 100
                print("Managed Spot Training savings: {:.1f}%".format(saving))
    return last_description
def _check_job_status(job, desc, status_key_name):
    """Check to see if the job completed successfully.

    If not, construct and raise a exceptions. (UnexpectedStatusException).

    Args:
        job (str): The name of the job to check.
        desc (dict[str, str]): The result of ``describe_training_job()``.
        status_key_name (str): Status key name to check for.

    Raises:
        exceptions.CapacityError: If the training job fails with CapacityError.
        exceptions.UnexpectedStatusException: If the training job fails.
    """
    raw_status = desc[status_key_name]
    # Normalize capital-case service statuses to Camel case via the lookup table.
    status = _STATUS_CODE_TABLE.get(raw_status, raw_status)

    if status == "Stopped":
        logger.warning(
            "Job ended with status 'Stopped' rather than 'Completed'. "
            "This could mean the job timed out or stopped early for some other reason: "
            "Consider checking whether it completed as you expect."
        )
        return
    if status == "Completed":
        return

    reason = desc.get("FailureReason", "(No reason provided)")
    job_type = status_key_name.replace("JobStatus", " job")
    troubleshooting = (
        "https://docs.aws.amazon.com/sagemaker/latest/dg/"
        "sagemaker-python-sdk-troubleshooting.html"
    )
    message = (
        "Error for {job_type} {job_name}: {status}. Reason: {reason}. "
        "Check troubleshooting guide for common errors: {troubleshooting}"
    ).format(
        job_type=job_type,
        job_name=job,
        status=status,
        reason=reason,
        troubleshooting=troubleshooting,
    )
    # Capacity problems get their own exception type so callers can retry differently.
    error_cls = (
        exceptions.CapacityError
        if "CapacityError" in str(reason)
        else exceptions.UnexpectedStatusException
    )
    raise error_cls(
        message=message,
        allowed_statuses=["Completed", "Stopped"],
        actual_status=status,
    )
def _logs_init(boto_session, description, job):
    """Prepare the shared state for tailing a SageMaker job's CloudWatch logs.

    Returns:
        tuple: (instance_count, stream_names, positions, logs client, log group name,
        dot flag, color wrapper) as consumed by ``_flush_log_streams``.
    """
    if job == "Training":
        resource_config = description["ResourceConfig"]
        if "InstanceGroups" in resource_config:
            # Heterogeneous cluster: total the instances across every group.
            instance_count = sum(
                group["InstanceCount"] for group in resource_config["InstanceGroups"]
            )
        else:
            instance_count = resource_config["InstanceCount"]
    elif job == "Transform":
        instance_count = description["TransformResources"]["InstanceCount"]
    elif job == "Processing":
        instance_count = description["ProcessingResources"]["ClusterConfig"]["InstanceCount"]
    elif job == "AutoML":
        instance_count = 0

    stream_names = []  # The list of log streams
    positions = {}  # The current position in each stream, map of stream name -> position

    # Increase retries allowed (from default of 4), as we don't want waiting for a training job
    # to be interrupted by a transient exception.
    config = botocore.config.Config(retries={"max_attempts": 15})
    client = boto_session.client("logs", config=config)
    log_group = "/aws/sagemaker/" + job + "Jobs"

    dot = False
    color_wrap = sagemaker.core.logs.ColorWrap()
    return instance_count, stream_names, positions, client, log_group, dot, color_wrap
def _flush_log_streams(
stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap
):
"""Placeholder docstring"""
if len(stream_names) < instance_count:
# Log streams are created whenever a container starts writing to stdout/err, so this list
# may be dynamic until we have a stream for every instance.
try:
streams = client.describe_log_streams(
logGroupName=log_group,
logStreamNamePrefix=job_name + "/",
orderBy="LogStreamName",
limit=min(instance_count, 50),
)
stream_names = [s["logStreamName"] for s in streams["logStreams"]]
while "nextToken" in streams:
streams = client.describe_log_streams(
logGroupName=log_group,
logStreamNamePrefix=job_name + "/",
orderBy="LogStreamName",
limit=50,
)
stream_names.extend([s["logStreamName"] for s in streams["logStreams"]])
positions.update(
[
(s, sagemaker.core.logs.Position(timestamp=0, skip=0))
for s in stream_names
if s not in positions
]
)
except ClientError as e:
# On the very first training job run on an account, there's no log group until
# the container starts logging, so ignore any errors thrown about that
err = e.response.get("Error", {})
if err.get("Code", None) != "ResourceNotFoundException":
raise
if len(stream_names) > 0:
if dot:
print("")
dot = False
for idx, event in sagemaker.core.logs.multi_stream_iter(
client, log_group, stream_names, positions
):
color_wrap(idx, event["message"])
ts, count = positions[stream_names[idx]]
if event["timestamp"] == ts:
positions[stream_names[idx]] = sagemaker.core.logs.Position(
timestamp=ts, skip=count + 1
)
else:
positions[stream_names[idx]] = sagemaker.core.logs.Position(
timestamp=event["timestamp"], skip=1
)
else:
dot = True
print(".", end="")
sys.stdout.flush()
def _wait_until(callable_fn, poll=5):
"""Placeholder docstring"""
elapsed_time = 0
result = None
while result is None:
try:
elapsed_time += poll
time.sleep(poll)
result = callable_fn()
except botocore.exceptions.ClientError as err:
# For initial 5 mins we accept/pass AccessDeniedException.
# The reason is to await tag propagation to avoid false AccessDenied claims for an
# access policy based on resource tags, The caveat here is for true AccessDenied
# cases the routine will fail after 5 mins
if err.response["Error"]["Code"] == "AccessDeniedException" and elapsed_time <= 300:
logger.warning(
"Received AccessDeniedException. This could mean the IAM role does not "
"have the resource permissions, in which case please add resource access "
"and retry. For cases where the role has tag based resource policy, "
"continuing to wait for tag propagation.."
)
continue
raise err
return result
def _get_initial_job_state(description, status_key, wait):
    """Return TAILING when logs of a still-running job should be followed, else COMPLETE."""
    already_finished = description[status_key] in ("Completed", "Failed", "Stopped")
    if wait and not already_finished:
        return LogState.TAILING
    return LogState.COMPLETE
def _rule_statuses_changed(current_statuses, last_statuses):
"""Checks the rule evaluation statuses for SageMaker Debugger and Profiler rules."""
if not last_statuses:
return True
for current, last in zip(current_statuses, last_statuses):
if (current["RuleConfigurationName"] == last["RuleConfigurationName"]) and (
current["RuleEvaluationStatus"] != last["RuleEvaluationStatus"]
):
return True
return False
[docs]
def update_args(args: Dict[str, Any], **kwargs):
    """Updates the request arguments dict in place with each keyword value that is not None.

    This is to handle the case that the service API doesn't like NoneTypes for argument values.

    Args:
        args (Dict[str, Any]): the request arguments dict, mutated in place.
            (The previous docstring called this parameter ``request_args``, which
            does not match the signature.)
        kwargs: key, value pairs to update the args dict
    """
    for key, value in kwargs.items():
        if value is not None:
            args[key] = value
[docs]
def production_variant(
    model_name=None,
    instance_type=None,
    initial_instance_count=None,
    variant_name="AllTraffic",
    initial_weight=1,
    accelerator_type=None,
    serverless_inference_config=None,
    volume_size=None,
    model_data_download_timeout=None,
    container_startup_health_check_timeout=None,
    managed_instance_scaling=None,
    routing_config=None,
    inference_ami_version=None,
):
    """Create a production variant description suitable for use in a ``ProductionVariant`` list.

    This is also part of a ``CreateEndpointConfig`` request.

    Args:
        model_name (str): The name of the SageMaker model this production variant references.
        instance_type (str): The EC2 instance type for this production variant. For example,
            'ml.c4.8xlarge'.
        initial_instance_count (int): The initial instance count for this production variant.
            Ignored when ``serverless_inference_config`` is given; otherwise defaults to 1
            when not provided.
        variant_name (string): The ``VariantName`` of this production variant
            (default: 'AllTraffic').
        initial_weight (int): The relative ``InitialVariantWeight`` of this production variant
            (default: 1).
        accelerator_type (str): Type of Elastic Inference accelerator for this production variant.
            For example, 'ml.eia1.medium'.
            For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
        serverless_inference_config (dict): Specifies configuration dict related to serverless
            endpoint. The dict is converted from sagemaker.model_monitor.ServerlessInferenceConfig
            object (default: None)
        volume_size (int): The size, in GB, of the ML storage volume attached to individual
            inference instance associated with the production variant. Currenly only Amazon EBS
            gp2 storage volumes are supported.
        model_data_download_timeout (int): The timeout value, in seconds, to download and extract
            model data from Amazon S3 to the individual inference instance associated with this
            production variant.
        container_startup_health_check_timeout (int): The timeout value, in seconds, for your
            inference container to pass health check by SageMaker Hosting. For more information
            about health check see:
            https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
        managed_instance_scaling (dict): Passed through as ``ManagedInstanceScaling`` when
            truthy (default: None).
        routing_config (dict): Passed through as ``RoutingConfig`` when provided; ignored
            for serverless variants (default: None).
        inference_ami_version (str): Passed through as ``InferenceAmiVersion`` when truthy
            (default: None).

    Returns:
        dict[str, str]: An SageMaker ``ProductionVariant`` description
    """
    production_variant_configuration = {
        "VariantName": variant_name,
    }
    if model_name:
        production_variant_configuration["ModelName"] = model_name
    production_variant_configuration["InitialVariantWeight"] = initial_weight
    if accelerator_type:
        production_variant_configuration["AcceleratorType"] = accelerator_type
    if managed_instance_scaling:
        production_variant_configuration["ManagedInstanceScaling"] = managed_instance_scaling

    if serverless_inference_config:
        production_variant_configuration["ServerlessConfig"] = serverless_inference_config
    else:
        # Instance-backed variant: instance settings only apply when not serverless.
        initial_instance_count = initial_instance_count or 1
        production_variant_configuration["InitialInstanceCount"] = initial_instance_count
        production_variant_configuration["InstanceType"] = instance_type
        # update_args skips None values, which the service API rejects.
        update_args(
            production_variant_configuration,
            VolumeSizeInGB=volume_size,
            ModelDataDownloadTimeoutInSeconds=model_data_download_timeout,
            ContainerStartupHealthCheckTimeoutInSeconds=container_startup_health_check_timeout,
            RoutingConfig=routing_config,
        )

    if inference_ami_version:
        production_variant_configuration["InferenceAmiVersion"] = inference_ami_version
    return production_variant_configuration
def _has_permission_for_live_logging(boto_session, endpoint_name) -> bool:
"""Validate if customer's role has the right permission to access logs from CloudWatch"""
try:
cloudwatch_client = boto_session.client("logs")
cloudwatch_client.filter_log_events(
logGroupName=f"/aws/sagemaker/Endpoints/{endpoint_name}",
logStreamNamePrefix="AllTraffic/",
)
return True
except ClientError as e:
if e.response["Error"]["Code"] == "AccessDeniedException":
LOGGER.warning(
("Failed to enable live logging: %s. Fallback to default logging..."),
e,
)
return False
return True
def _deploy_done(sagemaker_client, endpoint_name):
"""Placeholder docstring"""
hosting_status_codes = {
"OutOfService": "x",
"Creating": "-",
"Updating": "-",
"InService": "!",
"RollingBack": "<",
"Deleting": "o",
"Failed": "*",
}
in_progress_statuses = ["Creating", "Updating"]
desc = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
status = desc["EndpointStatus"]
print(hosting_status_codes.get(status, "?"), end="")
sys.stdout.flush()
return None if status in in_progress_statuses else desc
def _live_logging_deploy_done(sagemaker_client, endpoint_name, paginator, paginator_config, poll):
    """Poll the endpoint once, streaming its CloudWatch logs; return the description when done.

    Returns None while the endpoint is still creating (or not yet visible), so a
    waiter keeps calling this function; returns the describe_endpoint response once
    the endpoint has left the "Creating" state and the remaining logs were flushed.
    """
    stop = False
    endpoint_status = None
    try:
        desc = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
        endpoint_status = desc["EndpointStatus"]
    except ClientError as e:
        # The endpoint may not be visible yet right after the create call.
        if e.response["Error"]["Code"] == "ValidationException":
            LOGGER.debug("Waiting for endpoint to become visible")
            return None
        raise e
    try:
        # if endpoint is in an invalid state -> set stop to true, sleep, and flush the logs
        if endpoint_status != "Creating":
            stop = True
            if endpoint_status == "InService":
                LOGGER.info(
                    "Created endpoint with name %s. Waiting for it to be InService", endpoint_name
                )
            else:
                # Terminal-but-not-InService (e.g. Failed): pause once so trailing
                # log events can arrive before the final flush below.
                time.sleep(poll)
        pages = paginator.paginate(
            logGroupName=f"/aws/sagemaker/Endpoints/{endpoint_name}",
            logStreamNamePrefix="AllTraffic/",
            PaginationConfig=paginator_config,
        )
        for page in pages:
            if "nextToken" in page:
                # Remember where to resume on the next poll.
                paginator_config["StartingToken"] = page["nextToken"]
            for event in page["events"]:
                LOGGER.info(event["message"])
        else:
            # NOTE(review): a for/else with no break always runs its else clause, so this
            # debug line fires after every pagination pass — confirm whether it was meant
            # to run only when no pages/events were returned.
            LOGGER.debug("No log events available")
        # if stop is true -> return the describe response and stop polling
        if stop:
            return desc
    except ClientError as e:
        # The log group appears only once the container starts writing.
        if e.response["Error"]["Code"] == "ResourceNotFoundException":
            LOGGER.debug("Waiting for endpoint log group to appear")
            return None
        raise e
    return None
def _deployment_entity_exists(describe_fn):
"""Placeholder docstring"""
try:
describe_fn()
return True
except ClientError as ce:
error_code = ce.response["Error"]["Code"]
if not (
error_code == "ValidationException"
and "Could not find" in ce.response["Error"]["Message"]
):
raise ce
return False
[docs]
def get_log_events_for_inference_recommender(cw_client, log_group_name, log_stream_name):
    """Retrieves log events from the specified CloudWatch log group and log stream.

    Retries for up to 5 minutes while the stream does not yet exist.

    Args:
        cw_client (boto3.client): A boto3 CloudWatch client.
        log_group_name (str): The name of the CloudWatch log group.
        log_stream_name (str): The name of the CloudWatch log stream.

    Returns:
        (dict): A dictionary containing log events from CloudWatch log group and log stream.
    """
    print("Fetching logs from CloudWatch...", flush=True)
    for _ in retries(
        max_retry_count=30,  # 30*10 = 5min
        exception_message_prefix="Waiting for cloudwatch stream to appear. ",
        seconds_to_sleep=10,
    ):
        try:
            return cw_client.get_log_events(
                logGroupName=log_group_name, logStreamName=log_stream_name
            )
        except ClientError as e:
            # BUGFIX: only ResourceNotFoundException (stream not created yet) should be
            # retried. Previously every ClientError was swallowed, so e.g. an
            # AccessDenied surfaced only as a confusing retry-exhaustion failure.
            if e.response["Error"]["Code"] != "ResourceNotFoundException":
                raise
def _describe_inference_recommendations_job_status(sagemaker_client, job_name: str):
"""Describes the status of a job and returns the job description.
Args:
sagemaker_client (boto3.client.sagemaker): A SageMaker client.
job_name (str): The name of the job.
Returns:
dict: The job description, or None if the job is still in progress.
"""
inference_recommendations_job_status_codes = {
"PENDING": ".",
"IN_PROGRESS": ".",
"COMPLETED": "!",
"FAILED": "*",
"STOPPING": "_",
"STOPPED": "s",
}
in_progress_statuses = {"PENDING", "IN_PROGRESS", "STOPPING"}
desc = sagemaker_client.describe_inference_recommendations_job(JobName=job_name)
status = desc["Status"]
print(inference_recommendations_job_status_codes.get(status, "?"), end="", flush=True)
if status in in_progress_statuses:
return None
print("")
return desc
def _display_inference_recommendations_job_steps_status(
    sagemaker_session, sagemaker_client, job_name: str, poll: int = 60
):
    """Tail and print CloudWatch execution logs for an Inference Recommender job.

    Polls the job's ``<job_name>/execution`` log stream, printing each batch of
    log messages, until the job leaves an in-progress status and one final
    poll cycle has been given for trailing logs to arrive.

    Args:
        sagemaker_session: Session whose ``boto_session`` is used to create the
            CloudWatch Logs client.
        sagemaker_client (boto3.client.sagemaker): A SageMaker client used to
            poll the job status.
        job_name (str): The name of the inference recommendations job.
        poll (int): Seconds to sleep between polling iterations (default: 60).
    """
    cloudwatch_client = sagemaker_session.boto_session.client("logs")
    in_progress_statuses = {"PENDING", "IN_PROGRESS", "STOPPING"}
    log_group_name = "/aws/sagemaker/InferenceRecommendationsJobs"
    log_stream_name = job_name + "/execution"
    # Blocks (with retries) until the log stream exists, then returns the first
    # batch of events.
    initial_logs_batch = get_log_events_for_inference_recommender(
        cloudwatch_client, log_group_name, log_stream_name
    )
    print(f"Retrieved logStream: {log_stream_name} from logGroup: {log_group_name}", flush=True)
    events = initial_logs_batch["events"]
    print(*[event["message"] for event in events], sep="\n", flush=True)
    # Adopt the forward token only when events were returned; an empty batch
    # starts pagination from the beginning on the next call instead.
    next_forward_token = initial_logs_batch["nextForwardToken"] if events else None
    # After the job finishes, allow exactly one extra poll cycle to pick up any
    # trailing log events before the final print-and-break below.
    flush_remaining = True
    while True:
        # Omit nextToken entirely when we have no token yet; passing None is
        # not accepted by the API.
        logs_batch = (
            cloudwatch_client.get_log_events(
                logGroupName=log_group_name,
                logStreamName=log_stream_name,
                nextToken=next_forward_token,
            )
            if next_forward_token
            else cloudwatch_client.get_log_events(
                logGroupName=log_group_name, logStreamName=log_stream_name
            )
        )
        events = logs_batch["events"]
        desc = sagemaker_client.describe_inference_recommendations_job(JobName=job_name)
        status = desc["Status"]
        if not events:
            if status in in_progress_statuses:
                time.sleep(poll)
                continue
            if flush_remaining:
                # Job finished but no new events this cycle: wait one more poll
                # interval in case logs are still being delivered, then fall
                # through on the next empty batch.
                flush_remaining = False
                time.sleep(poll)
                continue
        # Record the latest token and emit this batch (possibly empty on the
        # final fall-through pass after the flush cycle).
        next_forward_token = logs_batch["nextForwardToken"]
        print(*[event["message"] for event in events], sep="\n", flush=True)
        if status not in in_progress_statuses:
            break
        time.sleep(poll)
def _optimization_job_status(sagemaker_client, job_name):
"""Placeholder docstring"""
optimization_job_status_codes = {
"INPROGRESS": ".",
"COMPLETED": "!",
"FAILED": "*",
"STARTING": ".",
"STOPPING": "_",
"STOPPED": "s",
}
in_progress_statuses = ["INPROGRESS", "STARTING", "STOPPING"]
desc = sagemaker_client.describe_optimization_job(OptimizationJobName=job_name)
status = desc["OptimizationJobStatus"]
print(optimization_job_status_codes.get(status, "?"), end="")
sys.stdout.flush()
if status in in_progress_statuses:
return None
print("")
return desc
# [docs]
def container_def(
    image_uri,
    model_data_url=None,
    env=None,
    container_mode=None,
    image_config=None,
    accept_eula=None,
    additional_model_data_sources=None,
    model_reference_arn=None,
):
    """Create a definition for executing a container as part of a SageMaker model.

    Args:
        image_uri (str): Docker image URI to run for this container.
        model_data_url (str or dict[str, Any]): S3 location of model data required by this
            container, e.g. SageMaker training job model artifacts. It can either be a string
            representing S3 URI of model data, or a dictionary representing a
            ``ModelDataSource`` object. (default: None).
        env (dict[str, str]): Environment variables to set inside the container (default: None).
        container_mode (str): The model container mode. Valid modes:

            * MultiModel: Indicates that model container can support hosting multiple models
            * SingleModel: Indicates that model container can support hosting a single model.
              This is the default model container mode when container_mode = None

        image_config (dict[str, str]): Specifies whether the image of model container is pulled
            from ECR, or private registry in your VPC. By default it is set to pull model
            container image from ECR. (default: None).
        accept_eula (bool): For models that require a Model Access Config, specify True or
            False to indicate whether model terms of use have been accepted.
            The `accept_eula` value must be explicitly defined as `True` in order to
            accept the end-user license agreement (EULA) that some
            models require. (Default: None).
        additional_model_data_sources (PipelineVariable or dict): Additional location
            of SageMaker model data (default: None).
        model_reference_arn (str): Hub content ARN attached as a ``HubAccessConfig``
            on the generated ``S3DataSource`` when ``model_data_url`` is an S3
            tarball URI requiring a model data source (default: None).

    Returns:
        dict[str, Any]: A complete container definition object usable with the CreateModel API
        if passed via `PrimaryContainers` field.
    """
    if env is None:
        env = {}
    c_def = {"Image": image_uri, "Environment": env}

    if additional_model_data_sources:
        c_def["AdditionalModelDataSources"] = additional_model_data_sources

    if isinstance(model_data_url, str) and (
        not (model_data_url.startswith("s3://") and model_data_url.endswith("tar.gz"))
        or accept_eula is None
    ):
        # Plain string URLs use the legacy ModelDataUrl field — except for
        # s3:// tarballs that must carry an explicit EULA decision, which are
        # promoted to a ModelDataSource below.
        c_def["ModelDataUrl"] = model_data_url
    elif isinstance(model_data_url, (dict, str)):
        if isinstance(model_data_url, dict):
            # Caller supplied a complete ModelDataSource dictionary; pass it through.
            c_def["ModelDataSource"] = model_data_url
        else:
            c_def["ModelDataSource"] = {
                "S3DataSource": {
                    "S3Uri": model_data_url,
                    "S3DataType": "S3Object",
                    "CompressionType": "Gzip",
                }
            }
            if accept_eula is not None:
                c_def["ModelDataSource"]["S3DataSource"]["ModelAccessConfig"] = {
                    "AcceptEula": accept_eula
                }
            if model_reference_arn:
                c_def["ModelDataSource"]["S3DataSource"]["HubAccessConfig"] = {
                    "HubContentArn": model_reference_arn
                }
    elif model_data_url is not None:
        c_def["ModelDataUrl"] = model_data_url

    if container_mode:
        c_def["Mode"] = container_mode
    if image_config:
        c_def["ImageConfig"] = image_config
    return c_def