# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Utility functions and mixins for ModelBuilder.
This module provides utility functions for:
- Session management and initialization
- Instance type detection and optimization
- Container image auto-detection
- HuggingFace and JumpStart model handling
- Resource requirement calculation
- Framework serialization support
- MLflow model integration
- General model deployment utilities
Example:
Basic usage as a mixin class::
class MyModelBuilder(ModelBuilderUtils):
def __init__(self):
self.model = "huggingface-model-id"
self.instance_type = "ml.g5.xlarge"
def build(self):
self._auto_detect_image_uri()
return self.image_uri
"""
from __future__ import absolute_import, annotations
# Standard library imports
import importlib.util
import sys
import shutil
import json
import os
import platform
import re
import uuid
from pathlib import Path
import subprocess
from typing import Any, Dict, List, Optional, Tuple, Union
# Third-party imports
from packaging.version import Version
# SageMaker core imports
from sagemaker.core.helper.session_helper import Session
from sagemaker.core.utils.utils import logger
from sagemaker.train import ModelTrainer
# SageMaker serve imports
from sagemaker.serve.compute_resource_requirements import ResourceRequirements
from sagemaker.serve.constants import (
DEFAULT_SERIALIZERS_BY_FRAMEWORK,
Framework,
)
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.builder.serve_settings import _ServeSettings
from sagemaker.serve.detector.image_detector import (
_cast_to_compatible_version,
_detect_framework_and_version,
auto_detect_container,
_get_model_base,
)
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.utils import task
from sagemaker.serve.utils.exceptions import TaskNotFoundException
from sagemaker.serve.utils.hardware_detector import _total_inference_model_size_mib
from sagemaker.serve.utils.types import ModelServer
from sagemaker.core.resources import Model
# MLflow imports
from sagemaker.serve.model_format.mlflow.constants import (
MLFLOW_METADATA_FILE,
MLFLOW_MODEL_PATH,
MLFLOW_PIP_DEPENDENCY_FILE,
MLFLOW_REGISTRY_PATH_REGEX,
MLFLOW_RUN_ID_REGEX,
MLFLOW_TRACKING_ARN,
MODEL_PACKAGE_ARN_REGEX,
)
from sagemaker.serve.model_format.mlflow.utils import (
_copy_directory_contents,
_download_s3_artifacts,
_generate_mlflow_artifact_path,
_get_all_flavor_metadata,
_get_default_model_server_for_mlflow,
_get_deployment_flavor,
_select_container_for_mlflow_model,
_validate_input_for_mlflow,
)
# SageMaker utils imports
from sagemaker.core.deserializers import JSONDeserializer
from sagemaker.core.jumpstart.accessors import JumpStartS3PayloadAccessor
from sagemaker.core.jumpstart.factory.utils import get_init_kwargs, get_deploy_kwargs
from sagemaker.core.jumpstart.utils import (
get_jumpstart_base_name_if_jumpstart_model,
get_jumpstart_content_bucket,
get_eula_message,
get_metrics_from_deployment_configs,
add_instance_rate_stats_to_benchmark_metrics,
)
from sagemaker.core.jumpstart.types import DeploymentConfigMetadata
from sagemaker.core.jumpstart import accessors
from sagemaker.core.enums import Tag
from sagemaker.core.local.local_session import LocalSession
from sagemaker.core.s3 import S3Downloader
from sagemaker.core.serializers import NumpySerializer
from sagemaker.core.common_utils import (
Tags,
_validate_new_tags,
base_name_from_image,
remove_tag_with_key,
)
from sagemaker.core.helper.pipeline_variable import PipelineVariable
from sagemaker.core import model_uris
from sagemaker.serve.utils.local_hardware import _get_available_gpus
from sagemaker.core.base_serializers import JSONSerializer
from sagemaker.core.deserializers import JSONDeserializer
from sagemaker.serve.detector.pickler import save_pkl
from sagemaker.serve.builder.requirements_manager import RequirementsManager
from sagemaker.serve.validations.check_integrity import (
compute_hash,
)
from sagemaker.core.remote_function.core.serialization import _MetaData
from sagemaker.serve.model_server.triton.config_template import CONFIG_TEMPLATE
# Container mount path where additional (speculative-decoding draft) model data is placed.
SPECULATIVE_DRAFT_MODEL = "/opt/ml/additional-model-data-sources"
# Entry-point script name used by the DJL model builder.
_DJL_MODEL_BUILDER_ENTRY_POINT = "inference.py"
# Message used when a model ID is not a JumpStart ID and HF handling is used instead.
_NO_JS_MODEL_EX = "HuggingFace JumpStart Model ID not detected. Building for HuggingFace Model ID."
# Scope passed to JumpStart artifact lookups.
_JS_SCOPE = "inference"
# Folder name for packaged inference code.
_CODE_FOLDER = "code"
# Minimum LMI image tag template; "{}" is presumably filled with the ECR repo URI — TODO confirm.
_JS_MINIMUM_VERSION_IMAGE = "{}:0.31.0-lmi13.0.0-cu124"
# Error messages raised when SchemaBuilder sample data doesn't match the model server's contract.
_INVALID_DJL_SAMPLE_DATA_EX = (
    'For djl-serving, sample input must be of {"inputs": str, "parameters": dict}, '
    'sample output must be of [{"generated_text": str,}]'
)
_INVALID_TGI_SAMPLE_DATA_EX = (
    'For tgi, sample input must be of {"inputs": str, "parameters": dict}, '
    'sample output must be of [{"generated_text": str,}]'
)
# Deployment modes and frameworks supported by the Triton model server path.
SUPPORTED_TRITON_MODE = {Mode.LOCAL_CONTAINER, Mode.SAGEMAKER_ENDPOINT, Mode.IN_PROCESS}
SUPPORTED_TRITON_FRAMEWORK = {Framework.PYTORCH, Framework.TENSORFLOW}
# Default tensor names used in Triton request payloads and config templates.
INPUT_NAME = "input_1"
OUTPUT_NAME = "output_1"
# Region -> AWS account ID hosting the SageMaker Triton server ECR repositories.
TRITON_IMAGE_ACCOUNT_ID_MAP = {
    "us-east-1": "785573368785",
    "us-east-2": "007439368137",
    "us-west-1": "710691900526",
    "us-west-2": "301217895009",
    "eu-west-1": "802834080501",
    "eu-west-2": "205493899709",
    "eu-west-3": "254080097072",
    "eu-north-1": "601324751636",
    "eu-south-1": "966458181534",
    "eu-central-1": "746233611703",
    "ap-east-1": "110948597952",
    "ap-south-1": "763008648453",
    "ap-northeast-1": "941853720454",
    "ap-northeast-2": "151534178276",
    "ap-southeast-1": "324986816169",
    "ap-southeast-2": "355873309152",
    "cn-northwest-1": "474822919863",
    "cn-north-1": "472730292857",
    "sa-east-1": "756306329178",
    "ca-central-1": "464438896020",
    "me-south-1": "836785723513",
    "af-south-1": "774647643957",
}
# Instance-type family prefixes that provide GPUs (plus the local GPU pseudo-type).
GPU_INSTANCE_FAMILIES = {
    "ml.g4dn",
    "ml.g5",
    "ml.p3",
    "ml.p3dn",
    "ml.p4",
    "ml.p4d",
    "ml.p4de",
    "local_gpu",
}
# URI template for Triton images; "base" is the per-partition ECR domain suffix.
TRITON_IMAGE_BASE = "{account_id}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:{version}-py3"
# Pinned Triton server versions (TF1 currently uses the same release).
LATEST_VERSION = "23.02"
VERSION_FOR_TF1 = "23.02"
class TritonSerializer(JSONSerializer):
    """A wrapper of JSONSerializer because Triton expects input to be a certain format.

    Wraps a serializer that produces a numpy array and re-packages its output
    into the Triton inference-server JSON request structure.
    """

    def __init__(self, input_serializer, dtype: str, content_type="application/json"):
        """Initialize TritonSerializer with input serializer and data type.

        Args:
            input_serializer: Serializer whose ``serialize()`` returns a numpy
                array (must expose ``shape`` and ``tolist()``).
            dtype: Triton datatype string for the input tensor (e.g. ``"FP32"``).
            content_type: MIME type of the serialized payload.
        """
        super().__init__(content_type)
        self.input_serializer = input_serializer
        self.dtype = dtype

    def serialize(self, data):
        """Serialize ``data`` into Triton's ``{"inputs": [...]}`` request format.

        Args:
            data: Raw input accepted by ``input_serializer``.

        Returns:
            The JSON-serialized Triton request payload.
        """
        numpy_data = self.input_serializer.serialize(data)
        payload = {
            "inputs": [
                {
                    "name": INPUT_NAME,
                    "shape": numpy_data.shape,
                    "datatype": self.dtype,
                    "data": numpy_data.tolist(),
                }
            ]
        }
        return super().serialize(payload)
class _ModelBuilderUtils:
"""Utility mixin class providing common functionality for ModelBuilder.
This class provides utility methods for:
- Session management and initialization
- Instance type detection and optimization
- Container image auto-detection
- HuggingFace and JumpStart model handling
- Resource requirement calculation
- Framework serialization support
- MLflow model integration
- General model deployment utilities
This class is designed to be used as a mixin with ModelBuilder classes.
It expects certain attributes to be available on the instance:
- sagemaker_session: SageMaker session object
- model: Model identifier or object
- instance_type: EC2 instance type
- region: AWS region
- env_vars: Environment variables dict
Example:
class MyModelBuilder(ModelBuilderUtils):
def __init__(self):
self.model = "huggingface-model-id"
self.instance_type = "ml.g5.xlarge"
self.sagemaker_session = None
def build(self):
self._init_sagemaker_session_if_does_not_exist()
self._auto_detect_image_uri()
return self.image_uri
"""
# ========================================
# Session Management
# ========================================
def _init_sagemaker_session_if_does_not_exist(
self, instance_type: Optional[str] = None
) -> None:
"""Initialize SageMaker session if it doesn't exist.
Sets self.sagemaker_session to LocalSession for local instances,
or regular Session for remote instances.
Args:
instance_type: EC2 instance type to determine session type.
If None, uses self.instance_type.
"""
if self.sagemaker_session:
return
effective_instance_type = instance_type or getattr(self, "instance_type", None)
if effective_instance_type in ("local", "local_gpu"):
self.sagemaker_session = LocalSession(
sagemaker_config=getattr(self, "_sagemaker_config", None)
)
else:
# Create session with correct region
if hasattr(self, "region") and self.region:
import boto3
boto_session = boto3.Session(region_name=self.region)
self.sagemaker_session = Session(
boto_session=boto_session,
sagemaker_config=getattr(self, "_sagemaker_config", None),
)
else:
self.sagemaker_session = Session(
sagemaker_config=getattr(self, "_sagemaker_config", None)
)
# ========================================
# Instance Type Detection
# ========================================
def _get_jumpstart_recommended_instance_type(self) -> Optional[str]:
"""Get recommended instance type from JumpStart metadata.
Returns:
Recommended instance type string, or None if not available.
"""
try:
deploy_kwargs = get_deploy_kwargs(
model_id=self.model,
model_version=getattr(self, "model_version", None) or "*",
region=self.region,
tolerate_vulnerable_model=getattr(self, "tolerate_vulnerable_model", None),
tolerate_deprecated_model=getattr(self, "tolerate_deprecated_model", None),
)
# JumpStart provides recommended instance type
if hasattr(deploy_kwargs, "instance_type") and deploy_kwargs.instance_type:
return deploy_kwargs.instance_type
except Exception:
pass
return None
def _get_default_instance_type(self) -> str:
"""Get optimal default instance type based on model characteristics.
Analyzes the model to determine appropriate instance type:
- JumpStart models: Use recommended instance type from metadata
- HuggingFace models: Analyze model size and tags for GPU requirements
- Fallback: ml.m5.large for CPU workloads
Returns:
Instance type string (e.g., 'ml.g5.xlarge', 'ml.m5.large').
"""
logger.debug("Auto-detecting optimal instance type for model...")
if isinstance(self.model, str) and self._is_jumpstart_model_id():
recommended_type = self._get_jumpstart_recommended_instance_type()
if recommended_type:
logger.debug(f"Using JumpStart recommended instance type: {recommended_type}")
return recommended_type
# For HuggingFace models, use metadata to detect requirements
elif isinstance(self.model, str):
try:
env_vars = getattr(self, "env_vars", {}) or {}
hf_model_md = self.get_huggingface_model_metadata(
self.model, env_vars.get("HUGGING_FACE_HUB_TOKEN")
)
# Check model size from metadata
model_size = hf_model_md.get("safetensors", {}).get("total", 0)
model_tags = hf_model_md.get("tags", [])
# Large models or specific tags indicate GPU need
if (
model_size > 2_000_000_000 # > 2GB
or any(tag in model_tags for tag in ["7b", "13b", "70b"])
or "7b" in self.model.lower()
or "13b" in self.model.lower()
):
logger.debug("Detected large model, using GPU instance type: ml.g5.xlarge")
return "ml.g5.xlarge"
except Exception as e:
logger.debug(f"Could not get HF metadata for smart detection: {e}")
# Default fallback
logger.debug("Using default CPU instance type: ml.m5.large")
return "ml.m5.large"
# ========================================
# Image Detection and Container Utils
# ========================================
def _auto_detect_container_default(self) -> str:
"""Auto-detect container image for framework-based models.
Detects the appropriate Deep Learning Container (DLC) based on:
- Model framework (PyTorch, TensorFlow)
- Framework version from HuggingFace metadata
- Python version compatibility
- Instance type requirements
Returns:
Container image URI string.
Raises:
ValueError: If instance type not specified or no compatible image found.
"""
from sagemaker.core import image_uris
logger.debug("Auto-detecting image since image_uri was not provided in ModelBuilder()")
if not getattr(self, "instance_type", None):
raise ValueError(
"Instance type is not specified. "
"Unable to detect if the container needs to be GPU or CPU."
)
logger.warning(
"Auto detection is only supported for single models DLCs with a framework backend."
)
py_tuple = platform.python_version_tuple()
env_vars = getattr(self, "env_vars", {}) or {}
torch_v, tf_v, base_hf_v, _ = self._get_hf_framework_versions(
self.model, env_vars.get("HUGGING_FACE_HUB_TOKEN")
)
if torch_v:
fw, fw_version = "pytorch", torch_v
elif tf_v:
fw, fw_version = "tensorflow", tf_v
else:
raise ValueError("Could not detect framework from HuggingFace model metadata")
logger.debug("Auto-detected framework: %s", fw)
logger.debug("Auto-detected framework version: %s", fw_version)
casted_versions = _cast_to_compatible_version(fw, fw_version) if fw_version else (None,)
dlc = None
for casted_version in filter(None, casted_versions):
try:
dlc = image_uris.retrieve(
framework=fw,
region=self.region,
version=casted_version,
image_scope="inference",
py_version=f"py{py_tuple[0]}{py_tuple[1]}",
instance_type=self.instance_type,
)
break
except ValueError:
pass
if dlc:
logger.debug("Auto-detected container: %s. Proceeding with deployment.", dlc)
return dlc
raise ValueError(
f"Unable to auto-detect a DLC for framework {fw}, "
f"framework version {fw_version} and python version py{py_tuple[0]}{py_tuple[1]}. "
f"Please manually provide image_uri to ModelBuilder()"
)
def _get_smd_image_uri(self, processing_unit: Optional[str] = None) -> str:
"""Get SageMaker Distribution (SMD) inference image URI.
Retrieves the appropriate SMD container image for custom orchestrator deployment.
Requires Python >= 3.12 for SMD inference.
Args:
processing_unit: Target processing unit ('cpu' or 'gpu').
If None, defaults to 'cpu'.
Returns:
SMD inference image URI string.
Raises:
ValueError: If Python version < 3.12 or invalid processing unit.
"""
import sys
from sagemaker.core import image_uris
if not self.sagemaker_session:
if hasattr(self, "region") and self.region:
import boto3
boto_session = boto3.Session(region_name=self.region)
self.sagemaker_session = Session(boto_session=boto_session)
else:
self.sagemaker_session = Session()
formatted_py_version = f"py{sys.version_info.major}{sys.version_info.minor}"
if Version(f"{sys.version_info.major}{sys.version_info.minor}") < Version("3.12"):
raise ValueError(
f"Found Python version {formatted_py_version} but "
f"custom orchestrator deployment requires Python version >= 3.12."
)
INSTANCE_TYPES = {"cpu": "ml.c5.xlarge", "gpu": "ml.g5.4xlarge"}
effective_processing_unit = processing_unit or "cpu"
if effective_processing_unit not in INSTANCE_TYPES:
raise ValueError(
f"Invalid processing unit '{effective_processing_unit}'. "
f"Must be one of: {list(INSTANCE_TYPES.keys())}"
)
logger.debug(
"Finding SMD inference image URI for a %s instance.", effective_processing_unit
)
smd_uri = image_uris.retrieve(
framework="sagemaker-distribution",
image_scope="inference",
instance_type=INSTANCE_TYPES[effective_processing_unit],
region=self.region,
)
logger.debug("Found compatible image: %s", smd_uri)
return smd_uri
def _is_huggingface_model(self) -> bool:
"""Check if model is a HuggingFace model ID.
Determines if the model string represents a HuggingFace model by:
- Checking for organization/model-name format
- Checking explicit model_type designation
- Fallback: assume HuggingFace if not JumpStart
Returns:
True if model appears to be a HuggingFace model ID.
"""
if not isinstance(self.model, str):
return False
# Simple pattern matching for HuggingFace model IDs
# Format: "organization/model-name" or just "model-name"
model_type = getattr(self, "model_type", None)
if "/" in self.model or model_type == "huggingface":
return True
# Additional check: if it's not a JumpStart model, assume HuggingFace
return not self._is_jumpstart_model_id()
def _get_supported_version(
self, hf_config: Dict[str, Any], hugging_face_version: str, base_fw: str
) -> str:
"""Extract supported framework version from HuggingFace config.
Uses the HuggingFace JSON config to pick the best supported version
for the given framework.
Args:
hf_config: HuggingFace configuration dictionary
hugging_face_version: HuggingFace transformers version
base_fw: Base framework name (e.g., 'pytorch', 'tensorflow')
Returns:
Best supported framework version string.
"""
version_config = hf_config.get("versions", {}).get(hugging_face_version, {})
versions_to_return = []
for key in version_config.keys():
if key.startswith(base_fw):
base_fw_version = key[len(base_fw) :]
if len(hugging_face_version.split(".")) == 2:
base_fw_version = ".".join(base_fw_version.split(".")[:-1])
versions_to_return.append(base_fw_version)
if not versions_to_return:
raise ValueError(f"No supported versions found for framework {base_fw}")
return sorted(versions_to_return, reverse=True)[0]
def _get_hf_framework_versions(
self, model_id: str, hf_token: Optional[str] = None
) -> Tuple[Optional[str], Optional[str], str, str]:
"""Get HuggingFace framework versions for image_uris.retrieve().
Analyzes HuggingFace model metadata to determine the appropriate
framework versions for container image selection.
Args:
model_id: HuggingFace model identifier
hf_token: Optional HuggingFace API token for private models
Returns:
Tuple of (pytorch_version, tensorflow_version, transformers_version, py_version).
One of pytorch_version or tensorflow_version will be None.
Raises:
ValueError: If no supported framework versions found.
"""
from sagemaker.core import image_uris
# Get model metadata for framework detection
hf_model_md = self.get_huggingface_model_metadata(model_id, hf_token)
# Get HuggingFace framework configuration
hf_config = image_uris.config_for_framework("huggingface").get("inference")
config = hf_config["versions"]
base_hf_version = sorted(config.keys(), key=lambda v: Version(v), reverse=True)[0]
model_tags = hf_model_md.get("tags", [])
# Detect framework from model tags
if "pytorch" in model_tags:
pytorch_version = self._get_supported_version(hf_config, base_hf_version, "pytorch")
py_version = config[base_hf_version][f"pytorch{pytorch_version}"].get(
"py_versions", []
)[-1]
return pytorch_version, None, base_hf_version, py_version
elif "keras" in model_tags or "tensorflow" in model_tags:
tensorflow_version = self._get_supported_version(
hf_config, base_hf_version, "tensorflow"
)
py_version = config[base_hf_version][f"tensorflow{tensorflow_version}"].get(
"py_versions", []
)[-1]
return None, tensorflow_version, base_hf_version, py_version
else:
# Default to PyTorch if no framework detected (matches V2 behavior)
pytorch_version = self._get_supported_version(hf_config, base_hf_version, "pytorch")
py_version = config[base_hf_version][f"pytorch{pytorch_version}"].get(
"py_versions", []
)[-1]
return pytorch_version, None, base_hf_version, py_version
def _detect_jumpstart_image(self) -> None:
"""Detect and set image URI for JumpStart models.
Uses JumpStart metadata to determine the appropriate container image
and framework information for the model.
Raises:
ValueError: If image URI cannot be determined or JumpStart lookup fails.
"""
try:
init_kwargs = get_init_kwargs(
model_id=self.model,
model_version=getattr(self, "model_version", None) or "*",
region=self.region,
instance_type=getattr(self, "instance_type", None),
tolerate_vulnerable_model=getattr(self, "tolerate_vulnerable_model", None),
tolerate_deprecated_model=getattr(self, "tolerate_deprecated_model", None),
)
self.image_uri = init_kwargs.get("image_uri")
if not self.image_uri:
raise ValueError(f"Could not determine image URI for JumpStart model: {self.model}")
logger.debug("Auto-detected JumpStart image: %s", self.image_uri)
self.framework, self.framework_version = self._extract_framework_from_image_uri()
except Exception as e:
raise ValueError(
f"Failed to auto-detect image for JumpStart model {self.model}: {e}"
) from e
def _detect_huggingface_image(self) -> None:
"""Detect and set image URI for HuggingFace models based on model server.
Automatically selects the appropriate container image based on:
- Explicit model_server setting
- Model task type from HuggingFace metadata
- Framework requirements and versions
Raises:
ValueError: If image detection fails or unsupported model server.
"""
from sagemaker.core import image_uris
try:
env_vars = getattr(self, "env_vars", {}) or {}
# Determine which model server we're using
model_server = getattr(self, "model_server", None)
if not model_server:
# Auto-select model server based on HF model task
hf_model_md = self.get_huggingface_model_metadata(
self.model, env_vars.get("HUGGING_FACE_HUB_TOKEN")
)
model_task = hf_model_md.get("pipeline_tag")
if model_task == "text-generation":
effective_model_server = ModelServer.TGI
elif model_task in ["sentence-similarity", "feature-extraction"]:
effective_model_server = ModelServer.TEI
else:
effective_model_server = ModelServer.MMS # Transformers
else:
effective_model_server = model_server
# Choose image based on effective model server
if effective_model_server == ModelServer.TGI:
# TGI: Use image_uris.retrieve with "huggingface-llm" framework
self.image_uri = image_uris.retrieve(
"huggingface-llm",
region=self.region,
version=None, # Use latest version
image_scope="inference",
)
self.framework = Framework.HUGGINGFACE
elif effective_model_server == ModelServer.TEI:
# TEI: Use image_uris.retrieve with "huggingface-tei" framework
self.image_uri = image_uris.retrieve(
framework="huggingface-tei",
image_scope="inference",
instance_type=getattr(self, "instance_type", None),
region=self.region,
)
self.framework = Framework.HUGGINGFACE
elif effective_model_server == ModelServer.DJL_SERVING:
# DJL: Use image_uris.retrieve with "djl-lmi" framework (matches DJLModel default)
self.image_uri = image_uris.retrieve(
framework="djl-lmi",
region=self.region,
version="latest",
image_scope="inference",
instance_type=getattr(self, "instance_type", None),
)
self.framework = Framework.DJL
elif effective_model_server == ModelServer.MMS: # Transformers
# Transformers: Use HuggingFace framework with detected versions
pytorch_version, tensorflow_version, transformers_version, py_version = (
self._get_hf_framework_versions(
self.model, env_vars.get("HUGGING_FACE_HUB_TOKEN")
)
)
base_framework_version = (
f"pytorch{pytorch_version}"
if pytorch_version
else f"tensorflow{tensorflow_version}"
)
self.image_uri = image_uris.retrieve(
framework="huggingface",
region=self.region,
version=transformers_version,
py_version=py_version,
instance_type=getattr(self, "instance_type", None),
image_scope="inference",
base_framework_version=base_framework_version,
)
self.framework = Framework.HUGGINGFACE
elif effective_model_server == ModelServer.TORCHSERVE:
# TorchServe: Use HuggingFace framework with detected versions
pytorch_version, tensorflow_version, transformers_version, py_version = (
self._get_hf_framework_versions(
self.model, env_vars.get("HUGGING_FACE_HUB_TOKEN")
)
)
base_framework_version = (
f"pytorch{pytorch_version}"
if pytorch_version
else f"tensorflow{tensorflow_version}"
)
self.image_uri = image_uris.retrieve(
framework="huggingface",
region=self.region,
version=transformers_version,
py_version=py_version,
instance_type=getattr(self, "instance_type", None),
image_scope="inference",
base_framework_version=base_framework_version,
)
self.framework = Framework.HUGGINGFACE
elif effective_model_server == ModelServer.TRITON:
# Triton: Uses custom image construction (not image_uris.retrieve)
raise ValueError(
"Triton image detection for HuggingFace models requires custom implementation"
)
elif effective_model_server == ModelServer.TENSORFLOW_SERVING:
# TensorFlow Serving: V2 required explicit image_uri (no auto-detection)
raise ValueError("TensorFlow Serving requires explicit image_uri specification")
elif effective_model_server == ModelServer.SMD:
# SMD: Uses _get_smd_image_uri helper
cpu_or_gpu = self._get_processing_unit()
self.image_uri = self._get_smd_image_uri(processing_unit=cpu_or_gpu)
self.framework = Framework.SMD
else:
raise ValueError(
f"Unsupported model server for HuggingFace models: {effective_model_server}"
)
logger.debug("Auto-detected HuggingFace image: %s", self.image_uri)
except Exception as e:
raise ValueError(
f"Failed to auto-detect image for HuggingFace model {self.model}: {e}"
) from e
def _detect_model_object_image(self) -> None:
"""Detect image for legacy object-based models.
Handles model objects (not string IDs) by using the auto_detect_container
function to determine appropriate container image.
Raises:
ValueError: If neither model nor inference_spec available for detection.
"""
model = getattr(self, "model", None)
inference_spec = getattr(self, "inference_spec", None)
model_path = getattr(self, "model_path", None)
if model:
logger.debug(
"Auto-detecting container URL for the provided model on instance %s",
getattr(self, "instance_type", None),
)
self.image_uri, fw, self.framework_version = auto_detect_container(
model, self.region, getattr(self, "instance_type", None)
)
self.framework = self._normalize_framework_to_enum(fw)
elif inference_spec:
logger.warning(
"model_path provided with no image_uri. Attempting to auto-detect the image "
"by loading the model using inference_spec.load()..."
)
self.image_uri, fw, self.framework_version = auto_detect_container(
inference_spec.load(model_path),
self.region,
getattr(self, "instance_type", None),
)
self.framework = self._normalize_framework_to_enum(fw)
else:
raise ValueError("Cannot detect required model or inference spec")
def _auto_detect_image_uri(self) -> None:
"""Auto-detect container image URI based on model type.
Determines the appropriate container image by:
1. Using provided image_uri if available
2. For string models: JumpStart vs HuggingFace detection
3. For object models: Legacy auto-detection
Sets self.image_uri, self.framework, and self.framework_version.
Raises:
ValueError: If image cannot be auto-detected for the model type.
"""
image_uri = getattr(self, "image_uri", None)
if image_uri:
self.framework, self.framework_version = self._extract_framework_from_image_uri()
logger.debug("Skipping auto-detection as image_uri is provided: %s", image_uri)
return
if isinstance(self.model, ModelTrainer):
self._detect_inference_image_from_training()
return
model = getattr(self, "model", None)
inference_spec = getattr(self, "inference_spec", None)
if isinstance(model, str):
# V3: String-based model detection
model_type = getattr(self, "model_type", None)
# First priority: Use model_type if it indicates JumpStart
if model_type in ["open_weights", "proprietary"]:
self._detect_jumpstart_image()
else:
# model_type is None - use pattern-based detection
if self._is_jumpstart_model_id():
self._detect_jumpstart_image()
elif self._is_huggingface_model():
self._detect_huggingface_image()
else:
raise ValueError(f"Cannot auto-detect image for model: {model}")
elif inference_spec and hasattr(inference_spec, "get_model"):
try:
spec_model = inference_spec.get_model()
if spec_model is None:
logger.warning(
"InferenceSpec.get_model() returned None. If you are using a JumpStar or HuggingFace model, you may need to implement get_model() in your InferenceSpec class"
)
if isinstance(spec_model, str):
# Temporarily set model for detection, then restore
original_model = self.model
self.model = spec_model
# Use existing detection logic
if self._is_jumpstart_model_id():
self._detect_jumpstart_image()
elif self._is_huggingface_model():
self._detect_huggingface_image()
else:
raise ValueError(
f"Cannot auto-detect image for inference_spec model: {spec_model}"
)
# Restore original model
self.model = original_model
return
except Exception as e:
pass
# Fall back to existing object detection
self._detect_model_object_image()
else:
# V2: Object-based model detection
self._detect_model_object_image()
# ========================================
# HuggingFace Jumpstart Utils
# ========================================
    def _use_jumpstart_equivalent(self) -> bool:
        """Swap a HuggingFace model for its JumpStart equivalent when one exists.

        Looks up ``self.model`` in the HF->JumpStart mapping (cached on the
        instance after the first call), and if found: initializes a schema
        builder from HF task metadata when none was supplied, rewrites
        ``self.model`` to the JumpStart model ID, triggers the JumpStart build
        path, and warns the user about the substitution. Replacement is
        skipped entirely when ``image_uri`` or ``env_vars`` were explicitly
        provided.

        Returns:
            True if JumpStart equivalent was found and used, False otherwise.
        """
        # Do not use the equivalent JS model if image_uri or env_vars is provided
        image_uri = getattr(self, "image_uri", None)
        env_vars = getattr(self, "env_vars", None)
        if image_uri or env_vars:
            return False
        # Cache both the mapping and the membership result so repeated calls
        # don't re-download the mapping from S3.
        if not hasattr(self, "_has_jumpstart_equivalent"):
            self._jumpstart_mapping = self._retrieve_hugging_face_model_mapping()
            self._has_jumpstart_equivalent = self.model in self._jumpstart_mapping
        if self._has_jumpstart_equivalent:
            # Use schema builder from HF model metadata
            schema_builder = getattr(self, "schema_builder", None)
            if not schema_builder:
                # Prefer an explicit HF_TASK from model_metadata; fall back to
                # the hub's pipeline_tag.
                model_task = None
                model_metadata = getattr(self, "model_metadata", None)
                if model_metadata:
                    model_task = model_metadata.get("HF_TASK")
                hf_model_md = self.get_huggingface_model_metadata(self.model)
                if not model_task:
                    model_task = hf_model_md.get("pipeline_tag")
                if model_task:
                    self._hf_schema_builder_init(model_task)
            huggingface_model_id = self.model
            jumpstart_model_id = self._jumpstart_mapping[huggingface_model_id]["jumpstart-model-id"]
            # Point the builder at the JumpStart copy of the model.
            self.model = jumpstart_model_id
            merged_date = self._jumpstart_mapping[huggingface_model_id].get("merged-at")
            # Call _build_for_jumpstart if method exists
            if hasattr(self, "_build_for_jumpstart"):
                self._build_for_jumpstart()
            compare_model_diff_message = (
                "If you want to identify the differences between the two, "
                "please use model_uris.retrieve() to retrieve the model "
                "artifact S3 URI and compare them."
            )
            # The diff hint is omitted for gated models — presumably because
            # their artifacts can't be freely retrieved for comparison.
            is_gated = hasattr(self, "_is_gated_model") and self._is_gated_model()
            logger.warning(
                "Please note that for this model we are using the JumpStart's "
                f'local copy "{jumpstart_model_id}" '
                f'of the HuggingFace model "{huggingface_model_id}" you chose. '
                "We strive to keep our local copy synced with the HF model hub closely. "
                "This model was synced "
                f"{f'on {merged_date}' if merged_date else 'before 11/04/2024'}. "
                f"{compare_model_diff_message if not is_gated else ''}"
            )
            return True
        return False
def _hf_schema_builder_init(self, model_task: str) -> None:
"""Initialize schema builder for HuggingFace model task.
Attempts to load I/O schemas locally first, then falls back to remote
schema retrieval for the given HuggingFace task.
Args:
model_task: HuggingFace task name (e.g., 'text-generation', 'text-classification')
Raises:
TaskNotFoundException: If I/O schema for the task cannot be found locally or remotely.
"""
try:
try:
sample_inputs, sample_outputs = task.retrieve_local_schemas(model_task)
except ValueError:
# Samples could not be loaded locally, try to fetch remote HF schema
from sagemaker_schema_inference_artifacts.huggingface import remote_schema_retriever
if model_task in ("text-to-image", "automatic-speech-recognition"):
logger.warning(
"HF SchemaBuilder for %s is in beta mode, and is not guaranteed to work "
"with all models at this time.",
model_task,
)
remote_hf_schema_helper = remote_schema_retriever.RemoteSchemaRetriever()
(
sample_inputs,
sample_outputs,
) = remote_hf_schema_helper.get_resolved_hf_schema_for_task(model_task)
self.schema_builder = SchemaBuilder(sample_inputs, sample_outputs)
except ValueError as e:
raise TaskNotFoundException(
f"HuggingFace Schema builder samples for {model_task} could not be found "
f"locally or via remote."
) from e
def _retrieve_hugging_face_model_mapping(self) -> Dict[str, Dict[str, Any]]:
"""Retrieve and preprocess HuggingFace/JumpStart model mapping.
Downloads the mapping file from S3 that contains the correspondence
between HuggingFace model IDs and their JumpStart equivalents.
Returns:
Dictionary mapping HuggingFace model IDs to JumpStart model metadata.
Empty dict if mapping cannot be retrieved.
"""
converted_mapping = {}
session = getattr(self, "sagemaker_session", None)
if not session:
return converted_mapping
region = session.boto_region_name
try:
mapping_json_object = JumpStartS3PayloadAccessor.get_object_cached(
bucket=get_jumpstart_content_bucket(region),
key="hf_model_id_map_cache.json",
region=region,
s3_client=session.s3_client,
)
mapping = json.loads(mapping_json_object)
except Exception:
return converted_mapping
for k, v in mapping.items():
converted_mapping[v["hf-model-id"]] = {
"jumpstart-model-id": k,
"jumpstart-model-version": v["jumpstart-model-version"],
"merged-at": v.get("merged-at"),
"hf-model-repo-sha": v.get("hf-model-repo-sha"),
}
return converted_mapping
def _prepare_hf_model_for_upload(self) -> None:
    """Download HuggingFace model metadata for upload.

    Defaults ``model_path`` to a /tmp staging directory when unset, then
    downloads the HuggingFace model metadata files into
    ``<model_path>/code``.

    Side effects:
        May set ``self.model_path``; writes metadata files to disk.
    """
    # NOTE(review): the download appears to run even when model_path was
    # already set -- confirm whether pre-existing paths should be re-fetched.
    model_path = getattr(self, "model_path", None)
    if not model_path:
        self.model_path = f"/tmp/sagemaker/model-builder/{self.model}"
    env_vars = getattr(self, "env_vars", {}) or {}
    # Hub token is optional; only needed for gated/private HF repositories.
    self.download_huggingface_model_metadata(
        self.model,
        os.path.join(self.model_path, "code"),
        env_vars.get("HUGGING_FACE_HUB_TOKEN"),
    )
# ========================================
# Resource and Hardware Utils
# ========================================
def _get_processing_unit(self) -> str:
"""Detect if resource requirements are intended for CPU or GPU instance.
Analyzes resource requirements to determine the target processing unit:
- Checks for accelerator requirements in resource_requirements
- Checks for accelerator requirements in modelbuilder_list items
- Defaults to CPU if no accelerators specified
Returns:
'gpu' if accelerators are required, 'cpu' otherwise.
"""
# Assume custom orchestrator will be deployed as an endpoint to a CPU instance
resource_requirements = getattr(self, "resource_requirements", None)
if not resource_requirements or not getattr(
resource_requirements, "num_accelerators", None
):
modelbuilder_list = getattr(self, "modelbuilder_list", None) or []
for ic in modelbuilder_list:
ic_resource_req = getattr(ic, "resource_requirements", None)
if ic_resource_req and getattr(ic_resource_req, "num_accelerators", 0) > 0:
return "gpu"
return "cpu"
if getattr(resource_requirements, "num_accelerators", 0) > 0:
return "gpu"
return "cpu"
def _get_inference_component_resource_requirements(self, mb) -> None:
"""Fetch pre-benchmarked resource requirements from JumpStart.
Attempts to retrieve and set resource requirements for inference components
using JumpStart deployment configurations when available.
Raises:
ValueError: If no resource requirements provided and no JumpStart configs found.
"""
resource_requirements = getattr(mb, "resource_requirements", None)
if mb._is_jumpstart_model_id() and not resource_requirements:
if not hasattr(mb, "list_deployment_configs"):
return
deployment_configs = mb.list_deployment_configs()
if not deployment_configs:
inference_component_name = getattr(mb, "inference_component_name", "Unknown")
raise ValueError(
f"No resource requirements were provided for Inference Component "
f"{inference_component_name} and no default deployment "
f"configs were found in JumpStart."
)
compute_requirements = (
deployment_configs[0]
.get("DeploymentArgs", {})
.get("ComputeResourceRequirements", {})
)
logger.debug("Retrieved pre-benchmarked deployment configurations from JumpStart.")
mb.resource_requirements = ResourceRequirements(
requests={
"memory": compute_requirements.get("MinMemoryRequiredInMb"),
"num_accelerators": compute_requirements.get(
"NumberOfAcceleratorDevicesRequired", None
),
"copies": 1,
"num_cpus": compute_requirements.get("NumberOfCpuCoresRequired", None),
},
limits={"memory": compute_requirements.get("MaxMemoryRequiredInMb", None)},
)
return mb
def _can_fit_on_single_gpu(self) -> bool:
"""Check if model can fit on a single GPU.
Compares the total inference model size with single GPU memory capacity
to determine if the model can fit on a single GPU device.
Returns:
True if model size <= single GPU memory size, False otherwise.
"""
try:
if not hasattr(self, "_try_fetch_gpu_info"):
return False
single_gpu_size_mib = self._try_fetch_gpu_info()
env_vars = getattr(self, "env_vars", {}) or {}
model_size_mib = _total_inference_model_size_mib(
self.model, env_vars.get("dtypes", "float32")
)
if model_size_mib <= single_gpu_size_mib:
logger.debug(
"Total inference model size: %s MiB, single GPU size: %s MiB",
model_size_mib,
single_gpu_size_mib,
)
return True
return False
except ValueError:
instance_type = getattr(self, "instance_type", "Unknown")
logger.debug("Unable to determine single GPU size for instance %s", instance_type)
return False
# ========================================
# Serialization Utils
# ========================================
def _extract_framework_from_image_uri(self) -> Tuple[Optional[Framework], Optional[str]]:
"""Extract framework and version information from SageMaker image URI.
Analyzes the container image URI to determine the ML framework
and version being used.
Returns:
Tuple of (Framework enum, version string). Both can be None if not detected.
"""
image_uri = getattr(self, "image_uri", None)
if not image_uri:
return None, None
if "pytorch-inference" in image_uri or "pytorch-training" in image_uri:
version_match = re.search(r"pytorch.*:(\d+\.\d+\.\d+)", image_uri)
return Framework.PYTORCH, version_match.group(1) if version_match else None
elif "tensorflow-inference" in image_uri or "tensorflow-training" in image_uri:
version_match = re.search(r"tensorflow.*:(\d+\.\d+\.\d+)", image_uri)
return Framework.TENSORFLOW, version_match.group(1) if version_match else None
elif "sagemaker-xgboost" in image_uri:
version_match = re.search(r"sagemaker-xgboost:(\d+\.\d+)", image_uri)
return Framework.XGBOOST, version_match.group(1) if version_match else None
elif "sagemaker-scikit-learn" in image_uri:
version_match = re.search(r"scikit-learn:(\d+\.\d+)", image_uri)
return Framework.SKLEARN, version_match.group(1) if version_match else None
elif "huggingface" in image_uri:
return Framework.HUGGINGFACE, None
elif "mxnet" in image_uri:
version_match = re.search(r"mxnet.*:(\d+\.\d+\.\d+)", image_uri)
return Framework.MXNET, version_match.group(1) if version_match else None
return None, None
def _fetch_serializer_and_deserializer_for_framework(self, framework: str) -> Tuple[Any, Any]:
    """Return the default (serializer, deserializer) pair for a framework.

    Args:
        framework: Framework name as a string.

    Returns:
        Tuple of (serializer, deserializer) instances; falls back to
        (NumpySerializer, JSONDeserializer) for unknown frameworks.
    """
    framework_enum = self._normalize_framework_to_enum(framework)
    if framework_enum:
        try:
            return DEFAULT_SERIALIZERS_BY_FRAMEWORK[framework_enum]
        except KeyError:
            pass
    return NumpySerializer(), JSONDeserializer()
def _normalize_framework_to_enum(
self, framework: Union[str, Framework, None]
) -> Optional[Framework]:
"""Convert any framework input to Framework enum.
Args:
framework: Framework as string, enum, or None
Returns:
Framework enum or None if not found/None input
"""
if framework is None:
return None
if isinstance(framework, Framework):
return framework
if not isinstance(framework, str):
return None
framework_mapping = {
"xgboost": Framework.XGBOOST,
"xgb": Framework.XGBOOST,
"pytorch": Framework.PYTORCH,
"torch": Framework.PYTORCH,
"tensorflow": Framework.TENSORFLOW,
"tf": Framework.TENSORFLOW,
"sklearn": Framework.SKLEARN,
"scikit-learn": Framework.SKLEARN,
"scikit_learn": Framework.SKLEARN,
"sk-learn": Framework.SKLEARN,
"huggingface": Framework.HUGGINGFACE,
"hf": Framework.HUGGINGFACE,
"transformers": Framework.HUGGINGFACE,
"mxnet": Framework.MXNET,
"chainer": Framework.CHAINER,
"djl": Framework.DJL,
"sparkml": Framework.SPARKML,
"spark": Framework.SPARKML,
"lda": Framework.LDA,
"ntm": Framework.NTM,
"smd": Framework.SMD,
"sagemaker-distribution": Framework.SMD,
}
return framework_mapping.get(framework.lower())
# ========================================
# MLflow Utils
# ========================================
def _handle_mlflow_input(self) -> None:
    """Detect and initialize MLflow model input, if any.

    A no-op unless MLflow arguments are present, a model path is given in
    the metadata, and MLflow metadata exists at the resolved artifact
    location; otherwise initializes MLflow config and validates it
    against the model server.
    """
    self._is_mlflow_model = self._has_mlflow_arguments()
    if not self._is_mlflow_model:
        return
    mlflow_model_path = getattr(self, "model_metadata", {}).get(MLFLOW_MODEL_PATH)
    if not mlflow_model_path:
        return
    artifact_path = self._get_artifact_path(mlflow_model_path)
    if not self._mlflow_metadata_exists(artifact_path):
        return
    self._initialize_for_mlflow(artifact_path)
    _validate_input_for_mlflow(
        getattr(self, "model_server", None),
        (getattr(self, "env_vars", {}) or {}).get("MLFLOW_MODEL_FLAVOR"),
    )
def _has_mlflow_arguments(self) -> bool:
    """Report whether MLflow model arguments are present on this builder.

    Returns:
        True when model_metadata carries an MLflow model path and neither
        a model nor an inference spec was supplied; False otherwise.
    """
    if getattr(self, "inference_spec", None) or getattr(self, "model", None):
        logger.debug(
            "Either inference spec or model is provided. "
            "ModelBuilder is not handling MLflow model input"
        )
        return False
    model_metadata = getattr(self, "model_metadata", None)
    if not model_metadata:
        logger.debug(
            "No ModelMetadata provided. ModelBuilder is not handling MLflow model input"
        )
        return False
    if not model_metadata.get(MLFLOW_MODEL_PATH):
        logger.debug(
            "%s is not provided in ModelMetadata. ModelBuilder is not handling MLflow model "
            "input",
            MLFLOW_MODEL_PATH,
        )
        return False
    return True
def _get_artifact_path(self, mlflow_model_path: str) -> str:
    """Retrieve model artifact location from MLflow model path.

    Handles different MLflow path formats:
    - Run ID paths: runs:/<run_id>/<model_path>
    - Registry paths: models:/<model_name>/<version_or_alias>
    - Model package ARNs (resolved to the package's SourceUri)
    - Direct file paths (returned unchanged)

    Args:
        mlflow_model_path: MLflow model path input.

    Returns:
        Path to the model artifact (S3 URI or local path).

    Raises:
        ValueError: If tracking ARN not provided for run/registry paths.
        ImportError: If sagemaker_mlflow not installed.
    """
    is_run_id_type = re.match(MLFLOW_RUN_ID_REGEX, mlflow_model_path)
    is_registry_type = re.match(MLFLOW_REGISTRY_PATH_REGEX, mlflow_model_path)
    if is_run_id_type or is_registry_type:
        # runs:/ and models:/ paths can only be resolved via a tracking server.
        model_metadata = getattr(self, "model_metadata", {})
        mlflow_tracking_arn = model_metadata.get(MLFLOW_TRACKING_ARN)
        if not mlflow_tracking_arn:
            raise ValueError(
                f"{MLFLOW_TRACKING_ARN} is not provided in ModelMetadata or through set_tracking_arn "
                f"but MLflow model path was provided."
            )
        # sagemaker_mlflow supplies the plugin that lets the mlflow client
        # talk to a SageMaker-hosted tracking server ARN.
        if not importlib.util.find_spec("sagemaker_mlflow"):
            raise ImportError(
                "Unable to import sagemaker_mlflow, check if sagemaker_mlflow is installed"
            )
        import mlflow
        mlflow.set_tracking_uri(mlflow_tracking_arn)
        if is_run_id_type:
            # runs:/<run_id>/<relative_model_path>
            _, run_id, model_path = mlflow_model_path.split("/", 2)
            artifact_uri = mlflow.get_run(run_id).info.artifact_uri
            if not artifact_uri.endswith("/"):
                artifact_uri += "/"
            return artifact_uri + model_path
        # Registry path handling: models:/<name>@<alias>/... or
        # models:/<name>/<version>/...
        mlflow_client = mlflow.MlflowClient()
        if not mlflow_model_path.endswith("/"):
            mlflow_model_path += "/"
        if "@" in mlflow_model_path:
            # Alias form: resolve the alias to a concrete model version first.
            _, model_name_and_alias, artifact_uri = mlflow_model_path.split("/", 2)
            model_name, model_alias = model_name_and_alias.split("@")
            model_version_info = mlflow_client.get_model_version_by_alias(
                model_name, model_alias
            )
            source = mlflow_client.get_model_version_download_uri(
                model_name, model_version_info.version
            )
        else:
            _, model_name, model_version, artifact_uri = mlflow_model_path.split("/", 3)
            source = mlflow_client.get_model_version_download_uri(model_name, model_version)
        if not source.endswith("/"):
            source += "/"
        return source + artifact_uri
    # Handle model package ARN
    if re.match(MODEL_PACKAGE_ARN_REGEX, mlflow_model_path):
        sagemaker_session = getattr(self, "sagemaker_session", None)
        if sagemaker_session:
            model_package = sagemaker_session.sagemaker_client.describe_model_package(
                ModelPackageName=mlflow_model_path
            )
            return model_package["SourceUri"]
        # NOTE(review): without a session, an ARN falls through and is
        # returned as-is -- confirm that is intended rather than an error.
    # Direct path
    return mlflow_model_path
def _mlflow_metadata_exists(self, path: str) -> bool:
    """Check whether the MLmodel metadata file exists under ``path``.

    Args:
        path: Directory to inspect; S3 URIs are listed remotely, anything
            else is treated as a local directory.

    Returns:
        True if the MLmodel file is present, False otherwise (including
        when an S3 path is given but no session is available).
    """
    if not path.startswith("s3://"):
        return os.path.isfile(os.path.join(path, MLFLOW_METADATA_FILE))
    sagemaker_session = getattr(self, "sagemaker_session", None)
    if not sagemaker_session:
        return False
    prefix = path if path.endswith("/") else path + "/"
    listing = S3Downloader().list(f"{prefix}{MLFLOW_METADATA_FILE}", sagemaker_session)
    return len(listing) > 0
def _initialize_for_mlflow(self, artifact_path: str) -> None:
    """Initialize MLflow model artifacts, image URI and model server.

    Copies or downloads the MLflow artifacts into ``model_path``, reads
    the flavor metadata, and fills in model_server, image_uri, env_vars
    and dependencies for deployment.

    Args:
        artifact_path: Path to the MLflow artifact store (S3 URI or local).

    Raises:
        ValueError: If artifact path is neither an S3 URI nor an existing
            local path.
    """
    model_path = getattr(self, "model_path", None)
    sagemaker_session = getattr(self, "sagemaker_session", None)
    # NOTE(review): model_path may be None here; the download/copy helpers
    # appear to assume a destination is set -- confirm upstream guarantees.
    if artifact_path.startswith("s3://"):
        _download_s3_artifacts(artifact_path, model_path, sagemaker_session)
    elif os.path.exists(artifact_path):
        _copy_directory_contents(artifact_path, model_path)
    else:
        raise ValueError(f"Invalid path: {artifact_path}")
    # Locate the MLmodel metadata and pip dependency files in the copy.
    mlflow_model_metadata_path = _generate_mlflow_artifact_path(
        model_path, MLFLOW_METADATA_FILE
    )
    mlflow_model_dependency_path = _generate_mlflow_artifact_path(
        model_path, MLFLOW_PIP_DEPENDENCY_FILE
    )
    flavor_metadata = _get_all_flavor_metadata(mlflow_model_metadata_path)
    deployment_flavor = _get_deployment_flavor(flavor_metadata)
    # Respect a user-specified model server; otherwise derive it from the flavor.
    current_model_server = getattr(self, "model_server", None)
    self.model_server = current_model_server or _get_default_model_server_for_mlflow(
        deployment_flavor
    )
    # Respect a user-specified image; otherwise select one for the flavor.
    current_image_uri = getattr(self, "image_uri", None)
    if not current_image_uri:
        self.image_uri = _select_container_for_mlflow_model(
            mlflow_model_src_path=model_path,
            deployment_flavor=deployment_flavor,
            region=sagemaker_session.boto_region_name if sagemaker_session else None,
            instance_type=getattr(self, "instance_type", None),
        )
    env_vars = getattr(self, "env_vars", {})
    env_vars.update({"MLFLOW_MODEL_FLAVOR": f"{deployment_flavor}"})
    dependencies = getattr(self, "dependencies", {})
    dependencies.update({"requirements": mlflow_model_dependency_path})
# ========================================
# Optimize Utils
# ========================================
def _is_inferentia_or_trainium(self, instance_type: Optional[str]) -> bool:
"""Checks whether an instance is compatible with Inferentia.
Args:
instance_type (str): The instance type used for the compilation job.
Returns:
bool: Whether the given instance type is Inferentia or Trainium.
"""
if isinstance(instance_type, str):
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
if match:
if match[1].startswith("inf") or match[1].startswith("trn"):
return True
return False
def _is_image_compatible_with_optimization_job(self, image_uri: Optional[str]) -> bool:
"""Checks whether an instance is compatible with an optimization job.
Args:
image_uri (str): The image URI of the optimization job.
Returns:
bool: Whether the given instance type is compatible with an optimization job.
"""
if image_uri is None:
return True
return "djl-inference:" in image_uri and ("-lmi" in image_uri or "-neuronx-" in image_uri)
def _deployment_config_contains_draft_model(self, deployment_config: Optional[Dict]) -> bool:
"""Checks whether a deployment config contains a speculative decoding draft model.
Args:
deployment_config (Dict): The deployment config to check.
Returns:
bool: Whether the deployment config contains a draft model or not.
"""
if deployment_config is None:
return False
deployment_args = deployment_config.get("DeploymentArgs", {})
additional_data_sources = deployment_args.get("AdditionalDataSources")
return (
"speculative_decoding" in additional_data_sources if additional_data_sources else False
)
def _is_draft_model_jumpstart_provided(self, deployment_config: Optional[Dict]) -> bool:
"""Checks whether a deployment config's draft model is provided by JumpStart.
Args:
deployment_config (Dict): The deployment config to check.
Returns:
bool: Whether the draft model is provided by JumpStart or not.
"""
if deployment_config is None:
return False
additional_model_data_sources = deployment_config.get("DeploymentArgs", {}).get(
"AdditionalDataSources"
)
for source in additional_model_data_sources.get("speculative_decoding", []):
if source["channel_name"] == "draft_model":
if source.get("provider", {}).get("name") == "JumpStart":
return True
continue
return False
def _generate_optimized_model(self, optimization_response: dict):
    """Apply an optimization job's outputs to this model.

    Updates the image URI, artifact S3 location and instance type from
    the optimization response, then tags the model with the job name.

    Args:
        optimization_response (dict): The optimization job response.

    Returns:
        Model: A deployable optimized model.
    """
    optimized_image = optimization_response.get("OptimizationOutput", {}).get(
        "RecommendedInferenceImage"
    )
    optimized_s3_uri = optimization_response.get("OutputConfig", {}).get("S3OutputLocation")
    optimized_instance_type = optimization_response.get("DeploymentInstanceType")

    if optimized_image:
        self.image_uri = optimized_image
    if optimized_s3_uri:
        self.s3_upload_path["S3DataSource"]["S3Uri"] = optimized_s3_uri
    if optimized_instance_type:
        self.instance_type = optimized_instance_type
    self.add_tags(
        {
            "Key": Tag.OPTIMIZATION_JOB_NAME,
            "Value": optimization_response["OptimizationJobName"],
        }
    )
def _is_optimized(self) -> bool:
    """Report whether this model carries optimization tags.

    Returns:
        bool: True when a tag marks an optimization job or a speculative
        draft model provider, False otherwise.
    """
    optimized_tags = [Tag.OPTIMIZATION_JOB_NAME, Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER]
    tags = getattr(self, "_tags", None)
    if not tags:
        return False
    if isinstance(tags, dict):
        # A single tag dict rather than a list of tag dicts.
        return tags.get("Key") in optimized_tags
    return any(tag.get("Key") in optimized_tags for tag in tags)
def _generate_model_source(
self, model_data: Optional[Union[Dict[str, Any], str]], accept_eula: Optional[bool]
) -> Optional[Dict[str, Any]]:
"""Extracts model source from model data.
Args:
model_data (Optional[Union[Dict[str, Any], str]]): A model data.
Returns:
Optional[Dict[str, Any]]: Model source data.
"""
if model_data is None:
raise ValueError("Model Optimization Job only supports model with S3 data source.")
s3_uri = model_data
if isinstance(s3_uri, dict):
s3_uri = s3_uri.get("S3DataSource").get("S3Uri")
model_source = {"S3": {"S3Uri": s3_uri}}
if accept_eula:
model_source["S3"]["ModelAccessConfig"] = {"AcceptEula": True}
return model_source
def _update_environment_variables(
self, env: Optional[Dict[str, str]], new_env: Optional[Dict[str, str]]
) -> Optional[Dict[str, str]]:
"""Updates environment variables based on environment variables.
Args:
env (Optional[Dict[str, str]]): The environment variables.
new_env (Optional[Dict[str, str]]): The new environment variables.
Returns:
Optional[Dict[str, str]]: The updated environment variables.
"""
if new_env:
if env:
env.update(new_env)
else:
env = new_env
return env
def _extract_speculative_draft_model_provider(
self,
speculative_decoding_config: Optional[Dict] = None,
) -> Optional[str]:
"""Extracts speculative draft model provider from speculative decoding config.
Args:
speculative_decoding_config (Optional[Dict]): A speculative decoding config.
Returns:
Optional[str]: The speculative draft model provider.
"""
if speculative_decoding_config is None:
return None
model_provider = speculative_decoding_config.get("ModelProvider", "").lower()
if model_provider == "jumpstart":
return "jumpstart"
if model_provider == "custom" or speculative_decoding_config.get("ModelSource"):
return "custom"
if model_provider == "sagemaker":
return "sagemaker"
return "auto"
def _extract_additional_model_data_source_s3_uri(
self,
additional_model_data_source: Optional[Dict] = None,
) -> Optional[str]:
"""Extracts model data source s3 uri from a model data source in Pascal case.
Args:
additional_model_data_source (Optional[Dict]): A model data source.
Returns:
str: S3 uri of the model resources.
"""
if (
additional_model_data_source is None
or additional_model_data_source.get("S3DataSource", None) is None
):
return None
return additional_model_data_source.get("S3DataSource").get("S3Uri")
def _extract_deployment_config_additional_model_data_source_s3_uri(
self,
additional_model_data_source: Optional[Dict] = None,
) -> Optional[str]:
"""Extracts model data source s3 uri from a model data source in snake case.
Args:
additional_model_data_source (Optional[Dict]): A model data source.
Returns:
str: S3 uri of the model resources.
"""
if (
additional_model_data_source is None
or additional_model_data_source.get("s3_data_source", None) is None
):
return None
return additional_model_data_source.get("s3_data_source").get("s3_uri", None)
def _is_draft_model_gated(
self,
draft_model_config: Optional[Dict] = None,
) -> bool:
"""Extracts model gated-ness from draft model data source.
Args:
draft_model_config (Optional[Dict]): A model data source.
Returns:
bool: Whether the draft model is gated or not.
"""
return "hosting_eula_key" in draft_model_config if draft_model_config else False
def _extracts_and_validates_speculative_model_source(
self,
speculative_decoding_config: Dict,
) -> str:
"""Extracts model source from speculative decoding config.
Args:
speculative_decoding_config (Optional[Dict]): A speculative decoding config.
Returns:
str: Model source.
Raises:
ValueError: If model source is none.
"""
model_source: str = speculative_decoding_config.get("ModelSource")
if not model_source:
raise ValueError("ModelSource must be provided in speculative decoding config.")
return model_source
def _generate_channel_name(self, additional_model_data_sources: Optional[List[Dict]]) -> str:
"""Generates a channel name.
Args:
additional_model_data_sources (Optional[List[Dict]]): The additional model data sources.
Returns:
str: The channel name.
"""
channel_name = "draft_model"
if additional_model_data_sources and len(additional_model_data_sources) > 0:
channel_name = additional_model_data_sources[0].get("ChannelName", channel_name)
return channel_name
def _generate_additional_model_data_sources(
self,
model_source: str,
channel_name: str,
accept_eula: bool = False,
s3_data_type: Optional[str] = "S3Prefix",
compression_type: Optional[str] = "None",
) -> List[Dict]:
"""Generates additional model data sources.
Args:
model_source (Optional[str]): The model source.
channel_name (Optional[str]): The channel name.
accept_eula (Optional[bool]): Whether to accept eula or not.
s3_data_type (Optional[str]): The S3 data type, defaults to 'S3Prefix'.
compression_type (Optional[str]): The compression type, defaults to None.
Returns:
List[Dict]: The additional model data sources.
"""
additional_model_data_source = {
"ChannelName": channel_name,
"S3DataSource": {
"S3Uri": model_source,
"S3DataType": s3_data_type,
"CompressionType": compression_type,
},
}
if accept_eula:
additional_model_data_source["S3DataSource"]["ModelAccessConfig"] = {"AcceptEula": True}
return [additional_model_data_source]
def _is_s3_uri(self, s3_uri: Optional[str]) -> bool:
"""Checks whether an S3 URI is valid.
Args:
s3_uri (Optional[str]): The S3 URI.
Returns:
bool: Whether the S3 URI is valid.
"""
if s3_uri is None:
return False
return re.match("^s3://([^/]+)/?(.*)$", s3_uri) is not None
def _extract_optimization_config_and_env(
self,
quantization_config: Optional[Dict] = None,
compilation_config: Optional[Dict] = None,
sharding_config: Optional[Dict] = None,
) -> Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict], Optional[Dict]]]:
"""Extracts optimization config and environment variables.
Args:
quantization_config (Optional[Dict]): The quantization config.
compilation_config (Optional[Dict]): The compilation config.
sharding_config (Optional[Dict]): The sharding config.
Returns:
Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict], Optional[Dict]]]:
The optimization config and environment variables.
"""
optimization_config = {}
quantization_override_env = (
quantization_config.get("OverrideEnvironment") if quantization_config else None
)
compilation_override_env = (
compilation_config.get("OverrideEnvironment") if compilation_config else None
)
sharding_override_env = (
sharding_config.get("OverrideEnvironment") if sharding_config else None
)
if quantization_config is not None:
optimization_config["ModelQuantizationConfig"] = quantization_config
if compilation_config is not None:
optimization_config["ModelCompilationConfig"] = compilation_config
if sharding_config is not None:
optimization_config["ModelShardingConfig"] = sharding_config
# Return optimization config dict and environment variables if either is present
if optimization_config:
return (
optimization_config,
quantization_override_env,
compilation_override_env,
sharding_override_env,
)
return None, None, None, None
def _custom_speculative_decoding(
    self,
    speculative_decoding_config: Optional[Dict],
    accept_eula: Optional[bool] = False,
):
    """Configure this model for speculative decoding with a custom draft model.

    Args:
        speculative_decoding_config (Optional[Dict]): The speculative
            decoding config; a no-op when falsy.
        accept_eula (Optional[bool]): Whether to accept the EULA; may be
            overridden by ``AcceptEula`` in the config.
    """
    if not speculative_decoding_config:
        return
    draft_model_source = self._extracts_and_validates_speculative_model_source(
        speculative_decoding_config
    )
    accept_eula = speculative_decoding_config.get("AcceptEula", accept_eula)
    if self._is_s3_uri(draft_model_source):
        # S3-hosted draft model: attach it as an additional data source
        # and point the container at the mounted channel path.
        channel_name = self._generate_channel_name(self.additional_model_data_sources)
        speculative_draft_model = f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}"
        self.additional_model_data_sources = self._generate_additional_model_data_sources(
            draft_model_source, channel_name, accept_eula
        )
    else:
        speculative_draft_model = draft_model_source
    self.env_vars.update({"OPTION_SPECULATIVE_DRAFT_MODEL": speculative_draft_model})
    self.add_tags(
        {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "custom"},
    )
def _get_cached_model_specs(self, model_id, version, region, sagemaker_session):
"""Get cached JumpStart model specs to avoid repeated fetches"""
if not hasattr(self, "_cached_js_model_specs"):
self._cached_js_model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
model_id=model_id,
version=version,
region=region,
sagemaker_session=sagemaker_session,
)
return self._cached_js_model_specs
def _jumpstart_speculative_decoding(
    self,
    speculative_decoding_config: Optional[Dict[str, Any]] = None,
    sagemaker_session: Optional[Session] = None,
):
    """Configure this model for speculative decoding with a JumpStart draft model.

    Resolves the draft model's JumpStart artifact location (public or
    gated bucket), attaches it as an additional model data source, and
    points OPTION_SPECULATIVE_DRAFT_MODEL at its channel.

    Args:
        speculative_decoding_config (Optional[Dict]): The speculative
            decoding config; requires a `ModelID` key. A no-op when falsy.
        sagemaker_session (Optional[Session]): Sagemaker session for execution.

    Raises:
        ValueError: If `ModelID` is missing, or the draft model is gated
            and `AcceptEula` is not True.
    """
    if speculative_decoding_config:
        js_id = speculative_decoding_config.get("ModelID")
        if not js_id:
            raise ValueError(
                "`ModelID` is a required field in `speculative_decoding_config` when "
                "using JumpStart as draft model provider."
            )
        model_version = speculative_decoding_config.get("ModelVersion", "*")
        accept_eula = speculative_decoding_config.get("AcceptEula", False)
        channel_name = self._generate_channel_name(self.additional_model_data_sources)
        model_specs = self._get_cached_model_specs(
            model_id=js_id,
            version=model_version,
            region=sagemaker_session.boto_region_name,
            sagemaker_session=sagemaker_session,
        )
        model_spec_json = model_specs.to_json()
        # NOTE(review): bucket lookups use self.region while specs/EULA use
        # sagemaker_session.boto_region_name -- confirm these always agree.
        js_bucket = accessors.JumpStartModelsAccessor.get_jumpstart_content_bucket(self.region)
        if model_spec_json.get("gated_bucket", False):
            # Gated artifacts require an explicit EULA acknowledgement and
            # live in a separate gated content bucket.
            if not accept_eula:
                eula_message = get_eula_message(
                    model_specs=model_specs, region=sagemaker_session.boto_region_name
                )
                raise ValueError(
                    f"{eula_message} Set `AcceptEula`=True in "
                    f"speculative_decoding_config once acknowledged."
                )
            js_bucket = accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket(
                self.region
            )
        key_prefix = model_spec_json.get("hosting_prepacked_artifact_key")
        self.additional_model_data_sources = self._generate_additional_model_data_sources(
            f"s3://{js_bucket}/{key_prefix}",
            channel_name,
            accept_eula,
        )
        self.env_vars.update(
            {"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}/"}
        )
        self.add_tags(
            {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "jumpstart"},
        )
def _optimize_for_hf(
    self,
    output_path: str,
    tags: Optional[Tags] = None,
    job_name: Optional[str] = None,
    quantization_config: Optional[Dict] = None,
    compilation_config: Optional[Dict] = None,
    speculative_decoding_config: Optional[Dict] = None,
    sharding_config: Optional[Dict] = None,
    env_vars: Optional[Dict] = None,
    vpc_config: Optional[Dict] = None,
    kms_key: Optional[str] = None,
    max_runtime_in_sec: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
    """Runs a model optimization job.

    Args:
        output_path (str): Specifies where to store the compiled/quantized model.
        tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``.
        job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``.
        quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``.
        compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
        speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
            Defaults to ``None``
        sharding_config (Optional[Dict]): Model sharding configuration.
            Defaults to ``None``
        env_vars (Optional[Dict]): Additional environment variables to run the optimization
            container. Defaults to ``None``.
        vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
        kms_key (Optional[str]): KMS key ARN used to encrypt the model artifacts when uploading
            to S3. Defaults to ``None``.
        max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to
            ``None``.

    Returns:
        Optional[Dict[str, Any]]: Model optimization job input arguments,
        or None when no quantization/compilation/sharding config is given
        (speculative decoding alone needs no optimization job).
    """
    # Speculative decoding is configured up-front, independently of
    # whether an optimization job ends up being created.
    if speculative_decoding_config:
        if speculative_decoding_config.get("ModelProvider", "").lower() == "jumpstart":
            self._jumpstart_speculative_decoding(
                speculative_decoding_config=speculative_decoding_config,
                sagemaker_session=self.sagemaker_session,
            )
        else:
            self._custom_speculative_decoding(speculative_decoding_config, False)
    if quantization_config or compilation_config or sharding_config:
        create_optimization_job_args = {
            "OptimizationJobName": job_name,
            "DeploymentInstanceType": self.instance_type,
            "RoleArn": self.role_arn,
        }
        if env_vars:
            self.env_vars.update(env_vars)
            create_optimization_job_args["OptimizationEnvironment"] = env_vars
        # Stage/upload the HF artifacts so the job has an S3 model source.
        self._optimize_prepare_for_hf()
        model_source = self._generate_model_source(self.s3_upload_path, False)
        create_optimization_job_args["ModelSource"] = model_source
        (
            optimization_config,
            quantization_override_env,
            compilation_override_env,
            sharding_override_env,
        ) = self._extract_optimization_config_and_env(
            quantization_config, compilation_config, sharding_config
        )
        # The API takes a list of single-entry config dicts.
        create_optimization_job_args["OptimizationConfigs"] = [
            {k: v} for k, v in optimization_config.items()
        ]
        self.env_vars.update(
            {
                **(quantization_override_env or {}),
                **(compilation_override_env or {}),
                **(sharding_override_env or {}),
            }
        )
        output_config = {"S3OutputLocation": output_path}
        if kms_key:
            output_config["KmsKeyId"] = kms_key
        create_optimization_job_args["OutputConfig"] = output_config
        if max_runtime_in_sec:
            create_optimization_job_args["StoppingCondition"] = {
                "MaxRuntimeInSeconds": max_runtime_in_sec
            }
        if tags:
            create_optimization_job_args["Tags"] = tags
        if vpc_config:
            create_optimization_job_args["VpcConfig"] = vpc_config
        # NOTE(review): HF_MODEL_ID is dropped so the container loads the
        # optimized S3 artifacts instead of re-pulling from the hub -- confirm.
        if "HF_MODEL_ID" in self.env_vars:
            del self.env_vars["HF_MODEL_ID"]
        return create_optimization_job_args
    return None
def _optimize_prepare_for_hf(self):
    """Stage HuggingFace model artifacts and upload them ahead of an optimization job."""
    metadata = self.model_metadata if self.model_metadata else None
    custom_model_path: str = metadata.get("CUSTOM_MODEL_PATH") if metadata else None
    if self._is_s3_uri(custom_model_path):
        # Normalize away a single trailing slash on the S3 prefix.
        if custom_model_path.endswith("/"):
            custom_model_path = custom_model_path[:-1]
    else:
        if not custom_model_path:
            custom_model_path = f"/tmp/sagemaker/model-builder/{self.model}"
        # Local path: fetch the HF model metadata into <path>/code first.
        self.download_huggingface_model_metadata(
            self.model,
            os.path.join(custom_model_path, "code"),
            self.env_vars.get("HUGGING_FACE_HUB_TOKEN"),
        )
    upload_path, extra_env = self._prepare_for_mode(
        model_path=custom_model_path,
        should_upload_artifacts=True,
    )
    self.s3_upload_path = upload_path
    self.env_vars.update(extra_env)
def _is_gated_model(self) -> bool:
"""Determine if ``this`` Model is Gated
Args:
model (Model): Jumpstart Model
Returns:
bool: ``True`` if ``this`` Model is Gated
"""
s3_uri = self.s3_upload_path
if isinstance(s3_uri, dict):
s3_uri = s3_uri.get("S3DataSource").get("S3Uri")
if s3_uri is None:
return False
return "private" in s3_uri
def set_js_deployment_config(self, config_name: str, instance_type: str) -> None:
    """Apply a named JumpStart deployment config to the model.

    Args:
        config_name (str): Name of the deployment config to apply; call
            ``list_deployment_configs`` to see the available names.
        instance_type (str): Instance type the model will use once the
            config is applied.
    """
    self.set_deployment_config(config_name, instance_type)
    self.deployment_config_name = config_name
    self.instance_type = instance_type
    if self.additional_model_data_sources:
        # The config ships additional data sources, i.e. a SageMaker-provided
        # draft model for speculative decoding.
        self.speculative_decoding_draft_model_source = "sagemaker"
        self.add_tags(
            {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "sagemaker"},
        )
    # Drop tags that no longer describe this model after the config switch.
    for stale_key in (
        Tag.OPTIMIZATION_JOB_NAME,
        Tag.FINE_TUNING_MODEL_PATH,
        Tag.FINE_TUNING_JOB_NAME,
    ):
        self.remove_tag_with_key(stale_key)
def _set_additional_model_source(
    self, speculative_decoding_config: Optional[Dict[str, Any]] = None
) -> None:
    """Set Additional Model Source (speculative-decoding draft model) on ``this`` model.

    Resolves the draft-model provider requested in the config and wires the
    matching data source, env vars, and tags onto the model.

    Args:
        speculative_decoding_config (Optional[Dict[str, Any]]): Speculative decoding
            config. No-op when ``None``/empty.

    Raises:
        ValueError: When no SageMaker-provided draft model exists for the model,
            or no compatible deployment config can be found.
    """
    if speculative_decoding_config:
        model_provider = self._extract_speculative_draft_model_provider(
            speculative_decoding_config
        )
        channel_name = self._generate_channel_name(self.additional_model_data_sources)
        if model_provider in ["sagemaker", "auto"]:
            # Draft model is expected among the selected deployment config's
            # additional data sources, when one is selected.
            additional_model_data_sources = (
                self._deployment_config.get("DeploymentArgs", {}).get("AdditionalDataSources")
                if self._deployment_config
                else None
            )
            if additional_model_data_sources is None:
                # Current config carries no draft model; search for a compatible config.
                deployment_config = self._find_compatible_deployment_config(
                    speculative_decoding_config
                )
                if deployment_config:
                    if (
                        model_provider == "sagemaker"
                        and self._is_draft_model_jumpstart_provided(deployment_config)
                    ):
                        # "sagemaker" explicitly excludes JumpStart-provided drafts.
                        raise ValueError(
                            "No `Sagemaker` provided draft model was found for "
                            f"{self.model}. Try setting `ModelProvider` "
                            "to `Auto` instead."
                        )
                    try:
                        self.set_js_deployment_config(
                            config_name=deployment_config.get("DeploymentConfigName"),
                            instance_type=deployment_config.get("InstanceType"),
                        )
                    except ValueError as e:
                        # Presumably a gated draft model — surface the EULA hint.
                        raise ValueError(
                            f"{e} If using speculative_decoding_config, "
                            "accept the EULA by setting `AcceptEula`=True."
                        )
                else:
                    raise ValueError(
                        "Cannot find deployment config compatible for optimization job."
                    )
            else:
                if model_provider == "sagemaker" and self._is_draft_model_jumpstart_provided(
                    self._deployment_config
                ):
                    raise ValueError(
                        "No `Sagemaker` provided draft model was found for "
                        f"{self.model}. Try setting `ModelProvider` "
                        "to `Auto` instead."
                    )
            # Point the serving container at the draft-model channel.
            self.env_vars.update(
                {"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}/"}
            )
            self.add_tags(
                {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": model_provider},
            )
        elif model_provider == "jumpstart":
            self._jumpstart_speculative_decoding(
                speculative_decoding_config=speculative_decoding_config,
                sagemaker_session=self.sagemaker_session,
            )
        else:
            self._custom_speculative_decoding(
                speculative_decoding_config,
                speculative_decoding_config.get("AcceptEula", False),
            )
def _find_compatible_deployment_config(
    self, speculative_decoding_config: Optional[Dict] = None
) -> Optional[Dict[str, Any]]:
    """Finds compatible model deployment config for optimization job.

    Args:
        speculative_decoding_config (Optional[Dict]): Speculative decoding config.

    Returns:
        Optional[Dict[str, Any]]: A compatible deployment config; ``None`` when a
        sagemaker/auto draft model was requested but none is available; otherwise
        the currently selected ``_deployment_config``.
    """
    self._ensure_metadata_configs()
    model_provider = self._extract_speculative_draft_model_provider(speculative_decoding_config)
    for deployment_config in self.list_deployment_configs():
        # NOTE(review): lowercase "deployment_config" key here vs. the
        # "DeploymentArgs" casing below — confirm this matches the shape
        # returned by list_deployment_configs().
        image_uri = deployment_config.get("deployment_config", {}).get("ImageUri")
        if self._is_image_compatible_with_optimization_job(
            image_uri
        ) and self._deployment_config_contains_draft_model(deployment_config):
            # sagemaker/auto additionally require the config to ship a draft
            # model via AdditionalDataSources; "custom" accepts any compatible config.
            if (
                model_provider in ["sagemaker", "auto"]
                and deployment_config.get("DeploymentArgs", {}).get("AdditionalDataSources")
            ) or model_provider == "custom":
                return deployment_config
    if model_provider in ["sagemaker", "auto"]:
        return None
    return self._deployment_config
def _get_neuron_model_env_vars(
self, instance_type: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""Gets Neuron model env vars.
Args:
instance_type (Optional[str]): Instance type.
Returns:
Optional[Dict[str, Any]]: Neuron Model environment variables.
"""
metadata_configs = self._metadata_configs
if metadata_configs:
metadata_config = metadata_configs.get(self.config_name)
resolve_config = metadata_config.resolved_config if metadata_config else None
if resolve_config and instance_type not in resolve_config.get(
"supported_inference_instance_types", []
):
neuro_model_id = resolve_config.get("hosting_neuron_model_id")
neuro_model_version = resolve_config.get("hosting_neuron_model_version", "*")
if neuro_model_id:
model_specs = self._get_cached_model_specs(
model_id=neuro_model_id,
version=neuro_model_version,
region=self.region,
sagemaker_session=self.sagemaker_session,
)
model_spec_json = model_specs.to_json()
return model_spec_json.get("hosting_env_vars", {})
return None
def _set_optimization_image_default(
    self, create_optimization_job_args: Dict[str, Any]
) -> Dict[str, Any]:
    """Defaults the optimization image to the JumpStart deployment config default.

    Picks the newest LMI image among the JumpStart default and any images
    explicitly provided in the quantization/sharding configs, then stamps the
    winner onto every config entry.

    Args:
        create_optimization_job_args (Dict[str, Any]): create optimization job request

    Returns:
        Dict[str, Any]: create optimization job request with image uri default
    """
    init_kwargs = get_init_kwargs(
        config_name=self.config_name,
        model_id=self.model,
        instance_type=self.instance_type,
        sagemaker_session=self.sagemaker_session,
        image_uri=self.image_uri,
        region=self.region,
        model_version=self.model_version,
        hub_arn=self.hub_arn,
        tolerate_vulnerable_model=getattr(self, "tolerate_vulnerable_model", None),
        tolerate_deprecated_model=getattr(self, "tolerate_deprecated_model", None),
    )
    default_image = self._get_default_vllm_image(init_kwargs.image_uri)
    # First pass: promote any user-provided image whose LMI version is >= the
    # current default.
    for optimization_config in create_optimization_job_args.get("OptimizationConfigs"):
        if optimization_config.get("ModelQuantizationConfig"):
            model_quantization_config = optimization_config.get("ModelQuantizationConfig")
            provided_image = model_quantization_config.get("Image")
            if provided_image and self._get_latest_lmi_version_from_list(
                default_image, provided_image
            ):
                default_image = provided_image
        if optimization_config.get("ModelShardingConfig"):
            model_sharding_config = optimization_config.get("ModelShardingConfig")
            provided_image = model_sharding_config.get("Image")
            if provided_image and self._get_latest_lmi_version_from_list(
                default_image, provided_image
            ):
                default_image = provided_image
    # Second pass: apply the winning image to every quantization/sharding config
    # (mutates the dicts inside the request in place).
    for optimization_config in create_optimization_job_args.get("OptimizationConfigs"):
        if optimization_config.get("ModelQuantizationConfig") is not None:
            optimization_config.get("ModelQuantizationConfig")["Image"] = default_image
        if optimization_config.get("ModelShardingConfig") is not None:
            optimization_config.get("ModelShardingConfig")["Image"] = default_image
    logger.debug(f"Defaulting to {default_image} image for optimization job")
    return create_optimization_job_args
