# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The `ModelStep` definition for SageMaker Pipelines Workflows"""
from __future__ import absolute_import
import logging
from typing import Any, Union, List, Dict, Optional
from sagemaker.core.resources import Model
from sagemaker.mlops.workflow._utils import _RepackModelStep
from sagemaker.core.workflow.pipeline_context import PipelineSession, _ModelStepArguments
from sagemaker.mlops.workflow.retry import RetryPolicy, SageMakerJobStepRetryPolicy
from sagemaker.mlops.workflow.step_collections import StepCollection
from sagemaker.mlops.workflow.steps import Step, ConfigurableRetryStep, StepTypeEnum
from sagemaker.core.helper.pipeline_variable import RequestType
from sagemaker.core.workflow.properties import Properties
from sagemaker.core.workflow.utilities import trim_request_dict
# Key looked up in a dict-valued `retry_policies` argument to find the retry
# policies that apply to the repack model step.
_REPACK_MODEL_RETRY_POLICIES = "repack_model_retry_policies"

# Middle segment of the generated repack step name: "<step name>-RepackModel-<model name or index>".
_REPACK_MODEL_NAME_BASE = "RepackModel"

# Keys that are removed from `repack_model_step_settings` before the repack step
# is built; the corresponding values are taken from the model instead.
_IGNORED_REPACK_PARAM_LIST = [
    "entry_point",
    "source_dir",
    "hyperparameters",
    "dependencies",
]

