# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Functions for generating ECR image URIs for pre-built SageMaker Docker images."""
from __future__ import absolute_import
import json
import logging
import os
from typing import Optional
from packaging.version import Version
import requests
from sagemaker.core.serverless_inference_config import ServerlessInferenceConfig
from sagemaker.core.training_compiler_config import TrainingCompilerConfig
from sagemaker.core.fw_utils import (
GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY,
GRAVITON_ALLOWED_FRAMEWORKS,
)
from sagemaker.core.common_utils import _botocore_resolver, get_instance_type_family
logger = logging.getLogger(__name__)
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}"
HUGGING_FACE_FRAMEWORK = "huggingface"
HUGGING_FACE_LLM_FRAMEWORK = "huggingface-llm"
HUGGING_FACE_TEI_GPU_FRAMEWORK = "huggingface-tei"
HUGGING_FACE_TEI_CPU_FRAMEWORK = "huggingface-tei-cpu"
HUGGING_FACE_LLM_NEURONX_FRAMEWORK = "huggingface-llm-neuronx"
XGBOOST_FRAMEWORK = "xgboost"
SKLEARN_FRAMEWORK = "sklearn"
TRAINIUM_ALLOWED_FRAMEWORKS = "pytorch"
INFERENCE_GRAVITON = "inference_graviton"
DATA_WRANGLER_FRAMEWORK = "data-wrangler"
STABILITYAI_FRAMEWORK = "stabilityai"
SAGEMAKER_TRITONSERVER_FRAMEWORK = "sagemaker-tritonserver"
def _get_image_tag(
container_version,
distributed,
final_image_scope,
framework,
inference_tool,
instance_type,
processor,
py_version,
tag_prefix,
version,
):
"""Return image tag based on framework, container, and compute configuration(s)."""
instance_type_family = get_instance_type_family(instance_type)
if framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK):
if instance_type_family and final_image_scope == INFERENCE_GRAVITON:
_validate_arg(
instance_type_family,
GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY,
"instance type",
)
if (
instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
or final_image_scope == INFERENCE_GRAVITON
):
version_to_arm64_tag_mapping = {
"xgboost": {
"1.5-1": "1.5-1-arm64",
"1.3-1": "1.3-1-arm64",
},
"sklearn": {
"1.0-1": "1.0-1-arm64-cpu-py3",
},
}
tag = version_to_arm64_tag_mapping[framework][version]
else:
tag = _format_tag(tag_prefix, processor, py_version, container_version, inference_tool)
else:
tag = _format_tag(tag_prefix, processor, py_version, container_version, inference_tool)
if instance_type is not None and _should_auto_select_container_version(
instance_type, distributed
):
container_versions = {
"tensorflow-2.3-gpu-py37": "cu110-ubuntu18.04-v3",
"tensorflow-2.3.1-gpu-py37": "cu110-ubuntu18.04",
"tensorflow-2.3.2-gpu-py37": "cu110-ubuntu18.04",
"tensorflow-1.15-gpu-py37": "cu110-ubuntu18.04-v8",
"tensorflow-1.15.4-gpu-py37": "cu110-ubuntu18.04",
"tensorflow-1.15.5-gpu-py37": "cu110-ubuntu18.04",
"mxnet-1.8-gpu-py37": "cu110-ubuntu16.04-v1",
"mxnet-1.8.0-gpu-py37": "cu110-ubuntu16.04",
"pytorch-1.6-gpu-py36": "cu110-ubuntu18.04-v3",
"pytorch-1.6.0-gpu-py36": "cu110-ubuntu18.04",
"pytorch-1.6-gpu-py3": "cu110-ubuntu18.04-v3",
"pytorch-1.6.0-gpu-py3": "cu110-ubuntu18.04",
}
key = "-".join([framework, tag])
if key in container_versions:
tag = "-".join([tag, container_versions[key]])
# Triton images don't have a trailing -gpu tag. Only -cpu images do.
if framework == SAGEMAKER_TRITONSERVER_FRAMEWORK:
if processor == "gpu":
tag = tag.rstrip("-gpu")
return tag
def _config_for_framework_and_scope(framework, image_scope, accelerator_type=None):
"""Loads the JSON config for the given framework and image scope."""
config = config_for_framework(framework)
if accelerator_type:
_validate_accelerator_type(accelerator_type)
if image_scope not in ("eia", "inference"):
logger.warning(
"Elastic inference is for inference only. Ignoring image scope: %s.",
image_scope,
)
image_scope = "eia"
available_scopes = config.get("scope", list(config.keys()))
if len(available_scopes) == 1:
if image_scope and image_scope != available_scopes[0]:
logger.warning(
"Defaulting to only supported image scope: %s. Ignoring image scope: %s.",
available_scopes[0],
image_scope,
)
image_scope = available_scopes[0]
if not image_scope and "scope" in config and set(available_scopes) == {"training", "inference"}:
logger.info(
"Same images used for training and inference. Defaulting to image scope: %s.",
available_scopes[0],
)
image_scope = available_scopes[0]
_validate_arg(image_scope, available_scopes, "image scope")
return config if "scope" in config else config[image_scope]
def _validate_instance_deprecation(framework, instance_type, version):
"""Check if instance type is deprecated for a certain framework with a certain version"""
if get_instance_type_family(instance_type) == "p2":
if (framework == "pytorch" and Version(version) >= Version("1.13")) or (
framework == "tensorflow" and Version(version) >= Version("2.12")
):
raise ValueError(
"P2 instances have been deprecated for sagemaker jobs starting PyTorch 1.13 and TensorFlow 2.12"
"For information about supported instance types please refer to "
"https://aws.amazon.com/sagemaker/pricing/"
)
def _validate_for_suppported_frameworks_and_instance_type(framework, instance_type):
"""Validate if framework is supported for the instance_type"""
# Validate for Trainium allowed frameworks
if (
instance_type is not None
and "trn" in instance_type
and framework not in TRAINIUM_ALLOWED_FRAMEWORKS
):
_validate_framework(framework, TRAINIUM_ALLOWED_FRAMEWORKS, "framework", "Trainium")
# Validate for Graviton allowed frameowrks
if (
instance_type is not None
and get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
and framework not in GRAVITON_ALLOWED_FRAMEWORKS
):
_validate_framework(framework, GRAVITON_ALLOWED_FRAMEWORKS, "framework", "Graviton")
[docs]
def config_for_framework(framework):
"""Loads the JSON config for the given framework."""
fname = os.path.join(os.path.dirname(__file__), "..", "image_uri_config", "{}.json".format(framework))
with open(fname) as f:
return json.load(f)
def _get_final_image_scope(framework, instance_type, image_scope):
"""Return final image scope based on provided framework and instance type."""
if (
framework in GRAVITON_ALLOWED_FRAMEWORKS
and get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
):
return INFERENCE_GRAVITON
if image_scope is None and framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK):
# Preserves backwards compatibility with XGB/SKLearn configs which no
# longer define top-level "scope" keys after introducing support for
# Graviton inference. Training and inference configs for XGB/SKLearn are
# identical, so default to training.
return "training"
return image_scope
def _get_inference_tool(inference_tool, instance_type):
"""Extract the inference tool name from instance type."""
if not inference_tool:
instance_type_family = get_instance_type_family(instance_type)
if instance_type_family.startswith("inf") or instance_type_family.startswith("trn"):
return "neuron"
return inference_tool
def _get_latest_versions(list_of_versions):
"""Extract the latest version from the input list of available versions."""
return sorted(list_of_versions, reverse=True)[0]
def _get_latest_version(framework, version, image_scope):
"""Get the latest version from the input framework"""
if version:
return version
try:
framework_config = config_for_framework(framework)
except FileNotFoundError:
raise ValueError("Invalid framework {}".format(framework))
if not framework_config:
raise ValueError("Invalid framework {}".format(framework))
if not version:
version = _fetch_latest_version_from_config(framework_config, image_scope)
return version
def _validate_accelerator_type(accelerator_type):
"""Raises a ``ValueError`` if ``accelerator_type`` is invalid."""
if not accelerator_type.startswith("ml.eia") and accelerator_type != "local_sagemaker_notebook":
raise ValueError(
"Invalid SageMaker Elastic Inference accelerator type: {}. "
"See https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html".format(accelerator_type)
)
def _validate_version_and_set_if_needed(version, config, framework, image_scope):
"""Checks if the framework/algorithm version is one of the supported versions."""
if not config:
config = config_for_framework(framework)
available_versions = list(config["versions"].keys())
aliased_versions = list(config.get("version_aliases", {}).keys())
if len(available_versions) == 1 and version not in aliased_versions:
return available_versions[0]
if not version:
version = _get_latest_version(framework, version, image_scope)
_validate_arg(version, available_versions + aliased_versions, "{} version".format(framework))
return version
def _version_for_config(version, config):
"""Returns the version string for retrieving a framework version's specific config."""
if "version_aliases" in config:
if version in config["version_aliases"].keys():
return config["version_aliases"][version]
return version
def _registry_from_region(region, registry_dict):
"""Returns the ECR registry (AWS account number) for the given region."""
_validate_arg(region, registry_dict.keys(), "region")
return registry_dict[region]
def _processor(instance_type, available_processors, serverless_inference_config=None):
"""Returns the processor type for the given instance type."""
if not available_processors:
logger.info("Ignoring unnecessary instance type: %s.", instance_type)
return None
if len(available_processors) == 1 and not instance_type:
logger.info("Defaulting to only supported image scope: %s.", available_processors[0])
return available_processors[0]
if serverless_inference_config is not None:
logger.info("Defaulting to CPU type when using serverless inference")
return "cpu"
if not instance_type:
raise ValueError(
"Empty SageMaker instance type. For options, see: "
"https://aws.amazon.com/sagemaker/pricing/instance-types"
)
if instance_type.startswith("local"):
processor = "cpu" if instance_type == "local" else "gpu"
elif instance_type.startswith("neuron"):
processor = "neuron"
else:
# looks for either "ml.<family>.<size>" or "ml_<family>"
family = get_instance_type_family(instance_type)
if family:
# For some frameworks, we have optimized images for specific families, e.g c5 or p3.
# In those cases, we use the family name in the image tag. In other cases, we use
# 'cpu' or 'gpu'.
if family in available_processors:
processor = family
elif family.startswith("inf"):
processor = "inf"
elif family.startswith("trn"):
processor = "trn"
elif family[0] in ("g", "p"):
processor = "gpu"
else:
processor = "cpu"
else:
raise ValueError(
"Invalid SageMaker instance type: {}. For options, see: "
"https://aws.amazon.com/sagemaker/pricing/instance-types".format(instance_type)
)
_validate_arg(processor, available_processors, "processor")
return processor
def _should_auto_select_container_version(instance_type, distributed):
"""Returns a boolean that indicates whether to use an auto-selected container version."""
p4d = False
if instance_type:
# looks for either "ml.<family>.<size>" or "ml_<family>"
family = get_instance_type_family(instance_type)
if family:
p4d = family == "p4d"
return p4d or distributed
def _validate_py_version_and_set_if_needed(py_version, version_config, framework):
"""Checks if the Python version is one of the supported versions."""
if "repository" in version_config:
available_versions = version_config.get("py_versions")
else:
available_versions = list(version_config.keys())
if not available_versions:
if py_version:
logger.info("Ignoring unnecessary Python version: %s.", py_version)
return None
if py_version is None and defaults.SPARK_NAME == framework:
return None
if py_version is None and len(available_versions) == 1:
logger.info("Defaulting to only available Python version: %s", available_versions[0])
return available_versions[0]
_validate_arg(py_version, available_versions, "Python version")
return py_version
def _validate_arg(arg, available_options, arg_name):
"""Checks if the arg is in the available options, and raises a ``ValueError`` if not."""
if arg not in available_options:
raise ValueError(
"Unsupported {arg_name}: {arg}. You may need to upgrade your SDK version "
"(pip install -U sagemaker) for newer {arg_name}s. Supported {arg_name}(s): "
"{options}.".format(arg_name=arg_name, arg=arg, options=", ".join(available_options))
)
def _validate_framework(framework, allowed_frameworks, arg_name, hardware_name):
"""Checks if the framework is in the allowed frameworks, and raises a ``ValueError`` if not."""
if framework not in allowed_frameworks:
raise ValueError(
f"Unsupported {arg_name}: {framework}. "
f"Supported {arg_name}(s) for {hardware_name} instances: {allowed_frameworks}."
)
def _format_tag(tag_prefix, processor, py_version, container_version, inference_tool=None):
"""Creates a tag for the image URI."""
if inference_tool:
return "-".join(x for x in (tag_prefix, inference_tool, py_version, container_version) if x)
return "-".join(x for x in (tag_prefix, processor, py_version, container_version) if x)
def _fetch_latest_version_from_config( # pylint: disable=R0911
framework_config: dict, image_scope: Optional[str] = None
) -> Optional[str]:
"""Helper function to fetch the latest version as a string from a framework's config
Args:
framework_config (dict): A framework config dict.
image_scope (str): Scope of the image, eg: training, inference
Returns:
Version string if latest version found else None
"""
if image_scope in framework_config:
if image_scope_config := framework_config[image_scope]:
if "version_aliases" in image_scope_config:
if "latest" in image_scope_config["version_aliases"]:
return image_scope_config["version_aliases"]["latest"]
top_version = None
bottom_version = None
if "versions" in framework_config:
versions = list(framework_config["versions"].keys())
if len(versions) == 1:
return versions[0]
top_version = versions[0]
bottom_version = versions[-1]
if top_version == "latest" or bottom_version == "latest":
return None
elif (
image_scope is not None
and image_scope in framework_config
and "versions" in framework_config[image_scope]
):
versions = list(framework_config[image_scope]["versions"].keys())
top_version = versions[0]
bottom_version = versions[-1]
elif "processing" in framework_config and "versions" in framework_config["processing"]:
versions = list(framework_config["processing"]["versions"].keys())
top_version = versions[0]
bottom_version = versions[-1]
if top_version and bottom_version:
if top_version.endswith(".x") or bottom_version.endswith(".x"):
top_number = int(top_version[:-2])
bottom_number = int(bottom_version[:-2])
max_version = max(top_number, bottom_number)
return f"{max_version}.x"
if Version(top_version) >= Version(bottom_version):
return top_version
return bottom_version
return None
def _retrieve_pytorch_uri_inputs_are_all_default(
version: Optional[str] = None,
py_version: Optional[str] = None,
instance_type: Optional[str] = None,
accelerator_type: Optional[str] = None,
image_scope: Optional[str] = None,
container_version: str = None,
distributed: bool = False,
smp: bool = False,
training_compiler_config: TrainingCompilerConfig = None,
sdk_version: Optional[str] = None,
inference_tool: Optional[str] = None,
serverless_inference_config: ServerlessInferenceConfig = None,
) -> bool:
"""
Determine if the inputs for _retrieve_pytorch_uri() are all default values.
"""
return (
not version
and not py_version
and not instance_type
and not accelerator_type
and not image_scope
and not container_version
and not distributed
and not smp
and not training_compiler_config
and not sdk_version
and not inference_tool
and not serverless_inference_config
)
def _retrieve_latest_pytorch_training_uri(region: str):
"""
Retrive the URI for the latest PyTorch training image for CPU
"""
config = config_for_framework("pytorch")
image_scope = "training"
latest_version = _fetch_latest_version_from_config(config, image_scope)
version_config = config[image_scope]["versions"][latest_version]
py_version = _validate_py_version_and_set_if_needed(None, version_config, None)
endpoint_data = _botocore_resolver().construct_endpoint("ecr", region)
if region == "il-central-1" and not endpoint_data:
endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)}
registry = _registry_from_region(region, version_config["registries"])
hostname = endpoint_data["hostname"]
repo = version_config["repository"]
tag = _get_image_tag(
container_version="ec2",
distributed=False,
final_image_scope=image_scope,
framework="pytorch",
inference_tool="",
instance_type="",
processor="cpu",
py_version=py_version,
tag_prefix=latest_version,
version=latest_version,
)
if tag:
repo += ":{}".format(tag)
return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo)