Source code for sagemaker.core.image_retriever.image_retriever

import re
from typing import Optional
from graphene.utils.str_converters import to_camel_case

from sagemaker.core.inference_config import ServerlessInferenceConfig
from sagemaker.core.training_compiler.config import TrainingCompilerConfig
from sagemaker.core.common_utils import _botocore_resolver
from sagemaker.core.workflow import is_pipeline_variable
from sagemaker.core.image_retriever.image_retriever_utils import (
    _config_for_framework_and_scope,
    _get_final_image_scope,
    _get_image_tag,
    _get_inference_tool,
    _get_latest_versions,
    _processor,
    _registry_from_region,
    _retrieve_latest_pytorch_training_uri,
    _retrieve_pytorch_uri_inputs_are_all_default,
    _validate_arg,
    _validate_for_suppported_frameworks_and_instance_type,
    _validate_instance_deprecation,
    _validate_py_version_and_set_if_needed,
    _validate_version_and_set_if_needed,
    _version_for_config,
    config_for_framework,
)
from sagemaker.core.workflow.utilities import override_pipeline_parameter_var
from sagemaker.core.config.config_schema import IMAGE_RETRIEVER, MODULES, SAGEMAKER, _simple_path
from sagemaker.core.config.config_manager import SageMakerConfig

ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}"
HUGGING_FACE_FRAMEWORK = "huggingface"
PYTORCH_FRAMEWORK = "pytorch"

CONFIGURABLE_ATTRIBUTES = [
    "version",
    "py_version",
    "instance_type",
    "accelerator_type",
    "image_scope",
    "container_version",
    "distributed",
    "smp",
    "base_framework_version",
    "training_compiler_config",
    "model_id",
    "model_version",
    "sdk_version",
    "inference_tool",
    "serverless_inference_config",
]


