Source code for sagemaker.mlops.workflow.automl_step

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The `AutoMLStep` definition for SageMaker Pipelines Workflows"""
from __future__ import absolute_import

from typing import Union, Optional, List

from sagemaker.core.helper.session_helper import Session
from sagemaker.core.exceptions import AutoMLStepInvalidModeError

from sagemaker.core.helper.pipeline_variable import RequestType
from sagemaker.core.workflow.pipeline_context import _JobStepArguments
from sagemaker.core.workflow.properties import Properties
from sagemaker.mlops.workflow.retry import RetryPolicy
from sagemaker.mlops.workflow.steps import ConfigurableRetryStep, CacheConfig, Step, StepTypeEnum
from sagemaker.core.workflow.utilities import validate_step_args_input, trim_request_dict

from sagemaker.serve.model_builder import ModelBuilder


[docs] class AutoMLStep(ConfigurableRetryStep): """`AutoMLStep` for SageMaker Pipelines Workflows.""" def __init__( self, name: str, step_args: _JobStepArguments, display_name: Optional[str] = None, description: Optional[str] = None, cache_config: Optional[CacheConfig] = None, depends_on: Optional[List[Union[str, Step]]] = None, retry_policies: Optional[List[RetryPolicy]] = None, ): """Construct a `AutoMLStep`, given a `AutoML` instance. In addition to the `AutoML` instance, the other arguments are those that are supplied to the `fit` method of the `sagemaker.automl.automl.AutoML`. Args: name (str): The name of the `AutoMLStep`. step_args (_JobStepArguments): The arguments for the `AutoMLStep` definition. display_name (str): The display name of the `AutoMLStep`. description (str): The description of the `AutoMLStep`. cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. depends_on (List[Union[str, Step]]): A list of `Step` names or `Step` instances that this `AutoMLStep` depends on. retry_policies (List[RetryPolicy]): A list of retry policies. """ super(AutoMLStep, self).__init__( name, StepTypeEnum.AUTOML, display_name, description, depends_on, retry_policies ) validate_step_args_input( step_args=step_args, expected_caller={Session.auto_ml.__name__}, error_message="The step_args of AutoMLStep must be obtained " "from automl.fit().", ) self.step_args = step_args self.cache_config = cache_config root_property = Properties( step_name=name, step=self, shape_name="DescribeAutoMLJobResponse" ) best_candidate_properties = Properties( step_name=name, step=self, path="BestCandidateProperties" ) best_candidate_properties.__dict__["ModelInsightsJsonReportPath"] = Properties( step_name=name, step=self, path="BestCandidateProperties.ModelInsightsJsonReportPath" ) best_candidate_properties.__dict__["ExplainabilityJsonReportPath"] = Properties( step_name=name, step=self, path="BestCandidateProperties.ExplainabilityJsonReportPath" ) root_property.__dict__["BestCandidateProperties"] = best_candidate_properties self._properties = root_property @property def arguments(self) -> RequestType: """The arguments dictionary that is used to call `create_auto_ml_job`. NOTE: The `CreateAutoMLJob` request is not quite the args list that workflow needs. `ModelDeployConfig` and `GenerateCandidateDefinitionsOnly` attribute cannot be included. """ from sagemaker.core.workflow.utilities import execute_job_functions from sagemaker.core.workflow.utilities import _pipeline_config # execute fit function in AutoML with saved parameters, # and store args in PipelineSession's _context execute_job_functions(self.step_args) # populate request dict with args auto_ml = self.step_args.func_args[0] request_dict = auto_ml.sagemaker_session.context.args if "AutoMLJobConfig" not in request_dict: raise AutoMLStepInvalidModeError() if ( "Mode" not in request_dict["AutoMLJobConfig"] or request_dict["AutoMLJobConfig"]["Mode"] != "ENSEMBLING" ): raise AutoMLStepInvalidModeError() if "ModelDeployConfig" in request_dict: request_dict.pop("ModelDeployConfig", None) if "GenerateCandidateDefinitionsOnly" in request_dict: request_dict.pop("GenerateCandidateDefinitionsOnly", None) # Continue to pop job name if not explicitly opted-in via config # AutoML Trims to AutoMLJo-2023-06-23-22-57-39-083 request_dict = trim_request_dict(request_dict, "AutoMLJobName", _pipeline_config) return request_dict @property def properties(self): """A `Properties` object representing the `DescribeAutoMLJobResponse` data model.""" return self._properties
[docs] def to_request(self) -> RequestType: """Updates the dictionary with cache configuration.""" request_dict = super().to_request() if self.cache_config: request_dict.update(self.cache_config.config) return request_dict
[docs] def get_best_auto_ml_model_builder(self, role, sagemaker_session=None): """Get the best candidate model artifacts, image uri and env variables for the best model. Args: role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker AutoML jobs and APIs that create Amazon SageMaker endpoints use this role to access training data and model artifacts. sagemaker_session (sagemaker.core.helper.session.Session): A SageMaker Session object, used for SageMaker interactions. If the best model will be used as part of ModelStep, then sagemaker_session should be class:`~sagemaker.workflow.pipeline_context.PipelineSession`. Example:: model = Model(sagemaker_session=PipelineSession()) model_step = ModelStep(step_args=model.register()) """ inference_container = self.properties.BestCandidate.InferenceContainers[0] inference_container_environment = inference_container.Environment image = inference_container.Image model_data = inference_container.ModelDataUrl model_builder = ModelBuilder( image_uri=image, s3_model_data_url=model_data, env_vars={ "MODEL_NAME": inference_container_environment["MODEL_NAME"], "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": inference_container_environment[ "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT" ], "SAGEMAKER_SUBMIT_DIRECTORY": inference_container_environment[ "SAGEMAKER_SUBMIT_DIRECTORY" ], "SAGEMAKER_INFERENCE_SUPPORTED": inference_container_environment[ "SAGEMAKER_INFERENCE_SUPPORTED" ], "SAGEMAKER_INFERENCE_OUTPUT": inference_container_environment[ "SAGEMAKER_INFERENCE_OUTPUT" ], "SAGEMAKER_PROGRAM": inference_container_environment["SAGEMAKER_PROGRAM"], }, sagemaker_session=sagemaker_session, role_arn=role, ) return model_builder