# Source code for sagemaker.core.jumpstart.factory.utils

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module stores JumpStart factory utilities."""

from __future__ import absolute_import
import json
from typing import Any, Dict, List, Optional, Tuple, Union
from sagemaker.core.shapes import ModelAccessConfig
from sagemaker.core import (
    environment_variables,
    image_uris,
    instance_types,
    model_uris,
    script_uris,
)
from sagemaker.serve.async_inference.async_inference_config import AsyncInferenceConfig
from sagemaker.core.deserializers.base import BaseDeserializer
from sagemaker.core.serializers.base import BaseSerializer
from sagemaker.core.explainer.explainer_config import ExplainerConfig
from sagemaker.core.jumpstart.artifacts import (
    _model_supports_inference_script_uri,
    _retrieve_model_init_kwargs,
    _retrieve_model_deploy_kwargs,
    _retrieve_model_package_arn,
)
from sagemaker.core.jumpstart.artifacts.resource_names import _retrieve_resource_name_base
from sagemaker.core.jumpstart.constants import (
    INFERENCE_ENTRY_POINT_SCRIPT_NAME,
    JUMPSTART_DEFAULT_REGION_NAME,
    JUMPSTART_LOGGER,
)

from sagemaker.core.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.core.jumpstart.hub.utils import (
    construct_hub_model_arn_from_inputs,
    construct_hub_model_reference_arn_from_inputs,
)

from sagemaker.core.jumpstart.enums import (
    JumpStartScriptScope,
    JumpStartModelType,
    HubContentCapability,
)
from sagemaker.core.jumpstart.types import (
    HubContentType,
    JumpStartEstimatorDeployKwargs,
    JumpStartEstimatorFitKwargs,
    JumpStartEstimatorInitKwargs,
    JumpStartModelDeployKwargs,
    JumpStartModelInitKwargs,
    JumpStartModelSpecs,
)
from sagemaker.core.jumpstart.utils import (
    add_hub_content_arn_tags,
    add_jumpstart_model_info_tags,
    add_bedrock_store_tags,
    get_default_jumpstart_session_with_user_agent_suffix,
    get_top_ranked_config_name,
    update_dict_if_key_not_present,
    resolve_model_sagemaker_config_field,
    verify_model_region_and_return_specs,
    get_draft_model_content_bucket,
)

from sagemaker.core.model_monitor.data_capture_config import DataCaptureConfig

from sagemaker.serve.serverless.serverless_inference_config import ServerlessInferenceConfig
from sagemaker.core.helper.session_helper import Session
from sagemaker.core.common_utils import (
    camel_case_to_pascal_case,
    name_from_base,
    format_tags,
    Tags,
)
from sagemaker.core.helper.pipeline_variable import PipelineVariable
from sagemaker.serve.compute_resource_requirements.resource_requirements import ResourceRequirements
from sagemaker.core import resource_requirements
from sagemaker.core.enums import EndpointType


KwargsType = Union[
    JumpStartModelDeployKwargs,
    JumpStartModelInitKwargs,
    JumpStartEstimatorFitKwargs,
    JumpStartEstimatorInitKwargs,
    JumpStartEstimatorDeployKwargs,
]


def get_model_info_default_kwargs(
    kwargs: KwargsType,
    include_config_name: bool = True,
    include_model_version: bool = True,
    include_tolerate_flags: bool = True,
) -> dict:
    """Collects the model-identifying fields of ``kwargs`` for use with JumpStart APIs.

    The result always carries ``model_id``, ``hub_arn``, ``region``,
    ``sagemaker_session`` and ``model_type``; the remaining keys are opt-out.

    Args:
        kwargs (KwargsType): Kwargs object the field values are read from.
        include_config_name (bool): Whether to add ``config_name``. (Default: True).
        include_model_version (bool): Whether to add ``model_version``. (Default: True).
        include_tolerate_flags (bool): Whether to add ``tolerate_deprecated_model``
            and ``tolerate_vulnerable_model``. (Default: True).

    Returns:
        dict: A new dictionary suitable for ``**`` expansion into JumpStart calls.
    """
    model_info = {
        "model_id": kwargs.model_id,
        "hub_arn": kwargs.hub_arn,
        "region": kwargs.region,
        "sagemaker_session": kwargs.sagemaker_session,
        "model_type": kwargs.model_type,
    }

    if include_config_name:
        model_info["config_name"] = kwargs.config_name

    if include_model_version:
        model_info["model_version"] = kwargs.model_version

    if include_tolerate_flags:
        model_info["tolerate_deprecated_model"] = kwargs.tolerate_deprecated_model
        model_info["tolerate_vulnerable_model"] = kwargs.tolerate_vulnerable_model

    return model_info
def _set_temp_sagemaker_session_if_not_set(
    kwargs: KwargsType,
) -> Tuple[KwargsType, Optional[Session]]:
    """Sets a temporary sagemaker session if one is not set, and returns original session.

    We need to create a default JS session (without custom user agent)
    in order to retrieve config name info.

    Args:
        kwargs (KwargsType): Kwargs object whose ``sagemaker_session`` may be ``None``.

    Returns:
        Tuple[KwargsType, Optional[Session]]: The (mutated) kwargs and the session
        that ``kwargs`` held on entry, which is ``None`` if no session was set.
    """
    orig_session = kwargs.sagemaker_session
    if kwargs.sagemaker_session is None:
        kwargs.sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION
    return kwargs, orig_session


def _add_region_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Sets region kwargs based on default or override, returns full kwargs.

    Precedence: the explicit ``kwargs.region``, then the session's
    ``boto_region_name``, then ``JUMPSTART_DEFAULT_REGION_NAME``.
    """
    kwargs.region = (
        kwargs.region or kwargs.sagemaker_session.boto_region_name or JUMPSTART_DEFAULT_REGION_NAME
    )
    return kwargs


