Source code for sagemaker.mlops.workflow.retry

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Retry policies and retryable exception types for workflow pipeline steps."""
from __future__ import absolute_import

from enum import Enum
from typing import List, Optional
import attr

from sagemaker.core.workflow.entities import Entity, DefaultEnumMeta, RequestType


# Default for RetryPolicy.backoff_rate: the multiplier by which the retry
# interval increases during each attempt.
DEFAULT_BACKOFF_RATE = 2.0
# Default for RetryPolicy.interval_seconds: the number of seconds before the
# first retry attempt.
DEFAULT_INTERVAL_SECONDS = 1
# Largest max_attempts value accepted by RetryPolicy.validate_max_attempts.
MAX_ATTEMPTS_CAP = 20
# Largest expire_after_mins value accepted by
# RetryPolicy.validate_expire_after_mins (14400 minutes = 10 days).
MAX_EXPIRE_AFTER_MIN = 14400


class StepExceptionTypeEnum(Enum, metaclass=DefaultEnumMeta):
    """Exception types that a ``StepRetryPolicy`` can be configured to match.

    Each member's value is the string placed in the ``ExceptionType`` list of
    the retry policy request.
    """

    SERVICE_FAULT = "Step.SERVICE_FAULT"
    THROTTLING = "Step.THROTTLING"
class SageMakerJobExceptionTypeEnum(Enum, metaclass=DefaultEnumMeta):
    """Exception types that a ``SageMakerJobStepRetryPolicy`` can match.

    Each member's value is the string placed in the ``ExceptionType`` list of
    the retry policy request.
    """

    INTERNAL_ERROR = "SageMaker.JOB_INTERNAL_ERROR"
    CAPACITY_ERROR = "SageMaker.CAPACITY_ERROR"
    RESOURCE_LIMIT = "SageMaker.RESOURCE_LIMIT"
@attr.s
class RetryPolicy(Entity):
    """Base class describing how a failed pipeline step is retried.

    A policy is bounded either by a number of attempts (``max_attempts``) or
    by a time window (``expire_after_mins``); setting both is rejected by
    ``to_request``.

    Attributes:
        backoff_rate (float): The multiplier by which the retry interval
            increases during each attempt (default: 2.0)
        interval_seconds (int): An integer that represents the number of
            seconds before the first retry attempt (default: 1)
        max_attempts (int): A positive integer that represents the maximum
            number of retry attempts. (default: None)
        expire_after_mins (int): A positive integer that represents the
            maximum minute to expire any further retry attempt
            (default: None)
    """

    backoff_rate: float = attr.ib(default=DEFAULT_BACKOFF_RATE)
    interval_seconds: int = attr.ib(default=DEFAULT_INTERVAL_SECONDS)
    max_attempts: Optional[int] = attr.ib(default=None)
    expire_after_mins: Optional[int] = attr.ib(default=None)

    # NOTE: the validators below rely on ``assert`` and therefore raise
    # AssertionError; they are skipped entirely when Python runs with -O.
    # Falsy values (None and 0) bypass every check because of ``if value``.

    @backoff_rate.validator
    def validate_backoff_rate(self, _, value):
        """Validate that a truthy ``backoff_rate`` is non-negative.

        Args:
            _: The attrs attribute being validated (unused).
            value: The value supplied for ``backoff_rate``.

        Raises:
            AssertionError: If ``value`` is truthy and negative.
        """
        if value:
            assert value >= 0.0, "backoff_rate should be non-negative"

    @interval_seconds.validator
    def validate_interval_seconds(self, _, value):
        """Validate that a truthy ``interval_seconds`` is non-negative.

        Args:
            _: The attrs attribute being validated (unused).
            value: The value supplied for ``interval_seconds``.

        Raises:
            AssertionError: If ``value`` is truthy and negative.
        """
        if value:
            assert value >= 0.0, "interval_seconds rate should be non-negative"

    @max_attempts.validator
    def validate_max_attempts(self, _, value):
        """Validate that a truthy ``max_attempts`` is within [1, MAX_ATTEMPTS_CAP].

        Args:
            _: The attrs attribute being validated (unused).
            value: The value supplied for ``max_attempts``.

        Raises:
            AssertionError: If ``value`` is truthy and outside the range.
        """
        if value:
            assert (
                MAX_ATTEMPTS_CAP >= value >= 1
            ), f"max_attempts must in range of (0, {MAX_ATTEMPTS_CAP}] attempts"

    @expire_after_mins.validator
    def validate_expire_after_mins(self, _, value):
        """Validate that a truthy ``expire_after_mins`` is within [0, MAX_EXPIRE_AFTER_MIN].

        Args:
            _: The attrs attribute being validated (unused).
            value: The value supplied for ``expire_after_mins``.

        Raises:
            AssertionError: If ``value`` is truthy and outside the range.
        """
        if value:
            assert (
                MAX_EXPIRE_AFTER_MIN >= value >= 0
            ), f"expire_after_mins must in range of (0, {MAX_EXPIRE_AFTER_MIN}] minutes"

    def to_request(self) -> RequestType:
        """Get the request structure for workflow service calls.

        Returns:
            A dict with ``BackoffRate`` and ``IntervalSeconds``, plus
            ``MaxAttempts`` or ``ExpireAfterMin`` when the corresponding
            attribute is truthy. When neither limit is set, both keys are
            omitted.

        Raises:
            ValueError: If both ``max_attempts`` and ``expire_after_mins``
                are set.
        """
        # The previous check, `(a is None) == b is None`, was parsed as the
        # chained comparison `((a is None) == b) and (b is None)`, which can
        # never be true, so this error was unreachable. Test both limits
        # explicitly instead.
        if self.max_attempts is not None and self.expire_after_mins is not None:
            raise ValueError("Only one of [max_attempts] and [expire_after_mins] can be given.")

        request = {
            "BackoffRate": self.backoff_rate,
            "IntervalSeconds": self.interval_seconds,
        }

        if self.max_attempts:
            request["MaxAttempts"] = self.max_attempts

        if self.expire_after_mins:
            request["ExpireAfterMin"] = self.expire_after_mins

        return request
class StepRetryPolicy(RetryPolicy):
    """Retry policy for a retryable pipeline step.

    By default the pipeline service retries
    `StepExceptionTypeEnum.SERVICE_FAULT` and
    `StepExceptionTypeEnum.THROTTLING` regardless of pipeline step type.
    For a step defined as retryable, those defaults can be overridden by
    specifying a StepRetryPolicy.

    Attributes:
        exception_types (List[StepExceptionTypeEnum]): the exception types to
            match for this policy
        backoff_rate (float): The multiplier by which the retry interval
            increases during each attempt (default: 2.0)
        interval_seconds (int): An integer that represents the number of
            seconds before the first retry attempt (default: 1)
        max_attempts (int): A positive integer that represents the maximum
            number of retry attempts. (default: None)
        expire_after_mins (int): A positive integer that represents the
            maximum minute to expire any further retry attempt
            (default: None)
    """

    def __init__(
        self,
        exception_types: List[StepExceptionTypeEnum],
        backoff_rate: float = 2.0,
        interval_seconds: int = 1,
        max_attempts: Optional[int] = None,
        expire_after_mins: Optional[int] = None,
    ):
        """Initialize the policy and type-check every exception type.

        Raises:
            ValueError: If any entry of ``exception_types`` is not a
                ``StepExceptionTypeEnum`` member.
        """
        # Base-class validators run first, exactly as before.
        super().__init__(
            backoff_rate=backoff_rate,
            interval_seconds=interval_seconds,
            max_attempts=max_attempts,
            expire_after_mins=expire_after_mins,
        )

        for candidate in exception_types:
            if isinstance(candidate, StepExceptionTypeEnum):
                continue
            raise ValueError(f"{candidate} is not of StepExceptionTypeEnum.")

        self.exception_types = exception_types

    def to_request(self) -> RequestType:
        """Gets the request structure for retry policy.

        Returns:
            The base policy request extended with an ``ExceptionType`` list
            holding the value of each configured exception type.
        """
        matched_values = [member.value for member in self.exception_types]
        return {**super().to_request(), "ExceptionType": matched_values}

    def __hash__(self):
        """Hash function for StepRetryPolicy types"""
        # tuple() over the request dict yields its keys only, so the hash
        # depends on which keys are present, not on their values.
        request_keys = tuple(self.to_request())
        return hash(request_keys)
class SageMakerJobStepRetryPolicy(RetryPolicy):
    """Retry policy for exceptions thrown by a SageMaker job.

    Attributes:
        exception_types (List[SageMakerJobExceptionTypeEnum]): The SageMaker
            exception to match for this policy. The SageMaker exceptions
            captured here are the exceptions thrown by synchronously
            creating the job. For instance the resource limit exception.
        failure_reason_types (List[SageMakerJobExceptionTypeEnum]): the
            SageMaker failure reason types to match for this policy. The
            failure reason type is presented in FailureReason field of the
            Describe response, it indicates the runtime failure reason for a
            job.
        backoff_rate (float): The multiplier by which the retry interval
            increases during each attempt (default: 2.0)
        interval_seconds (int): An integer that represents the number of
            seconds before the first retry attempt (default: 1)
        max_attempts (int): A positive integer that represents the maximum
            number of retry attempts. (default: None)
        expire_after_mins (int): A positive integer that represents the
            maximum minute to expire any further retry attempt
            (default: None)
    """

    def __init__(
        self,
        exception_types: Optional[List[SageMakerJobExceptionTypeEnum]] = None,
        failure_reason_types: Optional[List[SageMakerJobExceptionTypeEnum]] = None,
        backoff_rate: float = 2.0,
        interval_seconds: int = 1,
        max_attempts: Optional[int] = None,
        expire_after_mins: Optional[int] = None,
    ):
        """Initialize the policy from exception and/or failure reason types.

        Both lists are merged, ``exception_types`` first, into
        ``self.exception_type_list``.

        Raises:
            ValueError: If both lists are empty or None, or if any merged
                entry is not a ``SageMakerJobExceptionTypeEnum`` member.
        """
        # Base-class validators run first, exactly as before.
        super().__init__(
            backoff_rate=backoff_rate,
            interval_seconds=interval_seconds,
            max_attempts=max_attempts,
            expire_after_mins=expire_after_mins,
        )

        if not (exception_types or failure_reason_types):
            raise ValueError(
                "At least one of the [exception_types, failure_reason_types] needs to be given."
            )

        merged: List[SageMakerJobExceptionTypeEnum] = []
        for group in (exception_types, failure_reason_types):
            if group:
                merged.extend(group)
        self.exception_type_list = merged

        for candidate in merged:
            if isinstance(candidate, SageMakerJobExceptionTypeEnum):
                continue
            raise ValueError(f"{candidate} is not of SageMakerJobExceptionTypeEnum.")

    def to_request(self) -> RequestType:
        """Gets the request structure for retry policy.

        Returns:
            The base policy request extended with an ``ExceptionType`` list
            holding the value of each entry in ``exception_type_list``.
        """
        matched_values = [member.value for member in self.exception_type_list]
        return {**super().to_request(), "ExceptionType": matched_values}

    def __hash__(self):
        """Hash function for SageMakerJobStepRetryPolicy types"""
        # tuple() over the request dict yields its keys only, so the hash
        # depends on which keys are present, not on their values.
        request_keys = tuple(self.to_request())
        return hash(request_keys)