Source code for sagemaker.core.jumpstart.hub.interfaces

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module stores types related to SageMaker JumpStart HubAPI requests and responses."""
from __future__ import absolute_import

from enum import Enum
import re
import json
import datetime

from typing import Any, Dict, List, Union, Optional
from sagemaker.core.jumpstart.enums import JumpStartScriptScope
from sagemaker.core.jumpstart.types import (
    HubContentType,
    HubArnExtractedInfo,
    JumpStartConfigComponent,
    JumpStartConfigRanking,
    JumpStartMetadataConfig,
    JumpStartMetadataConfigs,
    JumpStartPredictorSpecs,
    JumpStartHyperparameter,
    JumpStartDataHolderType,
    JumpStartEnvironmentVariable,
    JumpStartSerializablePayload,
    JumpStartInstanceTypeVariants,
)
from sagemaker.core.jumpstart.hub.parser_utils import (
    snake_to_upper_camel,
    walk_and_apply_json,
)


class _ComponentType(str, Enum):
    """Enum for different component types."""

    INFERENCE = "Inference"
    TRAINING = "Training"


class HubDataHolderType(JumpStartDataHolderType):
    """Common base for the Hub API data classes defined in this module."""

    def to_json(self) -> Dict[str, Any]:
        """Build a dictionary from the populated, serializable slots.

        Slots listed in ``_non_serializable_slots``, slots that were never
        assigned, and slots holding ``None`` are left out. Nested
        ``JumpStartDataHolderType`` values (directly or inside a list) are
        converted with their own ``to_json``; datetimes become strings.

        Returns:
            Dict[str, Any]: Mapping of slot name to serialized value.
        """
        serialized: Dict[str, Any] = {}
        for slot in self.__slots__:
            if slot in self._non_serializable_slots or not hasattr(self, slot):
                continue

            value = getattr(self, slot)
            # Null values are never serialized.
            if value is None:
                continue

            if issubclass(type(value), JumpStartDataHolderType):
                serialized[slot] = value.to_json()
            elif isinstance(value, list):
                serialized[slot] = [
                    item.to_json()
                    if issubclass(type(item), JumpStartDataHolderType)
                    else item
                    for item in value
                ]
            elif isinstance(value, datetime.datetime):
                serialized[slot] = str(value)
            else:
                serialized[slot] = value
        return serialized

    def __str__(self) -> str:
        """Return the object as a JSON string with UpperCamelCase keys.

        The ``to_json`` output is passed through ``walk_and_apply_json`` with
        ``snake_to_upper_camel`` before being dumped.
        """
        camel_cased = walk_and_apply_json(self.to_json(), snake_to_upper_camel)
        return json.dumps(camel_cased, default=lambda o: o.to_json())
class CreateHubResponse(HubDataHolderType):
    """Holds the fields of a ``session.create_hub()`` response."""

    __slots__ = ["hub_arn"]

    def __init__(self, json_obj: Dict[str, Any]) -> None:
        """Build the object from a ``session.create_hub()`` response.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of the response.
        """
        self.from_json(json_obj)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Populate the object from a response dictionary.

        Args:
            json_obj (Dict[str, Any]): Dictionary containing a ``HubArn`` entry.

        Raises:
            KeyError: If ``HubArn`` is missing.
        """
        arn: str = json_obj["HubArn"]
        self.hub_arn = arn
class HubContentDependency(HubDataHolderType):
    """One dependency entry attached to a piece of hub content.

    Content can be scripts, model artifacts, datasets, or notebooks.
    """

    __slots__ = ["dependency_copy_path", "dependency_origin_path", "dependency_type"]

    def __init__(self, json_obj: Dict[str, Any]) -> None:
        """Build the object from a dependency dictionary.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of one dependency.
        """
        self.from_json(json_obj)

    def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
        """Populate the object from a dependency dictionary.

        Every field falls back to an empty string when its key is absent.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of one dependency.
        """
        # (attribute, response key) pairs, assigned in declaration order.
        field_keys = (
            ("dependency_copy_path", "DependencyCopyPath"),
            ("dependency_origin_path", "DependencyOriginPath"),
            ("dependency_type", "DependencyType"),
        )
        for attribute, key in field_keys:
            setattr(self, attribute, json_obj.get(key, ""))