def _add_sagemaker_session_with_custom_user_agent_to_kwargs(
    kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs],
    orig_session: Optional[Session],
) -> Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs]:
    """Sets session in kwargs based on default or override, returns full kwargs.

    Args:
        kwargs: Init or deploy kwargs to update in place.
        orig_session (Optional[Session]): Session originally supplied by the caller.
            When falsy, a default JumpStart session with a user-agent suffix built
            from the model id, version and config name is used instead.
    """
    kwargs.sagemaker_session = orig_session or get_default_jumpstart_session_with_user_agent_suffix(
        model_id=kwargs.model_id,
        model_version=kwargs.model_version,
        config_name=kwargs.config_name,
        is_hub_content=kwargs.hub_arn is not None,
    )
    return kwargs


def _add_role_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Sets role based on default or override, returns full kwargs.

    The current ``kwargs.role`` is passed both as the field value and as the
    default value to ``resolve_model_sagemaker_config_field``.
    """
    kwargs.role = resolve_model_sagemaker_config_field(
        field_name="role",
        field_val=kwargs.role,
        sagemaker_session=kwargs.sagemaker_session,
        default_value=kwargs.role,
    )
    return kwargs


def _add_model_version_to_kwargs(
    kwargs: JumpStartModelInitKwargs,
) -> JumpStartModelInitKwargs:
    """Sets model version based on default or override, returns full kwargs.

    Defaults to ``"*"`` when no version is given. When ``kwargs.hub_arn`` is set,
    the version recorded in ``kwargs.specs`` replaces whatever was requested.
    """
    kwargs.model_version = kwargs.model_version or "*"

    if kwargs.hub_arn:
        hub_content_version = kwargs.specs.version
        kwargs.model_version = hub_content_version

    return kwargs


def _add_vulnerable_and_deprecated_status_to_kwargs(
    kwargs: JumpStartModelInitKwargs,
) -> JumpStartModelInitKwargs:
    """Sets deprecated and vulnerability check status, returns full kwargs.

    Falsy values (including ``None``) for either flag are normalized to ``False``.
    """
    kwargs.tolerate_deprecated_model = kwargs.tolerate_deprecated_model or False
    kwargs.tolerate_vulnerable_model = kwargs.tolerate_vulnerable_model or False
    return kwargs


def _add_instance_type_to_kwargs(
    kwargs: JumpStartModelInitKwargs, disable_instance_type_logging: bool = False
) -> JumpStartModelInitKwargs:
    """Sets instance type based on default or override, returns full kwargs.

    Args:
        kwargs: Kwargs to update in place.
        disable_instance_type_logging (bool): Suppresses the info log emitted when
            no instance type was supplied and a default is chosen. (Default: False).

    Returns:
        The same kwargs object with ``instance_type`` populated.
    """
    orig_instance_type = kwargs.instance_type
    kwargs.instance_type = kwargs.instance_type or instance_types.retrieve_default(
        **get_model_info_default_kwargs(kwargs),
        scope=JumpStartScriptScope.INFERENCE,
        training_instance_type=kwargs.training_instance_type,
    )

    if not disable_instance_type_logging and orig_instance_type is None:
        JUMPSTART_LOGGER.info(
            "No instance type selected for inference hosting endpoint. Defaulting to %s.",
            kwargs.instance_type,
        )

    specs = kwargs.specs

    # Nothing to validate against when the selected config is unknown to the specs.
    if specs.inference_configs and kwargs.config_name not in specs.inference_configs.configs:
        return kwargs

    resolved_config = (
        specs.inference_configs.configs[kwargs.config_name].resolved_config
        if specs.inference_configs
        else None
    )
    if resolved_config is None:
        return kwargs

    # Only a warning is logged here; the instance type itself is left untouched.
    supported_instance_types = resolved_config.get("supported_inference_instance_types", [])
    if kwargs.instance_type not in supported_instance_types:
        JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type)
    return kwargs


def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Sets image uri based on default or override, returns full kwargs.

    Proprietary JumpStart models use ModelPackages, so their image uri is
    cleared (set to ``None``) rather than retrieved.
    """
    if kwargs.model_type == JumpStartModelType.PROPRIETARY:
        kwargs.image_uri = None
        return kwargs

    kwargs.image_uri = kwargs.image_uri or image_uris.retrieve(
        **get_model_info_default_kwargs(kwargs),
        framework=None,
        image_scope=JumpStartScriptScope.INFERENCE,
        instance_type=kwargs.instance_type,
    )
    return kwargs


