from sagemaker.core.common_utils import (
format_tags,
resolve_value_from_config,
update_list_of_dicts_with_values_from_config,
_create_resource,
can_model_package_source_uri_autopopulate,
)
from sagemaker.core.config import (
MODEL_PACKAGE_VALIDATION_ROLE_PATH,
VALIDATION_ROLE,
VALIDATION_PROFILES,
MODEL_PACKAGE_INFERENCE_SPECIFICATION_CONTAINERS_PATH,
MODEL_PACKAGE_VALIDATION_PROFILES_PATH,
)
from sagemaker.core.resources import ModelPackageModelCard
from botocore.exceptions import ClientError
import logging
logger = LOGGER = logging.getLogger("sagemaker")
[docs]
def get_model_package_args(
    content_types=None,
    response_types=None,
    inference_instances=None,
    transform_instances=None,
    model_package_name=None,
    model_package_group_name=None,
    model_data=None,
    image_uri=None,
    model_metrics=None,
    metadata_properties=None,
    marketplace_cert=False,
    approval_status=None,
    description=None,
    tags=None,
    container_def_list=None,
    drift_check_baselines=None,
    customer_metadata_properties=None,
    validation_specification=None,
    domain=None,
    sample_payload_url=None,
    task=None,
    skip_model_validation=None,
    source_uri=None,
    model_card=None,
    model_life_cycle=None,
):
    """Collect model-package settings into a keyword-argument dictionary.

    The keys ``containers``, ``inference_instances``, ``transform_instances``
    and ``marketplace_cert`` are always present in the result. Every other
    key is added only when the matching argument is not None.

    Args:
        content_types (list): The supported MIME types for the input data.
        response_types (list): The supported MIME types for the output data.
        inference_instances (list): Instance types used for real-time inference.
        transform_instances (list): Instance types used for transform jobs.
        model_package_name (str): Model Package name.
        model_package_group_name (str): Model Package Group name.
        model_data: Stored as ``ModelDataUrl`` of the default container; only
            used when ``container_def_list`` is None.
        image_uri: Stored as ``Image`` of the default container; only used
            when ``container_def_list`` is None.
        model_metrics: Object exposing ``_to_request_dict()``; the result of
            that call is stored.
        metadata_properties: Object exposing ``_to_request_dict()``.
        marketplace_cert (bool): Stored as-is (default: False).
        approval_status (str): Model approval status.
        description (str): Model Package description.
        tags: Passed through ``format_tags`` before being stored.
        container_def_list (list): Ready-made container definitions. When
            given, ``image_uri`` and ``model_data`` are ignored.
        drift_check_baselines: Object exposing ``_to_request_dict()``.
        customer_metadata_properties (dict): Stored as-is.
        validation_specification: Stored as-is.
        domain (str): Stored as-is.
        sample_payload_url (str): Stored as-is.
        task (str): Stored as-is.
        skip_model_validation (str): Stored as-is.
        source_uri (str): Stored as-is.
        model_card: A ``ModelPackageModelCard`` (its ``model_card_content`` is
            used) or another object exposing ``content``. In both cases
            ``model_card_status`` is read as well.
        model_life_cycle: Object exposing ``_to_request_dict()``.

    Returns:
        dict: The collected arguments, keyed by snake_case argument name.
    """

    def _as_request(obj):
        return obj._to_request_dict()

    def _as_tags(raw_tags):
        return format_tags(raw_tags)

    # Either take the caller's container definitions verbatim or synthesize a
    # single container from image_uri / model_data.
    if container_def_list is None:
        default_container = {"Image": image_uri}
        if model_data is not None:
            default_container["ModelDataUrl"] = model_data
        containers = [default_container]
    else:
        containers = container_def_list

    args = {
        "containers": containers,
        "inference_instances": inference_instances,
        "transform_instances": transform_instances,
        "marketplace_cert": marketplace_cert,
    }

    # (key, value, converter) in insertion order; a None converter means the
    # value is stored unchanged. Converters run only for non-None values.
    optional_entries = (
        ("content_types", content_types, None),
        ("response_types", response_types, None),
        ("model_package_name", model_package_name, None),
        ("model_package_group_name", model_package_group_name, None),
        ("model_metrics", model_metrics, _as_request),
        ("drift_check_baselines", drift_check_baselines, _as_request),
        ("metadata_properties", metadata_properties, _as_request),
        ("approval_status", approval_status, None),
        ("description", description, None),
        ("tags", tags, _as_tags),
        ("customer_metadata_properties", customer_metadata_properties, None),
        ("validation_specification", validation_specification, None),
        ("domain", domain, None),
        ("sample_payload_url", sample_payload_url, None),
        ("task", task, None),
        ("skip_model_validation", skip_model_validation, None),
        ("source_uri", source_uri, None),
        ("model_life_cycle", model_life_cycle, _as_request),
    )
    for key, value, convert in optional_entries:
        if value is None:
            continue
        args[key] = value if convert is None else convert(value)

    if model_card is not None:
        # The two supported card types expose their content under different
        # attribute names.
        card_content = (
            model_card.model_card_content
            if isinstance(model_card, ModelPackageModelCard)
            else model_card.content
        )
        args["model_card"] = {
            "ModelCardContent": card_content,
            "ModelCardStatus": model_card.model_card_status,
        }

    return args