class DescribeHubContentResponse(HubDataHolderType):
    """Data class for the Hub Content from session.describe_hub_contents()"""

    __slots__ = [
        "creation_time",
        "document_schema_version",
        "failure_reason",
        "hub_arn",
        "hub_content_arn",
        "hub_content_dependencies",
        "hub_content_description",
        "hub_content_display_name",
        "hub_content_document",
        "hub_content_markdown",
        "hub_content_name",
        "hub_content_search_keywords",
        "hub_content_status",
        "hub_content_type",
        "hub_content_version",
        "reference_min_version",
        "hub_name",
        "_region",
    ]

    _non_serializable_slots = ["_region"]

    def __init__(self, json_obj: Dict[str, Any]) -> None:
        """Instantiates DescribeHubContentResponse object.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of hub content description.
        """
        self.from_json(json_obj)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of hub content description.

        Raises:
            KeyError: If a required key (for example ``HubArn`` or
                ``HubContentDocument``) is missing.
            ValueError: If ``HubContentType`` is not a model, model reference
                or notebook type.
        """
        self.creation_time: datetime.datetime = json_obj["CreationTime"]
        self.document_schema_version: str = json_obj["DocumentSchemaVersion"]
        self.failure_reason: Optional[str] = json_obj.get("FailureReason")
        self.hub_arn: str = json_obj["HubArn"]
        self.hub_content_arn: str = json_obj["HubContentArn"]

        # The key must be looked up as a string; a list key raises
        # "TypeError: unhashable type" as soon as dependencies are present.
        # NOTE(review): the service API appears to name this field
        # "HubContentDependencies" - confirm which key callers actually send.
        self.hub_content_dependencies: Optional[List[HubContentDependency]] = [
            HubContentDependency(dep) for dep in json_obj.get("Dependencies") or []
        ]

        self.hub_content_description: str = json_obj.get("HubContentDescription")
        self.hub_content_display_name: str = json_obj.get("HubContentDisplayName")
        hub_region: Optional[str] = HubArnExtractedInfo.extract_region_from_arn(self.hub_arn)
        self._region = hub_region
        self.hub_content_type: str = json_obj.get("HubContentType")

        hub_content_document = json.loads(json_obj["HubContentDocument"])
        is_model_document = (
            self.hub_content_type == HubContentType.MODEL
            or self.hub_content_type == HubContentType.MODEL_REFERENCE
        )
        if is_model_document:
            # Models and model references share the same document class.
            self.hub_content_document: HubContentDocument = HubModelDocument(
                json_obj=hub_content_document,
                region=self._region,
                dependencies=self.hub_content_dependencies,
            )
        elif self.hub_content_type == HubContentType.NOTEBOOK:
            self.hub_content_document: HubContentDocument = HubNotebookDocument(
                json_obj=hub_content_document, region=self._region
            )
        else:
            raise ValueError(
                f"[{self.hub_content_type}] is not a valid HubContentType."
                f"Should be one of: {[item.name for item in HubContentType]}."
            )

        self.hub_content_markdown: str = json_obj.get("HubContentMarkdown")
        self.hub_content_name: str = json_obj["HubContentName"]
        self.hub_content_search_keywords: List[str] = json_obj.get("HubContentSearchKeywords")
        self.hub_content_status: str = json_obj["HubContentStatus"]
        self.hub_content_version: str = json_obj["HubContentVersion"]
        # Declared in __slots__ but previously never populated.
        self.reference_min_version: Optional[str] = json_obj.get("ReferenceMinVersion")
        self.hub_name: str = json_obj["HubName"]

    def get_hub_region(self) -> Optional[str]:
        """Returns the region hub is in."""
        return self._region
class HubS3StorageConfig(HubDataHolderType):
    """Holds the S3 storage configuration dictionary of a hub."""

    __slots__ = ["s3_output_path"]

    def __init__(self, json_obj: Dict[str, Any]) -> None:
        """Build the object from a storage configuration dictionary.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of the configuration.
        """
        self.from_json(json_obj)

    def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
        """Populate the object from a storage configuration dictionary.

        Args:
            json_obj (Dict[str, Any]): Dictionary that may contain ``S3OutputPath``;
                an empty string is stored when the key is absent.
        """
        output_path: Optional[str] = json_obj.get("S3OutputPath", "")
        self.s3_output_path = output_path
class DescribeHubResponse(HubDataHolderType):
    """Data class for the Hub from session.describe_hub()"""

    __slots__ = [
        "creation_time",
        "failure_reason",
        "hub_arn",
        "hub_description",
        "hub_display_name",
        "hub_name",
        "hub_search_keywords",
        "hub_status",
        "last_modified_time",
        "s3_storage_config",
        "_region",
    ]

    _non_serializable_slots = ["_region"]

    def __init__(self, json_obj: Dict[str, Any]) -> None:
        """Instantiates DescribeHubResponse object.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of hub description.
        """
        self.from_json(json_obj)

    @staticmethod
    def _parse_timestamp(value: Any) -> datetime.datetime:
        """Returns ``value`` as a datetime.

        ``datetime.datetime(value)`` cannot be used for this: the constructor
        expects year/month/day integers and raises ``TypeError`` for a single
        datetime or string argument.

        Args:
            value (Any): A ``datetime.datetime`` (returned unchanged) or an
                ISO 8601 string.

        Raises:
            TypeError: If ``value`` is neither a datetime nor a string.
            ValueError: If a string cannot be parsed by ``fromisoformat``.
        """
        if isinstance(value, datetime.datetime):
            return value
        if isinstance(value, str):
            return datetime.datetime.fromisoformat(value)
        raise TypeError(
            f"Expected a datetime or ISO 8601 string, got {type(value).__name__}."
        )

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of hub description.

        Raises:
            KeyError: If a required key (``CreationTime``, ``HubArn``, ``HubName``,
                ``HubStatus``, ``LastModifiedTime`` or ``S3StorageConfig``) is missing.
        """
        self.creation_time: datetime.datetime = self._parse_timestamp(json_obj["CreationTime"])
        # Optional in the response: read with .get so that a hub without a
        # failure reason, description, display name or keywords still parses.
        self.failure_reason: Optional[str] = json_obj.get("FailureReason")
        self.hub_arn: str = json_obj["HubArn"]
        hub_region: Optional[str] = HubArnExtractedInfo.extract_region_from_arn(self.hub_arn)
        self._region = hub_region
        self.hub_description: Optional[str] = json_obj.get("HubDescription")
        self.hub_display_name: Optional[str] = json_obj.get("HubDisplayName")
        self.hub_name: str = json_obj["HubName"]
        self.hub_search_keywords: Optional[List[str]] = json_obj.get("HubSearchKeywords")
        self.hub_status: str = json_obj["HubStatus"]
        self.last_modified_time: datetime.datetime = self._parse_timestamp(
            json_obj["LastModifiedTime"]
        )
        self.s3_storage_config: HubS3StorageConfig = HubS3StorageConfig(
            json_obj["S3StorageConfig"]
        )

    def get_hub_region(self) -> Optional[str]:
        """Returns the region hub is in."""
        return self._region
