# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The step definitions for workflow."""
from __future__ import absolute_import
import logging
from abc import ABC
from typing import List, Union, Optional
import os
import pathlib
import attr
from sagemaker.core import s3
from sagemaker.core.model_monitor import ModelMonitor
from sagemaker.core.shapes import ProcessingInput, ProcessingOutput
from sagemaker.core.processing import Processor
from sagemaker.core.workflow import is_pipeline_variable
from sagemaker.core.helper.pipeline_variable import (
RequestType,
PipelineVariable,
PrimitiveType
)
from sagemaker.core.workflow.parameters import Parameter, ParameterString
from sagemaker.core.workflow.properties import (
Properties,
)
from sagemaker.mlops.workflow.steps import Step, StepTypeEnum, CacheConfig
from sagemaker.mlops.workflow.check_job_config import CheckJobConfig
from sagemaker.core.workflow.utilities import trim_request_dict
# Container-side directory layout used by the baselining processing job.
_CONTAINER_BASE_PATH = "/opt/ml/processing"
_CONTAINER_INPUT_PATH = "input"
_CONTAINER_OUTPUT_PATH = "output"

# Channel names for the job's ProcessingInput / ProcessingOutput objects.
_BASELINE_DATASET_INPUT_NAME = "baseline_dataset_input"
_RECORD_PREPROCESSOR_SCRIPT_INPUT_NAME = "record_preprocessor_script_input"
_POST_ANALYTICS_PROCESSOR_SCRIPT_INPUT_NAME = "post_analytics_processor_script_input"
_DEFAULT_OUTPUT_NAME = "quality_check_output"

# Path segments of the auto-generated S3 results location, joined as
# s3://<bucket>/<bucket_prefix>/model-monitor/baselining/<job_name>/results
_MODEL_MONITOR_S3_PATH = "model-monitor"
_BASELINING_S3_PATH = "baselining"
_RESULTS_S3_PATH = "results"

# Values written to the step request's "CheckType" field (the model quality
# value is also passed as the baselining analysis type).
_MODEL_QUALITY_TYPE = "MODEL_QUALITY"
_DATA_QUALITY_TYPE = "DATA_QUALITY"

logger = logging.getLogger(__name__)
@attr.s
class QualityCheckConfig(ABC):
    """Quality Check Config.

    Settings shared by every quality check configuration. QualityCheckStep
    only accepts the DataQualityCheckConfig and ModelQualityCheckConfig
    subclasses.

    Attributes:
        baseline_dataset (str or PipelineVariable): The path to the
            baseline_dataset file. This can be a local path or an S3 uri string.
        dataset_format (dict): The format of the baseline_dataset.
        output_s3_uri (str or PipelineVariable): Desired S3 destination of
            the constraint_violations and statistics json files (default: None).
            If not specified an auto generated path will be used:
            "s3://<default_session_bucket>/<default_bucket_prefix>/model-monitor/baselining/<job_name>/results"
        post_analytics_processor_script (str): The path to the post-analytics
            processor script (default: None). This can be a local path or an S3 uri string
            but CANNOT be any type of the PipelineVariable.
    """

    baseline_dataset: Union[str, PipelineVariable] = attr.ib()
    dataset_format: dict = attr.ib()
    # Keyword-only so subclasses can append their own positional attributes.
    output_s3_uri: Optional[Union[str, PipelineVariable]] = attr.ib(kw_only=True, default=None)
    post_analytics_processor_script: Optional[str] = attr.ib(kw_only=True, default=None)
@attr.s
class DataQualityCheckConfig(QualityCheckConfig):
    """Data Quality Check Config.

    Attributes:
        record_preprocessor_script (str): The path to the record preprocessor script
            (default: None). This can be a local path or an S3 uri string
            but CANNOT be any type of the PipelineVariable.
    """

    record_preprocessor_script: Optional[str] = attr.ib(default=None)