def _add_model_reference_arn_to_kwargs(
    kwargs: JumpStartModelInitKwargs,
) -> JumpStartModelInitKwargs:
    """Sets Model Reference ARN if the hub content type is Model Reference, returns full kwargs.

    ``kwargs.hub_content_type`` is set from the specs only when a hub ARN is
    present; otherwise it is ``None``.
    """
    hub_content_type = kwargs.specs.hub_content_type
    kwargs.hub_content_type = hub_content_type if kwargs.hub_arn else None

    if hub_content_type == HubContentType.MODEL_REFERENCE:
        kwargs.model_reference_arn = construct_hub_model_reference_arn_from_inputs(
            hub_arn=kwargs.hub_arn, model_name=kwargs.model_id, version=kwargs.model_version
        )
    else:
        kwargs.model_reference_arn = None
    return kwargs


def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Sets model data based on default or override, returns full kwargs.

    Proprietary models get ``model_data = None``. A string model data value that
    is an ``s3://`` URI ending in ``/`` is converted to an uncompressed
    ``S3DataSource`` prefix dictionary.
    """
    if kwargs.model_type == JumpStartModelType.PROPRIETARY:
        kwargs.model_data = None
        return kwargs

    model_info_kwargs = get_model_info_default_kwargs(kwargs)
    model_data: Union[str, dict] = kwargs.model_data or model_uris.retrieve(
        **model_info_kwargs,
        model_scope=JumpStartScriptScope.INFERENCE,
        instance_type=kwargs.instance_type,
    )

    if isinstance(model_data, str) and model_data.startswith("s3://") and model_data.endswith("/"):
        old_model_data_str = model_data
        model_data = {
            "S3DataSource": {
                "S3Uri": model_data,
                "S3DataType": "S3Prefix",
                "CompressionType": "None",
            }
        }
        # Only log the conversion when the prefix came from the caller.
        if kwargs.model_data:
            JUMPSTART_LOGGER.info(
                "S3 prefix model_data detected for JumpStartModel: '%s'. "
                "Converting to S3DataSource dictionary: '%s'.",
                old_model_data_str,
                json.dumps(model_data),
            )

    kwargs.model_data = model_data
    return kwargs


def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Sets source dir based on default or override, returns full kwargs.

    Proprietary models get ``source_dir = None``. A default script uri is only
    retrieved when the model supports an inference script uri.
    """
    if kwargs.model_type == JumpStartModelType.PROPRIETARY:
        kwargs.source_dir = None
        return kwargs

    source_dir = kwargs.source_dir

    if _model_supports_inference_script_uri(**get_model_info_default_kwargs(kwargs)):
        source_dir = source_dir or script_uris.retrieve(
            **get_model_info_default_kwargs(kwargs), script_scope=JumpStartScriptScope.INFERENCE
        )

    kwargs.source_dir = source_dir
    return kwargs