class ImportHubResponse(HubDataHolderType):
    """Holds the fields of a ``session.import_hub()`` response."""

    __slots__ = ["hub_arn", "hub_content_arn"]

    def __init__(self, json_obj: Dict[str, Any]) -> None:
        """Build the object from a ``session.import_hub()`` response.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of the response.
        """
        self.from_json(json_obj)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Populate the object from a response dictionary.

        Args:
            json_obj (Dict[str, Any]): Dictionary containing ``HubArn`` and
                ``HubContentArn`` entries.

        Raises:
            KeyError: If either entry is missing.
        """
        # Both keys are required; read them in the declared order.
        for attribute, key in (("hub_arn", "HubArn"), ("hub_content_arn", "HubContentArn")):
            setattr(self, attribute, json_obj[key])
class HubSummary(HubDataHolderType):
    """Data class for the HubSummary from session.list_hubs()"""

    __slots__ = [
        "creation_time",
        "hub_arn",
        "hub_description",
        "hub_display_name",
        "hub_name",
        "hub_search_keywords",
        "hub_status",
        "last_modified_time",
    ]

    def __init__(self, json_obj: Dict[str, Any]) -> None:
        """Instantiates HubSummary object.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of hub description.
        """
        self.from_json(json_obj)

    @staticmethod
    def _parse_timestamp(value: Any) -> datetime.datetime:
        """Returns ``value`` as a datetime.

        ``datetime.datetime(value)`` cannot be used for this: the constructor
        expects year/month/day integers and raises ``TypeError`` for a single
        datetime or string argument.

        Args:
            value (Any): A ``datetime.datetime`` (returned unchanged) or an
                ISO 8601 string.

        Raises:
            TypeError: If ``value`` is neither a datetime nor a string.
            ValueError: If a string cannot be parsed by ``fromisoformat``.
        """
        if isinstance(value, datetime.datetime):
            return value
        if isinstance(value, str):
            return datetime.datetime.fromisoformat(value)
        raise TypeError(
            f"Expected a datetime or ISO 8601 string, got {type(value).__name__}."
        )

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of hub description.

        Raises:
            KeyError: If ``CreationTime``, ``HubArn``, ``HubName``, ``HubStatus``
                or ``LastModifiedTime`` is missing.
        """
        self.creation_time: datetime.datetime = self._parse_timestamp(json_obj["CreationTime"])
        self.hub_arn: str = json_obj["HubArn"]
        # Description, display name and keywords are optional in a summary.
        self.hub_description: Optional[str] = json_obj.get("HubDescription")
        self.hub_display_name: Optional[str] = json_obj.get("HubDisplayName")
        self.hub_name: str = json_obj["HubName"]
        self.hub_search_keywords: Optional[List[str]] = json_obj.get("HubSearchKeywords")
        self.hub_status: str = json_obj["HubStatus"]
        self.last_modified_time: datetime.datetime = self._parse_timestamp(
            json_obj["LastModifiedTime"]
        )
