Source code for sagemaker.core.jumpstart.hub.parsers

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
# pylint: skip-file
"""This module stores Hub converter utilities for JumpStart."""
from __future__ import absolute_import

from typing import Any, Dict, List
from sagemaker.core.jumpstart.enums import ModelSpecKwargType, NamingConventionType
from sagemaker.core.s3 import parse_s3_url
from sagemaker.core.jumpstart.types import (
    JumpStartModelSpecs,
    HubContentType,
    JumpStartDataHolderType,
)
from sagemaker.core.jumpstart.hub.interfaces import (
    DescribeHubContentResponse,
    HubModelDocument,
)
from sagemaker.core.jumpstart.hub.parser_utils import (
    camel_to_snake,
    snake_to_upper_camel,
    walk_and_apply_json,
)


def _to_json(dictionary: Dict[Any, Any]) -> Dict[Any, Any]:
    """Convert a nested dictionary of JumpStartDataHolderType into json with UpperCamelCase keys"""
    for key, value in dictionary.items():
        if issubclass(type(value), JumpStartDataHolderType):
            dictionary[key] = walk_and_apply_json(value.to_json(), snake_to_upper_camel)
        elif isinstance(value, list):
            new_value = []
            for value_in_list in value:
                new_value_in_list = value_in_list
                if issubclass(type(value_in_list), JumpStartDataHolderType):
                    new_value_in_list = walk_and_apply_json(
                        value_in_list.to_json(), snake_to_upper_camel
                    )
                new_value.append(new_value_in_list)
            dictionary[key] = new_value
        elif isinstance(value, dict):
            for key_in_dict, value_in_dict in value.items():
                if issubclass(type(value_in_dict), JumpStartDataHolderType):
                    value[key_in_dict] = walk_and_apply_json(
                        value_in_dict.to_json(), snake_to_upper_camel
                    )
    return dictionary