[docs]
def get_create_model_package_request(
    model_package_name=None,
    model_package_group_name=None,
    containers=None,
    content_types=None,
    response_types=None,
    inference_instances=None,
    transform_instances=None,
    model_metrics=None,
    metadata_properties=None,
    marketplace_cert=False,
    approval_status="PendingManualApproval",
    description=None,
    tags=None,
    drift_check_baselines=None,
    customer_metadata_properties=None,
    validation_specification=None,
    domain=None,
    sample_payload_url=None,
    task=None,
    skip_model_validation="None",
    source_uri=None,
    model_card=None,
    model_life_cycle=None,
):
    """Get request dictionary for CreateModelPackage API.

    Args:
        model_package_name (str): Model Package name, exclusive to
            `model_package_group_name` (default: None).
        model_package_group_name (str): Model Package Group name, exclusive to
            `model_package_name` (default: None).
        containers (list): Inference containers placed in the
            ``InferenceSpecification`` (default: None).
        content_types (list): The supported MIME types for the input data.
        response_types (list): The supported MIME types for the output data.
        inference_instances (list): Instance types for real-time inference.
        transform_instances (list): Instance types for transform jobs.
        model_metrics: Stored under ``ModelMetrics`` when truthy.
        metadata_properties: Stored under ``MetadataProperties`` when truthy.
        marketplace_cert (bool): Stored under ``CertifyForMarketplace``.
        approval_status (str): Stored under ``ModelApprovalStatus``
            (default: "PendingManualApproval").
        description (str): Model Package description.
        tags: Passed through ``format_tags`` and stored under ``Tags``.
        drift_check_baselines: Stored under ``DriftCheckBaselines`` when truthy.
        customer_metadata_properties (dict): Stored when not None.
        validation_specification: Stored when truthy.
        domain (str): Stored under ``Domain`` when not None.
        sample_payload_url (str): Stored under ``SamplePayloadUrl``.
        task (str): Stored under ``Task`` when not None.
        skip_model_validation (str): Stored under ``SkipModelValidation``
            (default: "None").
        source_uri (str): Stored under ``SourceUri`` when not None.
        model_card: Stored under ``ModelCard`` when not None.
        model_life_cycle: Stored under ``ModelLifeCycle`` when not None.

    Returns:
        dict: The request keyed by CreateModelPackage field names.

    Raises:
        ValueError: If both `model_package_name` and
            `model_package_group_name` are given; if `model_package_name` is
            combined with `source_uri` or with a container that has a
            ``ModelDataSource``; or if `containers` is given without
            `model_package_group_name` and either instance list is missing.
    """
    # ---- argument validation -------------------------------------------
    if model_package_name and model_package_group_name:
        raise ValueError(
            "model_package_name and model_package_group_name cannot be present at the same time."
        )
    if model_package_name and source_uri:
        raise ValueError(
            "Un-versioned SageMaker Model Package currently cannot be created with source_uri."
        )
    if containers is not None:
        uses_model_data_source = any(
            "ModelDataSource" in entry and entry["ModelDataSource"] is not None
            for entry in containers
        )
        if model_package_name and uses_model_data_source:
            raise ValueError(
                "Un-versioned SageMaker Model Package currently cannot be "
                "created with ModelDataSource."
            )

    # ---- optional top-level fields (insertion order matters to readers) --
    request = {}
    for field, value in (
        ("ModelPackageName", model_package_name),
        ("ModelPackageGroupName", model_package_group_name),
        ("ModelPackageDescription", description),
    ):
        if value is not None:
            request[field] = value

    if tags is not None:
        request["Tags"] = format_tags(tags)

    # These three are skipped when empty, not only when None.
    for field, value in (
        ("ModelMetrics", model_metrics),
        ("DriftCheckBaselines", drift_check_baselines),
        ("MetadataProperties", metadata_properties),
    ):
        if value:
            request[field] = value

    if customer_metadata_properties is not None:
        request["CustomerMetadataProperties"] = customer_metadata_properties
    if validation_specification:
        request["ValidationSpecification"] = validation_specification

    for field, value in (
        ("Domain", domain),
        ("SamplePayloadUrl", sample_payload_url),
        ("Task", task),
        ("SourceUri", source_uri),
    ):
        if value is not None:
            request[field] = value

    # ---- inference specification ----------------------------------------
    if containers is not None:
        spec = {"Containers": containers}
        if content_types is not None:
            spec["SupportedContentTypes"] = content_types
        if response_types is not None:
            spec["SupportedResponseMIMETypes"] = response_types

        if model_package_group_name is None:
            # Without a group, both instance lists are mandatory.
            if not (inference_instances and transform_instances):
                raise ValueError(
                    "inference_instances and transform_instances "
                    "must be provided if model_package_group_name is not present."
                )
            spec["SupportedRealtimeInferenceInstanceTypes"] = inference_instances
            spec["SupportedTransformInstanceTypes"] = transform_instances
        else:
            if inference_instances is not None:
                spec["SupportedRealtimeInferenceInstanceTypes"] = inference_instances
            if transform_instances is not None:
                spec["SupportedTransformInstanceTypes"] = transform_instances

        request["InferenceSpecification"] = spec

    # ---- fields that are always sent --------------------------------------
    request["CertifyForMarketplace"] = marketplace_cert
    request["ModelApprovalStatus"] = approval_status
    request["SkipModelValidation"] = skip_model_validation

    if model_card is not None:
        request["ModelCard"] = model_card
    if model_life_cycle is not None:
        request["ModelLifeCycle"] = model_life_cycle

    return request