def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Sets entry point based on default or override, returns full kwargs.

    Proprietary models get ``entry_point = None``. The default entry point
    ``INFERENCE_ENTRY_POINT_SCRIPT_NAME`` is only applied when the model supports
    an inference script uri.
    """
    if kwargs.model_type == JumpStartModelType.PROPRIETARY:
        kwargs.entry_point = None
        return kwargs

    entry_point = kwargs.entry_point

    if _model_supports_inference_script_uri(**get_model_info_default_kwargs(kwargs)):
        entry_point = entry_point or INFERENCE_ENTRY_POINT_SCRIPT_NAME

    kwargs.entry_point = entry_point
    return kwargs


def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Sets env based on default or override, returns full kwargs.

    Proprietary models get ``env = None``. Otherwise default environment
    variables are added for keys the caller did not set; an env that ends up
    empty is stored as ``None``.
    """
    if kwargs.model_type == JumpStartModelType.PROPRIETARY:
        kwargs.env = None
        return kwargs

    env = kwargs.env

    if env is None:
        env = {}

    extra_env_vars = environment_variables.retrieve_default(
        **get_model_info_default_kwargs(kwargs),
        include_aws_sdk_env_vars=False,
        script=JumpStartScriptScope.INFERENCE,
        instance_type=kwargs.instance_type,
    )

    # NOTE: a caller-supplied env dict is updated in place here.
    for key, value in extra_env_vars.items():
        update_dict_if_key_not_present(
            env,
            key,
            value,
        )

    if env == {}:
        env = None

    kwargs.env = env
    return kwargs


def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Sets model package arn based on default or override, returns full kwargs."""
    model_package_arn = kwargs.model_package_arn or _retrieve_model_package_arn(
        **get_model_info_default_kwargs(kwargs),
        instance_type=kwargs.instance_type,
        scope=JumpStartScriptScope.INFERENCE,
    )

    kwargs.model_package_arn = model_package_arn
    return kwargs


def _add_extra_model_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Sets extra kwargs based on default or override, returns full kwargs.

    Only attributes that are currently ``None`` are filled; each default is
    first passed through ``resolve_model_sagemaker_config_field``.
    """
    model_kwargs_to_add = _retrieve_model_init_kwargs(**get_model_info_default_kwargs(kwargs))

    for key, value in model_kwargs_to_add.items():
        if getattr(kwargs, key) is None:
            resolved_value = resolve_model_sagemaker_config_field(
                field_name=key,
                field_val=value,
                sagemaker_session=kwargs.sagemaker_session,
            )
            setattr(kwargs, key, resolved_value)

    return kwargs


def _add_endpoint_name_to_kwargs(
    kwargs: JumpStartModelDeployKwargs,
) -> JumpStartModelDeployKwargs:
    """Sets resource name based on default or override, returns full kwargs.

    When no endpoint name is given and a resource name base is available, a
    name is generated from that base; otherwise the endpoint name stays ``None``.
    """
    default_endpoint_name = _retrieve_resource_name_base(**get_model_info_default_kwargs(kwargs))

    kwargs.endpoint_name = kwargs.endpoint_name or (
        name_from_base(default_endpoint_name) if default_endpoint_name is not None else None
    )

    return kwargs


