# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
# pylint: skip-file
"""This module contains accessors related to SageMaker JumpStart."""
from __future__ import absolute_import
import functools
import logging
from typing import Any, Dict, List, Optional
import boto3
from sagemaker.core.deprecations import deprecated
from sagemaker.core.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs, HubContentType
from sagemaker.core.jumpstart.enums import JumpStartModelType
from sagemaker.core.jumpstart import cache
from sagemaker.core.jumpstart.hub.utils import (
construct_hub_model_arn_from_inputs,
construct_hub_model_reference_arn_from_inputs,
)
from sagemaker.core.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
from sagemaker.core.helper.session_helper import Session
from sagemaker.core.jumpstart import constants
class SageMakerSettings(object):
    """Holder for the parsed SageMaker version, stored on the class itself.

    Every caller reads and writes the same class attribute, so no instance
    of this class is ever created.
    """

    # Parsed SageMaker version; the empty string until a version is recorded.
    _parsed_sagemaker_version: str = ""

    @staticmethod
    def set_sagemaker_version(version: str) -> None:
        """Record ``version`` as the SageMaker version.

        Args:
            version (str): The version string to store.
        """
        setattr(SageMakerSettings, "_parsed_sagemaker_version", version)

    @staticmethod
    def get_sagemaker_version() -> str:
        """Return the recorded SageMaker version (``""`` if none was set)."""
        stored_version = SageMakerSettings._parsed_sagemaker_version
        return stored_version