[docs]
def create_model_package_from_containers(
    sagemaker_session,
    containers=None,
    content_types=None,
    response_types=None,
    inference_instances=None,
    transform_instances=None,
    model_package_name=None,
    model_package_group_name=None,
    model_metrics=None,
    metadata_properties=None,
    marketplace_cert=False,
    approval_status="PendingManualApproval",
    description=None,
    drift_check_baselines=None,
    customer_metadata_properties=None,
    validation_specification=None,
    domain=None,
    sample_payload_url=None,
    task=None,
    skip_model_validation="None",
    source_uri=None,
    model_card=None,
    model_life_cycle=None,
):
    """Build a CreateModelPackage request and submit it through the session.

    Missing container and validation-specification values are first merged in
    from the SageMaker config. The request is then built with
    ``get_create_model_package_request`` and handed, together with a submit
    callback, to ``sagemaker_session._intercept_create_request``.

    When a Model Package Group *name* (not an ARN) is given, the submit
    callback creates the group first if it cannot be found.

    Args:
        sagemaker_session: Session object providing ``search``,
            ``sagemaker_client`` and ``_intercept_create_request``.
        containers (list): A list of inference containers that can be used for inference
            specifications of Model Package (default: None).
        content_types (list): The supported MIME types for the input data (default: None).
        response_types (list): The supported MIME types for the output data (default: None).
        inference_instances (list): A list of the instance types that are used to
            generate inferences in real-time (default: None).
        transform_instances (list): A list of the instance types on which a transformation
            job can be run or on which an endpoint can be deployed (default: None).
        model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
            using `model_package_name` makes the Model Package un-versioned (default: None).
        model_package_group_name (str): Model Package Group name, exclusive to
            `model_package_name`, using `model_package_group_name` makes the Model Package
            versioned (default: None).
        model_metrics (ModelMetrics): ModelMetrics object (default: None).
        metadata_properties (MetadataProperties): MetadataProperties object (default: None)
        marketplace_cert (bool): A boolean value indicating if the Model Package is certified
            for AWS Marketplace (default: False).
        approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
            or "PendingManualApproval" (default: "PendingManualApproval").
        description (str): Model Package description (default: None).
        drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
        customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
            metadata properties (default: None).
        validation_specification (dict): Validation specification; when truthy, its
            validation role and validation profiles are completed from the SageMaker
            config and the dictionary is updated in place (default: None).
        domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
            "MACHINE_LEARNING" (default: None).
        sample_payload_url (str): The S3 path where the sample payload is stored
            (default: None).
        task (str): Task values which are supported by Inference Recommender are "FILL_MASK",
            "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION",
            "CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
        skip_model_validation (str): Indicates if you want to skip model validation.
            Values can be "All" or "None" (default: "None").
        source_uri (str): The URI of the source for the model package (default: None).
        model_card (ModelCard or ModelPackageModelCard): document contains qualitative and
            quantitative information about a model (default: None).
        model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).

    Returns:
        Whatever ``sagemaker_session._intercept_create_request`` returns for the
        built request and the submit callback.
    """
    if containers:
        # Containers are provided. Now we can merge missing entries from config.
        # If Containers are not provided, it is safe to ignore. This is because,
        # if this object is provided to the API, then Image is required for Containers.
        # That is not supported by the config now. So if we merge values from config,
        # then API will throw an exception. In the future, when SageMaker Config starts
        # supporting other parameters we can add that.
        update_list_of_dicts_with_values_from_config(
            containers,
            MODEL_PACKAGE_INFERENCE_SPECIFICATION_CONTAINERS_PATH,
            required_key_paths=["Image"],
            sagemaker_session=sagemaker_session,
        )

    if validation_specification:
        # ValidationSpecification is provided. Now we can merge missing entries from config.
        # If ValidationSpecification is not provided, it is safe to ignore. This is because,
        # if this object is provided to the API, then both ValidationProfiles and ValidationRole
        # are required and for ValidationProfile, ProfileName is a required parameter. That is
        # not supported by the config now. So if we merge values from config, then API will
        # throw an exception. In the future, when SageMaker Config starts supporting other
        # parameters we can add that.
        validation_specification[VALIDATION_ROLE] = resolve_value_from_config(
            validation_specification.get(VALIDATION_ROLE, None),
            MODEL_PACKAGE_VALIDATION_ROLE_PATH,
            sagemaker_session=sagemaker_session,
        )
        update_list_of_dicts_with_values_from_config(
            validation_specification.get(VALIDATION_PROFILES, []),
            MODEL_PACKAGE_VALIDATION_PROFILES_PATH,
            required_key_paths=["ProfileName", "TransformJobDefinition"],
            sagemaker_session=sagemaker_session,
        )

    model_pkg_request = get_create_model_package_request(
        model_package_name=model_package_name,
        model_package_group_name=model_package_group_name,
        containers=containers,
        content_types=content_types,
        response_types=response_types,
        inference_instances=inference_instances,
        transform_instances=transform_instances,
        model_metrics=model_metrics,
        metadata_properties=metadata_properties,
        marketplace_cert=marketplace_cert,
        approval_status=approval_status,
        description=description,
        drift_check_baselines=drift_check_baselines,
        customer_metadata_properties=customer_metadata_properties,
        validation_specification=validation_specification,
        domain=domain,
        sample_payload_url=sample_payload_url,
        task=task,
        skip_model_validation=skip_model_validation,
        source_uri=source_uri,
        model_card=model_card,
        model_life_cycle=model_life_cycle,
    )

    def submit(request):
        client = sagemaker_session.sagemaker_client

        # A group given by name (not ARN) is created on demand if missing.
        needs_group_check = (
            model_package_group_name is not None
            and not model_package_group_name.startswith("arn:")
        )
        if needs_group_check:
            group_name = request["ModelPackageGroupName"]
            group_exists = False
            try:
                search_response = sagemaker_session.search(
                    resource="ModelPackageGroup",
                    search_expression={
                        "Filters": [
                            {
                                "Name": "ModelPackageGroupName",
                                "Value": group_name,
                                "Operator": "Equals",
                            }
                        ],
                    },
                )
                if len(search_response.get("Results")) > 0:
                    group_exists = True
            except Exception:  # pylint: disable=W0703
                # Search failed for any reason: fall back to paging through
                # list_model_package_groups and matching the exact name.
                summaries = []
                list_kwargs = {"NameContains": group_name}
                while True:
                    page = client.list_model_package_groups(**list_kwargs)
                    summaries.extend(page["ModelPackageGroupSummaryList"])
                    token = page.get("NextToken")
                    if token is None or token == "":
                        break
                    list_kwargs["NextToken"] = token
                group_exists = any(
                    summary.get("ModelPackageGroupName") == group_name
                    for summary in summaries
                )

            if not group_exists:
                _create_resource(
                    lambda: client.create_model_package_group(ModelPackageGroupName=group_name)
                )

        pending_source_uri = request.get("SourceUri")
        if pending_source_uri is None:
            return client.create_model_package(**request)

        # Remove inference spec from request if the
        # given source uri can lead to auto-population of it
        if can_model_package_source_uri_autopopulate(pending_source_uri):
            request.pop("InferenceSpecification", None)
            return client.create_model_package(**request)

        # If source uri can't autopopulate,
        # first create model package with just the inference spec
        # and then update model package with the source uri.
        # Done this way because passing source uri and inference spec together
        # in create/update model package is not allowed in the base sdk.
        del request["SourceUri"]
        created = client.create_model_package(**request)
        return client.update_model_package(
            ModelPackageArn=created.get("ModelPackageArn"),
            SourceUri=pending_source_uri,
        )

    return sagemaker_session._intercept_create_request(
        model_pkg_request, submit, create_model_package_from_containers.__name__
    )
