# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The `Step` definitions for SageMaker Pipelines Workflows."""
from __future__ import absolute_import
import abc
from enum import Enum
from typing import Dict, List, Set, Union, Optional, Any, TYPE_CHECKING
import attr
from sagemaker.core.local.local_session import LocalSagemakerClient
# Primitive imports (stay in core)
from sagemaker.core.workflow.entities import Entity
from sagemaker.core.helper.pipeline_variable import RequestType
from sagemaker.core.workflow.pipeline_context import _JobStepArguments
from sagemaker.core.workflow.properties import (
PropertyFile,
Properties,
)
from sagemaker.core.helper.pipeline_variable import PipelineVariable
from sagemaker.core.workflow.functions import Join, JsonGet
# Orchestration imports (now in mlops)
from sagemaker.mlops.workflow.retry import RetryPolicy
from sagemaker.core.workflow.step_outputs import StepOutput
from sagemaker.core.workflow.utilities import trim_request_dict
from sagemaker.core.processing import Processor
# StepCollection is imported for type annotations only; a runtime import here would be circular.
# Where it (or ModelTrainer) is needed at runtime, it is imported lazily inside the method.
if TYPE_CHECKING:
from sagemaker.mlops.workflow.step_collections import StepCollection
class StepTypeEnum(Enum):
    """Enum of `Step` types.

    The value of each member is the string written to the ``"Type"`` field of a
    step's request structure.
    """

    CONDITION = "Condition"
    CREATE_MODEL = "Model"
    PROCESSING = "Processing"
    REGISTER_MODEL = "RegisterModel"
    TRAINING = "Training"
    TRANSFORM = "Transform"
    CALLBACK = "Callback"
    TUNING = "Tuning"
    LAMBDA = "Lambda"
    QUALITY_CHECK = "QualityCheck"
    CLARIFY_CHECK = "ClarifyCheck"
    EMR = "EMR"
    EMR_SERVERLESS = "EMRServerless"
    FAIL = "Fail"
    AUTOML = "AutoML"