logger = logging.getLogger(__name__)
class ModelStep(ConfigurableRetryStep):
    """`ModelStep` for SageMaker Pipelines Workflows."""

    def __init__(
        self,
        name: str,
        step_args: Union[_ModelStepArguments, Dict],
        depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
        retry_policies: Optional[Union[List[RetryPolicy], Dict[str, List[RetryPolicy]]]] = None,
        display_name: Optional[str] = None,
        description: Optional[str] = None,
        repack_model_step_settings: Optional[Dict[str, Any]] = None,
    ):
        """Constructs a `ModelStep`.

        Args:
            name (str): The name of the `ModelStep`. A name is required and must be
                unique within a pipeline.
            step_args (_ModelStepArguments or Dict): The arguments for the `ModelStep`
                definition, generated by invoking the :func:`~sagemaker.model.Model.register` or
                :func:`~sagemaker.model.Model.create`
                under the :class:`~sagemaker.workflow.pipeline_context.PipelineSession`. Example::

                    model = Model(sagemaker_session=PipelineSession())
                    model_step = ModelStep(step_args=model.register())

                A plain dict is also accepted; it is used as-is as the create-model
                request, with no model attached and no runtime repack.
            depends_on (List[Union[str, Step, StepCollection]]):
                A list of `Step` or `StepCollection`
                names or `Step` instances or `StepCollection` that it depends on.
                If a listed `Step` name does not exist, an error is returned (default: None).
            retry_policies (List[RetryPolicy] or Dict[str, List[RetryPolicy]]): The retry
                policies for the `ModelStep` (default: None). When a list is given it is also
                applied to the repack model step. When a dict is given, the value under the
                ``"repack_model_retry_policies"`` key is applied to the repack model step.
                Note: `SageMakerJobStepRetryPolicy` is not allowed, since
                create/register model step does not support it.

                .. code:: python

                    ModelStep(
                        ...
                        retry_policies=[
                            StepRetryPolicy(...),
                        ],
                    )

            display_name (str): The display name of the `ModelStep`.
                The display name provides better UI readability. (default: None).
            description (str): The description of the `ModelStep` (default: None).
            repack_model_step_settings (Dict[str, Any]): The kwargs passed to the _RepackModelStep
                to customize the configuration of the underlying repack model job (default: None).
                Only used if model repacking is needed. The dict is copied, so the caller's
                object is never modified.

        Raises:
            ValueError: If a non-dict `step_args` carries both or neither of the create-model
                and create-model-package requests, or if a list-valued `retry_policies`
                contains a `SageMakerJobStepRetryPolicy`.
            TypeError: If the model attached to a non-dict `step_args` does not use a
                `PipelineSession`.
        """
        from sagemaker.core.workflow.utilities import validate_step_args_input

        validate_step_args_input(
            step_args=step_args,
            expected_caller={
                "create_model",
                "create_model_package_from_containers",
            },
            error_message="The step_args of ModelStep must be obtained from ModelBuilder.build() "
            "or ModelBuilder.register(). For more, see: https://sagemaker.readthedocs.io/en/stable/"
            "amazon_sagemaker_model_building_pipeline.html#model-step",
        )

        # Handle both dictionary (from ModelBuilder.build()) and
        # _ModelStepArguments (from Model.create())
        if isinstance(step_args, dict):
            # step_args is a dictionary from ModelBuilder.build()
            # Convert to _ModelStepArguments-like structure
            class DictStepArgs:
                """Adapter exposing the attributes this class reads from step_args."""

                def __init__(self, args_dict):
                    self.create_model_request = args_dict
                    self.create_model_package_request = None
                    self.need_runtime_repack = set()
                    self.runtime_repack_output_prefix = None
                    self.model = None  # ModelBuilder instance not available in dict case

            step_args = DictStepArgs(step_args)
        else:
            # step_args is _ModelStepArguments from Model.create()
            has_create_request = step_args.create_model_request is not None
            has_register_request = step_args.create_model_package_request is not None
            # Exactly one of the two requests must be present.
            if has_create_request == has_register_request:
                raise ValueError(
                    "Invalid step_args: either _register_model_args or _create_model_args"
                    " should be provided. They are mutually exclusive. Please use the model's "
                    ".create() or .register() method to generate the step_args under PipelineSession."
                )
            if not isinstance(step_args.model.sagemaker_session, PipelineSession):
                raise TypeError(
                    "To correctly configure a ModelStep, "
                    "the sagemaker_session of the model must be a PipelineSession object."
                )

        # Determine step type based on step_args
        if step_args.create_model_package_request:
            step_type = StepTypeEnum.REGISTER_MODEL
        else:
            step_type = StepTypeEnum.CREATE_MODEL

        # NOTE(review): a dict-valued `retry_policies` is forwarded to the base class
        # unchanged -- confirm that ConfigurableRetryStep accepts a dict here.
        super().__init__(name, step_type, display_name, description, depends_on, retry_policies)

        self.step_args = step_args
        self.steps: List[Step] = []
        # Copy so that the pops performed while building the repack step never
        # touch the caller's dict.
        self._repack_model_step_settings = (
            dict(repack_model_step_settings) if repack_model_step_settings else {}
        )
        self._model = step_args.model
        self._create_model_args = self.step_args.create_model_request
        self._register_model_args = self.step_args.create_model_package_request
        self._need_runtime_repack = self.step_args.need_runtime_repack
        self._runtime_repack_output_prefix = self.step_args.runtime_repack_output_prefix

        if isinstance(retry_policies, dict):
            self._repack_model_retry_policies = retry_policies.get(
                _REPACK_MODEL_RETRY_POLICIES, None
            )
        else:
            self._repack_model_retry_policies = retry_policies

        # Validate that SageMakerJobStepRetryPolicy is not used for model step.
        # Only list-valued retry_policies are inspected; dict values are not checked.
        if retry_policies and not isinstance(retry_policies, dict):
            for policy in retry_policies:
                if isinstance(policy, SageMakerJobStepRetryPolicy):
                    raise ValueError(
                        "SageMakerJobStepRetryPolicy is not allowed for a create/register"
                        " model step. Please use StepRetryPolicy instead"
                    )

        # Set up properties based on step type
        if self._register_model_args:
            self._properties = Properties(
                step_name=name, step=self, shape_name="DescribeModelPackageOutput"
            )
        else:
            self._properties = Properties(
                step_name=name, step=self, shape_name="DescribeModelOutput"
            )

        if self._need_runtime_repack:
            self._append_repack_model_step()
        elif self._repack_model_step_settings:
            logger.warning(
                "Non-empty repack_model_step_settings is supplied but no repack model "
                "step is needed. Ignoring the repack_model_step_settings."
            )

    @property
    def arguments(self) -> RequestType:
        """The arguments dict that are used to call the appropriate SageMaker API.

        For a register-model step, the "CertifyForMarketplace" and "Description" keys
        are removed from the stored request (in place, with a warning for each key
        found). The request is then passed through `trim_request_dict` together with
        the name key ("ModelPackageName" or "ModelName") and the pipeline config.
        """
        # Imported at call time rather than at module import.
        from sagemaker.core.workflow.utilities import _pipeline_config

        if self._register_model_args:
            request_dict = self._register_model_args
            # these are not available in the workflow service and will cause rejection
            warn_msg_template = (
                "Popping out '%s' from the pipeline definition "
                "since it will be overridden in pipeline execution time."
            )
            for unsupported_key in ("CertifyForMarketplace", "Description"):
                if unsupported_key in request_dict:
                    request_dict.pop(unsupported_key)
                    logger.warning(warn_msg_template, unsupported_key)
            # Continue to pop job name if not explicitly opted-in via config
            request_dict = trim_request_dict(request_dict, "ModelPackageName", _pipeline_config)
        else:
            request_dict = self._create_model_args
            # Continue to pop job name if not explicitly opted-in via config
            request_dict = trim_request_dict(request_dict, "ModelName", _pipeline_config)
        return request_dict

    @property
    def properties(self):
        """A Properties object representing the appropriate SageMaker response data model."""
        return self._properties

    def _append_repack_model_step(self):
        """Create and append a `_RepackModelStep` for the runtime repack.

        Does nothing (beyond logging a warning) when the step has no `Model`
        attached. Otherwise a repack step is appended to `self.steps` for the
        model if its `id()` is listed in `self._need_runtime_repack`, and the
        container's "ModelDataUrl" in the stored request is re-pointed at the
        repack step's output artifact.
        """
        if isinstance(self._model, Model):
            model_list = [self._model]
        else:
            logger.warning("No models to repack")
            return

        # Both helpers remove keys from the settings so that the remaining
        # entries can be forwarded safely via ** below.
        self._pop_out_non_configurable_repack_model_step_args()
        security_group_ids, subnets = self._resolve_repack_model_step_vpc_configs()

        for i, model in enumerate(model_list):
            runtime_repack_flg = (
                self._need_runtime_repack and id(model) in self._need_runtime_repack
            )
            if not runtime_repack_flg:
                continue

            # Fall back to the positional index when the model has no name.
            name_base = model.name or i
            repack_model_step = _RepackModelStep(
                name="{}-{}-{}".format(self.name, _REPACK_MODEL_NAME_BASE, name_base),
                # User-supplied settings win; otherwise fall back to the model's values.
                sagemaker_session=(
                    self._repack_model_step_settings.pop("sagemaker_session", None)
                    or self._model.sagemaker_session
                    or model.sagemaker_session
                ),
                role=(
                    self._repack_model_step_settings.pop("role", None)
                    or self._model.role
                    or model.role
                ),
                model_data=model.model_data,
                entry_point=model.entry_point,
                source_dir=model.source_dir,
                dependencies=model.dependencies,
                subnets=subnets,
                security_group_ids=security_group_ids,
                description=(
                    "Used to repack a model with customer scripts for a "
                    "register/create model step"
                ),
                depends_on=self.depends_on,
                retry_policies=self._repack_model_retry_policies,
                output_path=(
                    self._repack_model_step_settings.pop("output_path", None)
                    or self._runtime_repack_output_prefix
                ),
                output_kms_key=(
                    self._repack_model_step_settings.pop("output_kms_key", None)
                    or model.model_kms_key
                ),
                # Whatever the user supplied beyond the keys handled above.
                **self._repack_model_step_settings,
            )
            self.steps.append(repack_model_step)

            # Point the create/register request at the repacked artifact.
            repacked_model_data = repack_model_step.properties.ModelArtifacts.S3ModelArtifacts
            if self._create_model_args:
                container = self.step_args.create_model_request["PrimaryContainer"]
            else:
                container = self.step_args.create_model_package_request[
                    "InferenceSpecification"
                ]["Containers"][i]
            container["ModelDataUrl"] = repacked_model_data

    def _pop_out_non_configurable_repack_model_step_args(self):
        """Pop out non-configurable args from _repack_model_step_settings.

        Every key in `_IGNORED_REPACK_PARAM_LIST` is removed; a warning is
        logged only for keys that carried a truthy value.
        """
        if not self._repack_model_step_settings:
            return
        for ignored_param in _IGNORED_REPACK_PARAM_LIST:
            if self._repack_model_step_settings.pop(ignored_param, None):
                logger.warning(
                    "The repack model step parameter - %s is not configurable. Ignoring it.",
                    ignored_param,
                )

    def _resolve_repack_model_step_vpc_configs(self):
        """Resolve vpc configs for repack model step.

        Values supplied in `_repack_model_step_settings` take precedence over
        the model's `vpc_config`.

        Returns:
            tuple: `(security_group_ids, subnets)`; both are None when neither
            the settings nor the model provide a VPC configuration.
        """
        # Note: the EstimatorBase constructor ensures that:
        # "When setting up custom VPC, both subnets and security_group_ids must be set"

        # Always remove both keys, even when their values are falsy: the caller passes
        # `subnets=` / `security_group_ids=` explicitly AND forwards the remaining
        # settings via **, so a leftover key would raise
        # "TypeError: got multiple values for keyword argument".
        security_group_ids = self._repack_model_step_settings.pop("security_group_ids", None)
        subnets = self._repack_model_step_settings.pop("subnets", None)
        if security_group_ids or subnets:
            return security_group_ids, subnets

        if self._model.vpc_config:
            security_group_ids = self._model.vpc_config.get("SecurityGroupIds", None)
            subnets = self._model.vpc_config.get("Subnets", None)
            return security_group_ids, subnets

        return None, None