@attr.s
class ModelQualityCheckConfig(QualityCheckConfig):
    """Model Quality Check Config.

    If a pipeline Parameter is used for any of the optional attributes below,
    it must be a ParameterString.

    Attributes:
        problem_type (str or PipelineVariable): The type of problem of this model
            quality monitoring.
            Valid values are "Regression", "BinaryClassification", "MulticlassClassification".
        inference_attribute (str or PipelineVariable): Index or JSONpath to
            locate predicted label(s) (default: None).
        probability_attribute (str or PipelineVariable): Index or JSONpath to
            locate probabilities (default: None).
        ground_truth_attribute (str or PipelineVariable): Index or JSONpath to
            locate actual label(s) (default: None).
        probability_threshold_attribute (str or PipelineVariable): Threshold to
            convert probabilities to binaries (default: None).
    """

    problem_type: Union[str, PipelineVariable] = attr.ib()
    inference_attribute: Optional[Union[str, PipelineVariable]] = attr.ib(default=None)
    probability_attribute: Optional[Union[str, PipelineVariable]] = attr.ib(default=None)
    ground_truth_attribute: Optional[Union[str, PipelineVariable]] = attr.ib(default=None)
    probability_threshold_attribute: Optional[Union[str, PipelineVariable]] = attr.ib(
        default=None
    )