class ListHubsResponse(HubDataHolderType):
    """Data class for the Hub from session.list_hubs()"""

    __slots__ = [
        "hub_summaries",
        "next_token",
    ]

    def __init__(self, json_obj: Dict[str, Any]) -> None:
        """Instantiates ListHubsResponse object.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of session.list_hubs() response.
        """
        self.from_json(json_obj)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of session.list_hubs() response.

        Raises:
            KeyError: If ``HubSummaries`` is missing.
        """
        self.hub_summaries: List[HubSummary] = [
            HubSummary(item) for item in json_obj["HubSummaries"]
        ]
        # The pagination token is absent on the final page, so a plain index
        # lookup would raise KeyError exactly when listing completes.
        self.next_token: Optional[str] = json_obj.get("NextToken")
class EcrUri(HubDataHolderType):
    """Data class for ECR image uri."""

    __slots__ = ["account", "region_name", "repository", "tag"]

    # Compiled once at class creation; not a slot, so it is never serialized.
    _URI_PATTERN = re.compile(
        r"^(?:(?P<account_id>[a-zA-Z0-9][\w-]*)\.dkr\.ecr\.(?P<region>[a-zA-Z0-9][\w-]*)"
        r"\.(?P<tld>[a-zA-Z0-9\.-]+))\/(?P<repository_name>([a-z0-9]+"
        r"(?:[._-][a-z0-9]+)*\/)*[a-z0-9]+(?:[._-][a-z0-9]+)*)(:*)(?P<image_tag>.*)?"
    )

    def __init__(self, uri: str):
        """Instantiates EcrUri object.

        Args:
            uri (str): ECR image uri to parse.
        """
        self.from_ecr_uri(uri)

    def from_ecr_uri(self, uri: str) -> None:
        """Parse a given aws ecr image uri into its various components.

        Args:
            uri (str): ECR image uri of the form
                ``<account>.dkr.ecr.<region>.<tld>/<repository>[:<tag>]``.

        Raises:
            ValueError: If ``uri`` does not match the expected ECR uri format.
        """
        parsed_image_uri = self._URI_PATTERN.match(uri)
        if parsed_image_uri is None:
            # Fail with a message naming the input instead of an
            # AttributeError on ``None.group``.
            raise ValueError(f"[{uri}] is not a valid ECR image uri.")

        self.account = parsed_image_uri.group("account_id")
        self.region_name = parsed_image_uri.group("region")
        self.repository = parsed_image_uri.group("repository_name")
        self.tag = parsed_image_uri.group("image_tag")
class NotebookLocationUris(HubDataHolderType):
    """Data class for Notebook Location uri."""

    __slots__ = ["demo_notebook", "model_fit", "model_deploy"]

    def __init__(self, json_obj: Dict[str, Any]):
        """Instantiates NotebookLocationUris object.

        Args:
            json_obj (Dict[str, Any]): Dictionary of notebook location uris.
        """
        self.from_json(json_obj)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        Each field is ``None`` when its key is absent.

        Args:
            json_obj (Dict[str, Any]): Dictionary with optional ``demo_notebook``,
                ``model_fit`` and ``model_deploy`` entries.
        """
        # Keys here are snake_case, unlike the UpperCamelCase keys of the
        # other hub response dictionaries in this module.
        self.demo_notebook = json_obj.get("demo_notebook")
        self.model_fit = json_obj.get("model_fit")
        self.model_deploy = json_obj.get("model_deploy")