[docs] def get_model_spec_arg_keys( arg_type: ModelSpecKwargType, naming_convention: NamingConventionType = NamingConventionType.DEFAULT, ) -> List[str]: """Returns a list of arg keys for a specific model spec arg type. Args: arg_type (ModelSpecKwargType): Type of the model spec's kwarg. naming_convention (NamingConventionType): Type of naming convention to return. Raises: ValueError: If the naming convention is not valid. """ arg_keys: List[str] = [] if arg_type == ModelSpecKwargType.DEPLOY: arg_keys = [ "ModelDataDownloadTimeout", "ContainerStartupHealthCheckTimeout", "InferenceAmiVersion", ] elif arg_type == ModelSpecKwargType.ESTIMATOR: arg_keys = [ "EncryptInterContainerTraffic", "MaxRuntimeInSeconds", "DisableOutputCompression", "ModelDir", ] elif arg_type == ModelSpecKwargType.MODEL: arg_keys = [] elif arg_type == ModelSpecKwargType.FIT: arg_keys = [] if naming_convention == NamingConventionType.SNAKE_CASE: arg_keys = [camel_to_snake(key) for key in arg_keys] elif naming_convention == NamingConventionType.UPPER_CAMEL_CASE: return arg_keys else: raise ValueError("Please provide a valid naming convention.") return arg_keys
[docs] def get_model_spec_kwargs_from_hub_model_document( arg_type: ModelSpecKwargType, hub_content_document: Dict[str, Any], naming_convention: NamingConventionType = NamingConventionType.UPPER_CAMEL_CASE, ) -> Dict[str, Any]: """Returns a map of arg type to arg keys for a given hub content document. Args: arg_type (ModelSpecKwargType): Type of the model spec's kwarg. hub_content_document: A dictionary representation of hub content document. naming_convention (NamingConventionType): Type of naming convention to return. """ kwargs = dict() keys = get_model_spec_arg_keys(arg_type, naming_convention=naming_convention) for k in keys: kwarg_value = hub_content_document.get(k) if kwarg_value is not None: kwargs[k] = kwarg_value return kwargs
[docs] def make_model_specs_from_describe_hub_content_response( response: DescribeHubContentResponse, ) -> JumpStartModelSpecs: """Sets fields in JumpStartModelSpecs based on values in DescribeHubContentResponse Args: response (Dict[str, any]): parsed DescribeHubContentResponse returned from SageMaker:DescribeHubContent """ if response.hub_content_type not in {HubContentType.MODEL, HubContentType.MODEL_REFERENCE}: raise AttributeError( "Invalid content type, use either HubContentType.MODEL or HubContentType.MODEL_REFERENCE." ) region = response.get_hub_region() specs = {} model_id = response.hub_content_name specs["model_id"] = model_id specs["version"] = response.hub_content_version hub_model_document: HubModelDocument = response.hub_content_document specs["url"] = hub_model_document.url specs["min_sdk_version"] = hub_model_document.min_sdk_version specs["model_types"] = hub_model_document.model_types specs["capabilities"] = hub_model_document.capabilities specs["training_supported"] = bool(hub_model_document.training_supported) specs["incremental_training_supported"] = bool( hub_model_document.incremental_training_supported ) specs["hosting_ecr_uri"] = hub_model_document.hosting_ecr_uri specs["inference_configs"] = hub_model_document.inference_configs specs["inference_config_components"] = hub_model_document.inference_config_components specs["inference_config_rankings"] = hub_model_document.inference_config_rankings if hub_model_document.hosting_artifact_uri: _, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable hub_model_document.hosting_artifact_uri ) specs["hosting_artifact_key"] = hosting_artifact_key specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri if hub_model_document.hosting_script_uri: _, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable hub_model_document.hosting_script_uri ) specs["hosting_script_key"] = hosting_script_key specs["inference_environment_variables"] = hub_model_document.inference_environment_variables specs["inference_vulnerable"] = False specs["inference_dependencies"] = hub_model_document.inference_dependencies specs["inference_vulnerabilities"] = [] specs["training_vulnerable"] = False specs["training_vulnerabilities"] = [] specs["deprecated"] = False specs["deprecated_message"] = None specs["deprecate_warn_message"] = None specs["usage_info_message"] = None specs["default_inference_instance_type"] = hub_model_document.default_inference_instance_type specs["supported_inference_instance_types"] = ( hub_model_document.supported_inference_instance_types ) specs["dynamic_container_deployment_supported"] = ( hub_model_document.dynamic_container_deployment_supported ) specs["hosting_resource_requirements"] = hub_model_document.hosting_resource_requirements specs["hosting_prepacked_artifact_key"] = None if hub_model_document.hosting_prepacked_artifact_uri is not None: ( hosting_prepacked_artifact_bucket, # pylint: disable=unused-variable hosting_prepacked_artifact_key, ) = parse_s3_url(hub_model_document.hosting_prepacked_artifact_uri) specs["hosting_prepacked_artifact_key"] = hosting_prepacked_artifact_key hub_content_document_dict: Dict[str, Any] = hub_model_document.to_json() specs["fit_kwargs"] = get_model_spec_kwargs_from_hub_model_document( ModelSpecKwargType.FIT, hub_content_document_dict ) specs["model_kwargs"] = get_model_spec_kwargs_from_hub_model_document( ModelSpecKwargType.MODEL, hub_content_document_dict ) specs["deploy_kwargs"] = get_model_spec_kwargs_from_hub_model_document( ModelSpecKwargType.DEPLOY, hub_content_document_dict ) specs["estimator_kwargs"] = get_model_spec_kwargs_from_hub_model_document( ModelSpecKwargType.ESTIMATOR, hub_content_document_dict ) specs["predictor_specs"] = hub_model_document.sage_maker_sdk_predictor_specifications default_payloads: Dict[str, Any] = {} if hub_model_document.default_payloads is not None: for alias, payload in hub_model_document.default_payloads.items(): default_payloads[alias] = walk_and_apply_json(payload.to_json(), camel_to_snake) specs["default_payloads"] = default_payloads specs["gated_bucket"] = hub_model_document.gated_bucket specs["inference_volume_size"] = hub_model_document.inference_volume_size specs["inference_enable_network_isolation"] = ( hub_model_document.inference_enable_network_isolation ) specs["resource_name_base"] = hub_model_document.resource_name_base specs["hosting_eula_key"] = None if hub_model_document.hosting_eula_uri is not None: hosting_eula_bucket, hosting_eula_key = parse_s3_url( # pylint: disable=unused-variable hub_model_document.hosting_eula_uri ) specs["hosting_eula_key"] = hosting_eula_key if hub_model_document.hosting_model_package_arn: specs["hosting_model_package_arns"] = {region: hub_model_document.hosting_model_package_arn} specs["model_subscription_link"] = hub_model_document.model_subscription_link specs["hosting_use_script_uri"] = hub_model_document.hosting_use_script_uri specs["hosting_instance_type_variants"] = hub_model_document.hosting_instance_type_variants if specs["training_supported"]: specs["training_ecr_uri"] = hub_model_document.training_ecr_uri ( training_artifact_bucket, # pylint: disable=unused-variable training_artifact_key, ) = parse_s3_url(hub_model_document.training_artifact_uri) specs["training_artifact_key"] = training_artifact_key ( training_script_bucket, # pylint: disable=unused-variable training_script_key, ) = parse_s3_url(hub_model_document.training_script_uri) specs["training_script_key"] = training_script_key specs["training_configs"] = hub_model_document.training_configs specs["training_config_components"] = hub_model_document.training_config_components specs["training_config_rankings"] = hub_model_document.training_config_rankings specs["training_dependencies"] = hub_model_document.training_dependencies specs["default_training_instance_type"] = hub_model_document.default_training_instance_type specs["supported_training_instance_types"] = ( hub_model_document.supported_training_instance_types ) specs["metrics"] = hub_model_document.training_metrics specs["training_prepacked_script_key"] = None if hub_model_document.training_prepacked_script_uri is not None: ( training_prepacked_script_bucket, # pylint: disable=unused-variable training_prepacked_script_key, ) = parse_s3_url(hub_model_document.training_prepacked_script_uri) specs["training_prepacked_script_key"] = training_prepacked_script_key specs["hyperparameters"] = hub_model_document.hyperparameters specs["training_volume_size"] = hub_model_document.training_volume_size specs["training_enable_network_isolation"] = ( hub_model_document.training_enable_network_isolation ) if hub_model_document.training_model_package_artifact_uri: specs["training_model_package_artifact_uris"] = { region: hub_model_document.training_model_package_artifact_uri } specs["training_instance_type_variants"] = ( hub_model_document.training_instance_type_variants ) if hub_model_document.default_training_dataset_uri: _, default_training_dataset_key = parse_s3_url( # pylint: disable=unused-variable hub_model_document.default_training_dataset_uri ) specs["default_training_dataset_key"] = default_training_dataset_key specs["default_training_dataset_uri"] = hub_model_document.default_training_dataset_uri return JumpStartModelSpecs(_to_json(specs), is_hub_content=True)