def _add_model_name_to_kwargs(
    kwargs: JumpStartModelInitKwargs,
) -> JumpStartModelInitKwargs:
    """Sets resource name based on default or override, returns full kwargs.

    When no model name is given and a resource name base is available, a name
    is generated from that base; otherwise the name stays ``None``.
    """
    default_model_name = _retrieve_resource_name_base(**get_model_info_default_kwargs(kwargs))

    kwargs.name = kwargs.name or (
        name_from_base(default_model_name) if default_model_name is not None else None
    )

    return kwargs


def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> JumpStartModelDeployKwargs:
    """Sets tags based on default or override, returns full kwargs.

    Adds, where applicable: JumpStart model info tags (when the session settings
    enable them), a hub content ARN tag (when a hub ARN is set), and Bedrock
    store tags (when the specs list the Bedrock console capability).
    """
    full_model_version = kwargs.specs.version

    if kwargs.sagemaker_session.settings.include_jumpstart_tags:
        kwargs.tags = add_jumpstart_model_info_tags(
            kwargs.tags,
            kwargs.model_id,
            full_model_version,
            kwargs.model_type,
            config_name=kwargs.config_name,
            scope=JumpStartScriptScope.INFERENCE,
        )

    if kwargs.hub_arn:
        if kwargs.model_reference_arn:
            hub_content_arn = construct_hub_model_reference_arn_from_inputs(
                kwargs.hub_arn, kwargs.model_id, kwargs.model_version
            )
        else:
            hub_content_arn = construct_hub_model_arn_from_inputs(
                kwargs.hub_arn, kwargs.model_id, kwargs.model_version
            )
        kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn)

    # NOTE(review): assumed to be independent of the hub_arn branch above - confirm nesting.
    if hasattr(kwargs.specs, "capabilities") and kwargs.specs.capabilities is not None:
        if HubContentCapability.BEDROCK_CONSOLE in kwargs.specs.capabilities:
            kwargs.tags = add_bedrock_store_tags(kwargs.tags, compatibility="compatible")

    return kwargs


def _add_deploy_extra_kwargs(kwargs: JumpStartModelDeployKwargs) -> JumpStartModelDeployKwargs:
    """Sets extra kwargs based on default or override, returns full kwargs.

    Only attributes that are currently ``None`` are filled from the retrieved
    deploy kwargs.
    """
    deploy_kwargs_to_add = _retrieve_model_deploy_kwargs(
        **get_model_info_default_kwargs(kwargs), instance_type=kwargs.instance_type
    )

    for key, value in deploy_kwargs_to_add.items():
        if getattr(kwargs, key) is None:
            setattr(kwargs, key, value)

    return kwargs


def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Sets the resource requirements based on the default or an override. Returns full kwargs."""
    kwargs.resources = kwargs.resources or resource_requirements.retrieve_default(
        **get_model_info_default_kwargs(kwargs),
        scope=JumpStartScriptScope.INFERENCE,
        instance_type=kwargs.instance_type,
    )

    return kwargs


def _select_inference_config_from_training_config(
    specs: JumpStartModelSpecs, training_config_name: str
) -> Optional[str]:
    """Selects the inference config from the training config.

    Args:
        specs (JumpStartModelSpecs): The specs for the model.
        training_config_name (str): The name of the training config.

    Returns:
        Optional[str]: The training config's default inference config name, or
        ``None`` if the specs have no training configs or the name is not found.
    """
    if specs.training_configs:
        resolved_training_config = specs.training_configs.configs.get(training_config_name)
        if resolved_training_config:
            return resolved_training_config.default_inference_config

    return None


def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Sets default config name to the kwargs. Returns full kwargs.

    An explicit ``kwargs.config_name`` is kept; otherwise the top-ranked
    inference config name is used, which may itself be ``None``.
    """
    kwargs.config_name = kwargs.config_name or get_top_ranked_config_name(
        **get_model_info_default_kwargs(kwargs, include_config_name=False),
        scope=JumpStartScriptScope.INFERENCE,
    )

    return kwargs