def _get_default_vllm_image(self, image: str) -> str:
    """Ensures the minimum working image version for vLLM enabled optimization techniques.

    Args:
        image (str): JumpStart provided default image.

    Returns:
        str: The original image when its major version meets the floor, otherwise
        the minimum supported image built from the same DLC repository name.
        (Return annotation fixed: this returns a string, not a bool.)
    """
    dlc_name, _ = image.split(":")
    major_version_number, _, _ = self._parse_lmi_version(image)
    # Only the major version is compared against the floor image's major version.
    if major_version_number < self._parse_lmi_version(_JS_MINIMUM_VERSION_IMAGE)[0]:
        minimum_version_default = _JS_MINIMUM_VERSION_IMAGE.format(dlc_name)
        return minimum_version_default
    return image
def _get_latest_lmi_version_from_list(self, version: str, version_to_compare: str) -> bool:
"""LMI version comparator
Args:
version (str): current version
version_to_compare (str): version to compare to
Returns:
bool: if version_to_compare larger or equal to version
"""
parse_lmi_version = self._parse_lmi_version(version)
parse_lmi_version_to_compare = self._parse_lmi_version(version_to_compare)
if parse_lmi_version_to_compare[0] > parse_lmi_version[0]:
return True
if parse_lmi_version_to_compare[0] == parse_lmi_version[0]:
if parse_lmi_version_to_compare[1] > parse_lmi_version[1]:
return True
if parse_lmi_version_to_compare[1] == parse_lmi_version[1]:
if parse_lmi_version_to_compare[2] >= parse_lmi_version[2]:
return True
return False
return False
return False
def _parse_lmi_version(self, image: str) -> Tuple[int, int, int]:
"""Parse out LMI version
Args:
image (str): image to parse version out of
Returns:
Tuple[int, int, int]: LMI version split into major, minor, patch
Raises:
ValueError: If the image format cannot be parsed
"""
_, dlc_tag = image.split(":")
parts = dlc_tag.split("-")
lmi_version = None
for part in parts:
if "." in part and part[0].isdigit():
lmi_version = part
break
if not lmi_version:
raise ValueError(f"Could not find version in image: {image}")
version_parts = lmi_version.split(".")
if len(version_parts) < 3:
raise ValueError(f"Invalid version format: {lmi_version} in image: {image}")
major_version = int(version_parts[0])
minor_version = int(version_parts[1])
patch_version = int(version_parts[2])
return (major_version, minor_version, patch_version)
def _optimize_for_jumpstart(
    self,
    output_path: Optional[str] = None,
    instance_type: Optional[str] = None,
    tags: Optional[Tags] = None,
    job_name: Optional[str] = None,
    accept_eula: Optional[bool] = None,
    quantization_config: Optional[Dict] = None,
    compilation_config: Optional[Dict] = None,
    speculative_decoding_config: Optional[Dict] = None,
    sharding_config: Optional[Dict] = None,
    env_vars: Optional[Dict] = None,
    vpc_config: Optional[Dict] = None,
    kms_key: Optional[str] = None,
    max_runtime_in_sec: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
    """Runs a model optimization job.

    Args:
        output_path (Optional[str]): Specifies where to store the compiled/quantized model.
        instance_type (str): Target deployment instance type that the model is optimized for.
        tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``.
        job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``.
        accept_eula (bool): For models that require a Model Access Config, specify True or
            False to indicate whether model terms of use have been accepted.
            The `accept_eula` value must be explicitly defined as `True` in order to
            accept the end-user license agreement (EULA) that some
            models require. (Default: None).
        quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``.
        compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
        speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
            Defaults to ``None``
        sharding_config (Optional[Dict]): Model sharding configuration.
            Defaults to ``None``
        env_vars (Optional[Dict]): Additional environment variables to run the optimization
            container. Defaults to ``None``. NOTE(review): see body — this argument is
            currently overwritten before use; confirm intended behavior.
        vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
        kms_key (Optional[str]): KMS key ARN used to encrypt the model artifacts when uploading
            to S3. Defaults to ``None``.
        max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to
            ``None``.

    Returns:
        Optional[Dict[str, Any]]: Model optimization job input arguments, or ``None``
        when no optimization technique applies.

    Raises:
        ValueError: When the model is gated and the EULA was not accepted.
    """
    # Gated models cannot be optimized without an explicit EULA acceptance.
    if self._is_gated_model() and accept_eula is not True:
        raise ValueError(
            f"Model '{self.model}' requires accepting end-user license agreement (EULA)."
        )
    is_compilation = (compilation_config is not None) or self._is_inferentia_or_trainium(
        instance_type
    )
    # NOTE(review): this unconditionally discards the caller-supplied ``env_vars``
    # argument — confirm whether it should merge instead of clobber.
    env_vars = dict()
    if is_compilation:
        # NOTE(review): _get_neuron_model_env_vars may return None — verify the
        # downstream dict-spreads/updates tolerate that.
        env_vars = self._get_neuron_model_env_vars(instance_type)
    (
        optimization_config,
        quantization_override_env,
        compilation_override_env,
        sharding_override_env,
    ) = self._extract_optimization_config_and_env(
        quantization_config, compilation_config, sharding_config
    )
    if not optimization_config:
        optimization_config = {}
    if not optimization_config.get("ModelCompilationConfig") and is_compilation:
        # Ensure compilation targets always carry a compilation config entry,
        # seeded with the Neuron env vars when no explicit override exists.
        if not compilation_override_env:
            compilation_override_env = env_vars
        override_compilation_config = (
            {"OverrideEnvironment": compilation_override_env}
            if compilation_override_env
            else {}
        )
        optimization_config["ModelCompilationConfig"] = override_compilation_config
    if speculative_decoding_config:
        self._set_additional_model_source(speculative_decoding_config)
    else:
        deployment_config = self._find_compatible_deployment_config(None)
        if deployment_config:
            self.set_js_deployment_config(
                config_name=deployment_config.get("DeploymentConfigName"),
                instance_type=deployment_config.get("InstanceType"),
            )
            env_vars = self.env_vars
    model_source = self._generate_model_source(self.s3_upload_path, accept_eula)
    # NOTE(review): env_vars is passed as both arguments here — verify this is
    # intentional and not meant to merge two distinct dicts.
    optimization_env_vars = self._update_environment_variables(env_vars, env_vars)
    output_config = {"S3OutputLocation": output_path}
    if kms_key:
        output_config["KmsKeyId"] = kms_key
    deployment_config_instance_type = (
        self._deployment_config.get("DeploymentArgs", {}).get("InstanceType")
        if self._deployment_config
        else None
    )
    # Precedence: explicit arg > deployment config default > notebook instance.
    self.instance_type = (
        instance_type or deployment_config_instance_type or self._get_nb_instance()
    )
    create_optimization_job_args = {
        "OptimizationJobName": job_name,
        "ModelSource": model_source,
        "DeploymentInstanceType": self.instance_type,
        "OptimizationConfigs": [{k: v} for k, v in optimization_config.items()],
        "OutputConfig": output_config,
        "RoleArn": self.role_arn,
    }
    if optimization_env_vars:
        create_optimization_job_args["OptimizationEnvironment"] = optimization_env_vars
    if max_runtime_in_sec:
        create_optimization_job_args["StoppingCondition"] = {
            "MaxRuntimeInSeconds": max_runtime_in_sec
        }
    if tags:
        create_optimization_job_args["Tags"] = tags
    if vpc_config:
        create_optimization_job_args["VpcConfig"] = vpc_config
    if accept_eula:
        self.accept_eula = accept_eula
        if isinstance(self.s3_upload_path, dict):
            # Record EULA acceptance directly on the model data source.
            self.s3_upload_path["S3DataSource"]["ModelAccessConfig"] = {"AcceptEula": True}
    # Layer per-technique override env vars on top of the job env vars.
    optimization_env_vars = self._update_environment_variables(
        optimization_env_vars,
        {
            **(quantization_override_env or {}),
            **(compilation_override_env or {}),
            **(sharding_override_env or {}),
        },
    )
    if optimization_env_vars:
        self.env_vars.update(optimization_env_vars)
    if sharding_config and self._enable_network_isolation:
        # Fast Model Loading needs network access, so isolation is forced off.
        logger.warning(
            "EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
            "Loading of model requires network access. Setting it to False."
        )
        self._enable_network_isolation = False
    if quantization_config or sharding_config or is_compilation:
        # Compilation keeps its own image; other techniques get the vLLM default.
        return (
            create_optimization_job_args
            if is_compilation
            else self._set_optimization_image_default(create_optimization_job_args)
        )
    return None
def _generate_optimized_core_model(self, optimization_response: dict) -> Model:
    """Generate the optimized CoreModel from an optimization job response."""
    job_output = optimization_response.get("OptimizationOutput", {})
    recommended_image = job_output.get("RecommendedInferenceImage")
    optimized_s3_uri = optimization_response.get("OutputConfig", {}).get("S3OutputLocation")
    recommended_instance = optimization_response.get("DeploymentInstanceType")
    if recommended_image:
        self.image_uri = recommended_image
    if optimized_s3_uri:
        # Preserve the dict-shaped data source when present; otherwise store
        # the raw URI directly.
        if isinstance(self.s3_upload_path, dict):
            self.s3_upload_path["S3DataSource"]["S3Uri"] = optimized_s3_uri
        else:
            self.s3_upload_path = optimized_s3_uri
    if recommended_instance:
        self.instance_type = recommended_instance
    self.add_tags(
        {"Key": "OptimizationJobName", "Value": optimization_response["OptimizationJobName"]}
    )
    self._optimizing = False
    optimized_core_model = self._create_model()
    self.built_model = optimized_core_model
    return optimized_core_model
def deployment_config_response_data(
    self,
    deployment_configs: Optional[List[DeploymentConfigMetadata]],
) -> List[Dict[str, Any]]:
    """Serialize deployment-config metadata into API response dictionaries.

    Args:
        deployment_configs (Optional[List[DeploymentConfigMetadata]]):
            List of deployment configs metadata.

    Returns:
        List[Dict[str, Any]]: One JSON-ready dict per deployment config, with
        benchmark metrics narrowed to the config's own instance type.
    """
    if not deployment_configs:
        return []
    response_data = []
    for config in deployment_configs:
        config_json = config.to_json()
        metrics = config_json.get("BenchmarkMetrics")
        if metrics and config.deployment_args:
            # Keep only the metrics row matching this config's instance type.
            target_instance = config.deployment_args.instance_type
            config_json["BenchmarkMetrics"] = {target_instance: metrics.get(target_instance)}
        response_data.append(config_json)
    return response_data
# @_deployment_config_lru_cache
# NOTE(review): the lru_cache decorator above is commented out — confirm whether
# caching was intentionally disabled.
def _get_deployment_configs_benchmarks_data(self) -> Dict[str, Any]:
    """Deployment configs benchmark metrics.

    Returns:
        Dict[str, Any]: Deployment config benchmark data built from all
        deployment configs (no config/instance-type filtering).
    """
    return get_metrics_from_deployment_configs(
        self._get_deployment_configs(None, None),
    )
# @_deployment_config_lru_cache
def _get_deployment_configs(
    self, selected_config_name: Optional[str], selected_instance_type: Optional[str]
) -> List[DeploymentConfigMetadata]:
    """Retrieve deployment configs metadata.

    Args:
        selected_config_name (Optional[str]): The name of the selected deployment config.
        selected_instance_type (Optional[str]): The selected instance type.

    Returns:
        List[DeploymentConfigMetadata]: One metadata entry per JumpStart config;
        empty when no metadata configs are loaded.
    """
    deployment_configs = []
    if not self._metadata_configs:
        return deployment_configs
    err = None
    for config_name, metadata_config in self._metadata_configs.items():
        # Only the explicitly selected config uses the caller's instance type;
        # every other config falls back to its own default.
        if selected_config_name == config_name:
            instance_type_to_use = selected_instance_type
        else:
            instance_type_to_use = metadata_config.resolved_config.get(
                "default_inference_instance_type"
            )
        if metadata_config.benchmark_metrics:
            # Enrich benchmark metrics with instance pricing; err carries any
            # API failure and is handled once after the loop.
            (
                err,
                metadata_config.benchmark_metrics,
            ) = add_instance_rate_stats_to_benchmark_metrics(
                self.region, metadata_config.benchmark_metrics
            )
        config_components = metadata_config.config_components.get(config_name)
        # Prefer the config's region-specific ECR alias; fall back to the
        # builder's own image URI.
        image_uri = (
            (
                config_components.hosting_instance_type_variants.get("regional_aliases", {})
                .get(self.region, {})
                .get("alias_ecr_uri_1")
            )
            if config_components
            else self.image_uri
        )
        init_kwargs = get_init_kwargs(
            config_name=config_name,
            model_id=self.model,
            instance_type=instance_type_to_use,
            sagemaker_session=self.sagemaker_session,
            image_uri=image_uri,
            region=self.region,
            model_version=getattr(self, "model_version", None) or "*",
            hub_arn=self.hub_arn,
            tolerate_vulnerable_model=getattr(self, "tolerate_vulnerable_model", None),
            tolerate_deprecated_model=getattr(self, "tolerate_deprecated_model", None),
        )
        deploy_kwargs = get_deploy_kwargs(
            model_id=self.model,
            instance_type=instance_type_to_use,
            sagemaker_session=self.sagemaker_session,
            region=self.region,
            model_version=getattr(self, "model_version", None) or "*",
            hub_arn=self.hub_arn,
            tolerate_vulnerable_model=getattr(self, "tolerate_vulnerable_model", None),
            tolerate_deprecated_model=getattr(self, "tolerate_deprecated_model", None),
        )
        deployment_config_metadata = DeploymentConfigMetadata(
            config_name,
            metadata_config,
            init_kwargs,
            deploy_kwargs,
        )
        deployment_configs.append(deployment_config_metadata)
    if err and err["Code"] == "AccessDeniedException":
        # Pricing lookup was denied; configs are still returned without rates.
        error_message = "Instance rate metrics will be omitted. Reason: %s"
        logger.warning(error_message, err["Message"])
    return deployment_configs
# ========================================
# General Utils
# ========================================
def add_tags(self, tags: Tags) -> None:
    """Add tags to this model.

    Args:
        tags: Tags to add to the model.
    """
    existing = getattr(self, "_tags", None)
    # The helper validates and merges the new tags with any existing ones.
    self._tags = _validate_new_tags(tags, existing)
def remove_tag_with_key(self, key: str) -> None:
    """Remove a tag with the given key from the list of tags.

    Args:
        key: The key of the tag to remove.
    """
    existing = getattr(self, "_tags", None)
    # Delegates to the module-level helper of the same name.
    self._tags = remove_tag_with_key(key, existing)
def _get_model_uri(self) -> Optional[str]:
    """Extract the model URI from ``s3_model_data_url``.

    Returns:
        Optional[str]: Model URI string, or ``None`` if not available.
    """
    model_data = getattr(self, "s3_model_data_url", None)
    if not model_data:
        return None
    if isinstance(model_data, (str, PipelineVariable)):
        return model_data
    if isinstance(model_data, dict):
        # Dict-shaped model data nests the URI under S3DataSource.
        return model_data.get("S3DataSource", {}).get("S3Uri", None)
    return None
def _ensure_base_name_if_needed(
    self, image_uri: str, script_uri: Optional[str], model_uri: Optional[str]
) -> None:
    """Create a base name from the image URI if no model name was provided.

    Precedence: existing ``_base_name``, then a JumpStart-derived name, then a
    name derived from the container image URI.

    Args:
        image_uri: Container image URI.
        script_uri: Optional script URI for JumpStart models.
        model_uri: Optional model URI for JumpStart models.
    """
    if getattr(self, "model_name", None) is not None:
        # An explicit model name wins; nothing to derive.
        return
    existing_base = getattr(self, "_base_name", None)
    self._base_name = (
        existing_base
        or get_jumpstart_base_name_if_jumpstart_model(script_uri, model_uri)
        or base_name_from_image(image_uri, default_base_name="ModelBuilder")
    )
def _ensure_metadata_configs(self) -> None:
    """Lazily load JumpStart metadata configs the first time they are needed."""
    already_loaded = getattr(self, "_metadata_configs", None) is not None
    model_id = getattr(self, "model", None)
    if already_loaded or not isinstance(model_id, str):
        return
    # Imported lazily to avoid the dependency at module import time.
    from sagemaker.core.jumpstart.utils import get_jumpstart_configs

    self._metadata_configs = get_jumpstart_configs(
        region=self.region,
        model_id=model_id,
        model_version=getattr(self, "model_version", None) or "*",
        sagemaker_session=getattr(self, "sagemaker_session", None),
    )
def _user_agent_decorator(self, func):
"""Decorator to add ModelBuilder to user agent string.
Args:
func: Function to decorate
Returns:
Decorated function that appends ModelBuilder to user agent.
"""
def wrapper(*args, **kwargs):
# Call the original function
result = func(*args, **kwargs)
if "ModelBuilder" in result:
return result
return result + " ModelBuilder"
return wrapper
def _get_serve_setting(self) -> _ServeSettings:
    """Get serve settings for model deployment.

    When no S3 model data URL is set, a default one is generated under the
    session's default bucket; the settings object is then assembled from the
    builder's current attributes.

    Returns:
        _ServeSettings: Deployment configuration for serving.
    """
    if not getattr(self, "s3_model_data_url", None):
        session = getattr(self, "sagemaker_session", None)
        if session:
            # Default layout: s3://<default-bucket>/model-builder/<name>/<uuid>/
            model_name = getattr(self, "model_name", None) or "model"
            self.s3_model_data_url = (
                f"s3://{session.default_bucket()}/"
                f"model-builder/{model_name}/{uuid.uuid4().hex}/"
            )
    return _ServeSettings(
        role_arn=getattr(self, "role_arn", None),
        s3_model_data_url=getattr(self, "s3_model_data_url", None),
        instance_type=getattr(self, "instance_type", None),
        env_vars=getattr(self, "env_vars", None),
        sagemaker_session=getattr(self, "sagemaker_session", None),
    )
def _is_jumpstart_model_id(self) -> bool:
    """Check whether ``self.model`` is a JumpStart model ID (result is cached)."""
    if hasattr(self, "_cached_is_jumpstart"):
        return self._cached_is_jumpstart
    detected = False
    if self.model is not None:
        try:
            # A successful retrieval means the ID is known to JumpStart.
            model_uris.retrieve(model_id=self.model, model_version="*", model_scope=_JS_SCOPE)
        except KeyError:
            logger.debug(_NO_JS_MODEL_EX)
        else:
            logger.debug("JumpStart Model ID detected.")
            detected = True
    self._cached_is_jumpstart = detected
    return detected
def _has_nvidia_gpu(self) -> bool:
    """Return whether GPU enumeration succeeds on this machine."""
    try:
        _get_available_gpus()
    except Exception:
        # for nvidia-smi to run, a cuda driver must be present
        logger.debug("CUDA not found, launching Triton in CPU mode.")
        return False
    return True
def _is_gpu_instance(self, instance_type: str) -> bool:
    """Return whether ``instance_type`` (e.g. "ml.g5.xlarge") belongs to a GPU family."""
    # Strip the trailing size suffix to get the family, e.g. "ml.g5".
    family = instance_type.rsplit(".", maxsplit=1)[0]
    return family in GPU_INSTANCE_FAMILIES
def _save_inference_spec(self) -> None:
    """Pickle the inference spec and schema builder into the Triton model repo."""
    if not self.inference_spec:
        return
    target_dir = Path(self.model_path).joinpath("model_repository").joinpath("model")
    save_pkl(target_dir, (self.inference_spec, self.schema_builder))
def _compute_integrity_hash(self):
    """Compute SHA-256 hash of serve.pkl and store in metadata.json for integrity check."""
    pkl_path = Path(self.model_path).joinpath("model_repository").joinpath("model")
    with open(str(pkl_path.joinpath("serve.pkl")), "rb") as f:
        buffer = f.read()
    hash_value = compute_hash(buffer=buffer)
    # NOTE(review): the file is opened in binary mode, so _MetaData.to_json()
    # is presumably returning bytes — confirm; a str return would raise
    # TypeError on write.
    with open(str(pkl_path.joinpath("metadata.json")), "wb") as metadata:
        metadata.write(_MetaData(hash_value).to_json())
def _generate_config_pbtxt(self, pkl_path: Path):
    """Write the Triton ``config.pbtxt`` describing the model's I/O tensors."""
    sample_input_shape = list(self.schema_builder._sample_input_ndarray.shape)
    sample_output_shape = list(self.schema_builder._sample_output_ndarray.shape)
    # -1 marks the batch dimension as dynamic.
    sample_input_shape[0] = -1
    sample_output_shape[0] = -1
    rendered = CONFIG_TEMPLATE.format(
        input_name=INPUT_NAME,
        input_shape=str(sample_input_shape),
        input_dtype=self.schema_builder._input_triton_dtype,
        output_name=OUTPUT_NAME,
        output_dtype=self.schema_builder._output_triton_dtype,
        output_shape=str(sample_output_shape),
        hardware_type="KIND_CPU" if "-cpu" in self.image_uri else "KIND_GPU",
    )
    destination = pkl_path.joinpath("config.pbtxt")
    with open(str(destination), "w") as config_file:
        config_file.write(rendered)
def _pack_conda_env(self, pkl_path: Path):
    """Pack the active conda environment into ``triton_env.tar.gz`` for Triton.

    Args:
        pkl_path (Path): Directory in which to place the packed environment.

    Raises:
        ImportError: If ``conda_pack`` is not installed.
        RuntimeError: If no conda environment is active.
        subprocess.CalledProcessError: If the packing script fails.
    """
    try:
        import conda_pack

        conda_pack.__version__
    except ModuleNotFoundError:
        raise ImportError(
            "Launching Triton with ModelBuilder requires conda_pack library "
            "but it was not found in your environment. "
            "Checkout the instructions on the installation page of its repo: "
            "https://conda.github.io/conda-pack/ "
            "And follow the ones that match your environment. "
            "Please note that you may need to restart your runtime after installation."
        )
    script_path = Path(__file__).parent.joinpath("pack_conda_env.sh")
    env_tar_path = pkl_path.joinpath("triton_env.tar.gz")
    conda_env_name = os.getenv("CONDA_DEFAULT_ENV")
    if not conda_env_name:
        # Without an active env, subprocess would receive None and fail obscurely.
        raise RuntimeError(
            "CONDA_DEFAULT_ENV is not set; an active conda environment is "
            "required to pack dependencies for Triton."
        )
    # check=True surfaces packing failures instead of silently continuing
    # with a missing/partial triton_env.tar.gz.
    subprocess.run(
        ["bash", str(script_path), conda_env_name, str(env_tar_path)], check=True
    )
def _export_tf_to_onnx(self, export_path: Path, model: object, schema_builder: SchemaBuilder):
    """Export a Keras/TensorFlow model to ONNX under ``export_path``.

    Args:
        export_path (Path): Directory receiving ``model.onnx``. (Annotation
            fixed from ``str``: the value is used via ``Path.joinpath``.)
        model (object): TensorFlow/Keras model object.
        schema_builder (SchemaBuilder): Supplies the sample input shape.

    Raises:
        ImportError: If ``tf2onnx`` (or TensorFlow) is not installed.
    """
    try:
        import tensorflow as tf
        import tf2onnx

        tf2onnx.convert.from_keras(
            model=model,
            input_signature=[
                tf.TensorSpec(shape=schema_builder.sample_input.shape, name=INPUT_NAME)
            ],
            output_path=str(export_path.joinpath("model.onnx")),
        )
    except ModuleNotFoundError:
        raise ImportError(
            "Launching Triton with ModelBuilder for a Tensorflow model requires tf2onnx module "
            "but it was not found in your environment. "
            "Checkout the instructions on the installation page of its repo: "
            "https://onnxruntime.ai/docs/install/ "
            "And follow the ones that match your environment. "
            "Please note that you may need to restart your runtime after installation."
        )
def _export_pytorch_to_onnx(
    self, model: object, export_path: Path, schema_builder: SchemaBuilder
):
    """Export PyTorch model object into ONNX format.

    Args:
        model (object): PyTorch model to export.
        export_path (Path): Directory receiving ``model.onnx``.
        schema_builder (SchemaBuilder): Supplies the sample input used for tracing.

    Raises:
        ImportError: If the ONNX export dependencies are missing.
    """
    logger.debug("Converting PyTorch model into ONNX format")
    try:
        from torch.onnx import export

        # export() itself can raise ModuleNotFoundError (e.g. missing onnx),
        # so the call stays inside the try to get the same ImportError guidance.
        export(
            model=model,
            args=schema_builder.sample_input,
            f=str(export_path.joinpath("model.onnx")),
            input_names=[INPUT_NAME],
            output_names=[OUTPUT_NAME],
            opset_version=17,
            verbose=False,
        )
    except ModuleNotFoundError:
        raise ImportError(
            "Launching Triton with ModelBuilder for a PyTorch model requires onnx module "
            "but it was not found in your environment. "
            "Checkout the instructions on the installation page of its repo: "
            "https://onnxruntime.ai/docs/install/ "
            "And follow the ones that match your environment. "
            "Please note that you may need to restart your runtime after installation."
        )
def _validate_for_triton(self):
    """Validate environment, mode, paths, and framework for Triton deployment.

    Raises:
        ImportError: If the tritonclient[http] dependency is missing.
        ValueError: On unsupported mode/framework, a GPU image on a machine
            without an NVIDIA GPU, or a non-conda python when InferenceSpec is used.
        Exception: If ``model_path`` exists but is not a directory.
    """
    try:
        import tritonclient.http as httpClient

        # Touch the module so the import is not flagged as unused.
        httpClient.__class__
    except ModuleNotFoundError:
        raise ImportError(
            "Launching Triton with ModelBuilder requires tritonClient[http] module "
            "but it was not found in your environment. "
            "Checkout the instructions on the installation page of its repo: "
            "https://github.com/triton-inference-server/client#getting-the-client-libraries-and-examples "
            "And follow the ones that match your environment. "
            "Please note that you may need to restart your runtime after installation."
        )
    # A GPU Triton image cannot run locally without an NVIDIA GPU.
    if (
        self.mode == Mode.LOCAL_CONTAINER
        and not self._has_nvidia_gpu()
        and self.image_uri
        and "cpu" not in self.image_uri
    ):
        raise ValueError(
            "Your device does not have a Nvidia GPU. "
            "Unable to launch Triton container in GPU mode in your local machine. "
            "Please provide a CPU version triton image to serve your model in LOCAL_CONTAINER mode."
        )
    if self.mode not in SUPPORTED_TRITON_MODE:
        raise ValueError("%s mode is not supported with Triton model server." % self.mode)
    model_path = Path(self.model_path)
    if not model_path.exists():
        model_path.mkdir(parents=True)
    elif not model_path.is_dir():
        raise Exception(f"model_path: {self.model_path} is not a valid directory")
    # Triton needs its own serializer/deserializer and dtype mapping.
    self.schema_builder._update_serializer_deserializer_for_triton()
    self.schema_builder._detect_dtype_for_triton()
    if not platform.python_version().startswith("3.8"):
        logger.warning(
            f"SageMaker Triton image uses python 3.8, your python version: {platform.python_version()}. "
            "It is recommended to use the same python version to avoid incompatibility."
        )
    if self.model:
        # Detect framework/version from the model object's base class.
        fw, self.framework_version = _detect_framework_and_version(
            str(_get_model_base(self.model))
        )
        if fw == "pytorch":
            self.framework = Framework.PYTORCH
        elif fw == "tensorflow":
            self.framework = Framework.TENSORFLOW
        if self.framework not in SUPPORTED_TRITON_FRAMEWORK:
            raise ValueError("%s is not supported with Triton model server" % self.framework)
    if self.inference_spec:
        # conda-pack needs a conda-managed interpreter to capture dependencies.
        if "conda" not in sys.executable.lower():
            raise ValueError(
                f"Invalid python environment {sys.executable}, please use anaconda "
                "or miniconda to manage your python environment "
                "as it is required by Triton to capture "
                "and pack your python dependencies."
            )
def _prepare_for_triton(self):
    """Prepare model artifacts in the Triton model-repository layout.

    Copies loose artifacts into ``model_repository/model`` and then either
    exports the model to ONNX (framework models) or stages the python-backend
    files (InferenceSpec models).

    Raises:
        ValueError: If the framework is unsupported, or neither model nor
            inference_spec is set.
    """
    model_path = Path(self.model_path)
    pkl_path = model_path.joinpath("model_repository").joinpath("model")
    if not pkl_path.exists():
        pkl_path.mkdir(parents=True)
    # Flatten all existing artifacts into the repo dir, skipping anything
    # already under model_repository to avoid copying the repo onto itself.
    for root, _, files in os.walk(self.model_path):
        for f in files:
            path_file = os.path.join(root, f)
            if "model_repository" not in path_file:
                shutil.copy2(path_file, str(pkl_path.joinpath(f)))
    # Triton expects versioned subdirectories; "1" is the first model version.
    export_path = model_path.joinpath("model_repository").joinpath("model").joinpath("1")
    if not export_path.exists():
        export_path.mkdir(parents=True)
    if self.model:
        # ONNX path: no pickle serialization, no serve.pkl, no integrity check needed.
        # Do not set secret_key — there is nothing to sign.
        if self.framework == Framework.PYTORCH:
            self._export_pytorch_to_onnx(
                export_path=export_path, model=self.model, schema_builder=self.schema_builder
            )
            return
        if self.framework == Framework.TENSORFLOW:
            self._export_tf_to_onnx(
                export_path=export_path, model=self.model, schema_builder=self.schema_builder
            )
            return
        raise ValueError("%s is not supported" % self.framework)
    if self.inference_spec:
        # Python-backend path: ship model.py, config.pbtxt, the packed conda
        # env, and an integrity hash for the pickled spec.
        triton_model_path = Path(__file__).parent.joinpath("model.py")
        shutil.copy2(str(triton_model_path), str(export_path))
        self._generate_config_pbtxt(pkl_path=pkl_path)
        self._pack_conda_env(pkl_path=pkl_path)
        self._compute_integrity_hash()
        return
    raise ValueError("Either model or inference_spec should be provided to ModelBuilder.")
def _auto_detect_image_for_triton(self):
    """Resolve the Triton container image for the configured region/framework.

    Detection is skipped entirely when ``image_uri`` was supplied by the
    caller. When an InferenceSpec is used, the latest Triton version is
    always selected; only TensorFlow 1.x framework models pin an older one.
    A ``-cpu`` image variant is chosen for non-GPU instance types.
    """
    if self.image_uri:
        logger.debug("Skipping auto detection as the image uri is provided %s", self.image_uri)
        return
    logger.debug(
        "Auto detect container url for the provided model and on instance %s",
        self.instance_type,
    )
    region = self.sagemaker_session.boto_region_name
    if region not in TRITON_IMAGE_ACCOUNT_ID_MAP:
        raise ValueError(
            f"{region} is not supported for triton image. "
            f"Please switch to the following region: {list(TRITON_IMAGE_ACCOUNT_ID_MAP.keys())}"
        )
    dns_suffix = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"
    # NOTE(review): framework is compared as a plain string here — confirm
    # Framework enum members compare equal to "tensorflow" for this branch.
    needs_tf1_image = (
        not self.inference_spec
        and self.framework == "tensorflow"
        and self.version.startswith("1")
    )
    triton_version = VERSION_FOR_TF1 if needs_tf1_image else LATEST_VERSION
    self.image_uri = TRITON_IMAGE_BASE.format(
        account_id=TRITON_IMAGE_ACCOUNT_ID_MAP.get(region),
        region=region,
        base=dns_suffix,
        version=triton_version,
    )
    if not self._is_gpu_instance(self.instance_type):
        self.image_uri += "-cpu"
    logger.debug(f"Autodetected image: {self.image_uri}. Proceeding with the deployment.")
def _validate_djl_serving_sample_data(self):
"""Validate sample data format for DJL serving."""
sample_input = self.schema_builder.sample_input
sample_output = self.schema_builder.sample_output
if (
not isinstance(sample_input, dict)
or "inputs" not in sample_input
or "parameters" not in sample_input
or not isinstance(sample_output, list)
or not isinstance(sample_output[0], dict)
or "generated_text" not in sample_output[0]
):
raise ValueError(_INVALID_DJL_SAMPLE_DATA_EX)
def _validate_tgi_serving_sample_data(self):
"""Validate sample data format for TGI serving."""
sample_input = self.schema_builder.sample_input
sample_output = self.schema_builder.sample_output
if (
not isinstance(sample_input, dict)
or "inputs" not in sample_input
or "parameters" not in sample_input
or not isinstance(sample_output, list)
or not isinstance(sample_output[0], dict)
or "generated_text" not in sample_output[0]
):
raise ValueError(_INVALID_TGI_SAMPLE_DATA_EX)
def _create_conda_env(self):
    """Create and populate a conda environment with captured dependencies.

    Bug fix: the previous code referenced
    ``RequirementsManager().capture_and_install_dependencies`` without
    calling it, so nothing was installed and the except branch was dead.
    """
    try:
        RequirementsManager().capture_and_install_dependencies()
    except subprocess.CalledProcessError:
        # Best-effort: log and continue rather than abort the build.
        logger.error("Failed to create and activate conda environment.")
def _extract_framework_from_model_trainer(
    self, model_trainer: ModelTrainer
) -> Optional[Framework]:
    """Map a ModelTrainer's training image URI onto a serve Framework.

    Falls back to the image recorded on the latest training job when the
    trainer carries no explicit training image. Returns ``None`` when no
    known framework marker appears in the image URI.
    """
    image = model_trainer.training_image
    if not image:
        image = (
            model_trainer._latest_training_job.algorithm_specification.training_image
        )
    lowered = image.lower()
    # Ordered markers: "pytorch" is checked first, matching the original
    # precedence (e.g. huggingface-pytorch images map to PYTORCH).
    for marker, framework in (
        ("pytorch", Framework.PYTORCH),
        ("tensorflow", Framework.TENSORFLOW),
        ("huggingface", Framework.HUGGINGFACE),
        ("xgboost", Framework.XGBOOST),
    ):
        if marker in lowered:
            return framework
    return None
def _infer_model_server_from_training(
    self, model_trainer: ModelTrainer
) -> Optional[ModelServer]:
    """Infer the best model server based on training configuration.

    Returns:
        Optional[ModelServer]: TGI for HuggingFace text-generation training
        jobs, MMS for other HuggingFace jobs, TORCHSERVE or
        TENSORFLOW_SERVING per detected framework, and TORCHSERVE as the
        logged fallback.
    """
    training_image = model_trainer.training_image
    framework = self._extract_framework_from_model_trainer(model_trainer)
    # training_image may be None when only the latest training job carries
    # the image (the framework helper handles that fallback); previously
    # this crashed on None.lower().
    if training_image and "huggingface" in training_image.lower():
        hyperparams = model_trainer.hyperparameters or {}
        # Text-generation hyperparameters indicate a TGI-style workload.
        if any(key in hyperparams for key in ["max_new_tokens", "do_sample", "temperature"]):
            logger.info("Auto-detected model server: TGI (HuggingFace text generation)")
            return ModelServer.TGI
        logger.info("Auto-detected model server: MMS (HuggingFace)")
        return ModelServer.MMS
    if framework == Framework.PYTORCH:
        logger.info("Auto-detected model server: TORCHSERVE (PyTorch framework)")
        return ModelServer.TORCHSERVE
    if framework == Framework.TENSORFLOW:
        logger.info("Auto-detected model server: TENSORFLOW_SERVING (TensorFlow framework)")
        return ModelServer.TENSORFLOW_SERVING
    logger.warning(
        f"Could not auto-detect model server for framework: {framework}. "
        "Defaulting to TORCHSERVE. Consider explicitly setting model_server parameter."
    )
    return ModelServer.TORCHSERVE
def _extract_inference_spec_from_training_code(
self, model_trainer: ModelTrainer
) -> Optional[str]:
"""Check if training source code contains inference.py."""
if not model_trainer.source_code or not model_trainer.source_code.source_dir:
return None
source_dir = model_trainer.source_code.source_dir
# Check for inference.py in source directory
if source_dir.startswith("s3://"):
pass
else:
inference_path = os.path.join(source_dir, "inference.py")
if os.path.exists(inference_path):
return inference_path
return None
def _inherit_training_environment(self, model_trainer: ModelTrainer) -> Dict[str, str]:
    """Carry inference-relevant environment variables over from training.

    Merges the trainer-level environment with the latest training job's
    environment (job-level values take precedence) and keeps only keys
    useful at inference time plus all ``SAGEMAKER_``-prefixed variables.
    """
    from sagemaker.core.utils.utils import Unassigned

    def _normalized(env):
        # sagemaker-core uses Unassigned as its "not set" sentinel.
        if env is None or isinstance(env, Unassigned):
            return {}
        return env

    merged_env = {
        **_normalized(model_trainer.environment or {}),
        **_normalized(model_trainer._latest_training_job.environment),
    }
    inference_relevant_keys = {
        "HUGGING_FACE_HUB_TOKEN",
        "HF_TOKEN",
        "MODEL_CLASS_NAME",
        "TRANSFORMERS_CACHE",
        "PYTORCH_TRANSFORMERS_CACHE",
        "HF_HOME",
    }
    return {
        key: value
        for key, value in merged_env.items()
        if key in inference_relevant_keys or key.startswith("SAGEMAKER_")
    }
def _extract_version_from_training_image(self, training_image: str) -> Optional[str]:
"""Extract framework version from training image URI."""
import re
version_match = re.search(r":(\d+\.\d+(?:\.\d+)?)", training_image)
if version_match:
return version_match.group(1)
return None
def _detect_inference_image_from_training(self) -> None:
"""Detect inference image based on ModelTrainer's training image."""
from sagemaker.core import image_uris
training_image = self.model.training_image
if "pytorch-training" in training_image:
self.image_uri = training_image.replace("pytorch-training", "pytorch-inference")
elif "tensorflow-training" in training_image:
self.image_uri = training_image.replace("tensorflow-training", "tensorflow-inference")
elif "huggingface-pytorch-training" in training_image:
self.image_uri = training_image.replace(
"huggingface-pytorch-training", "huggingface-pytorch-inference"
)
else:
framework = self._extract_framework_from_model_trainer(self.model)
fw = framework.value.lower() if framework else "pytorch"
fw_version = self._extract_version_from_training_image(training_image)
py_tuple = platform.python_version_tuple()
casted_versions = _cast_to_compatible_version(fw, fw_version) if fw_version else (None,)
dlc = None
for casted_version in filter(None, casted_versions):
try:
dlc = image_uris.retrieve(
framework=fw,
region=self.region,
version=casted_version,
image_scope="inference",
py_version=f"py{py_tuple[0]}{py_tuple[1]}",
instance_type=self.instance_type,
)
break
except ValueError:
pass
if dlc:
self.image_uri = dlc
else:
raise ValueError(
f"Could not detect inference image for training image: {training_image}"
)
def _extract_speculative_draft_model_provider(
self,
speculative_decoding_config: Optional[Dict] = None,
) -> Optional[str]:
"""Extracts speculative draft model provider from speculative decoding config.
Args:
speculative_decoding_config (Optional[Dict]): A speculative decoding config.
Returns:
Optional[str]: The speculative draft model provider.
"""
if speculative_decoding_config is None:
return None
model_provider = speculative_decoding_config.get("ModelProvider", "").lower()
if model_provider == "jumpstart":
return "jumpstart"
if model_provider == "custom" or speculative_decoding_config.get("ModelSource"):
return "custom"
if model_provider == "sagemaker":
return "sagemaker"
return "auto"
def get_huggingface_model_metadata(
    self, model_id: str, hf_hub_token: Optional[str] = None
) -> dict:
    """Retrieve the JSON metadata of a HuggingFace model via the Hub API.

    Args:
        model_id (str): The HuggingFace Model ID.
        hf_hub_token (Optional[str]): HuggingFace Hub token for
            private/gated models.

    Returns:
        dict: The model metadata retrieved from the HuggingFace API.

    Raises:
        ValueError: If ``model_id`` is empty, the model is gated/private and
            no valid credentials were given, or no metadata was retrieved.
    """
    import urllib.request
    from urllib.error import HTTPError, URLError
    from json import JSONDecodeError

    # Note: the redundant function-local `import json` was removed; `json`
    # is already imported at module level.
    if not model_id:
        raise ValueError("Model ID is empty. Please provide a valid Model ID.")
    hf_model_metadata_url = f"https://huggingface.co/api/models/{model_id}"
    hf_model_metadata_json = None
    try:
        if hf_hub_token:
            # Wrap the URL in a Request so the Authorization header is sent.
            hf_model_metadata_url = urllib.request.Request(
                hf_model_metadata_url, None, {"Authorization": "Bearer " + hf_hub_token}
            )
        with urllib.request.urlopen(hf_model_metadata_url) as response:
            hf_model_metadata_json = json.load(response)
    except (HTTPError, URLError, TimeoutError, JSONDecodeError) as e:
        if "HTTP Error 401: Unauthorized" in str(e):
            # Chain the underlying HTTP error for easier debugging.
            raise ValueError(
                "Trying to access a gated/private HuggingFace model without valid credentials. "
                "Please provide a HUGGING_FACE_HUB_TOKEN in env_vars"
            ) from e
        logger.warning(
            "Exception encountered while trying to retrieve HuggingFace model metadata %s. "
            "Details: %s",
            hf_model_metadata_url,
            e,
        )
    if not hf_model_metadata_json:
        raise ValueError(
            "Did not find model metadata for the following HuggingFace Model ID %s" % model_id
        )
    return hf_model_metadata_json
def download_huggingface_model_metadata(
    self, model_id: str, model_local_path: str, hf_hub_token: Optional[str] = None
) -> None:
    """Download a HuggingFace model snapshot into a local directory.

    Args:
        model_id (str): The HuggingFace Model ID to download.
        model_local_path (str): Destination directory for the snapshot
            (created if it does not exist).
        hf_hub_token (Optional[str]): HuggingFace Hub token for
            private/gated models.

    Raises:
        ImportError: If ``huggingface_hub`` is not installed.
    """
    # Probe for the optional dependency before importing it, so the error
    # message points the user at the missing package.
    if importlib.util.find_spec("huggingface_hub") is None:
        raise ImportError(
            "Unable to import huggingface_hub, check if huggingface_hub is installed"
        )
    from huggingface_hub import snapshot_download

    os.makedirs(model_local_path, exist_ok=True)
    logger.info("Downloading model %s from Hugging Face Hub to %s", model_id, model_local_path)
    snapshot_download(repo_id=model_id, local_dir=model_local_path, token=hf_hub_token)