class Step(Entity):
    """Pipeline `Step` for SageMaker Pipelines Workflows.

    Base class of the concrete step types. The `arguments` and `properties`
    properties are declared abstract and are supplied by subclasses.
    """

    def __init__(
        self,
        name: str,
        display_name: Optional[str] = None,
        description: Optional[str] = None,
        step_type: Optional[StepTypeEnum] = None,
        depends_on: Optional[List[Union[str, "Step", "StepCollection", StepOutput]]] = None,
    ):
        """Initialize a Step.

        Args:
            name (str): The name of the `Step`.
            display_name (str): The display name of the `Step`.
            description (str): The description of the `Step`.
            step_type (StepTypeEnum): The type of the `Step`.
            depends_on (List[Union[str, Step, StepCollection, StepOutput]]): The list of
                `Step`/`StepCollection` names or `Step`, `StepCollection` or `StepOutput`
                instances that the current `Step` depends on.
        """
        self.name = name
        self.display_name = display_name
        self.description = description
        self.step_type = step_type
        self._depends_on = depends_on

    @property
    def depends_on(self) -> Optional[List[Union[str, "Step", "StepCollection", StepOutput]]]:
        """The list of steps the current `Step` depends on."""
        return self._depends_on

    @depends_on.setter
    def depends_on(
        self, depends_on: Optional[List[Union[str, "Step", "StepCollection", StepOutput]]]
    ):
        """Set the list of steps the current step explicitly depends on."""
        self._depends_on = depends_on

    @property
    @abc.abstractmethod
    def arguments(self) -> RequestType:
        """The arguments to the particular `Step` service call."""

    @property
    def step_only_arguments(self) -> RequestType:
        """The arguments to this Step only.

        Compound Steps such as the ConditionStep will have to
        override this method to return arguments pertaining to only that step.
        """
        return self.arguments

    @property
    @abc.abstractmethod
    def properties(self):
        """The properties of the particular `Step`."""

    def to_request(self) -> RequestType:
        """Gets the request structure for workflow service calls.

        Returns:
            RequestType: A dict holding `Name`, `Type` and `Arguments`, plus `DependsOn`,
            `DisplayName` and `Description` when the corresponding attribute is truthy.
        """
        request_dict = {
            "Name": self.name,
            "Type": self.step_type.value,
            "Arguments": self.arguments,
        }
        if self.depends_on:
            request_dict["DependsOn"] = list(self.depends_on)
        if self.display_name:
            request_dict["DisplayName"] = self.display_name
        if self.description:
            request_dict["Description"] = self.description
        return request_dict

    def add_depends_on(self, step_names: List[Union[str, "Step", "StepCollection", StepOutput]]):
        """Add `Step` names or `Step` instances to the current `Step` depends on list.

        Args:
            step_names: The entries to append. A falsy value is a no-op.
        """
        if not step_names:
            return
        # Build a new list rather than extending in place: the stored list may be the
        # very object the caller passed to the constructor or the setter, and it must
        # not be modified behind the caller's back.
        self._depends_on = [*(self._depends_on or []), *step_names]

    @property
    def ref(self) -> Dict[str, str]:
        """Gets a reference dictionary for `Step` instances."""
        return {"Name": self.name}

    # TODO: move this method to CompiledStep
    def _find_step_dependencies(
        self, step_map: Dict[str, Union["Step", "StepCollection"]]
    ) -> List[str]:
        """Find all step names this step is dependent on.

        Combines the names found in the depends-on list with the names of the steps
        referenced by pipeline variables inside `step_only_arguments`.

        Args:
            step_map: Mapping from a name to the `Step` or `StepCollection` it denotes.

        Returns:
            List[str]: The de-duplicated step names, in no particular order.
        """
        step_dependencies = set()
        if self.depends_on:
            step_dependencies.update(self._find_dependencies_in_depends_on_list(step_map))
        step_dependencies.update(
            self._find_dependencies_in_step_arguments(self.step_only_arguments, step_map)
        )
        return list(step_dependencies)

    def _find_dependencies_in_depends_on_list(
        self, step_map: Dict[str, Union["Step", "StepCollection"]]
    ) -> Set[str]:
        """Find dependency steps referenced in the depends-on field of this step.

        A `StepCollection` contributes the name of its last step. Entries that are
        neither `Step`, `StepCollection` nor `str` are not resolved here.
        """
        # import here to prevent circular import
        from sagemaker.mlops.workflow.step_collections import StepCollection

        dependencies = set()
        for step in self.depends_on:
            if isinstance(step, Step):
                dependencies.add(step.name)
            elif isinstance(step, StepCollection):
                dependencies.add(step.steps[-1].name)
            elif isinstance(step, str):
                # step could be the name of a `Step` or a `StepCollection`
                dependencies.add(self._get_step_name_from_str(step, step_map))
        return dependencies

    def _find_dependencies_in_step_arguments(
        self, obj: Any, step_map: Dict[str, Union["Step", "StepCollection"]]
    ) -> Set[str]:
        """Find the step dependencies referenced in the arguments of this step.

        `JsonGet` variables (and `DelayedReturn` variables, after conversion to
        `JsonGet`) are validated along the way.

        Args:
            obj: The (possibly nested) step arguments to scan.
            step_map: Mapping from a name to the `Step` or `StepCollection` it denotes.

        Returns:
            Set[str]: Names of the steps referenced by the pipeline variables found.
        """
        dependencies = set()
        pipeline_variables = Step._find_pipeline_variables_in_step_arguments(obj)
        if not pipeline_variables:
            return dependencies

        # Imported lazily, and once instead of on every loop iteration.
        from sagemaker.core.workflow.function_step import DelayedReturn

        for pipeline_variable in pipeline_variables:
            for referenced_step in pipeline_variable._referenced_steps:
                if isinstance(referenced_step, Step):
                    dependencies.add(referenced_step.name)
                else:
                    dependencies.add(self._get_step_name_from_str(referenced_step, step_map))

            # TODO: we can remove the if-elif once move the validators to JsonGet constructor
            if isinstance(pipeline_variable, JsonGet):
                self._validate_json_get_function(pipeline_variable, step_map)
            elif isinstance(pipeline_variable, DelayedReturn):
                # DelayedReturn showing up in arguments, meaning that it's data referenced
                # We should convert it to JsonGet and validate the JsonGet object
                self._validate_json_get_function(pipeline_variable._to_json_get(), step_map)
        return dependencies

    def _validate_json_get_function(
        self, json_get: JsonGet, step_map: Dict[str, Union["Step", "StepCollection"]]
    ):
        """Validate the JsonGet function inputs.

        Only a `JsonGet` that carries a property file reference is checked.
        """
        if json_get.property_file:
            self._validate_json_get_property_file_reference(json_get=json_get, step_map=step_map)

    # TODO: move it to JsonGet constructor
    def _validate_json_get_property_file_reference(self, json_get: JsonGet, step_map: dict):
        """Validate the property file reference in JsonGet.

        Args:
            json_get: The `JsonGet` whose `property_file` is checked.
            step_map: Mapping from step name to step; looked up with `json_get.step_name`.

        Raises:
            ValueError: If a property file given by name targets a step that is not a
                processing step, if the property file cannot be resolved, or if the
                output named by the property file is not among the step's outputs.
        """
        property_file_reference = json_get.property_file
        processing_step = step_map[json_get.step_name]
        property_file = None
        if isinstance(property_file_reference, str):
            if processing_step.step_type != StepTypeEnum.PROCESSING:
                raise ValueError(
                    f"Invalid JsonGet function {json_get.expr} in step '{self.name}'. "
                    f"JsonGet function (with property_file) can only be evaluated "
                    f"on processing step outputs."
                )
            for file in processing_step.property_files:
                if file.name == property_file_reference:
                    property_file = file
                    break
        elif isinstance(property_file_reference, PropertyFile):
            property_file = property_file_reference
        if property_file is None:
            raise ValueError(
                f"Invalid JsonGet function {json_get.expr} in step '{self.name}'. Property file "
                f"reference '{property_file_reference}' is undefined in step "
                f"'{processing_step.name}'."
            )

        # `arguments` is a property; for a `ProcessingStep` every access re-executes the
        # saved job function, so read it a single time.
        step_arguments = processing_step.arguments
        output_found = False
        if "ProcessingOutputConfig" in step_arguments:
            output_found = any(
                output["OutputName"] == property_file.output_name
                for output in step_arguments["ProcessingOutputConfig"]["Outputs"]
            )
        if not output_found:
            raise ValueError(
                f"Processing output name '{property_file.output_name}' defined in property file "
                f"'{property_file.name}' not found in processing step '{processing_step.name}'."
            )

    @staticmethod
    def _find_pipeline_variables_in_step_arguments(obj: RequestType) -> List[PipelineVariable]:
        """Recursively find all the pipeline variables in the step arguments.

        Dict values and the items of lists and tuples are scanned; any other object
        yields no variables.
        """
        if isinstance(obj, dict):
            children = obj.values()
        elif isinstance(obj, (list, tuple)):
            children = obj
        else:
            return []

        pipeline_variables = []
        for child in children:
            if isinstance(child, PipelineVariable):
                pipeline_variables.append(child)
            else:
                pipeline_variables.extend(Step._find_pipeline_variables_in_step_arguments(child))
        return pipeline_variables

    @staticmethod
    def _get_step_name_from_str(
        str_input: str, step_map: Dict[str, Union["Step", "StepCollection"]]
    ) -> str:
        """Convert a Step or StepCollection name input to step name.

        Returns:
            str: `str_input` itself when it names a `Step`, or the name of the last step
            when it names a `StepCollection`.

        Raises:
            ValueError: If `str_input` is not a key of `step_map`.
        """
        from sagemaker.mlops.workflow.step_collections import StepCollection

        if str_input not in step_map:
            raise ValueError(f"Step {str_input} is undefined.")
        if isinstance(step_map[str_input], StepCollection):
            return step_map[str_input].steps[-1].name
        return str_input

    @staticmethod
    def _trim_experiment_config(request_dict: Dict):
        """For job steps, trim the experiment config to keep the trial component display name.

        The dict is modified in place. When no truthy `TrialComponentDisplayName` is
        present, the `ExperimentConfig` entry is removed altogether.
        """
        # `or {}` also covers an `ExperimentConfig` key that is present but set to None.
        experiment_config = request_dict.get("ExperimentConfig") or {}
        display_name = experiment_config.get("TrialComponentDisplayName")
        if display_name:
            request_dict["ExperimentConfig"] = {"TrialComponentDisplayName": display_name}
        else:
            request_dict.pop("ExperimentConfig", None)