def _add_additional_model_data_sources_to_kwargs(
    kwargs: JumpStartModelInitKwargs,
) -> JumpStartModelInitKwargs:
    """Sets default additional model data sources to init kwargs.

    Caller-supplied ``additional_model_data_sources`` take precedence; otherwise
    the speculative decoding data sources from the specs are used, or ``None``
    when the specs provide none.
    """
    specs = kwargs.specs

    # Append speculative decoding data source from metadata
    speculative_decoding_data_sources = specs.get_speculative_decoding_s3_data_sources()
    for data_source in speculative_decoding_data_sources:
        data_source.s3_data_source.set_bucket(
            get_draft_model_content_bucket(provider=data_source.provider, region=kwargs.region)
        )

    # Reuse the list fetched above instead of querying the specs a second time.
    api_shape_additional_model_data_sources = (
        [
            camel_case_to_pascal_case(data_source.to_json())
            for data_source in speculative_decoding_data_sources
        ]
        if speculative_decoding_data_sources
        else None
    )

    kwargs.additional_model_data_sources = (
        kwargs.additional_model_data_sources or api_shape_additional_model_data_sources
    )

    return kwargs


def _add_config_name_to_deploy_kwargs(
    kwargs: JumpStartModelDeployKwargs, training_config_name: Optional[str] = None
) -> JumpStartModelDeployKwargs:
    """Sets default config name to the kwargs. Returns full kwargs.

    If a training_config_name is passed, then choose the inference config
    based on the supported inference configs in that training config.
    An explicit ``kwargs.config_name`` always takes precedence over the
    selected default.

    Args:
        kwargs: Deploy kwargs to update in place.
        training_config_name (Optional[str]): Name of the training config whose
            default inference config should be used. (Default: None).
    """
    if training_config_name:
        specs = kwargs.specs
        default_config_name = _select_inference_config_from_training_config(
            specs=specs, training_config_name=training_config_name
        )
    else:
        default_config_name = kwargs.config_name or get_top_ranked_config_name(
            **get_model_info_default_kwargs(kwargs, include_config_name=False),
            scope=JumpStartScriptScope.INFERENCE,
        )

    kwargs.config_name = kwargs.config_name or default_config_name

    return kwargs