[docs]
def create_model_package_from_algorithm(self, name, description, algorithm_arn, model_data):
    """Create a SageMaker Model Package from the results of training with an Algorithm Package.

    If the service reports that the model package already exists, a warning is
    logged and the existing package is used; any other client error is re-raised.

    Args:
        self: Object exposing a ``sagemaker_client`` attribute
            (presumably a SageMaker session -- confirm against callers).
        name (str): ModelPackage name
        description (str): Model Package description
        algorithm_arn (str): arn or name of the algorithm used for training.
        model_data (str or dict[str, Any]): s3 URI or a dictionary representing a
            ``ModelDataSource`` to the model artifacts produced by training

    Raises:
        ClientError: For any client error other than a ``ValidationException``
            whose message contains "ModelPackage already exists".
    """
    # A dict is sent as a ModelDataSource; anything else as a plain URL.
    model_data_field = "ModelDataSource" if isinstance(model_data, dict) else "ModelDataUrl"
    source_algorithm = {
        "AlgorithmName": algorithm_arn,
        model_data_field: model_data,
    }
    create_args = {
        "ModelPackageName": name,
        "ModelPackageDescription": description,
        "SourceAlgorithmSpecification": {"SourceAlgorithms": [source_algorithm]},
    }

    logger.info("Creating model package with name: %s", name)
    try:
        self.sagemaker_client.create_model_package(**create_args)
    except ClientError as err:
        error_info = err.response["Error"]
        code = error_info["Code"]
        message = error_info["Message"]
        already_exists = (
            code == "ValidationException" and "ModelPackage already exists" in message
        )
        if not already_exists:
            raise
        logger.warning("Using already existing model package: %s", name)