@attr.s
class CacheConfig:
    """Configuration class to enable caching in SageMaker Pipelines Workflows.

    If caching is enabled, the pipeline attempts to find a previous execution of a `Step`
    that was called with the same arguments. `Step` caching only considers successful execution.
    If a successful previous execution is found, the pipeline propagates the values
    from the previous execution rather than recomputing the `Step`.
    When multiple successful executions exist within the timeout period,
    it uses the result for the most recent successful execution.

    Attributes:
        enable_caching (bool): To enable `Step` caching. Defaults to `False`.
        expire_after (str): If `Step` caching is enabled, a timeout also needs to be defined.
            It defines how old a previous execution can be to be considered for reuse.
            Value should be an ISO 8601 duration string. Defaults to `None`.

            Examples::

                'p30d' # 30 days
                'P4DT12H' # 4 days and 12 hours
                'T12H' # 12 hours
    """

    enable_caching: bool = attr.ib(default=False)
    # Must be a `str` when given; `None` is accepted and means "not set".
    expire_after = attr.ib(
        default=None, validator=attr.validators.optional(attr.validators.instance_of(str))
    )

    @property
    def config(self):
        """Configures `Step` caching for SageMaker Pipelines Workflows.

        Returns:
            dict: ``{"CacheConfig": {...}}`` where the inner dict always has `Enabled`
            and has `ExpireAfter` only when `expire_after` is not `None`.
        """
        cache_settings = (
            {"Enabled": self.enable_caching}
            if self.expire_after is None
            else {"Enabled": self.enable_caching, "ExpireAfter": self.expire_after}
        )
        return {"CacheConfig": cache_settings}