def get_deploy_kwargs(
    model_id: str,
    model_version: Optional[str] = None,
    hub_arn: Optional[str] = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
    region: Optional[str] = None,
    initial_instance_count: Optional[int] = None,
    instance_type: Optional[str] = None,
    serializer: Optional[BaseSerializer] = None,
    deserializer: Optional[BaseDeserializer] = None,
    accelerator_type: Optional[str] = None,
    endpoint_name: Optional[str] = None,
    inference_component_name: Optional[str] = None,
    tags: Optional[Tags] = None,
    kms_key: Optional[str] = None,
    wait: Optional[bool] = None,
    data_capture_config: Optional[DataCaptureConfig] = None,
    async_inference_config: Optional[AsyncInferenceConfig] = None,
    serverless_inference_config: Optional[ServerlessInferenceConfig] = None,
    volume_size: Optional[int] = None,
    model_data_download_timeout: Optional[int] = None,
    container_startup_health_check_timeout: Optional[int] = None,
    inference_recommendation_id: Optional[str] = None,
    explainer_config: Optional[ExplainerConfig] = None,
    tolerate_vulnerable_model: Optional[bool] = None,
    tolerate_deprecated_model: Optional[bool] = None,
    sagemaker_session: Optional[Session] = None,
    accept_eula: Optional[bool] = None,
    model_reference_arn: Optional[str] = None,
    endpoint_logging: Optional[bool] = None,
    resources: Optional[ResourceRequirements] = None,
    managed_instance_scaling: Optional[str] = None,
    endpoint_type: Optional[EndpointType] = None,
    training_config_name: Optional[str] = None,
    config_name: Optional[str] = None,
    routing_config: Optional[Dict[str, Any]] = None,
    model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None,
    inference_ami_version: Optional[str] = None,
) -> JumpStartModelDeployKwargs:
    """Builds the fully-resolved kwargs for calling `deploy` on a JumpStart model.

    The caller's arguments are packed into a ``JumpStartModelDeployKwargs``, the
    model specs are fetched, and then each unset field (config name, model
    version, session, endpoint name, instance type, extra deploy kwargs, tags)
    is filled in order. ``initial_instance_count`` defaults to 1.

    Returns:
        JumpStartModelDeployKwargs: The populated deploy kwargs.
    """
    normalized_tags = format_tags(tags)

    resolved: JumpStartModelDeployKwargs = JumpStartModelDeployKwargs(
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        model_type=model_type,
        region=region,
        initial_instance_count=initial_instance_count,
        instance_type=instance_type,
        serializer=serializer,
        deserializer=deserializer,
        accelerator_type=accelerator_type,
        endpoint_name=endpoint_name,
        inference_component_name=inference_component_name,
        tags=normalized_tags,
        kms_key=kms_key,
        wait=wait,
        data_capture_config=data_capture_config,
        async_inference_config=async_inference_config,
        serverless_inference_config=serverless_inference_config,
        volume_size=volume_size,
        model_data_download_timeout=model_data_download_timeout,
        container_startup_health_check_timeout=container_startup_health_check_timeout,
        inference_recommendation_id=inference_recommendation_id,
        explainer_config=explainer_config,
        tolerate_deprecated_model=tolerate_deprecated_model,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        sagemaker_session=sagemaker_session,
        accept_eula=accept_eula,
        model_reference_arn=model_reference_arn,
        endpoint_logging=endpoint_logging,
        resources=resources,
        config_name=config_name,
        routing_config=routing_config,
        model_access_configs=model_access_configs,
        inference_ami_version=inference_ami_version,
    )

    # A temporary session is needed for the spec lookup; the caller's own
    # session (possibly None) is remembered and restored further down.
    resolved, caller_session = _set_temp_sagemaker_session_if_not_set(kwargs=resolved)

    spec_lookup_args = get_model_info_default_kwargs(
        resolved, include_model_version=False, include_tolerate_flags=False
    )
    resolved.specs = verify_model_region_and_return_specs(
        **spec_lookup_args,
        version=resolved.model_version or "*",
        scope=JumpStartScriptScope.INFERENCE,
        # We set these flags to True to retrieve the json specs.
        # Exceptions will be thrown later if these are not tolerated.
        tolerate_deprecated_model=True,
        tolerate_vulnerable_model=True,
    )

    resolved = _add_config_name_to_deploy_kwargs(
        kwargs=resolved, training_config_name=training_config_name
    )
    resolved = _add_model_version_to_kwargs(kwargs=resolved)
    resolved = _add_sagemaker_session_with_custom_user_agent_to_kwargs(
        kwargs=resolved, orig_session=caller_session
    )

    for fill_default in (_add_endpoint_name_to_kwargs, _add_instance_type_to_kwargs):
        resolved = fill_default(kwargs=resolved)

    resolved.initial_instance_count = initial_instance_count or 1

    for fill_default in (_add_deploy_extra_kwargs, _add_tags_to_kwargs):
        resolved = fill_default(kwargs=resolved)

    # NOTE(review): all three statements are assumed to sit inside this branch - confirm.
    if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
        resolved = _add_resources_to_kwargs(kwargs=resolved)
        resolved.endpoint_type = endpoint_type
        resolved.managed_instance_scaling = managed_instance_scaling

    return resolved
