# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Common config for QualityCheckStep and ClarifyCheckStep."""
from __future__ import absolute_import
import logging
from typing import Optional
from sagemaker.core.helper.session_helper import Session
from sagemaker.core.model_monitor import (
ModelMonitor,
DefaultModelMonitor,
ModelQualityMonitor,
ModelBiasMonitor,
ModelExplainabilityMonitor,
)
from sagemaker.core.common_utils import format_tags
[docs]
class CheckJobConfig:
    """Job settings shared by QualityCheckStep and ClarifyCheckStep."""

    def __init__(
        self,
        role,
        instance_count=1,
        instance_type="ml.m5.xlarge",
        volume_size_in_gb=30,
        volume_kms_key=None,
        output_kms_key=None,
        max_runtime_in_seconds=None,
        base_job_name=None,
        sagemaker_session=None,
        env=None,
        tags=None,
        network_config=None,
    ):
        """Store the settings that are later forwarded to a model monitor.

        Args:
            role (str): An AWS IAM role. The Amazon SageMaker jobs use this role.
            instance_count (int): The number of instances to run the jobs with
                (default: 1).
            instance_type (str): Type of EC2 instance to use for the job
                (default: 'ml.m5.xlarge').
            volume_size_in_gb (int): Size in GB of the EBS volume to use for
                storing data during processing (default: 30).
            volume_kms_key (str): A KMS key for the processing volume
                (default: None).
            output_kms_key (str): The KMS key id for the job's outputs
                (default: None).
            max_runtime_in_seconds (int): Timeout in seconds. After this amount
                of time, Amazon SageMaker terminates the job regardless of its
                current status. Default: 3600 if not specified. The value is
                stored as given (None when omitted).
            base_job_name (str): Prefix for the job name. If not specified,
                a default name is generated based on the training image name
                and current timestamp (default: None).
            sagemaker_session (sagemaker.core.helper.session_helper.Session):
                Session object which manages interactions with Amazon SageMaker
                APIs and any other AWS services needed (default: None). When a
                falsy value is given, a new ``Session()`` is created.
            env (dict): Environment variables to be passed to the job
                (default: None).
            tags (Optional[Tags]): List of tags to be passed to the job
                (default: None). The value is run through ``format_tags``
                before being stored.
            network_config (NetworkConfig): A NetworkConfig object that
                configures network isolation, encryption of inter-container
                traffic, security group IDs, and subnets (default: None).
        """
        # Identity and compute sizing.
        self.role = role
        self.instance_type = instance_type
        self.instance_count = instance_count
        self.volume_size_in_gb = volume_size_in_gb

        # Encryption keys.
        self.output_kms_key = output_kms_key
        self.volume_kms_key = volume_kms_key

        # Job naming and runtime limit.
        self.base_job_name = base_job_name
        self.max_runtime_in_seconds = max_runtime_in_seconds

        # Fall back to a freshly created session when none (or a falsy one) is given.
        self.sagemaker_session = sagemaker_session or Session()

        # Environment, tagging and networking.
        self.env = env
        self.tags = format_tags(tags)
        self.network_config = network_config

    def _generate_model_monitor(self, mm_type: str) -> Optional[ModelMonitor]:
        """Build the model monitor named by ``mm_type`` from this config.

        Every supported monitor class is constructed with the same keyword
        arguments, all taken from the attributes stored on this instance.

        Args:
            mm_type (str): The subclass type of ModelMonitor object.
                A valid mm_type should be one of the following:
                "DefaultModelMonitor", "ModelQualityMonitor",
                "ModelBiasMonitor", "ModelExplainabilityMonitor"

        Return:
            sagemaker.core.model_monitor.ModelMonitor or None if the mm_type
            is not valid (a warning is logged in that case).
        """
        # Ordered (name, class) pairs; matched with ``==`` in this order.
        supported_monitors = (
            ("DefaultModelMonitor", DefaultModelMonitor),
            ("ModelQualityMonitor", ModelQualityMonitor),
            ("ModelBiasMonitor", ModelBiasMonitor),
            ("ModelExplainabilityMonitor", ModelExplainabilityMonitor),
        )
        monitor_cls = next(
            (cls for type_name, cls in supported_monitors if mm_type == type_name),
            None,
        )

        # Unknown type: warn and bail out before touching any instance state.
        if monitor_cls is None:
            logging.warning(
                'Expected model monitor types: "DefaultModelMonitor", "ModelQualityMonitor", '
                '"ModelBiasMonitor", "ModelExplainabilityMonitor"'
            )
            return None

        return monitor_cls(
            role=self.role,
            instance_count=self.instance_count,
            instance_type=self.instance_type,
            volume_size_in_gb=self.volume_size_in_gb,
            volume_kms_key=self.volume_kms_key,
            output_kms_key=self.output_kms_key,
            max_runtime_in_seconds=self.max_runtime_in_seconds,
            base_job_name=self.base_job_name,
            sagemaker_session=self.sagemaker_session,
            env=self.env,
            tags=self.tags,
            network_config=self.network_config,
        )