class ConfigurableRetryStep(Step):
    """`ConfigurableRetryStep` for SageMaker Pipelines Workflows."""

    def __init__(
        self,
        name: str,
        step_type: StepTypeEnum,
        display_name: Optional[str] = None,
        description: Optional[str] = None,
        depends_on: Optional[List[Union[str, Step, "StepCollection", StepOutput]]] = None,
        retry_policies: Optional[List[RetryPolicy]] = None,
    ):
        """Initialize a `ConfigurableRetryStep`.

        Args:
            name (str): The name of the step.
            step_type (StepTypeEnum): The type of the step.
            display_name (str): The display name of the step.
            description (str): The description of the step.
            depends_on (List[Union[str, Step, StepCollection, StepOutput]]): The names or
                instances this step depends on; passed through to `Step`.
            retry_policies (List[RetryPolicy]): A list of retry policies. A falsy value
                results in an empty list.
        """
        super().__init__(
            name=name,
            display_name=display_name,
            step_type=step_type,
            description=description,
            depends_on=depends_on,
        )
        # Keep a private copy so that `add_retry_policy` never appends to a list the
        # caller still owns (and may share between several steps).
        self.retry_policies = list(retry_policies) if retry_policies else []

    def add_retry_policy(self, retry_policy: RetryPolicy):
        """Add a policy to the current `ConfigurableRetryStep` retry policies list.

        Args:
            retry_policy (RetryPolicy): The policy to append. A falsy value is a no-op.
        """
        if not retry_policy:
            return
        # The attribute is public and may have been reset to None or another falsy value.
        if not self.retry_policies:
            self.retry_policies = []
        self.retry_policies.append(retry_policy)

    def to_request(self) -> RequestType:
        """Gets the request structure for `ConfigurableRetryStep`.

        Returns:
            RequestType: The base `Step` request, plus `RetryPolicies` when at least one
            retry policy is configured.
        """
        step_dict = super().to_request()
        if self.retry_policies:
            step_dict["RetryPolicies"] = self._resolve_retry_policy(self.retry_policies)
        return step_dict

    @staticmethod
    def _resolve_retry_policy(retry_policy_list: List[RetryPolicy]) -> List[RequestType]:
        """Resolve the `ConfigurableRetryStep` retry policy list.

        Returns:
            List[RequestType]: The result of `to_request()` for each policy, in order.
        """
        return [retry_policy.to_request() for retry_policy in retry_policy_list]