def get_init_kwargs(
    model_id: str,
    model_from_estimator: bool = False,
    model_version: Optional[str] = None,
    hub_arn: Optional[str] = None,
    model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
    tolerate_vulnerable_model: Optional[bool] = None,
    tolerate_deprecated_model: Optional[bool] = None,
    instance_type: Optional[str] = None,
    region: Optional[str] = None,
    image_uri: Optional[Union[str, PipelineVariable]] = None,
    model_data: Optional[Union[str, PipelineVariable, dict]] = None,
    role: Optional[str] = None,
    env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
    name: Optional[str] = None,
    vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None,
    sagemaker_session: Optional[Session] = None,
    enable_network_isolation: Union[bool, PipelineVariable] = None,
    model_kms_key: Optional[str] = None,
    image_config: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
    source_dir: Optional[str] = None,
    code_location: Optional[str] = None,
    entry_point: Optional[str] = None,
    container_log_level: Optional[Union[int, PipelineVariable]] = None,
    dependencies: Optional[List[str]] = None,
    git_config: Optional[Dict[str, str]] = None,
    model_package_arn: Optional[str] = None,
    training_instance_type: Optional[str] = None,
    disable_instance_type_logging: bool = False,
    resources: Optional[ResourceRequirements] = None,
    config_name: Optional[str] = None,
    additional_model_data_sources: Optional[Dict[str, Any]] = None,
) -> JumpStartModelInitKwargs:
    """Builds the fully-resolved kwargs for instantiating a JumpStart model.

    The caller's arguments are packed into a ``JumpStartModelInitKwargs``, the
    model specs are fetched, and then each unset field is filled in a fixed
    order. Model data is not resolved when ``model_from_estimator`` is true.

    Returns:
        JumpStartModelInitKwargs: The populated init kwargs.
    """
    resolved: JumpStartModelInitKwargs = JumpStartModelInitKwargs(
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        model_type=model_type,
        instance_type=instance_type,
        region=region,
        image_uri=image_uri,
        model_data=model_data,
        source_dir=source_dir,
        entry_point=entry_point,
        env=env,
        role=role,
        name=name,
        vpc_config=vpc_config,
        sagemaker_session=sagemaker_session,
        enable_network_isolation=enable_network_isolation,
        model_kms_key=model_kms_key,
        image_config=image_config,
        code_location=code_location,
        container_log_level=container_log_level,
        dependencies=dependencies,
        git_config=git_config,
        tolerate_deprecated_model=tolerate_deprecated_model,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        model_package_arn=model_package_arn,
        training_instance_type=training_instance_type,
        resources=resources,
        config_name=config_name,
        additional_model_data_sources=additional_model_data_sources,
    )

    # A temporary session is needed for the spec lookup; the caller's own
    # session (possibly None) is remembered and restored further down.
    resolved, caller_session = _set_temp_sagemaker_session_if_not_set(kwargs=resolved)

    spec_lookup_args = get_model_info_default_kwargs(
        resolved, include_model_version=False, include_tolerate_flags=False
    )
    resolved.specs = verify_model_region_and_return_specs(
        **spec_lookup_args,
        version=resolved.model_version or "*",
        scope=JumpStartScriptScope.INFERENCE,
        # We set these flags to True to retrieve the json specs.
        # Exceptions will be thrown later if these are not tolerated.
        tolerate_deprecated_model=True,
        tolerate_vulnerable_model=True,
    )

    for fill_default in (
        _add_vulnerable_and_deprecated_status_to_kwargs,
        _add_model_version_to_kwargs,
        _add_config_name_to_init_kwargs,
    ):
        resolved = fill_default(kwargs=resolved)

    resolved = _add_sagemaker_session_with_custom_user_agent_to_kwargs(
        kwargs=resolved, orig_session=caller_session
    )

    for fill_default in (_add_region_to_kwargs, _add_model_name_to_kwargs):
        resolved = fill_default(kwargs=resolved)

    resolved = _add_instance_type_to_kwargs(
        kwargs=resolved, disable_instance_type_logging=disable_instance_type_logging
    )
    resolved = _add_image_uri_to_kwargs(kwargs=resolved)

    if not hub_arn:
        resolved.model_reference_arn = None
        resolved.hub_content_type = None
    else:
        resolved = _add_model_reference_arn_to_kwargs(kwargs=resolved)

    # we use the model artifact from the training job output
    if not model_from_estimator:
        resolved = _add_model_data_to_kwargs(kwargs=resolved)

    for fill_default in (
        _add_source_dir_to_kwargs,
        _add_entry_point_to_kwargs,
        _add_env_to_kwargs,
        _add_extra_model_kwargs,
        _add_role_to_kwargs,
        _add_model_package_arn_to_kwargs,
        _add_resources_to_kwargs,
        _add_additional_model_data_sources_to_kwargs,
    ):
        resolved = fill_default(kwargs=resolved)

    return resolved