class JumpStartS3PayloadAccessor(object):
    """Static class for storing and retrieving S3 payload artifacts."""

    # Total byte budget for the object cache; only used to derive CACHE_SIZE.
    MAX_CACHE_SIZE_BYTES = int(100 * 1e6)
    # Objects larger than this many bytes are rejected by ``get_object``.
    MAX_PAYLOAD_SIZE_BYTES = int(6 * 1e6)
    # Max entries in the ``get_object_cached`` LRU cache: the byte budget divided
    # by the largest allowed payload, so the cache stays within the budget.
    CACHE_SIZE = MAX_CACHE_SIZE_BYTES // MAX_PAYLOAD_SIZE_BYTES

    @staticmethod
    def clear_cache() -> None:
        """Clears LRU caches associated with S3 client and retrieved objects."""
        JumpStartS3PayloadAccessor._get_default_s3_client.cache_clear()
        JumpStartS3PayloadAccessor.get_object_cached.cache_clear()

    @staticmethod
    @functools.lru_cache()
    def _get_default_s3_client(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> boto3.client:
        """Returns default S3 client associated with the region.

        Result is cached so multiple clients in memory are not created.

        Args:
            region (str): AWS region for the client.
        """
        return boto3.client("s3", region_name=region)

    @staticmethod
    @functools.lru_cache(maxsize=CACHE_SIZE)
    def get_object_cached(
        bucket: str,
        key: str,
        region: str = JUMPSTART_DEFAULT_REGION_NAME,
        s3_client: Optional[boto3.client] = None,
    ) -> bytes:
        """Returns S3 object located at the bucket and key, memoized.

        Results are kept in an LRU cache of at most ``CACHE_SIZE`` entries keyed
        on the call arguments (including region and client). Repeated identical
        requests are therefore served from memory until the entry is evicted or
        ``clear_cache`` is called.

        Args:
            bucket (str): S3 bucket name.
            key (str): S3 object key.
            region (str): Region used to build a default client if none is given.
            s3_client (Optional[boto3.client]): Client to use. If None, the cached
                default client for ``region`` is used.

        Raises:
            ValueError: The object size exceeds ``MAX_PAYLOAD_SIZE_BYTES``.
        """
        return JumpStartS3PayloadAccessor.get_object(
            bucket=bucket, key=key, region=region, s3_client=s3_client
        )

    @staticmethod
    def _get_object_size_bytes(
        bucket: str,
        key: str,
        region: str = JUMPSTART_DEFAULT_REGION_NAME,
        s3_client: Optional[boto3.client] = None,
    ) -> int:
        """Returns size in bytes of S3 object using S3.HeadObject operation.

        Args:
            bucket (str): S3 bucket name.
            key (str): S3 object key.
            region (str): Region used to build a default client if none is given.
            s3_client (Optional[boto3.client]): Client to use.

        Returns:
            int: The object's ``ContentLength``.
        """
        if s3_client is None:
            s3_client = JumpStartS3PayloadAccessor._get_default_s3_client(region)
        return s3_client.head_object(Bucket=bucket, Key=key)["ContentLength"]

    @staticmethod
    def get_object(
        bucket: str,
        key: str,
        region: str = JUMPSTART_DEFAULT_REGION_NAME,
        s3_client: Optional[boto3.client] = None,
    ) -> bytes:
        """Returns S3 object located at the bucket and key.

        The object size is checked with a HeadObject call before the body is
        downloaded, so oversized objects are rejected without being read.

        Args:
            bucket (str): S3 bucket name.
            key (str): S3 object key.
            region (str): Region used to build a default client if none is given.
            s3_client (Optional[boto3.client]): Client to use. If None, the cached
                default client for ``region`` is used.

        Returns:
            bytes: The full object body.

        Raises:
            ValueError: The object size is too large.
        """
        if s3_client is None:
            s3_client = JumpStartS3PayloadAccessor._get_default_s3_client(region)
        object_size_bytes = JumpStartS3PayloadAccessor._get_object_size_bytes(
            bucket=bucket, key=key, region=region, s3_client=s3_client
        )
        if object_size_bytes > JumpStartS3PayloadAccessor.MAX_PAYLOAD_SIZE_BYTES:
            raise ValueError(
                f"s3://{bucket}/{key} has size of {object_size_bytes} bytes, "
                "which exceeds maximum allowed size of "
                f"{JumpStartS3PayloadAccessor.MAX_PAYLOAD_SIZE_BYTES} bytes."
            )
        return s3_client.get_object(Bucket=bucket, Key=key)["Body"].read()
class JumpStartModelsAccessor(object):
    """Static class for storing the JumpStart models cache."""

    # Shared models cache; (re)built by ``_set_cache_and_region`` / ``set_cache_kwargs``.
    _cache: Optional[cache.JumpStartModelsCache] = None
    # Region recorded for the current ``_cache``.
    _curr_region = JUMPSTART_DEFAULT_REGION_NAME
    _content_bucket: Optional[str] = None
    _gated_content_bucket: Optional[str] = None
    # Keyword arguments forwarded to ``cache.JumpStartModelsCache`` when it is built.
    _cache_kwargs: Dict[str, Any] = {}

    @staticmethod
    def set_jumpstart_content_bucket(content_bucket: str) -> None:
        """Sets JumpStart content bucket."""
        JumpStartModelsAccessor._content_bucket = content_bucket

    @staticmethod
    def get_jumpstart_content_bucket() -> Optional[str]:
        """Returns JumpStart content bucket."""
        return JumpStartModelsAccessor._content_bucket

    @staticmethod
    def set_jumpstart_gated_content_bucket(gated_content_bucket: str) -> None:
        """Sets JumpStart gated content bucket."""
        JumpStartModelsAccessor._gated_content_bucket = gated_content_bucket

    @staticmethod
    def get_jumpstart_gated_content_bucket() -> Optional[str]:
        """Returns JumpStart gated content bucket."""
        return JumpStartModelsAccessor._gated_content_bucket

    @staticmethod
    def _validate_and_mutate_region_cache_kwargs(
        cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None
    ) -> Dict[str, Any]:
        """Returns a copy of cache_kwargs with the region argument removed if present.

        The ``region`` key is only removed when ``region`` is not None. The
        caller's dictionary is never modified; a new dict is always returned.

        Args:
            cache_kwargs (Optional[Dict[str, Any]]): cache kwargs to validate.
            region (Optional[str]): The region to validate along with the kwargs.

        Returns:
            Dict[str, Any]: The validated kwargs.

        Raises:
            ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.
        """
        # Copy so that callers passing their own dict (e.g. to set_cache_kwargs)
        # do not see their ``region`` entry silently deleted.
        cache_kwargs_dict: Dict[str, Any] = {} if cache_kwargs is None else dict(cache_kwargs)
        if region is not None and "region" in cache_kwargs_dict:
            if region != cache_kwargs_dict["region"]:
                raise ValueError(
                    f"Inconsistent region definitions: {region}, {cache_kwargs_dict['region']}"
                )
            del cache_kwargs_dict["region"]
        return cache_kwargs_dict

    @staticmethod
    def _set_cache_and_region(region: str, cache_kwargs: dict) -> None:
        """Sets ``JumpStartModelsAccessor._cache`` and ``JumpStartModelsAccessor._curr_region``.

        The cache is rebuilt only if none exists yet, the region changed, or the
        validated kwargs differ from the ones the current cache was built with.

        Args:
            region (str): region for which to retrieve header/spec.
            cache_kwargs (dict): kwargs to pass to ``JumpStartModelsCache``.

        Raises:
            ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.
        """
        new_cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
            cache_kwargs, region
        )
        if (
            JumpStartModelsAccessor._cache is None
            or region != JumpStartModelsAccessor._curr_region
            or new_cache_kwargs != JumpStartModelsAccessor._cache_kwargs
        ):
            # Build from the validated kwargs: ``region`` has been stripped from
            # them, so it is not passed twice.
            JumpStartModelsAccessor._cache = cache.JumpStartModelsCache(
                region=region, **new_cache_kwargs
            )
            JumpStartModelsAccessor._curr_region = region
            JumpStartModelsAccessor._cache_kwargs = new_cache_kwargs

    @staticmethod
    def _get_manifest(
        region: str = JUMPSTART_DEFAULT_REGION_NAME,
        s3_client: Optional[boto3.client] = None,
        model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    ) -> List[JumpStartModelHeader]:
        """Return entire JumpStart models manifest.

        Args:
            region (str): Optional. The region to use for the cache.
            s3_client (boto3.client): Optional. Boto3 client to use for accessing JumpStart models
                s3 cache. If not set, a default client will be made.
            model_type (JumpStartModelType): Optional. Type of models whose manifest is
                returned. (Default: JumpStartModelType.OPEN_WEIGHTS).

        Raises:
            ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.
        """
        additional_kwargs: Dict[str, Any] = {}
        if s3_client is not None:
            additional_kwargs["s3_client"] = s3_client
        cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
            {**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs},
            region,
        )
        JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
        return JumpStartModelsAccessor._cache.get_manifest(model_type)  # type: ignore

    @staticmethod
    def get_model_specs(
        region: str,
        model_id: str,
        version: str,
        hub_arn: Optional[str] = None,
        s3_client: Optional[boto3.client] = None,
        model_type=JumpStartModelType.OPEN_WEIGHTS,
        sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    ) -> JumpStartModelSpecs:
        """Returns model specs from JumpStart models cache.

        When ``hub_arn`` is set, the model is looked up in that hub. The lookup
        first tries it as a ModelReference and, if that fails for any reason,
        then as a Model. Otherwise the specs come from the JumpStart models cache.

        Args:
            region (str): region for which to retrieve the specs.
            model_id (str): model ID to retrieve.
            version (str): semantic version to retrieve for the model ID.
            hub_arn (Optional[str]): ARN of the hub to look the model up in. (Default: None).
            s3_client (boto3.client): boto3 client to use for accessing JumpStart models s3 cache.
                If not set, a default client will be made.
            model_type (JumpStartModelType): model type passed to the cache when
                ``hub_arn`` is not set. (Default: JumpStartModelType.OPEN_WEIGHTS).
            sagemaker_session (Session): session added to the cache kwargs when
                ``hub_arn`` is set.

        Raises:
            RuntimeError: If ``hub_arn`` is set and the model can be described neither
                as a ModelReference nor as a Model.
        """
        additional_kwargs: Dict[str, Any] = {}
        if s3_client is not None:
            additional_kwargs["s3_client"] = s3_client
        if hub_arn:
            additional_kwargs["sagemaker_session"] = sagemaker_session
        cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
            {**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs}
        )
        JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)

        if not hub_arn:
            return JumpStartModelsAccessor._cache.get_specs(  # type: ignore
                model_id=model_id, version_str=version, model_type=model_type
            )

        # Users only input model id, not contentType, so first try to describe with
        # ModelReference, then with Model.
        try:
            hub_model_arn = construct_hub_model_reference_arn_from_inputs(
                hub_arn=hub_arn, model_name=model_id, version=version
            )
            model_specs = JumpStartModelsAccessor._cache.get_hub_model_reference(  # type: ignore
                hub_model_reference_arn=hub_model_arn
            )
            model_specs.set_hub_content_type(HubContentType.MODEL_REFERENCE)
            return model_specs
        except Exception as ex:
            # Deliberate best-effort: any failure falls through to the Model lookup.
            logging.info(
                "Received exception while calling APIs for ContentType ModelReference, "
                "retrying with ContentType Model: %s",
                ex,
            )

        # Failed to describe ModelReference, try with Model.
        hub_model_arn = construct_hub_model_arn_from_inputs(
            hub_arn=hub_arn, model_name=model_id, version=version
        )
        try:
            model_specs = JumpStartModelsAccessor._cache.get_hub_model(  # type: ignore
                hub_model_arn=hub_model_arn
            )
            model_specs.set_hub_content_type(HubContentType.MODEL)
            return model_specs
        except Exception as ex:
            # Failed with both, throw a custom error message (keeping the cause).
            raise RuntimeError(
                f"Cannot get details for {model_id} in Hub {hub_arn}. "
                f"{model_id} does not exist as a Model or ModelReference: \n{ex}"
            ) from ex

    @staticmethod
    def set_cache_kwargs(cache_kwargs: Dict[str, Any], region: Optional[str] = None) -> None:
        """Sets cache kwargs and replaces the cache with a new one built from them.

        Args:
            cache_kwargs (Dict[str, Any]): cache kwargs to validate.
            region (Optional[str]): Optional. The region to validate along with the kwargs.

        Raises:
            ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.
        """
        cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
            cache_kwargs, region
        )
        JumpStartModelsAccessor._cache_kwargs = cache_kwargs
        if region is None:
            # NOTE(review): ``_curr_region`` is left unchanged here although the new
            # cache is built without an explicit region. A later
            # ``_set_cache_and_region`` call with the previously recorded region and
            # equal kwargs will therefore reuse this cache. Confirm which region
            # JumpStartModelsCache defaults to.
            JumpStartModelsAccessor._cache = cache.JumpStartModelsCache(
                **JumpStartModelsAccessor._cache_kwargs
            )
        else:
            JumpStartModelsAccessor._curr_region = region
            JumpStartModelsAccessor._cache = cache.JumpStartModelsCache(
                region=region, **JumpStartModelsAccessor._cache_kwargs
            )

    @staticmethod
    def reset_cache(
        cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None
    ) -> None:
        """Resets cache, optionally allowing cache kwargs to be passed to the new cache.

        Args:
            cache_kwargs (Optional[Dict[str, Any]]): cache kwargs to validate.
            region (Optional[str]): The region to validate along with the kwargs.

        Raises:
            ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.
        """
        cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs
        JumpStartModelsAccessor.set_cache_kwargs(cache_kwargs_dict, region)

    @staticmethod
    @deprecated()
    def get_manifest(
        cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None
    ) -> List[JumpStartModelHeader]:
        """Return entire JumpStart models manifest.

        This always rebuilds the cache from the given kwargs before reading the manifest.

        Args:
            cache_kwargs (Dict[str, Any]): Optional. Cache kwargs to use.
                (Default: None).
            region (str): Optional. The region to use for the cache.
                (Default: None).

        Raises:
            ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.
        """
        cache_kwargs_dict: Dict[str, Any] = {} if cache_kwargs is None else cache_kwargs
        JumpStartModelsAccessor.set_cache_kwargs(cache_kwargs_dict, region)
        return JumpStartModelsAccessor._cache.get_manifest()  # type: ignore