class HubModelDocument(HubDataHolderType):
    """Data class for model type HubContentDocument from session.describe_hub_content()."""

    SCHEMA_VERSION = "2.3.0"

    __slots__ = [
        "url",
        "min_sdk_version",
        "training_supported",
        "model_types",
        "capabilities",
        "incremental_training_supported",
        "dynamic_container_deployment_supported",
        "hosting_ecr_uri",
        "hosting_artifact_s3_data_type",
        "hosting_artifact_compression_type",
        "hosting_artifact_uri",
        "hosting_prepacked_artifact_uri",
        "hosting_prepacked_artifact_version",
        "hosting_script_uri",
        "hosting_use_script_uri",
        "hosting_eula_uri",
        "hosting_model_package_arn",
        "inference_ami_version",
        "model_subscription_link",
        "inference_configs",
        "inference_config_components",
        "inference_config_rankings",
        "training_artifact_s3_data_type",
        "training_artifact_compression_type",
        "training_model_package_artifact_uri",
        "hyperparameters",
        "inference_environment_variables",
        "training_script_uri",
        "training_prepacked_script_uri",
        "training_prepacked_script_version",
        "training_ecr_uri",
        "training_metrics",
        "training_artifact_uri",
        "training_configs",
        "training_config_components",
        "training_config_rankings",
        "inference_dependencies",
        "training_dependencies",
        "default_inference_instance_type",
        "supported_inference_instance_types",
        "default_training_instance_type",
        "supported_training_instance_types",
        "sage_maker_sdk_predictor_specifications",
        "inference_volume_size",
        "training_volume_size",
        "inference_enable_network_isolation",
        "training_enable_network_isolation",
        "fine_tuning_supported",
        "validation_supported",
        "default_training_dataset_uri",
        "resource_name_base",
        "gated_bucket",
        "default_payloads",
        "hosting_resource_requirements",
        "hosting_instance_type_variants",
        "training_instance_type_variants",
        "notebook_location_uris",
        "model_provider_icon_uri",
        "task",
        "framework",
        "datatype",
        "license",
        "contextual_help",
        "model_data_download_timeout",
        "container_startup_health_check_timeout",
        "encrypt_inter_container_traffic",
        "max_runtime_in_seconds",
        "disable_output_compression",
        "model_dir",
        "dependencies",
        "_region",
    ]

    _non_serializable_slots = ["_region"]

    def __init__(
        self,
        json_obj: Dict[str, Any],
        region: str,
        dependencies: List[HubContentDependency] = None,
    ) -> None:
        """Instantiates HubModelDocument object.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of hub content document.
            region (str): Region of the hub the document belongs to.
            dependencies (List[HubContentDependency]): Dependencies of the hub
                content. Defaults to an empty list.
        """
        self._region = region
        self.dependencies = dependencies or []
        self.from_json(json_obj)

    @staticmethod
    def _true_or_none(value: Any) -> Optional[bool]:
        """Returns True for a truthy value and None otherwise.

        Falsy flags (missing, ``None`` or ``False``) collapse to ``None``.
        """
        return True if value else None

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        Training-related fields are only assigned when ``TrainingSupported``
        is truthy; otherwise those slots stay unset.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of hub model document.
        """
        self.url: str = json_obj.get("Url")
        self.min_sdk_version: str = json_obj.get("MinSdkVersion")
        self.hosting_ecr_uri: Optional[str] = json_obj.get("HostingEcrUri")
        self.hosting_artifact_uri = json_obj.get("HostingArtifactUri")
        self.hosting_script_uri = json_obj.get("HostingScriptUri")
        self.inference_dependencies: List[str] = json_obj.get("InferenceDependencies")
        self.inference_environment_variables: List[JumpStartEnvironmentVariable] = [
            JumpStartEnvironmentVariable(env_variable, is_hub_content=True)
            for env_variable in json_obj.get("InferenceEnvironmentVariables", [])
        ]
        self.model_types: Optional[List[str]] = json_obj.get("ModelTypes")
        self.capabilities: Optional[List[str]] = json_obj.get("Capabilities")
        self.training_supported: bool = bool(json_obj.get("TrainingSupported"))
        self.incremental_training_supported: bool = bool(
            json_obj.get("IncrementalTrainingSupported")
        )
        self.dynamic_container_deployment_supported: Optional[bool] = self._true_or_none(
            json_obj.get("DynamicContainerDeploymentSupported")
        )
        self.hosting_artifact_s3_data_type: Optional[str] = json_obj.get(
            "HostingArtifactS3DataType"
        )
        self.hosting_artifact_compression_type: Optional[str] = json_obj.get(
            "HostingArtifactCompressionType"
        )
        self.hosting_prepacked_artifact_uri: Optional[str] = json_obj.get(
            "HostingPrepackedArtifactUri"
        )
        self.hosting_prepacked_artifact_version: Optional[str] = json_obj.get(
            "HostingPrepackedArtifactVersion"
        )
        # Unlike the other flags, an explicit False is preserved here.
        self.hosting_use_script_uri: Optional[bool] = (
            bool(json_obj.get("HostingUseScriptUri"))
            if json_obj.get("HostingUseScriptUri") is not None
            else None
        )
        self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri")
        self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn")
        self.inference_ami_version: Optional[str] = json_obj.get("InferenceAmiVersion")
        self.model_subscription_link: Optional[str] = json_obj.get("ModelSubscriptionLink")

        # Order matters: _get_configs reads the rankings and components set here.
        self.inference_config_rankings = self._get_config_rankings(json_obj)
        self.inference_config_components = self._get_config_components(json_obj)
        self.inference_configs = self._get_configs(json_obj)

        self.default_inference_instance_type: Optional[str] = json_obj.get(
            "DefaultInferenceInstanceType"
        )
        self.supported_inference_instance_types: Optional[str] = json_obj.get(
            "SupportedInferenceInstanceTypes"
        )
        self.sage_maker_sdk_predictor_specifications: Optional[JumpStartPredictorSpecs] = (
            JumpStartPredictorSpecs(
                json_obj.get("SageMakerSdkPredictorSpecifications"),
                is_hub_content=True,
            )
            if json_obj.get("SageMakerSdkPredictorSpecifications")
            else None
        )
        self.inference_volume_size: Optional[int] = json_obj.get("InferenceVolumeSize")
        self.inference_enable_network_isolation: Optional[str] = json_obj.get(
            "InferenceEnableNetworkIsolation", False
        )
        self.fine_tuning_supported: Optional[bool] = self._true_or_none(
            json_obj.get("FineTuningSupported")
        )
        self.validation_supported: Optional[bool] = self._true_or_none(
            json_obj.get("ValidationSupported")
        )
        self.resource_name_base: Optional[str] = json_obj.get("ResourceNameBase")
        self.gated_bucket: bool = bool(json_obj.get("GatedBucket", False))
        self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = (
            {
                alias: JumpStartSerializablePayload(payload, is_hub_content=True)
                for alias, payload in json_obj.get("DefaultPayloads").items()
            }
            if json_obj.get("DefaultPayloads")
            else None
        )
        self.hosting_resource_requirements: Optional[Dict[str, int]] = json_obj.get(
            "HostingResourceRequirements", None
        )
        self.hosting_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = (
            JumpStartInstanceTypeVariants(
                json_obj.get("HostingInstanceTypeVariants"),
                is_hub_content=True,
            )
            if json_obj.get("HostingInstanceTypeVariants")
            else None
        )
        self.notebook_location_uris: Optional[NotebookLocationUris] = (
            NotebookLocationUris(json_obj.get("NotebookLocationUris"))
            if json_obj.get("NotebookLocationUris")
            else None
        )
        self.model_provider_icon_uri: Optional[str] = None  # Not needed for private beta
        self.task: Optional[str] = json_obj.get("Task")
        self.framework: Optional[str] = json_obj.get("Framework")
        self.datatype: Optional[str] = json_obj.get("Datatype")
        self.license: Optional[str] = json_obj.get("License")
        self.contextual_help: Optional[str] = json_obj.get("ContextualHelp")
        self.model_dir: Optional[str] = json_obj.get("ModelDir")

        # Deploy kwargs
        self.model_data_download_timeout: Optional[str] = json_obj.get("ModelDataDownloadTimeout")
        self.container_startup_health_check_timeout: Optional[str] = json_obj.get(
            "ContainerStartupHealthCheckTimeout"
        )

        if self.training_supported:
            self.default_training_dataset_uri: Optional[str] = json_obj.get(
                "DefaultTrainingDatasetUri"
            )
            self.training_model_package_artifact_uri: Optional[str] = json_obj.get(
                "TrainingModelPackageArtifactUri"
            )
            self.training_artifact_compression_type: Optional[str] = json_obj.get(
                "TrainingArtifactCompressionType"
            )
            self.training_artifact_s3_data_type: Optional[str] = json_obj.get(
                "TrainingArtifactS3DataType"
            )
            self.hyperparameters: List[JumpStartHyperparameter] = []
            hyperparameters: Any = json_obj.get("Hyperparameters")
            if hyperparameters is not None:
                self.hyperparameters.extend(
                    [
                        JumpStartHyperparameter(hyperparameter, is_hub_content=True)
                        for hyperparameter in hyperparameters
                    ]
                )

            self.training_script_uri: Optional[str] = json_obj.get("TrainingScriptUri")
            self.training_prepacked_script_uri: Optional[str] = json_obj.get(
                "TrainingPrepackedScriptUri"
            )
            self.training_prepacked_script_version: Optional[str] = json_obj.get(
                "TrainingPrepackedScriptVersion"
            )
            self.training_ecr_uri: Optional[str] = json_obj.get("TrainingEcrUri")
            # _non_serializable_slots is a class-level list shared by every
            # instance, so register the name once instead of appending a
            # duplicate each time a training-capable document is parsed.
            if "training_ecr_specs" not in self._non_serializable_slots:
                self._non_serializable_slots.append("training_ecr_specs")
            self.training_metrics: Optional[List[Dict[str, str]]] = json_obj.get(
                "TrainingMetrics", None
            )
            self.training_artifact_uri: Optional[str] = json_obj.get("TrainingArtifactUri")

            # Order matters: _get_configs reads the rankings and components set here.
            self.training_config_rankings = self._get_config_rankings(
                json_obj, _ComponentType.TRAINING
            )
            self.training_config_components = self._get_config_components(
                json_obj, _ComponentType.TRAINING
            )
            self.training_configs = self._get_configs(json_obj, _ComponentType.TRAINING)

            self.training_dependencies: Optional[str] = json_obj.get("TrainingDependencies")
            self.default_training_instance_type: Optional[str] = json_obj.get(
                "DefaultTrainingInstanceType"
            )
            self.supported_training_instance_types: Optional[str] = json_obj.get(
                "SupportedTrainingInstanceTypes"
            )
            self.training_volume_size: Optional[int] = json_obj.get("TrainingVolumeSize")
            self.training_enable_network_isolation: Optional[str] = json_obj.get(
                "TrainingEnableNetworkIsolation", False
            )
            self.training_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = (
                JumpStartInstanceTypeVariants(
                    json_obj.get("TrainingInstanceTypeVariants"),
                    is_hub_content=True,
                )
                if json_obj.get("TrainingInstanceTypeVariants")
                else None
            )

            # Estimator kwargs
            self.encrypt_inter_container_traffic: Optional[bool] = self._true_or_none(
                json_obj.get("EncryptInterContainerTraffic")
            )
            self.max_runtime_in_seconds: Optional[str] = json_obj.get("MaxRuntimeInSeconds")
            self.disable_output_compression: Optional[bool] = self._true_or_none(
                json_obj.get("DisableOutputCompression")
            )

    def get_schema_version(self) -> str:
        """Returns schema version."""
        return self.SCHEMA_VERSION

    def get_region(self) -> str:
        """Returns hub region."""
        return self._region

    def _get_config_rankings(
        self, json_obj: Dict[str, Any], component_type=_ComponentType.INFERENCE
    ) -> Optional[Dict[str, JumpStartConfigRanking]]:
        """Returns config rankings.

        Reads ``<component_type>ConfigRankings`` from ``json_obj``; returns
        None when that entry is missing or empty.
        """
        config_rankings = json_obj.get(f"{component_type.value}ConfigRankings")
        if not config_rankings:
            return None
        return {
            alias: JumpStartConfigRanking(ranking, is_hub_content=True)
            for alias, ranking in config_rankings.items()
        }

    def _get_config_components(
        self, json_obj: Dict[str, Any], component_type=_ComponentType.INFERENCE
    ) -> Optional[Dict[str, JumpStartConfigComponent]]:
        """Returns config components.

        Reads ``<component_type>ConfigComponents`` from ``json_obj``; returns
        None when that entry is missing or empty.
        """
        config_components = json_obj.get(f"{component_type.value}ConfigComponents")
        if not config_components:
            return None
        return {
            alias: JumpStartConfigComponent(alias, config, is_hub_content=True)
            for alias, config in config_components.items()
        }

    def _get_configs(
        self, json_obj: Dict[str, Any], component_type=_ComponentType.INFERENCE
    ) -> Optional[JumpStartMetadataConfigs]:
        """Returns configs.

        Reads ``<component_type>Configs`` from ``json_obj``; returns None when
        that entry is missing or empty. Relies on the matching
        ``*_config_components`` and ``*_config_rankings`` attributes having
        been assigned already.
        """
        if not (configs := json_obj.get(f"{component_type.value}Configs")):
            return None

        components_attr = f"{component_type.value.lower()}_config_components"
        configs_dict = {}
        for alias, config in configs.items():
            config_components = None
            if isinstance(config, dict) and (component_names := config.get("ComponentNames")):
                available_components = getattr(self, components_attr)
                config_components = {
                    name: available_components.get(name) for name in component_names
                }
            configs_dict[alias] = JumpStartMetadataConfig(
                alias, config, json_obj, config_components, is_hub_content=True
            )

        if component_type == _ComponentType.INFERENCE:
            config_rankings = self.inference_config_rankings
            scope = JumpStartScriptScope.INFERENCE
        else:
            config_rankings = self.training_config_rankings
            scope = JumpStartScriptScope.TRAINING

        return JumpStartMetadataConfigs(configs_dict, config_rankings, scope)
class HubNotebookDocument(HubDataHolderType):
    """Notebook-type hub content document from session.describe_hub_content()."""

    SCHEMA_VERSION = "1.0.0"

    __slots__ = ["notebook_location", "dependencies", "_region"]

    _non_serializable_slots = ["_region"]

    def __init__(self, json_obj: Dict[str, Any], region: str) -> None:
        """Build the object from a parsed notebook document.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of hub content document.
            region (str): Region stored on the object and returned by ``get_region``.
        """
        self._region = region
        self.from_json(json_obj)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Populate the object from a parsed notebook document.

        Args:
            json_obj (Dict[str, Any]): Dictionary with ``NotebookLocation`` and
                ``Dependencies`` entries.

        Raises:
            KeyError: If either entry is missing.
        """
        self.notebook_location = json_obj["NotebookLocation"]
        parsed_dependencies: List[HubContentDependency] = []
        for raw_dependency in json_obj["Dependencies"]:
            parsed_dependencies.append(HubContentDependency(raw_dependency))
        self.dependencies = parsed_dependencies

    def get_schema_version(self) -> str:
        """Return the schema version constant of this document class."""
        return self.SCHEMA_VERSION

    def get_region(self) -> str:
        """Return the region given at construction time."""
        return self._region
