# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The step definitions for workflow."""
from __future__ import absolute_import
import copy
import json
import os
import tempfile
from abc import ABC
from typing import List, Union, Optional
import attr
from sagemaker.core import s3
from sagemaker.core.clarify import (
DataConfig,
BiasConfig,
ModelConfig,
ModelPredictedLabelConfig,
SHAPConfig,
ProcessingOutputHandler,
_upload_analysis_config,
SageMakerClarifyProcessor,
_set,
)
from sagemaker.core.model_monitor import BiasAnalysisConfig, ExplainabilityAnalysisConfig
from sagemaker.core.model_monitor.model_monitoring import _MODEL_MONITOR_S3_PATH
from sagemaker.core.shapes import ProcessingInput, ProcessingOutput
from sagemaker.core.common_utils import name_from_base
from sagemaker.core.workflow import is_pipeline_variable
from sagemaker.core.helper.pipeline_variable import (
RequestType,
PipelineVariable,
)
from sagemaker.core.workflow.properties import Properties
from sagemaker.mlops.workflow.steps import Step, StepTypeEnum, CacheConfig
from sagemaker.mlops.workflow.check_job_config import CheckJobConfig
from sagemaker.core.workflow.utilities import trim_request_dict
_DATA_BIAS_TYPE = "DATA_BIAS"
_MODEL_BIAS_TYPE = "MODEL_BIAS"
_MODEL_EXPLAINABILITY_TYPE = "MODEL_EXPLAINABILITY"
_BIAS_MONITORING_CFG_BASE_NAME = "bias-monitoring"
_EXPLAINABILITY_MONITORING_CFG_BASE_NAME = "model-explainability-monitoring"
[docs]
@attr.s
class ClarifyCheckConfig(ABC):
"""Clarify Check Config
Attributes:
data_config (DataConfig): Config of the input/output data.
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).
This field CANNOT be any type of the `PipelineVariable`.
monitoring_analysis_config_uri: (str): The uri of monitoring analysis config.
This field does not take input.
It will be generated once uploading the created analysis config file.
"""
data_config: DataConfig = attr.ib()
kms_key: str = attr.ib(kw_only=True, default=None)
monitoring_analysis_config_uri: str = attr.ib(kw_only=True, default=None)
[docs]
@attr.s
class DataBiasCheckConfig(ClarifyCheckConfig):
"""Data Bias Check Config
Attributes:
data_bias_config (BiasConfig): Config of sensitive groups.
methods (str or list[str]): Selector of a subset of potential metrics:
["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance.html>`_",
"`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-true-label-imbalance.html>`_",
"`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kl-divergence.html>`_",
"`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-jensen-shannon-divergence.html>`_",
"`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-lp-norm.html>`_",
"`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-total-variation-distance.html>`_",
"`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kolmogorov-smirnov.html>`_",
"`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-cddl.html>`_"].
Defaults to computing all.
This field CANNOT be any type of the `PipelineVariable`.
""" # noqa E501
data_bias_config: BiasConfig = attr.ib()
methods: Union[str, List[str]] = attr.ib(default="all")
[docs]
@attr.s
class ModelBiasCheckConfig(ClarifyCheckConfig):
"""Model Bias Check Config
Attributes:
data_bias_config (BiasConfig): Config of sensitive groups.
model_config (ModelConfig): Config of the model and its endpoint to be created.
model_predicted_label_config (ModelPredictedLabelConfig): Config of how to
extract the predicted label from the model output.
methods (str or list[str]): Selector of a subset of potential metrics:
["`DPPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dppl.html>`_"
, "`DI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-di.html>`_",
"`DCA <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dca.html>`_",
"`DCR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dcr.html>`_",
"`RD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-rd.html>`_",
"`DAR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dar.html>`_",
"`DRR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-drr.html>`_",
"`AD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ad.html>`_",
"`CDDPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cddpl.html>`_
", "`TE <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-te.html>`_",
"`FT <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ft.html>`_"].
Defaults to computing all.
This field CANNOT be any type of the `PipelineVariable`.
"""
data_bias_config: BiasConfig = attr.ib()
model_config: ModelConfig = attr.ib()
model_predicted_label_config: ModelPredictedLabelConfig = attr.ib()
methods: Union[str, List[str]] = attr.ib(default="all")
[docs]
@attr.s
class ModelExplainabilityCheckConfig(ClarifyCheckConfig):
"""Model Explainability Check Config
Attributes:
model_config (ModelConfig): Config of the model and its endpoint to be created.
explainability_config (SHAPConfig or PDPConfig): Config of the explainability method.
Supports SHAP or PDP.
For `PDPConfig`, `features` must be specified.
`top_k_features` based on SHAP is currently not supported.
model_scores (str or int or ModelPredictedLabelConfig): Index or JMESPath expression
to locate the predicted scores in the model output (default: None).
This is not required if the model output is a single score. Alternatively,
an instance of ModelPredictedLabelConfig can be provided
but this field CANNOT be any type of the `PipelineVariable`.
"""
model_config: ModelConfig = attr.ib()
explainability_config: SHAPConfig = attr.ib()
model_scores: Union[str, int, ModelPredictedLabelConfig] = attr.ib(default=None)
[docs]
class ClarifyCheckStep(Step):
"""ClarifyCheckStep step for workflow."""
def __init__(
self,
name: str,
clarify_check_config: ClarifyCheckConfig,
check_job_config: CheckJobConfig,
skip_check: Union[bool, PipelineVariable] = False,
fail_on_violation: Union[bool, PipelineVariable] = True,
register_new_baseline: Union[bool, PipelineVariable] = False,
model_package_group_name: Optional[Union[str, PipelineVariable]] = None,
supplied_baseline_constraints: Optional[Union[str, PipelineVariable]] = None,
display_name: Optional[str] = None,
description: Optional[str] = None,
cache_config: Optional[CacheConfig] = None,
depends_on: Optional[List[Union[str, Step]]] = None,
):
"""Constructs a ClarifyCheckStep.
To understand the `skip_check`, `fail_on_violation`, `register_new_baseline`
and `supplied_baseline_constraints` parameters, check the following documentation:
https://docs.aws.amazon.com/sagemaker/latest/dg/pipelines-quality-clarify-baseline-lifecycle.html
Args:
name (str): The name of the ClarifyCheckStep step.
clarify_check_config (ClarifyCheckConfig): A ClarifyCheckConfig instance.
check_job_config (CheckJobConfig): A CheckJobConfig instance.
skip_check (bool or PipelineVariable): Whether the check
should be skipped (default: False).
fail_on_violation (bool or PipelineVariable): Whether to fail the step
if violation detected (default: True).
register_new_baseline (bool or PipelineVariable): Whether
the new baseline should be registered (default: False).
model_package_group_name (str or PipelineVariable): The name of a
registered model package group, among which the baseline will be fetched
from the latest approved model (default: None).
supplied_baseline_constraints (str or PipelineVariable): The S3 path
to the supplied constraints object representing the constraints JSON file
which will be used for drift to check (default: None).
display_name (str): The display name of the ClarifyCheckStep step (default: None).
description (str): The description of the ClarifyCheckStep step (default: None).
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance
(default: None).
depends_on (List[Union[str, Step]]): A list of `Step`
names or `Step` instances that this `ClarifyCheckStep`
depends on (default: None).
"""
if (
not isinstance(clarify_check_config, DataBiasCheckConfig)
and not isinstance(clarify_check_config, ModelBiasCheckConfig)
and not isinstance(clarify_check_config, ModelExplainabilityCheckConfig)
):
raise RuntimeError(
"The clarify_check_config can only be object of "
+ "DataBiasCheckConfig, ModelBiasCheckConfig or ModelExplainabilityCheckConfig"
)
if is_pipeline_variable(clarify_check_config.data_config.s3_analysis_config_output_path):
raise RuntimeError(
"s3_analysis_config_output_path cannot be of type "
+ "ExecutionVariable/Expression/Parameter/Properties"
)
if (
not clarify_check_config.data_config.s3_analysis_config_output_path
and is_pipeline_variable(clarify_check_config.data_config.s3_output_path)
):
raise RuntimeError(
"`s3_output_path` cannot be of type ExecutionVariable/Expression/Parameter"
+ "/Properties if `s3_analysis_config_output_path` is none or empty "
)
super(ClarifyCheckStep, self).__init__(
name, display_name, description, StepTypeEnum.CLARIFY_CHECK, depends_on
)
self.skip_check = skip_check
self.fail_on_violation = fail_on_violation
self.register_new_baseline = register_new_baseline
self.clarify_check_config = clarify_check_config
self.check_job_config = check_job_config
self.model_package_group_name = model_package_group_name
self.supplied_baseline_constraints = supplied_baseline_constraints
self.cache_config = cache_config
if isinstance(self.clarify_check_config, ModelExplainabilityCheckConfig):
self._model_monitor = self.check_job_config._generate_model_monitor(
"ModelExplainabilityMonitor"
)
else:
self._model_monitor = self.check_job_config._generate_model_monitor("ModelBiasMonitor")
self.clarify_check_config.monitoring_analysis_config_uri = (
self._upload_monitoring_analysis_config()
)
self._baselining_processor = self._model_monitor._create_baselining_processor()
self._processing_params = self._generate_processing_job_parameters(
self._generate_processing_job_analysis_config(), self._baselining_processor
)
root_prop = Properties(step_name=name, step=self)
root_prop.__dict__["CalculatedBaselineConstraints"] = Properties(
step_name=name, step=self, path="CalculatedBaselineConstraints"
)
root_prop.__dict__["BaselineUsedForDriftCheckConstraints"] = Properties(
step_name=name, step=self, path="BaselineUsedForDriftCheckConstraints"
)
self._properties = root_prop
@property
def arguments(self) -> RequestType:
"""The arguments dict that is used to define the ClarifyCheck step."""
from sagemaker.core.workflow.utilities import _pipeline_config
# Create request dictionary manually with correct AWS API field names
processing_inputs = []
for inp_key in ["config_input", "data_input"]:
inp = self._processing_params[inp_key]
input_dict = {
"InputName": inp.input_name,
}
if inp.s3_input:
input_dict["S3Input"] = {
"S3Uri": inp.s3_input.s3_uri,
"LocalPath": inp.s3_input.local_path,
"S3DataType": getattr(inp.s3_input, 's3_data_type', 'S3Prefix'),
"S3InputMode": getattr(inp.s3_input, 's3_input_mode', 'File'),
}
processing_inputs.append(input_dict)
processing_outputs = [{
"OutputName": self._processing_params["result_output"].output_name,
"S3Output": {
"S3Uri": self._processing_params["result_output"].s3_output.s3_uri,
"LocalPath": self._processing_params["result_output"].s3_output.local_path,
"S3UploadMode": self._processing_params["result_output"].s3_output.s3_upload_mode,
}
}]
request_dict = {
"ProcessingInputs": processing_inputs,
"ProcessingOutputConfig": {"Outputs": processing_outputs},
"ProcessingJobName": self._baselining_processor._current_job_name or "clarify-job",
"ProcessingResources": {
"ClusterConfig": {
"InstanceCount": self._baselining_processor.instance_count,
"InstanceType": self._baselining_processor.instance_type,
"VolumeSizeInGB": getattr(self._baselining_processor, 'volume_size_in_gb', 30),
}
},
"AppSpecification": {
"ImageUri": self._baselining_processor.image_uri,
},
"RoleArn": self._baselining_processor.role,
"StoppingCondition": {
"MaxRuntimeInSeconds": getattr(self._baselining_processor, 'max_runtime_in_seconds', None) or 86400
},
}
# Add optional fields if they exist
if self._baselining_processor.env:
request_dict["Environment"] = self._baselining_processor.env
if self._baselining_processor.network_config:
request_dict["NetworkConfig"] = self._baselining_processor.network_config
if self._baselining_processor.entrypoint:
request_dict["AppSpecification"]["ContainerEntrypoint"] = self._baselining_processor.entrypoint
if self._baselining_processor.arguments:
request_dict["AppSpecification"]["ContainerArguments"] = self._baselining_processor.arguments
# Continue to pop job name if not explicitly opted-in via config
request_dict = trim_request_dict(request_dict, "ProcessingJobName", _pipeline_config)
return request_dict
@property
def properties(self):
"""A Properties object representing the output parameters of the ClarifyCheck step."""
return self._properties
[docs]
def to_request(self) -> RequestType:
"""Updates the dictionary with cache configuration etc."""
request_dict = super().to_request()
if self.cache_config:
request_dict.update(self.cache_config.config)
if isinstance(self.clarify_check_config, DataBiasCheckConfig):
request_dict["CheckType"] = _DATA_BIAS_TYPE
elif isinstance(self.clarify_check_config, ModelBiasCheckConfig):
request_dict["CheckType"] = _MODEL_BIAS_TYPE
else:
request_dict["CheckType"] = _MODEL_EXPLAINABILITY_TYPE
request_dict["ModelPackageGroupName"] = self.model_package_group_name
request_dict["SkipCheck"] = self.skip_check
request_dict["FailOnViolation"] = self.fail_on_violation
request_dict["RegisterNewBaseline"] = self.register_new_baseline
request_dict["SuppliedBaselineConstraints"] = self.supplied_baseline_constraints
if isinstance(
self.clarify_check_config, (ModelBiasCheckConfig, ModelExplainabilityCheckConfig)
):
request_dict["ModelName"] = (
self.clarify_check_config.model_config.get_predictor_config()["model_name"]
)
return request_dict
def _generate_processing_job_analysis_config(self) -> dict:
"""Generate the clarify processing job analysis config
Returns:
dict: processing job analysis config dictionary.
"""
analysis_config = self.clarify_check_config.data_config.get_config()
if isinstance(self.clarify_check_config, DataBiasCheckConfig):
analysis_config.update(self.clarify_check_config.data_bias_config.get_config())
analysis_config["methods"] = {
"pre_training_bias": {"methods": self.clarify_check_config.methods}
}
elif isinstance(self.clarify_check_config, ModelBiasCheckConfig):
analysis_config.update(self.clarify_check_config.data_bias_config.get_config())
(
probability_threshold,
predictor_config,
) = self.clarify_check_config.model_predicted_label_config.get_predictor_config()
predictor_config.update(self.clarify_check_config.model_config.get_predictor_config())
if "model_name" in predictor_config:
predictor_config.pop("model_name")
analysis_config["methods"] = {
"post_training_bias": {"methods": self.clarify_check_config.methods}
}
analysis_config["predictor"] = predictor_config
_set(probability_threshold, "probability_threshold", analysis_config)
else:
predictor_config = self.clarify_check_config.model_config.get_predictor_config()
if "model_name" in predictor_config:
predictor_config.pop("model_name")
model_scores = self.clarify_check_config.model_scores
if isinstance(model_scores, ModelPredictedLabelConfig):
probability_threshold, predicted_label_config = model_scores.get_predictor_config()
_set(probability_threshold, "probability_threshold", analysis_config)
predictor_config.update(predicted_label_config)
else:
_set(model_scores, "label", predictor_config)
analysis_config["methods"] = (
self.clarify_check_config.explainability_config.get_explainability_config()
)
analysis_config["predictor"] = predictor_config
return analysis_config
def _generate_processing_job_parameters(
self, analysis_config: dict, baselining_processor: SageMakerClarifyProcessor
) -> dict:
"""Generates input and output parameters for the clarify processing job
Args:
analysis_config (dict): A clarify processing job analysis config
baselining_processor (SageMakerClarifyProcessor): A SageMakerClarifyProcessor instance
Returns:
dict: with two ProcessingInput objects as the clarify processing job inputs and
a ProcessingOutput object as the clarify processing job output parameter
"""
data_config = self.clarify_check_config.data_config
analysis_config["methods"]["report"] = {"name": "report", "title": "Analysis Report"}
with tempfile.TemporaryDirectory() as tmpdirname:
analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")
with open(analysis_config_file, "w") as f:
json.dump(analysis_config, f)
s3_analysis_config_file = _upload_analysis_config(
analysis_config_file,
data_config.s3_analysis_config_output_path or data_config.s3_output_path,
baselining_processor.sagemaker_session,
self.clarify_check_config.kms_key,
)
config_input = ProcessingInput(
input_name="analysis_config",
s3_input={
"s3_uri": s3_analysis_config_file,
"local_path": SageMakerClarifyProcessor._CLARIFY_CONFIG_INPUT,
"s3_data_type": "S3Prefix",
"s3_input_mode": "File",
"s3_compression_type": "None",
}
)
data_input = ProcessingInput(
input_name="dataset",
s3_input={
"s3_uri": data_config.s3_data_input_path,
"local_path": SageMakerClarifyProcessor._CLARIFY_DATA_INPUT,
"s3_data_type": "S3Prefix",
"s3_input_mode": "File",
"s3_data_distribution_type": data_config.s3_data_distribution_type,
"s3_compression_type": data_config.s3_compression_type,
}
)
result_output = ProcessingOutput(
output_name="analysis_result",
s3_output={
"s3_uri": data_config.s3_output_path,
"local_path": SageMakerClarifyProcessor._CLARIFY_OUTPUT,
"s3_upload_mode": ProcessingOutputHandler.get_s3_upload_mode(analysis_config),
}
)
return dict(config_input=config_input, data_input=data_input, result_output=result_output)
def _upload_monitoring_analysis_config(self) -> str:
"""Generate and upload monitoring schedule analysis config to s3
Returns:
str: The S3 uri of the uploaded monitoring schedule analysis config
"""
output_s3_uri = self._get_s3_base_uri_for_monitoring_analysis_config()
if isinstance(self.clarify_check_config, ModelExplainabilityCheckConfig):
# Explainability analysis doesn't need label
headers = copy.deepcopy(self.clarify_check_config.data_config.headers)
if headers and self.clarify_check_config.data_config.label in headers:
headers.remove(self.clarify_check_config.data_config.label)
explainability_analysis_config = ExplainabilityAnalysisConfig(
explainability_config=self.clarify_check_config.explainability_config,
model_config=self.clarify_check_config.model_config,
headers=headers,
)
analysis_config = explainability_analysis_config._to_dict()
if "predictor" in analysis_config and "model_name" in analysis_config["predictor"]:
analysis_config["predictor"].pop("model_name")
job_definition_name = name_from_base(
f"{_EXPLAINABILITY_MONITORING_CFG_BASE_NAME}-config"
)
else:
bias_analysis_config = BiasAnalysisConfig(
bias_config=self.clarify_check_config.data_bias_config,
headers=self.clarify_check_config.data_config.headers,
label=self.clarify_check_config.data_config.label,
)
analysis_config = bias_analysis_config._to_dict()
job_definition_name = name_from_base(f"{_BIAS_MONITORING_CFG_BASE_NAME}-config")
return self._model_monitor._upload_analysis_config(
analysis_config, output_s3_uri, job_definition_name, self.clarify_check_config.kms_key
)
def _get_s3_base_uri_for_monitoring_analysis_config(self) -> str:
"""Generate s3 base uri for monitoring schedule analysis config
Returns:
str: The S3 base uri of the monitoring schedule analysis config
"""
s3_analysis_config_output_path = (
self.clarify_check_config.data_config.s3_analysis_config_output_path
)
monitoring_cfg_base_name = f"{_BIAS_MONITORING_CFG_BASE_NAME}-configuration"
if isinstance(self.clarify_check_config, ModelExplainabilityCheckConfig):
monitoring_cfg_base_name = f"{_EXPLAINABILITY_MONITORING_CFG_BASE_NAME}-configuration"
if s3_analysis_config_output_path:
return s3.s3_path_join(
s3_analysis_config_output_path,
monitoring_cfg_base_name,
)
return s3.s3_path_join(
"s3://",
self._model_monitor.sagemaker_session.default_bucket(),
self._model_monitor.sagemaker_session.default_bucket_prefix,
_MODEL_MONITOR_S3_PATH,
monitoring_cfg_base_name,
)