# Source code for sagemaker.mlops.workflow.model_step

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The `ModelStep` definition for SageMaker Pipelines Workflows"""
from __future__ import absolute_import

import logging
from typing import Any, Dict, List, Optional, Union

from sagemaker.core.resources import Model
from sagemaker.mlops.workflow._utils import _RepackModelStep
from sagemaker.core.workflow.pipeline_context import PipelineSession, _ModelStepArguments
from sagemaker.mlops.workflow.retry import RetryPolicy, SageMakerJobStepRetryPolicy
from sagemaker.mlops.workflow.step_collections import StepCollection
from sagemaker.mlops.workflow.steps import Step, ConfigurableRetryStep, StepTypeEnum
from sagemaker.core.helper.pipeline_variable import RequestType
from sagemaker.core.workflow.properties import Properties
from sagemaker.core.workflow.utilities import trim_request_dict

# Key looked up in a dict-valued `retry_policies` to find the policies for the repack step.
_REPACK_MODEL_RETRY_POLICIES = "repack_model_retry_policies"
# Middle segment of the generated repack step name: "<step name>-RepackModel-<model name or index>".
_REPACK_MODEL_NAME_BASE = "RepackModel"
# Keys callers may not override through `repack_model_step_settings`; `ModelStep` removes
# them from the settings before the repack step is built.
_IGNORED_REPACK_PARAM_LIST = ["entry_point", "source_dir", "hyperparameters", "dependencies"]

# Module-level logger used for the warnings emitted by `ModelStep`.
logger = logging.getLogger(__name__)


class _DictStepArgs:
    """Adapter that gives a plain request dict the attributes `ModelStep` reads.

    `ModelStep` reads `create_model_request`, `create_model_package_request`,
    `need_runtime_repack`, `runtime_repack_output_prefix` and `model` from its
    step arguments. A dict supplied by the caller is used as the create-model
    request; every other attribute takes an empty value, so no repack step is
    generated for dict-based step arguments.
    """

    def __init__(self, args_dict):
        self.create_model_request = args_dict
        self.create_model_package_request = None
        self.need_runtime_repack = set()
        self.runtime_repack_output_prefix = None
        self.model = None  # ModelBuilder instance not available in dict case


class ModelStep(ConfigurableRetryStep):
    """`ModelStep` for SageMaker Pipelines Workflows.

    The step is a register-model step when its step arguments carry a
    create-model-package request and a create-model step otherwise. When the
    step arguments ask for a runtime repack, a `_RepackModelStep` is created,
    stored in `steps`, and its output artifact is wired into the model request.
    """

    def __init__(
        self,
        name: str,
        step_args: Union[_ModelStepArguments, Dict],
        depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
        retry_policies: Optional[Union[List[RetryPolicy], Dict[str, List[RetryPolicy]]]] = None,
        display_name: Optional[str] = None,
        description: Optional[str] = None,
        repack_model_step_settings: Optional[Dict[str, Any]] = None,
    ):
        """Constructs a `ModelStep`.

        Args:
            name (str): The name of the `ModelStep`. A name is required and must be
                unique within a pipeline.
            step_args (Union[_ModelStepArguments, Dict]): The arguments for the `ModelStep`
                definition, generated by invoking the :func:`~sagemaker.model.Model.register`
                or :func:`~sagemaker.model.Model.create` under the
                :class:`~sagemaker.workflow.pipeline_context.PipelineSession`. A plain dict
                is also accepted; it is used as the create-model request as is. Example::

                    model = Model(sagemaker_session=PipelineSession())
                    model_step = ModelStep(name="MyModelStep", step_args=model.register())

            depends_on (List[Union[str, Step, StepCollection]]): A list of `Step` or
                `StepCollection` names or `Step` instances or `StepCollection` that
                it depends on. If a listed `Step` name does not exist, an error is
                returned (default: None).
            retry_policies (Union[List[RetryPolicy], Dict[str, List[RetryPolicy]]]): The retry
                policies for the `ModelStep` (default: None). A list is also handed to the
                repack model step, if one is created. When a dict is given, the entry under
                ``"repack_model_retry_policies"`` is used for the repack model step.

                Note: `SageMakerJobStepRetryPolicy` is not allowed in the list form, since
                create/register model step does not support it.

                .. code:: python

                    ModelStep(
                        ...
                        retry_policies=[
                            StepRetryPolicy(...),
                        ],
                    )

            display_name (str): The display name of the `ModelStep`.
                The display name provides better UI readability. (default: None).
            description (str): The description of the `ModelStep` (default: None).
            repack_model_step_settings (Dict[str, Any]): The kwargs passed to the
                _RepackModelStep to customize the configuration of the underlying repack
                model job (default: None). Only used if model repacking is needed. The
                supplied dict is copied, so it is never modified by this step.

        Raises:
            ValueError: If a non-dict `step_args` carries both or neither of the create-model
                and create-model-package requests, or if a list-valued `retry_policies`
                contains a `SageMakerJobStepRetryPolicy`.
            TypeError: If the model referenced by a non-dict `step_args` does not use a
                `PipelineSession`.
        """
        # Imported at call time as in the original implementation.
        from sagemaker.core.workflow.utilities import validate_step_args_input

        validate_step_args_input(
            step_args=step_args,
            expected_caller={
                "create_model",
                "create_model_package_from_containers",
            },
            error_message="The step_args of ModelStep must be obtained from ModelBuilder.build() "
            "or ModelBuilder.register(). For more, see: https://sagemaker.readthedocs.io/en/stable/"
            "amazon_sagemaker_model_building_pipeline.html#model-step",
        )

        # Handle both dictionary (from ModelBuilder.build()) and
        # _ModelStepArguments (from Model.create())
        if isinstance(step_args, dict):
            step_args = _DictStepArgs(step_args)
        else:
            has_create_request = step_args.create_model_request is not None
            has_register_request = step_args.create_model_package_request is not None
            # Exactly one of the two requests must be present.
            if has_create_request == has_register_request:
                raise ValueError(
                    "Invalid step_args: either _register_model_args or _create_model_args"
                    " should be provided. They are mutually exclusive. Please use the model's "
                    ".create() or .register() method to generate the step_args under PipelineSession."
                )
            # Only checked here: the dict adapter has no model to inspect.
            if not isinstance(step_args.model.sagemaker_session, PipelineSession):
                raise TypeError(
                    "To correctly configure a ModelStep, "
                    "the sagemaker_session of the model must be a PipelineSession object."
                )

        # Determine step type based on step_args
        step_type = (
            StepTypeEnum.REGISTER_MODEL
            if step_args.create_model_package_request
            else StepTypeEnum.CREATE_MODEL
        )

        # NOTE(review): a dict-valued `retry_policies` is forwarded unchanged to the base
        # class, exactly as before — confirm ConfigurableRetryStep accepts a mapping.
        super().__init__(name, step_type, display_name, description, depends_on, retry_policies)

        self.step_args = step_args
        self.steps: List[Step] = []
        # Copy so that the pops performed while building the repack step never touch
        # the caller's dict.
        self._repack_model_step_settings = (
            dict(repack_model_step_settings) if repack_model_step_settings else {}
        )
        self._model = step_args.model
        self._create_model_args = step_args.create_model_request
        self._register_model_args = step_args.create_model_package_request
        self._need_runtime_repack = step_args.need_runtime_repack
        self._runtime_repack_output_prefix = step_args.runtime_repack_output_prefix

        if isinstance(retry_policies, dict):
            self._repack_model_retry_policies = retry_policies.get(_REPACK_MODEL_RETRY_POLICIES)
        else:
            self._repack_model_retry_policies = retry_policies
            # Validate that SageMakerJobStepRetryPolicy is not used for model step
            if any(
                isinstance(policy, SageMakerJobStepRetryPolicy)
                for policy in retry_policies or ()
            ):
                raise ValueError(
                    "SageMakerJobStepRetryPolicy is not allowed for a create/register"
                    " model step. Please use StepRetryPolicy instead"
                )

        # Set up properties based on step type
        shape_name = (
            "DescribeModelPackageOutput" if self._register_model_args else "DescribeModelOutput"
        )
        self._properties = Properties(step_name=name, step=self, shape_name=shape_name)

        if self._need_runtime_repack:
            self._append_repack_model_step()
        elif self._repack_model_step_settings:
            logger.warning(
                "Non-empty repack_model_step_settings is supplied but no repack model "
                "step is needed. Ignoring the repack_model_step_settings."
            )

    @property
    def arguments(self) -> RequestType:
        """The arguments dict that are used to call the appropriate SageMaker API.

        For a register-model step, `CertifyForMarketplace` and `Description` are removed
        from the stored request itself (with a warning) before it is passed through
        `trim_request_dict`. For a create-model step the stored request is passed to
        `trim_request_dict` directly.
        """
        # Imported at call time as in the original implementation; presumably so that the
        # current value of the module-level config is read — confirm before hoisting.
        from sagemaker.core.workflow.utilities import _pipeline_config

        if not self._register_model_args:
            # Continue to pop job name if not explicitly opted-in via config
            return trim_request_dict(self._create_model_args, "ModelName", _pipeline_config)

        request_dict = self._register_model_args
        # these are not available in the workflow service and will cause rejection
        warn_msg_template = (
            "Popping out '%s' from the pipeline definition "
            "since it will be overridden in pipeline execution time."
        )
        for unsupported_key in ("CertifyForMarketplace", "Description"):
            if unsupported_key in request_dict:
                request_dict.pop(unsupported_key)
                logger.warning(warn_msg_template, unsupported_key)

        # Continue to pop job name if not explicitly opted-in via config
        return trim_request_dict(request_dict, "ModelPackageName", _pipeline_config)

    @property
    def properties(self):
        """A Properties object representing the appropriate SageMaker response data model."""
        return self._properties

    def _append_repack_model_step(self):
        """Create and append a `_RepackModelStep` for the runtime repack

        Does nothing but log a warning when the step's model is not a `Model`. Otherwise
        a repack step is built for the model if `id(model)` is listed in the runtime
        repack set, appended to `steps`, and the `ModelDataUrl` of the matching container
        in the stored request is pointed at the repack step's output artifact.
        """
        if not isinstance(self._model, Model):
            logger.warning("No models to repack")
            return
        model_list = [self._model]

        self._pop_out_non_configurable_repack_model_step_args()
        security_group_ids, subnets = self._resolve_repack_model_step_vpc_configs()

        for i, model in enumerate(model_list):
            runtime_repack_flg = (
                self._need_runtime_repack and id(model) in self._need_runtime_repack
            )
            if not runtime_repack_flg:
                continue

            name_base = model.name or i
            # Explicit settings win over the model's own values. Each such setting is popped
            # so that only the remaining keys are forwarded through the ** expansion below.
            repack_model_step = _RepackModelStep(
                name=f"{self.name}-{_REPACK_MODEL_NAME_BASE}-{name_base}",
                sagemaker_session=(
                    self._repack_model_step_settings.pop("sagemaker_session", None)
                    or self._model.sagemaker_session
                    or model.sagemaker_session
                ),
                role=(
                    self._repack_model_step_settings.pop("role", None)
                    or self._model.role
                    or model.role
                ),
                model_data=model.model_data,
                entry_point=model.entry_point,
                source_dir=model.source_dir,
                dependencies=model.dependencies,
                subnets=subnets,
                security_group_ids=security_group_ids,
                description=(
                    "Used to repack a model with customer scripts for a "
                    "register/create model step"
                ),
                depends_on=self.depends_on,
                retry_policies=self._repack_model_retry_policies,
                output_path=(
                    self._repack_model_step_settings.pop("output_path", None)
                    or self._runtime_repack_output_prefix
                ),
                output_kms_key=(
                    self._repack_model_step_settings.pop("output_kms_key", None)
                    or model.model_kms_key
                ),
                **self._repack_model_step_settings,
            )
            self.steps.append(repack_model_step)

            repacked_model_data = repack_model_step.properties.ModelArtifacts.S3ModelArtifacts
            if self._create_model_args:
                container = self.step_args.create_model_request["PrimaryContainer"]
            else:
                container = self.step_args.create_model_package_request[
                    "InferenceSpecification"
                ]["Containers"][i]
            # The request dict is edited in place, so `arguments` sees the repacked artifact.
            container["ModelDataUrl"] = repacked_model_data

    def _pop_out_non_configurable_repack_model_step_args(self):
        """Pop out non-configurable args from _repack_model_step_settings

        Every key in `_IGNORED_REPACK_PARAM_LIST` is removed from the settings. A warning
        is logged only for keys whose removed value was truthy.
        """
        if not self._repack_model_step_settings:
            return
        for ignored_param in _IGNORED_REPACK_PARAM_LIST:
            if self._repack_model_step_settings.pop(ignored_param, None):
                logger.warning(
                    "The repack model step parameter - %s is not configurable. Ignoring it.",
                    ignored_param,
                )

    def _resolve_repack_model_step_vpc_configs(self):
        """Resolve vpc configs for repack model step

        Values supplied in the repack settings take precedence when at least one of them
        is truthy; otherwise the model's `vpc_config` is consulted.

        Returns:
            tuple: `(security_group_ids, subnets)`. Both are None when neither the settings
            nor the model provide a VPC configuration.
        """
        # Note: the EstimatorBase constructor ensures that:
        # "When setting up custom VPC, both subnets and security_group_ids must be set"
        settings = self._repack_model_step_settings

        # Always remove both keys, even when their values are falsy. The caller passes
        # `subnets=` and `security_group_ids=` explicitly and also expands the remaining
        # settings with **, so a leftover key would raise a duplicate-keyword TypeError.
        security_group_ids = settings.pop("security_group_ids", None)
        subnets = settings.pop("subnets", None)
        if security_group_ids or subnets:
            return security_group_ids, subnets

        vpc_config = self._model.vpc_config
        if vpc_config:
            return vpc_config.get("SecurityGroupIds", None), vpc_config.get("Subnets", None)
        return None, None