class TrainingStep(ConfigurableRetryStep):
    """`TrainingStep` for SageMaker Pipelines Workflows."""

    def __init__(
        self,
        name: str,
        step_args: Optional[_JobStepArguments] = None,
        display_name: Optional[str] = None,
        description: Optional[str] = None,
        cache_config: Optional[CacheConfig] = None,
        depends_on: Optional[List[Union[str, Step]]] = None,
        retry_policies: Optional[List[RetryPolicy]] = None,
    ):
        """Construct a `TrainingStep` using step_args from model_trainer.train().

        Args:
            name (str): The name of the `TrainingStep`.
            step_args (_JobStepArguments): The arguments for the `TrainingStep` definition.
                Not validated when falsy; in that case reading `arguments` raises.
            display_name (str): The display name of the `TrainingStep`.
            description (str): The description of the `TrainingStep`.
            cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
            depends_on (List[Union[str, Step]]): A list of `Step`
                names or `Step` instances that this `TrainingStep`
                depends on.
            retry_policies (List[RetryPolicy]): A list of retry policies.
        """
        super().__init__(
            name=name,
            step_type=StepTypeEnum.TRAINING,
            display_name=display_name,
            description=description,
            depends_on=depends_on,
            retry_policies=retry_policies,
        )

        if step_args:
            from sagemaker.core.workflow.utilities import validate_step_args_input

            # Lazy import to avoid circular dependency
            from sagemaker.train.model_trainer import ModelTrainer

            allowed_callers = {ModelTrainer.train.__name__}
            validate_step_args_input(
                step_args=step_args,
                expected_caller=allowed_callers,
                error_message="The step_args of TrainingStep must be obtained from model_trainer.train().",
            )

        self.step_args = step_args
        self._properties = Properties(
            step_name=name, step=self, shape_name="DescribeTrainingJobResponse"
        )
        self.cache_config = cache_config

    @property
    def arguments(self) -> RequestType:
        """The arguments dictionary that is used to call `create_training_job`.

        NOTE: The `CreateTrainingJob` request is not quite the args list that workflow needs.

        Raises:
            ValueError: If the step was constructed without `step_args`.
        """
        from sagemaker.core.workflow.utilities import _pipeline_config, execute_job_functions

        if not self.step_args:
            raise ValueError("step_args input is required.")

        # Run the saved train call; the resulting request is left in the session context.
        execute_job_functions(self.step_args)
        trainer = self.step_args.func_args[0]
        training_request = trainer.sagemaker_session.context.args

        if "HyperParameters" in training_request:
            training_request["HyperParameters"].pop("sagemaker_job_name", None)

        # Continue to pop job name if not explicitly opted-in via config
        training_request = trim_request_dict(
            training_request, "TrainingJobName", _pipeline_config
        )
        Step._trim_experiment_config(training_request)
        return training_request

    @property
    def properties(self):
        """A `Properties` object representing the `DescribeTrainingJobResponse` data model."""
        return self._properties

    def to_request(self) -> RequestType:
        """Updates the request dictionary with cache configuration.

        Returns:
            RequestType: The parent request, merged with `cache_config.config` when a
            cache configuration is set.
        """
        step_request = super().to_request()
        if not self.cache_config:
            return step_request
        step_request.update(self.cache_config.config)
        return step_request