# Type alias covering both document classes that a hub content's
# ``HubContentDocument`` payload is parsed into (model or notebook).
HubContentDocument = Union[HubModelDocument, HubNotebookDocument]
class HubContentInfo(HubDataHolderType):
    """One hub content summary entry from session.list_hub_contents()."""

    __slots__ = [
        "creation_time",
        "document_schema_version",
        "hub_content_arn",
        "hub_content_name",
        "hub_content_status",
        "hub_content_type",
        "hub_content_version",
        "hub_content_description",
        "hub_content_display_name",
        "hub_content_search_keywords",
        "_region",
    ]

    _non_serializable_slots = ["_region"]

    def __init__(self, json_obj: Dict[str, Any]) -> None:
        """Build the object from a hub content summary dictionary.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of hub content description.
        """
        self.from_json(json_obj)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Populate the object from a hub content summary dictionary.

        The region is derived from the hub content ARN. Description, display
        name and search keywords are ``None`` when absent.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of hub content description.

        Raises:
            KeyError: If a required key is missing.
        """
        # Required string fields, read in the same order as the response lists them.
        required_fields = (
            ("creation_time", "CreationTime"),
            ("document_schema_version", "DocumentSchemaVersion"),
            ("hub_content_arn", "HubContentArn"),
            ("hub_content_name", "HubContentName"),
            ("hub_content_status", "HubContentStatus"),
        )
        for attribute, key in required_fields:
            setattr(self, attribute, json_obj[key])

        content_type: HubContentType = HubContentType(json_obj["HubContentType"])
        self.hub_content_type = content_type
        self.hub_content_version: str = json_obj["HubContentVersion"]

        self.hub_content_description: Optional[str] = json_obj.get("HubContentDescription")
        self.hub_content_display_name: Optional[str] = json_obj.get("HubContentDisplayName")

        region: Optional[str] = HubArnExtractedInfo.extract_region_from_arn(
            self.hub_content_arn
        )
        self._region = region

        keywords: Optional[List[str]] = json_obj.get("HubContentSearchKeywords")
        self.hub_content_search_keywords = keywords

    def get_hub_region(self) -> Optional[str]:
        """Return the region extracted from the hub content ARN."""
        return self._region
class ListHubContentsResponse(HubDataHolderType):
    """Data class for the Hub from session.list_hub_contents()"""

    __slots__ = [
        "hub_content_summaries",
        "next_token",
    ]

    def __init__(self, json_obj: Dict[str, Any]) -> None:
        """Instantiates ListHubContentsResponse object.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of
                session.list_hub_contents() response.
        """
        self.from_json(json_obj)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of
                session.list_hub_contents() response.

        Raises:
            KeyError: If ``HubContentSummaries`` is missing.
        """
        self.hub_content_summaries: List[HubContentInfo] = [
            HubContentInfo(item) for item in json_obj["HubContentSummaries"]
        ]
        # The pagination token is absent on the final page, so a plain index
        # lookup would raise KeyError exactly when listing completes.
        self.next_token: Optional[str] = json_obj.get("NextToken")