[docs] class ImageRetriever:
[docs] @staticmethod def retrieve_hugging_face_uri( region: str, version: Optional[str] = None, py_version: Optional[str] = None, instance_type: Optional[str] = None, accelerator_type: Optional[str] = None, image_scope: Optional[str] = None, container_version: str = None, distributed: bool = False, base_framework_version: Optional[str] = None, training_compiler_config: TrainingCompilerConfig = None, sdk_version: Optional[str] = None, inference_tool: Optional[str] = None, serverless_inference_config: ServerlessInferenceConfig = None, ): """Retrieves the ECR URI for the Docker image of HuggingFace models. Args: region (str): The AWS region. version (Optional[str]): The framework or algorithm version. This is required if there is more than one supported version for the given framework or algorithm. py_version (Optional[str]): The Python version. This is required if there is more than one supported Python version for the given framework version. instance_type (Optional[str]): The SageMaker instance type. For supported types, see https://aws.amazon.com/sagemaker/pricing. This is required if there are different images for different processor types. accelerator_type (Optional[str]): Elastic Inference accelerator type. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html. image_scope (Optional[str]): The image type, i.e. what it is used for. Valid values: "training", "inference", "inference_graviton", "eia". If ``accelerator_type`` is set, ``image_scope`` is ignored. container_version (Optional[str]): the version of docker image. Ideally the value of parameter should be created inside the framework. For custom use, see the list of supported container versions: https://github.com/aws/deep-learning-containers/blob/master/available_images.md (default: None). distributed (bool): If the image is for running distributed training. base_framework_version (Optional[str]): The base version number of PyTorch or Tensorflow. (default: None). training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`): A configuration class for the SageMaker Training Compiler (default: None). sdk_version (Optional[str]): the version of python-sdk that will be used in the image retrieval. (default: None). inference_tool (Optional[str]): the tool that will be used to aid in the inference. Valid values: "neuron, neuronx, None" (default: None). serverless_inference_config (sagemaker.serve.serverless.ServerlessInferenceConfig): Specifies configuration related to serverless endpoint. Instance type is not provided in serverless inference. So this is used to determine processor type.. Returns: str: The ECR URI for the corresponding SageMaker Docker image. """ args = dict(locals()) for name, val in args.items(): if name in CONFIGURABLE_ATTRIBUTES and not val: default_value = SageMakerConfig.resolve_value_from_config( config_path=_simple_path( SAGEMAKER, MODULES, IMAGE_RETRIEVER, to_camel_case(name) ) ) if default_value is not None: locals()[name] = default_value if training_compiler_config: final_image_scope = image_scope config = _config_for_framework_and_scope( HUGGING_FACE_FRAMEWORK + "-training-compiler", final_image_scope, accelerator_type ) else: _framework = HUGGING_FACE_FRAMEWORK inference_tool = _get_inference_tool(inference_tool, instance_type) if inference_tool in ["neuron", "neuronx"]: _framework = f"{HUGGING_FACE_FRAMEWORK}-{inference_tool}" final_image_scope = _get_final_image_scope( HUGGING_FACE_FRAMEWORK, instance_type, image_scope ) _validate_for_suppported_frameworks_and_instance_type( HUGGING_FACE_FRAMEWORK, instance_type ) config = _config_for_framework_and_scope( _framework, final_image_scope, accelerator_type ) if base_framework_version is not None: processor = _processor(instance_type, ["cpu", "gpu"]) is_native_huggingface_gpu = processor == "gpu" and not training_compiler_config container_version = "cu110-ubuntu18.04" if is_native_huggingface_gpu else None original_version = version version = _validate_version_and_set_if_needed( version, config, HUGGING_FACE_FRAMEWORK, image_scope ) version_config = config["versions"][_version_for_config(version, config)] if version_config.get("version_aliases"): full_base_framework_version = version_config["version_aliases"].get( base_framework_version, base_framework_version ) _validate_arg(full_base_framework_version, list(version_config.keys()), "base framework") version_config = version_config.get(full_base_framework_version) py_version = _validate_py_version_and_set_if_needed( py_version, version_config, HUGGING_FACE_FRAMEWORK ) version_config = version_config.get(py_version) or version_config registry = _registry_from_region(region, version_config["registries"]) endpoint_data = _botocore_resolver().construct_endpoint("ecr", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)} hostname = endpoint_data["hostname"] repo = version_config["repository"] processor = _processor( instance_type, config.get("processors") or version_config.get("processors"), serverless_inference_config, ) # if container version is available in .json file, utilize that if version_config.get("container_version"): container_version = version_config["container_version"][processor] # Append sdk version in case of trainium instances if repo in ["pytorch-training-neuron", "pytorch-training-neuronx"]: if not sdk_version: sdk_version = _get_latest_versions(version_config["sdk_versions"]) container_version = sdk_version + "-" + container_version pt_or_tf_version = ( re.compile("^(pytorch|tensorflow)(.*)$").match(base_framework_version).group(2) ) _version = original_version if repo in [ "huggingface-pytorch-trcomp-training", "huggingface-tensorflow-trcomp-training", ]: _version = version if repo in [ "huggingface-pytorch-inference-neuron", "huggingface-pytorch-inference-neuronx", ]: if not sdk_version: sdk_version = _get_latest_versions(version_config["sdk_versions"]) container_version = sdk_version + "-" + container_version if config.get("version_aliases").get(original_version): _version = config.get("version_aliases")[original_version] if ( config.get("versions", {}) .get(_version, {}) .get("version_aliases", {}) .get(base_framework_version, {}) ): _base_framework_version = config.get("versions")[_version]["version_aliases"][ base_framework_version ] pt_or_tf_version = ( re.compile("^(pytorch|tensorflow)(.*)$").match(_base_framework_version).group(2) ) tag_prefix = f"{pt_or_tf_version}-transformers{_version}" if repo == f"{HUGGING_FACE_FRAMEWORK}-inference-graviton": container_version = f"{container_version}-sagemaker" _validate_instance_deprecation(HUGGING_FACE_FRAMEWORK, instance_type, version) tag = _get_image_tag( container_version, distributed, final_image_scope, HUGGING_FACE_FRAMEWORK, inference_tool, instance_type, processor, py_version, tag_prefix, version, ) if tag: repo += ":{}".format(tag) return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo)
[docs] @staticmethod def retrieve_jumpstart_uri(): """Retrieves the container image URI for JumpStart models. TODO: We are changing the Jumpstart utilities and moving from S3 based metadata to API based metadata. Implement this function after utility changes are finished. """
[docs] @staticmethod def retrieve_pytorch_uri( region: str, version: Optional[str] = None, py_version: Optional[str] = None, instance_type: Optional[str] = None, accelerator_type: Optional[str] = None, image_scope: Optional[str] = None, container_version: str = None, distributed: bool = False, smp: bool = False, training_compiler_config: TrainingCompilerConfig = None, sdk_version: Optional[str] = None, inference_tool: Optional[str] = None, serverless_inference_config: ServerlessInferenceConfig = None, ): """Retrieves the ECR URI for the Docker image of PyTorch models. Args: region (str): The AWS region. version (Optional[str]): The framework or algorithm version. This is required if there is more than one supported version for the given framework or algorithm. py_version (Optional[str]): The Python version. This is required if there is more than one supported Python version for the given framework version. instance_type (Optional[str]): The SageMaker instance type. For supported types, see https://aws.amazon.com/sagemaker/pricing. This is required if there are different images for different processor types. accelerator_type (Optional[str]): Elastic Inference accelerator type. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html. image_scope (Optional[str]): The image type, i.e. what it is used for. Valid values: "training", "inference", "inference_graviton", "eia". If ``accelerator_type`` is set, ``image_scope`` is ignored. container_version (Optional[str]): the version of docker image. Ideally the value of parameter should be created inside the framework. For custom use, see the list of supported container versions: https://github.com/aws/deep-learning-containers/blob/master/available_images.md (default: None). distributed (bool): If the image is for running distributed training. smp (bool): If using the SMP library for distributed training. training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`): A configuration class for the SageMaker Training Compiler (default: None). sdk_version (Optional[str]): the version of python-sdk that will be used in the image retrieval. (default: None). inference_tool (Optional[str]): the tool that will be used to aid in the inference. Valid values: "neuron, neuronx, None" (default: None). serverless_inference_config (sagemaker.serve.serverless.ServerlessInferenceConfig): Specifies configuration related to serverless endpoint. Instance type is not provided in serverless inference. So this is used to determine processor type. Returns: str: The ECR URI for the corresponding SageMaker Docker image. """ if _retrieve_pytorch_uri_inputs_are_all_default( version, py_version, instance_type, accelerator_type, image_scope, container_version, distributed, smp, training_compiler_config, sdk_version, inference_tool, serverless_inference_config, ): # if all inputs are not set, return the latest pytorch training uri return _retrieve_latest_pytorch_training_uri(region) framework = PYTORCH_FRAMEWORK if training_compiler_config: final_image_scope = image_scope config = _config_for_framework_and_scope( PYTORCH_FRAMEWORK + "-training-compiler", final_image_scope, accelerator_type ) else: _framework = PYTORCH_FRAMEWORK inference_tool = _get_inference_tool(inference_tool, instance_type) if inference_tool in ["neuron", "neuronx"]: _framework = f"{PYTORCH_FRAMEWORK}-{inference_tool}" final_image_scope = _get_final_image_scope( PYTORCH_FRAMEWORK, instance_type, image_scope ) _validate_for_suppported_frameworks_and_instance_type(PYTORCH_FRAMEWORK, instance_type) config = _config_for_framework_and_scope( _framework, final_image_scope, accelerator_type ) version = _validate_version_and_set_if_needed( version, config, PYTORCH_FRAMEWORK, image_scope ) version_config = config["versions"][_version_for_config(version, config)] if distributed and smp: framework = "pytorch-smp" supported_smp_pt_versions_cu124 = ("2.5",) supported_smp_pt_versions_cu121 = ("2.1", "2.2", "2.3", "2.4") if any(pt_version in version for pt_version in supported_smp_pt_versions_cu124): container_version = "cu124" elif "p5" in instance_type or any( pt_version in version for pt_version in supported_smp_pt_versions_cu121 ): container_version = "cu121" else: container_version = "cu118" py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework) version_config = version_config.get(py_version) or version_config registry = _registry_from_region(region, version_config["registries"]) endpoint_data = _botocore_resolver().construct_endpoint("ecr", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)} hostname = endpoint_data["hostname"] repo = version_config["repository"] processor = _processor( instance_type, config.get("processors") or version_config.get("processors"), serverless_inference_config, ) # if container version is available in .json file, utilize that if version_config.get("container_version"): container_version = version_config["container_version"][processor] # Append sdk version in case of trainium instances if repo in ["pytorch-training-neuron", "pytorch-training-neuronx"]: if not sdk_version: sdk_version = _get_latest_versions(version_config["sdk_versions"]) container_version = sdk_version + "-" + container_version tag_prefix = version_config.get("tag_prefix", version) if repo == f"{framework}-inference-graviton": container_version = f"{container_version}-sagemaker" _validate_instance_deprecation(framework, instance_type, version) tag = _get_image_tag( container_version, distributed, final_image_scope, framework, inference_tool, instance_type, processor, py_version, tag_prefix, version, ) if tag: repo += ":{}".format(tag) return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo)
[docs] @staticmethod @override_pipeline_parameter_var def retrieve( framework: str, region: str, version: Optional[str] = None, py_version: Optional[str] = None, instance_type: Optional[str] = None, accelerator_type: Optional[str] = None, image_scope: Optional[str] = None, container_version: str = None, distributed: bool = False, smp: bool = False, base_framework_version: Optional[str] = None, training_compiler_config: TrainingCompilerConfig = None, model_id: Optional[str] = None, model_version: Optional[str] = None, sdk_version: Optional[str] = None, inference_tool: Optional[str] = None, serverless_inference_config: ServerlessInferenceConfig = None, ): """Retrieves the ECR URI for the Docker image matching the given arguments. Args: framework (str): The name of the framework or algorithm. region (str): The AWS region. version (Optional[str]): The framework or algorithm version. This is required if there is more than one supported version for the given framework or algorithm. py_version (Optional[str]): The Python version. This is required if there is more than one supported Python version for the given framework version. instance_type (Optional[str]): The SageMaker instance type. For supported types, see https://aws.amazon.com/sagemaker/pricing. This is required if there are different images for different processor types. accelerator_type (Optional[str]): Elastic Inference accelerator type. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html. image_scope (Optional[str]): The image type, i.e. what it is used for. Valid values: "training", "inference", "inference_graviton", "eia". If ``accelerator_type`` is set, ``image_scope`` is ignored. container_version (Optional[str]): the version of docker image. Ideally the value of parameter should be created inside the framework. For custom use, see the list of supported container versions: https://github.com/aws/deep-learning-containers/blob/master/available_images.md (default: None). distributed (bool): If the image is for running distributed training. smp (bool): If using the SMP library for distributed training. base_framework_version (Optional[str]): The base version number of PyTorch or Tensorflow. (default: None). training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`): A configuration class for the SageMaker Training Compiler (default: None). model_id (Optional[str]): The JumpStart model ID for which to retrieve the image URI (default: None). model_version (Optional[str]): The version of the JumpStart model for which to retrieve the image URI (default: None). hub_arn (Optional[str]): The arn of the SageMaker Hub for which to retrieve model details from. (Default: None). tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model specifications should be tolerated without an exception raised. If ``False``, raises an exception if the script used by this version of the model has dependencies with known security vulnerabilities. (Default: False). tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated without an exception raised. If False, raises an exception if the version of the model is deprecated. (Default: False). sdk_version (Optional[str]): the version of python-sdk that will be used in the image retrieval. (default: None). inference_tool (Optional[str]): the tool that will be used to aid in the inference. Valid values: "neuron, neuronx, None" (default: None). serverless_inference_config (sagemaker.serve.serverless.ServerlessInferenceConfig): Specifies configuration related to serverless endpoint. Instance type is not provided in serverless inference. So this is used to determine processor type. sagemaker_session (sagemaker.core.helper.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). model_type (JumpStartModelType): The type of the model, can be open weights model or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: str: The ECR URI for the corresponding SageMaker Docker image. Raises: NotImplementedError: If the scope is not supported. ValueError: If the combination of arguments specified is not supported or any PipelineVariable object is passed in. VulnerableJumpStartModelError: If any of the dependencies required by the script have known security vulnerabilities. DeprecatedJumpStartModelError: If the version of the model is deprecated. """ args = dict(locals()) for name, val in args.items(): if name in CONFIGURABLE_ATTRIBUTES and not val: default_value = SageMakerConfig.resolve_value_from_config( config_path=_simple_path( SAGEMAKER, MODULES, IMAGE_RETRIEVER, to_camel_case(name) ) ) if default_value is not None: locals()[name] = default_value for name, val in args.items(): if is_pipeline_variable(val): raise ValueError( "When retrieving the image_uri, the argument %s should not be a pipeline variable " "(%s) since pipeline variables are only interpreted in the pipeline execution time." % (name, type(val)) ) if model_id: return ImageRetriever.retrieve_jumpstart_uri() if framework == HUGGING_FACE_FRAMEWORK: return ImageRetriever.retrieve_hugging_face_uri( region=region, version=version, py_version=py_version, instance_type=instance_type, accelerator_type=accelerator_type, image_scope=image_scope, container_version=container_version, distributed=distributed, base_framework_version=base_framework_version, training_compiler_config=training_compiler_config, sdk_version=sdk_version, inference_tool=inference_tool, serverless_inference_config=serverless_inference_config, ) if framework == PYTORCH_FRAMEWORK: return ImageRetriever.retrieve_pytorch_uri( region=region, version=version, py_version=py_version, instance_type=instance_type, accelerator_type=accelerator_type, image_scope=image_scope, container_version=container_version, distributed=distributed, smp=smp, training_compiler_config=training_compiler_config, sdk_version=sdk_version, inference_tool=inference_tool, serverless_inference_config=serverless_inference_config, ) final_image_scope = _get_final_image_scope(framework, instance_type, image_scope) _validate_for_suppported_frameworks_and_instance_type(framework, instance_type) config = _config_for_framework_and_scope(framework, final_image_scope, accelerator_type) version = _validate_version_and_set_if_needed(version, config, framework, image_scope) version_config = config["versions"][_version_for_config(version, config)] py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework) version_config = version_config.get(py_version) or version_config registry = _registry_from_region(region, version_config["registries"]) endpoint_data = _botocore_resolver().construct_endpoint("ecr", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)} hostname = endpoint_data["hostname"] repo = version_config["repository"] processor = _processor( instance_type, config.get("processors") or version_config.get("processors"), serverless_inference_config, ) # if container version is available in .json file, utilize that if version_config.get("container_version"): container_version = version_config["container_version"][processor] # Append sdk version in case of trainium instances if repo in ["pytorch-training-neuron", "pytorch-training-neuronx"]: if not sdk_version: sdk_version = _get_latest_versions(version_config["sdk_versions"]) container_version = sdk_version + "-" + container_version tag_prefix = version_config.get("tag_prefix", version) if repo == f"{framework}-inference-graviton": container_version = f"{container_version}-sagemaker" _validate_instance_deprecation(framework, instance_type, version) tag = _get_image_tag( container_version, distributed, final_image_scope, framework, inference_tool, instance_type, processor, py_version, tag_prefix, version, ) if tag: repo += ":{}".format(tag) return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo)
[docs] @staticmethod def retrieve_base_python_image_uri(region: str, py_version: str = "310") -> str: """Retrieves the image URI for base python image. Args: region (str): The AWS region to use for image URI. py_version (str): The python version to use for the image. Default to 310 Returns: str: The image URI string. """ framework = "sagemaker-base-python" version = "1.0" endpoint_data = _botocore_resolver().construct_endpoint("ecr", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)} hostname = endpoint_data["hostname"] config = config_for_framework(framework) version_config = config["versions"][_version_for_config(version, config)] registry = _registry_from_region(region, version_config["registries"]) repo = version_config["repository"] + "-" + py_version repo_and_tag = repo + ":" + version return ECR_URI_TEMPLATE.format( registry=registry, hostname=hostname, repository=repo_and_tag )