class ProcessingStep(ConfigurableRetryStep):
    """`ProcessingStep` for SageMaker Pipelines Workflows."""

    def __init__(
        self,
        name: str,
        step_args: Optional[_JobStepArguments] = None,
        display_name: Optional[str] = None,
        description: Optional[str] = None,
        property_files: Optional[List[PropertyFile]] = None,
        cache_config: Optional[CacheConfig] = None,
        depends_on: Optional[List[Union[str, Step]]] = None,
        retry_policies: Optional[List[RetryPolicy]] = None,
    ):
        """Construct a `ProcessingStep` from the step_args of a processing job call.

        Args:
            name (str): The name of the `ProcessingStep`.
            step_args (_JobStepArguments): The arguments for the `ProcessingStep` definition.
                Required despite the `None` default.
            display_name (str): The display name of the `ProcessingStep`.
            description (str): The description of the `ProcessingStep`
            property_files (List[PropertyFile]): A list of property files that workflow looks
                for and resolves from the configured processing output list.
            cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
            depends_on (List[Union[str, Step]]): A list of `Step`
                names or `Step` instances that this `ProcessingStep`
                depends on.
            retry_policies (List[RetryPolicy]): A list of retry policies.

        Raises:
            ValueError: If `step_args` is missing or falsy.
        """
        super().__init__(
            name, StepTypeEnum.PROCESSING, display_name, description, depends_on, retry_policies
        )

        if not step_args:
            raise ValueError("step_args is required for ProcessingStep.")

        from sagemaker.core.workflow.utilities import validate_step_args_input

        validate_step_args_input(
            step_args=step_args,
            expected_caller={
                Processor.run.__name__,
                # Read the method name off the class: constructing a client only to
                # learn the name of one of its methods is unnecessary work.
                LocalSagemakerClient.create_processing_job.__name__,
            },
            error_message=f"The step_args of ProcessingStep must be obtained from processor.run() or in local mode, not {step_args.caller_name}",
        )

        self.step_args = step_args
        self.property_files = property_files or []
        self.cache_config = cache_config
        self._properties = Properties(
            step_name=name, step=self, shape_name="DescribeProcessingJobResponse"
        )

    @property
    def arguments(self) -> RequestType:
        """The arguments dictionary that is used to call `create_processing_job`.

        NOTE: The `CreateProcessingJob` request is not quite the args list that workflow needs.
        `ExperimentConfig` cannot be included in the arguments.
        """
        from sagemaker.core.workflow.utilities import execute_job_functions
        from sagemaker.core.workflow.utilities import _pipeline_config

        # execute run function with saved parameters,
        # and store args in PipelineSession's _context
        execute_job_functions(self.step_args)

        # populate request dict with args
        processor = self.step_args.func_args[0]
        request_dict = processor.sagemaker_session.context.args

        # Continue to pop job name if not explicitly opted-in via config
        request_dict = trim_request_dict(request_dict, "ProcessingJobName", _pipeline_config)
        Step._trim_experiment_config(request_dict)
        return request_dict

    @property
    def properties(self):
        """A `Properties` object representing the `DescribeProcessingJobResponse` data model."""
        return self._properties

    def to_request(self) -> RequestType:
        """Get the request structure for workflow service calls.

        Returns:
            RequestType: The parent request, merged with `cache_config.config` when a
            cache configuration is set, plus `PropertyFiles` (the `expr` of each
            property file) when any property files are configured.
        """
        request_dict = super().to_request()
        if self.cache_config:
            request_dict.update(self.cache_config.config)
        if self.property_files:
            request_dict["PropertyFiles"] = [
                property_file.expr for property_file in self.property_files
            ]
        return request_dict
