Source code for sagemaker.core.jumpstart.document

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains utilities for JumpStart model metadata."""
from __future__ import absolute_import

import json
from typing import Optional, Tuple
from functools import lru_cache
from botocore.exceptions import ClientError

from sagemaker.core.helper.session_helper import Session
from sagemaker.core.utils.utils import logger
from sagemaker.core.resources import HubContent
from sagemaker.core.jumpstart.configs import JumpStartConfig
from sagemaker.core.jumpstart.models import HubContentDocument
from sagemaker.core.jumpstart.constants import SAGEMAKER_PUBLIC_HUB


@lru_cache(maxsize=128)
def get_hub_content_and_document(
    jumpstart_config: JumpStartConfig,
    sagemaker_session: Optional[Session] = None,
) -> Tuple[HubContent, HubContentDocument]:
    """Get the hub content and its parsed document for a JumpStart model.

    Results are memoized with ``functools.lru_cache`` (up to 128 entries),
    keyed on the call arguments. Both arguments must therefore be hashable,
    and repeated calls with the same arguments return the same cached
    objects without contacting the service again.

    Args:
        jumpstart_config (JumpStartConfig): JumpStart configuration. Its
            ``model_id``, ``model_version`` and ``hub_name`` attributes are
            read; a falsy ``hub_name`` selects the SageMaker public hub.
        sagemaker_session (Session, optional): SageMaker session used for the
            lookup. Defaults to None, in which case a default ``Session()``
            is created.

    Returns:
        Tuple[HubContent, HubContentDocument]: The hub content resource and
        the model metadata built from its JSON ``hub_content_document``.

    Raises:
        ClientError: Re-raised unchanged when the hub content lookup fails.
            A ``ResourceNotFound`` error is additionally logged with a hint
            before being re-raised.
    """
    if sagemaker_session is None:
        sagemaker_session = Session()
        logger.debug("No sagemaker session provided. Using default session.")

    hub_name = jumpstart_config.hub_name or SAGEMAKER_PUBLIC_HUB
    # Content in the public hub is looked up as "Model"; any other hub is
    # looked up as "ModelReference".
    hub_content_type = "Model" if hub_name == SAGEMAKER_PUBLIC_HUB else "ModelReference"
    region = sagemaker_session.boto_region_name

    try:
        hub_content = HubContent.get(
            hub_name=hub_name,
            hub_content_name=jumpstart_config.model_id,
            hub_content_version=jumpstart_config.model_version,
            hub_content_type=hub_content_type,
            session=sagemaker_session.boto_session,
            region=region,
        )
    except ClientError as e:
        # Use .get so an unexpectedly shaped error response cannot raise a
        # KeyError here and mask the original ClientError.
        error_code = e.response.get("Error", {}).get("Code")
        if error_code == "ResourceNotFound":
            logger.error(
                f"Hub content {jumpstart_config.model_id} not found in {hub_name}.\n"
                "Please check that the Model ID is available in the specified hub."
            )
        # Every ClientError propagates to the caller; the log above is only a hint.
        raise

    logger.info(
        f"hub_content_name: {hub_content.hub_content_name}, "
        f"hub_content_version: {hub_content.hub_content_version}"
    )
    document_json = json.loads(hub_content.hub_content_document)
    return (hub_content, HubContentDocument(**document_json))