class QualityCheckStep(Step):
    """QualityCheck step for workflow.

    Runs a Model Monitor baselining processing job for either a data quality
    check (DataQualityCheckConfig) or a model quality check
    (ModelQualityCheckConfig) and exposes the resulting baselines as step
    properties.
    """

    def __init__(
        self,
        name: str,
        quality_check_config: QualityCheckConfig,
        check_job_config: CheckJobConfig,
        skip_check: Union[bool, PipelineVariable] = False,
        fail_on_violation: Union[bool, PipelineVariable] = True,
        register_new_baseline: Union[bool, PipelineVariable] = False,
        model_package_group_name: Optional[Union[str, PipelineVariable]] = None,
        supplied_baseline_statistics: Optional[Union[str, PipelineVariable]] = None,
        supplied_baseline_constraints: Optional[Union[str, PipelineVariable]] = None,
        display_name: Optional[str] = None,
        description: Optional[str] = None,
        cache_config: Optional[CacheConfig] = None,
        depends_on: Optional[List[Union[str, Step]]] = None,
    ):
        """Constructs a QualityCheckStep.

        To understand the `skip_check`, `fail_on_violation`, `register_new_baseline`,
        `supplied_baseline_statistics` and `supplied_baseline_constraints` parameters,
        check the following documentation:
        https://docs.aws.amazon.com/sagemaker/latest/dg/pipelines-quality-clarify-baseline-lifecycle.html

        Args:
            name (str): The name of the QualityCheckStep step.
            quality_check_config (QualityCheckConfig): A DataQualityCheckConfig or
                ModelQualityCheckConfig instance.
            check_job_config (CheckJobConfig): A CheckJobConfig instance.
            skip_check (bool or PipelineVariable): Whether the check
                should be skipped (default: False).
            fail_on_violation (bool or PipelineVariable): Whether to fail the step
                if violation detected (default: True).
            register_new_baseline (bool or PipelineVariable): Whether
                the new baseline should be registered (default: False).
            model_package_group_name (str or PipelineVariable): The name of a
                registered model package group, among which the baseline will be fetched
                from the latest approved model (default: None).
            supplied_baseline_statistics (str or PipelineVariable): The S3 path
                to the supplied statistics object representing the statistics JSON file
                which will be used for drift to check (default: None).
            supplied_baseline_constraints (str or PipelineVariable): The S3 path
                to the supplied constraints object representing the constraints JSON file
                which will be used for drift to check (default: None).
            display_name (str): The display name of the QualityCheckStep step (default: None).
            description (str): The description of the QualityCheckStep step (default: None).
            cache_config (CacheConfig): A `CacheConfig` instance (default: None).
            depends_on (List[Union[str, Step]]): A list of `Step`
                names or `Step` instances that this `QualityCheckStep`
                depends on (default: None).

        Raises:
            RuntimeError: If `quality_check_config` is neither a
                DataQualityCheckConfig nor a ModelQualityCheckConfig.
        """
        if not isinstance(
            quality_check_config, (DataQualityCheckConfig, ModelQualityCheckConfig)
        ):
            raise RuntimeError(
                "The quality_check_config can only be object of "
                + "DataQualityCheckConfig or ModelQualityCheckConfig"
            )
        super(QualityCheckStep, self).__init__(
            name, display_name, description, StepTypeEnum.QUALITY_CHECK, depends_on
        )

        self.skip_check = skip_check
        self.fail_on_violation = fail_on_violation
        self.register_new_baseline = register_new_baseline
        self.check_job_config = check_job_config
        self.quality_check_config = quality_check_config
        self.model_package_group_name = model_package_group_name
        self.supplied_baseline_statistics = supplied_baseline_statistics
        self.supplied_baseline_constraints = supplied_baseline_constraints
        self.cache_config = cache_config

        # The monitor flavour requested from CheckJobConfig depends on the check type.
        monitor_type = (
            "DefaultModelMonitor"
            if isinstance(self.quality_check_config, DataQualityCheckConfig)
            else "ModelQualityMonitor"
        )
        self._model_monitor = self.check_job_config._generate_model_monitor(monitor_type)
        # Fix the baselining job name now: the default output S3 path built in
        # _generate_baseline_output embeds it.
        self._model_monitor.latest_baselining_job_name = (
            self._model_monitor._generate_baselining_job_name()
        )

        baseline_job_inputs_with_nones = self._generate_baseline_job_inputs()
        # Inputs that do not apply (e.g. the record preprocessor script for a
        # model quality check) are None; only real inputs go into the request.
        self._baseline_job_inputs = [
            baseline_job_input
            for baseline_job_input in baseline_job_inputs_with_nones.values()
            if baseline_job_input is not None
        ]
        self._baseline_output = self._generate_baseline_output()
        self._baselining_processor = self._generate_baseline_processor(
            baseline_dataset_input=baseline_job_inputs_with_nones["baseline_dataset_input"],
            baseline_output=self._baseline_output,
            post_processor_script_input=baseline_job_inputs_with_nones[
                "post_processor_script_input"
            ],
            record_preprocessor_script_input=baseline_job_inputs_with_nones[
                "record_preprocessor_script_input"
            ],
        )

        root_prop = Properties(step_name=name, step=self)
        for prop_path in (
            "CalculatedBaselineConstraints",
            "CalculatedBaselineStatistics",
            "BaselineUsedForDriftCheckStatistics",
            "BaselineUsedForDriftCheckConstraints",
        ):
            root_prop.__dict__[prop_path] = Properties(
                step_name=name, step=self, path=prop_path
            )
        self._properties = root_prop

    @property
    def arguments(self) -> RequestType:
        """The arguments dict that is used to define the QualityCheck step.

        The request is assembled by hand from the baselining processor, the
        baseline job inputs and the baseline output, using the SageMaker
        CreateProcessingJob field names. KMS keys configured on the processor
        are included so encryption settings are not lost.
        """
        from sagemaker.core.workflow.utilities import _pipeline_config

        processor = self._baselining_processor

        output_config = {
            "Outputs": [self._to_processing_output_request(self._baseline_output)]
        }
        output_kms_key = getattr(processor, "output_kms_key", None)
        if output_kms_key is not None:
            output_config["KmsKeyId"] = output_kms_key

        cluster_config = {
            "InstanceCount": processor.instance_count,
            "InstanceType": processor.instance_type,
            "VolumeSizeInGB": self._value_or_default(
                getattr(processor, "volume_size_in_gb", None), 30
            ),
        }
        volume_kms_key = getattr(processor, "volume_kms_key", None)
        if volume_kms_key is not None:
            cluster_config["VolumeKmsKeyId"] = volume_kms_key

        request_dict = {
            "ProcessingInputs": [
                self._to_processing_input_request(inp)
                for inp in self._baseline_job_inputs or []
            ],
            "ProcessingOutputConfig": output_config,
            "ProcessingJobName": processor._current_job_name or "baseline-job",
            "ProcessingResources": {"ClusterConfig": cluster_config},
            "AppSpecification": {"ImageUri": processor.image_uri},
            "RoleArn": processor.role,
            "StoppingCondition": {
                "MaxRuntimeInSeconds": getattr(processor, "max_runtime_in_seconds", None)
                or 86400
            },
        }

        # Add optional fields only when they are set on the processor.
        if processor.env:
            request_dict["Environment"] = processor.env
        if processor.network_config:
            request_dict["NetworkConfig"] = processor.network_config
        if processor.entrypoint:
            request_dict["AppSpecification"]["ContainerEntrypoint"] = processor.entrypoint
        if processor.arguments:
            request_dict["AppSpecification"]["ContainerArguments"] = processor.arguments

        # Continue to pop job name if not explicitly opted-in via config
        return trim_request_dict(request_dict, "ProcessingJobName", _pipeline_config)

    @staticmethod
    def _value_or_default(value, default):
        """Returns `value`, or `default` when `value` is None."""
        return default if value is None else value

    @classmethod
    def _to_processing_input_request(cls, processing_input) -> dict:
        """Converts a ProcessingInput into a `ProcessingInputs` request entry."""
        input_dict = {"InputName": processing_input.input_name}
        s3_input = processing_input.s3_input
        if s3_input:
            input_dict["S3Input"] = {
                "S3Uri": s3_input.s3_uri,
                "LocalPath": s3_input.local_path,
                # Fall back to "S3Prefix"/"File" when the attribute is missing or None.
                "S3DataType": cls._value_or_default(
                    getattr(s3_input, "s3_data_type", None), "S3Prefix"
                ),
                "S3InputMode": cls._value_or_default(
                    getattr(s3_input, "s3_input_mode", None), "File"
                ),
            }
        return input_dict

    @staticmethod
    def _to_processing_output_request(processing_output) -> dict:
        """Converts a ProcessingOutput into a `ProcessingOutputConfig.Outputs` entry."""
        s3_output = processing_output.s3_output
        return {
            "OutputName": processing_output.output_name,
            "S3Output": {
                "S3Uri": s3_output.s3_uri,
                "LocalPath": s3_output.local_path,
                "S3UploadMode": s3_output.s3_upload_mode,
            },
        }

    @property
    def properties(self):
        """A Properties object representing the output parameters of the QualityCheck step."""
        return self._properties

    def to_request(self) -> RequestType:
        """Updates the dictionary with cache configuration etc.

        Adds the cache config (if any), the check type and the baseline
        lifecycle flags to the base Step request.
        """
        request_dict = super().to_request()
        if self.cache_config:
            request_dict.update(self.cache_config.config)

        if isinstance(self.quality_check_config, DataQualityCheckConfig):
            request_dict["CheckType"] = _DATA_QUALITY_TYPE
        else:
            request_dict["CheckType"] = _MODEL_QUALITY_TYPE

        request_dict["ModelPackageGroupName"] = self.model_package_group_name
        request_dict["SkipCheck"] = self.skip_check
        request_dict["FailOnViolation"] = self.fail_on_violation
        request_dict["RegisterNewBaseline"] = self.register_new_baseline
        request_dict["SuppliedBaselineStatistics"] = self.supplied_baseline_statistics
        request_dict["SuppliedBaselineConstraints"] = self.supplied_baseline_constraints
        return request_dict

    def _generate_baseline_job_inputs(self):
        """Generates a dict with ProcessingInput objects.

        Generates a dict with three entries: baseline_dataset_input,
        post_processor_script_input and record_preprocessor_script_input.
        record_preprocessor_script_input is always None for model quality checks.

        Returns:
            dict: the three baseline job inputs keyed by the names above;
            values may be None.
        """

        def container_input_path(input_name):
            # e.g. /opt/ml/processing/input/<input_name>
            return str(
                pathlib.PurePosixPath(_CONTAINER_BASE_PATH, _CONTAINER_INPUT_PATH, input_name)
            )

        baseline_dataset = self.quality_check_config.baseline_dataset
        baseline_dataset_des = container_input_path(_BASELINE_DATASET_INPUT_NAME)
        if is_pipeline_variable(baseline_dataset):
            # Pipeline variables are referenced directly as the S3 URI instead
            # of going through the upload helper.
            baseline_dataset_input = ProcessingInput(
                input_name=_BASELINE_DATASET_INPUT_NAME,
                s3_input={
                    "s3_uri": baseline_dataset,
                    "local_path": baseline_dataset_des,
                    # S3DataType is required by the ProcessingS3Input shape; these
                    # match the defaults assumed when building the request.
                    "s3_data_type": "S3Prefix",
                    "s3_input_mode": "File",
                },
            )
        else:
            baseline_dataset_input = self._model_monitor._upload_and_convert_to_processing_input(
                source=baseline_dataset,
                destination=baseline_dataset_des,
                name=_BASELINE_DATASET_INPUT_NAME,
            )

        post_processor_script_input = self._model_monitor._upload_and_convert_to_processing_input(
            source=self.quality_check_config.post_analytics_processor_script,
            destination=container_input_path(_POST_ANALYTICS_PROCESSOR_SCRIPT_INPUT_NAME),
            name=_POST_ANALYTICS_PROCESSOR_SCRIPT_INPUT_NAME,
        )

        record_preprocessor_script_input = None
        if isinstance(self.quality_check_config, DataQualityCheckConfig):
            record_preprocessor_script_input = (
                self._model_monitor._upload_and_convert_to_processing_input(
                    source=self.quality_check_config.record_preprocessor_script,
                    destination=container_input_path(_RECORD_PREPROCESSOR_SCRIPT_INPUT_NAME),
                    name=_RECORD_PREPROCESSOR_SCRIPT_INPUT_NAME,
                )
            )

        return dict(
            baseline_dataset_input=baseline_dataset_input,
            post_processor_script_input=post_processor_script_input,
            record_preprocessor_script_input=record_preprocessor_script_input,
        )

    def _generate_baseline_output(self):
        """Generates a ProcessingOutput object.

        Uses `quality_check_config.output_s3_uri` when set, otherwise
        s3://<default_bucket>/<default_bucket_prefix>/model-monitor/baselining/
        <latest_baselining_job_name>/results.

        Returns:
            ProcessingOutput: The normalized ProcessingOutput object.
        """
        session = self._model_monitor.sagemaker_session
        s3_uri = self.quality_check_config.output_s3_uri or s3.s3_path_join(
            "s3://",
            session.default_bucket(),
            session.default_bucket_prefix,
            _MODEL_MONITOR_S3_PATH,
            _BASELINING_S3_PATH,
            self._model_monitor.latest_baselining_job_name,
            _RESULTS_S3_PATH,
        )
        return ProcessingOutput(
            output_name=_DEFAULT_OUTPUT_NAME,
            s3_output={
                "s3_uri": s3_uri,
                "local_path": str(pathlib.PurePosixPath(_CONTAINER_BASE_PATH, _CONTAINER_OUTPUT_PATH)),
                "s3_upload_mode": "EndOfJob",
            },
        )

    def _generate_baseline_processor(
        self,
        baseline_dataset_input,
        baseline_output,
        post_processor_script_input=None,
        record_preprocessor_script_input=None,
    ):
        """Generates a baseline processor.

        Args:
            baseline_dataset_input (ProcessingInput): A ProcessingInput instance for baseline
                dataset input.
            baseline_output (ProcessingOutput): A ProcessingOutput instance for baseline
                dataset output.
            post_processor_script_input (ProcessingInput): A ProcessingInput instance for
                post processor script input.
            record_preprocessor_script_input (ProcessingInput): A ProcessingInput instance for
                record preprocessor script input.

        Returns:
            Processor: The baseline processor, whose environment is generated by
            ModelMonitor._generate_env_map for the configured check type.

        Raises:
            ValueError: If a model quality attribute is a Parameter other than
                ParameterString.
        """
        quality_check_cfg = self.quality_check_config
        # Unlike other input, dataset must be a directory for the Monitoring image.
        baseline_dataset_container_path = baseline_dataset_input.s3_input.local_path

        post_processor_script_container_path = None
        if post_processor_script_input is not None:
            post_processor_script_container_path = str(
                pathlib.PurePosixPath(
                    post_processor_script_input.s3_input.local_path,
                    os.path.basename(quality_check_cfg.post_analytics_processor_script),
                )
            )

        # Environment settings shared by data quality and model quality checks.
        common_env_args = dict(
            env=self._model_monitor.env,
            dataset_format=quality_check_cfg.dataset_format,
            output_path=baseline_output.s3_output.local_path,
            enable_cloudwatch_metrics=False,  # Only supported for monitoring schedules
            dataset_source_container_path=baseline_dataset_container_path,
            post_processor_script_container_path=post_processor_script_container_path,
        )

        if isinstance(quality_check_cfg, DataQualityCheckConfig):
            record_preprocessor_script_container_path = None
            if record_preprocessor_script_input is not None:
                record_preprocessor_script_container_path = str(
                    pathlib.PurePosixPath(
                        record_preprocessor_script_input.s3_input.local_path,
                        os.path.basename(quality_check_cfg.record_preprocessor_script),
                    )
                )
            normalized_env = ModelMonitor._generate_env_map(
                record_preprocessor_script_container_path=record_preprocessor_script_container_path,
                **common_env_args,
            )
        else:
            # var_name values match the ModelQualityCheckConfig field names so
            # that warnings/errors point at the attribute the user set.
            normalized_env = ModelMonitor._generate_env_map(
                analysis_type=_MODEL_QUALITY_TYPE,
                problem_type=quality_check_cfg.problem_type,
                inference_attribute=_format_env_variable_value(
                    var_value=quality_check_cfg.inference_attribute,
                    var_name="inference_attribute",
                ),
                probability_attribute=_format_env_variable_value(
                    var_value=quality_check_cfg.probability_attribute,
                    var_name="probability_attribute",
                ),
                ground_truth_attribute=_format_env_variable_value(
                    var_value=quality_check_cfg.ground_truth_attribute,
                    var_name="ground_truth_attribute",
                ),
                probability_threshold_attribute=_format_env_variable_value(
                    var_value=quality_check_cfg.probability_threshold_attribute,
                    var_name="probability_threshold_attribute",
                ),
                **common_env_args,
            )

        monitor = self._model_monitor
        return Processor(
            role=monitor.role,
            image_uri=monitor.image_uri,
            instance_count=monitor.instance_count,
            instance_type=monitor.instance_type,
            entrypoint=monitor.entrypoint,
            volume_size_in_gb=monitor.volume_size_in_gb,
            volume_kms_key=monitor.volume_kms_key,
            output_kms_key=monitor.output_kms_key,
            max_runtime_in_seconds=monitor.max_runtime_in_seconds,
            base_job_name=monitor.base_job_name,
            sagemaker_session=monitor.sagemaker_session,
            env=normalized_env,
            tags=monitor.tags,
            network_config=monitor.network_config,
        )
def _format_env_variable_value(var_value: Union[PrimitiveType, PipelineVariable], var_name: str):
    """Normalize a value that will be placed in the baselining job environment.

    Args:
        var_value (PrimitiveType or PipelineVariable): The value of the variable.
        var_name (str): The name of the variable, used in the warning/error text.

    Returns:
        None when `var_value` is None; the pipeline variable itself when
        `var_value` is a pipeline variable; otherwise `str(var_value)`.

    Raises:
        ValueError: If `var_value` is a Parameter other than ParameterString.
    """
    if var_value is None:
        return None

    if not is_pipeline_variable(var_value):
        # Plain primitive values are stringified eagerly.
        return str(var_value)

    # Pipeline variables pass through unchanged, but Parameters must be strings.
    non_string_parameter = isinstance(var_value, Parameter) and not isinstance(
        var_value, ParameterString
    )
    if non_string_parameter:
        raise ValueError(f"{var_name} cannot be Parameter types other than ParameterString.")

    logger.warning("%s's runtime value must be the string type.", var_name)
    return var_value