class TuningStep(ConfigurableRetryStep):
    """`TuningStep` for SageMaker Pipelines Workflows."""

    def __init__(
        self,
        name: str,
        step_args: Optional[_JobStepArguments] = None,
        display_name: Optional[str] = None,
        description: Optional[str] = None,
        cache_config: Optional[CacheConfig] = None,
        depends_on: Optional[List[Union[str, Step]]] = None,
        retry_policies: Optional[List[RetryPolicy]] = None,
    ):
        """Construct a `TuningStep` from the step_args of a `tune` call.

        Args:
            name (str): The name of the `TuningStep`.
            step_args (_JobStepArguments): The arguments for the `TuningStep` definition.
                Required despite the `None` default.
            display_name (str): The display name of the `TuningStep`.
            description (str): The description of the `TuningStep`.
            cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
            depends_on (List[Union[str, Step]]): A list of `Step`
                names or `Step` instances that this `TuningStep`
                depends on.
            retry_policies (List[RetryPolicy]): A list of retry policies.

        Raises:
            ValueError: If `step_args` is missing or falsy.
        """
        super().__init__(
            name=name,
            step_type=StepTypeEnum.TUNING,
            display_name=display_name,
            description=description,
            depends_on=depends_on,
            retry_policies=retry_policies,
        )

        if not step_args:
            raise ValueError("step_args is required for TuningStep.")

        from sagemaker.core.workflow.utilities import validate_step_args_input

        validate_step_args_input(
            step_args=step_args,
            expected_caller={"tune"},
            error_message="The step_args of TuningStep must be obtained from tuner.tune().",
        )

        self.step_args = step_args
        # This step exposes two response shapes through a single `Properties` object.
        response_shapes = [
            "DescribeHyperParameterTuningJobResponse",
            "ListTrainingJobsForHyperParameterTuningJobResponse",
        ]
        self._properties = Properties(step_name=name, step=self, shape_names=response_shapes)
        self.cache_config = cache_config

    @property
    def arguments(self) -> RequestType:
        """The arguments dictionary that is used to call `create_hyper_parameter_tuning_job`.

        NOTE: The `CreateHyperParameterTuningJob` request is not quite the
        args list that workflow needs.
        """
        from sagemaker.core.workflow.utilities import _pipeline_config, execute_job_functions

        # Run the saved tune call; the resulting request is left in the session context.
        execute_job_functions(self.step_args)
        tuner = self.step_args.func_args[0]
        tuning_request = tuner.sagemaker_session.context.args

        # Continue to pop job name if not explicitly opted-in via config
        return trim_request_dict(tuning_request, "HyperParameterTuningJobName", _pipeline_config)

    @property
    def properties(self):
        """A `Properties` object

        A `Properties` object representing `DescribeHyperParameterTuningJobResponse` and
        `ListTrainingJobsForHyperParameterTuningJobResponse` data model.
        """
        return self._properties

    def to_request(self) -> RequestType:
        """Updates the dictionary with cache configuration.

        Returns:
            RequestType: The parent request, merged with `cache_config.config` when a
            cache configuration is set.
        """
        step_request = super().to_request()
        if not self.cache_config:
            return step_request
        step_request.update(self.cache_config.config)
        return step_request

    def get_top_model_s3_uri(self, top_k: int, s3_bucket: str, prefix: str = "") -> Join:
        """Get the model artifact S3 URI from the top performing training jobs.

        Args:
            top_k (int): The index of the top performing training job
                tuning step stores up to 50 top performing training jobs.
                A valid top_k value is from 0 to 49. The best training job
                model is at index 0.
            s3_bucket (str): The S3 bucket to store the training job output artifact.
            prefix (str): The S3 key prefix to store the training job output artifact.

        Returns:
            Join: A `/`-join of ``"s3:/"``, the bucket, the prefix (when given), the name
            of the selected training job and ``"output/model.tar.gz"``.
        """
        # "s3:/" followed by the "/" separator yields the "s3://" scheme prefix.
        uri_parts = ["s3:/", s3_bucket]
        has_prefix = prefix != "" and prefix is not None
        if has_prefix:
            uri_parts.append(prefix)
        uri_parts.append(self.properties.TrainingJobSummaries[top_k].TrainingJobName)
        uri_parts.append("output/model.tar.gz")
        return Join(on="/", values=uri_parts)