# Source code for sagemaker.core.jumpstart.exceptions

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module stores exceptions related to SageMaker JumpStart."""
from __future__ import absolute_import
from typing import List, Optional

from botocore.exceptions import ClientError

from sagemaker.core.jumpstart.constants import MODEL_ID_LIST_WEB_URL
from sagemaker.core.jumpstart.enums import JumpStartScriptScope


# Message template: the ``{region}`` and ``{model_id}`` placeholders are left for
# the caller to fill in (presumably with ``str.format``).
NO_AVAILABLE_INSTANCES_ERROR_MSG = (
    "No instances available in {region} that can support model ID '{model_id}'. "
    "Please try another region."
)

# Message template: the ``{model_id}`` placeholder is left for the caller to fill in.
NO_AVAILABLE_RESOURCE_REQUIREMENT_RECOMMENDATION_ERROR_MSG = (
    "No available compute resource requirement recommendation for model ID '{model_id}'. "
    "Provide the resource requirements in the deploy method."
)

# Mixed template: the f-string line interpolates MODEL_ID_LIST_WEB_URL once, when this
# module is imported, while ``{model_id}`` stays a literal placeholder for the caller.
# NOTE(review): a later ``str.format`` call would misbehave if the URL ever contained
# ``{`` or ``}`` -- confirm it does not.
# NOTE(review): this file lives under ``sagemaker.core.jumpstart``; confirm that the
# ``sagemaker.jumpstart.notebook_utils`` path mentioned below is still accurate.
INVALID_MODEL_ID_ERROR_MSG = (
    "Invalid model ID: '{model_id}'. Specify a different model ID or try a different AWS Region. "
    f"For a list of available models, see {MODEL_ID_LIST_WEB_URL}. "
    "The module `sagemaker.jumpstart.notebook_utils` contains utilities for "
    "fetching model IDs. We recommend upgrading to the latest version of sagemaker "
    "to get access to the most models."
)


# Trailing sentence appended by get_wildcard_model_version_msg and
# get_old_model_version_msg in this module.
_MAJOR_VERSION_WARNING_MSG = (
    "Note that models may have different input/output signatures after a major version upgrade."
)

# Shared advice appended to the default messages of VulnerableJumpStartModelError
# and DeprecatedJumpStartModelError in this module.
_VULNERABLE_DEPRECATED_ERROR_RECOMMENDATION = (
    "We recommend that you specify a more recent "
    "model version or choose a different model. To access the latest models "
    "and model versions, be sure to upgrade to the latest version of the SageMaker Python SDK."
)


def get_wildcard_model_version_msg(
    model_id: str, wildcard_model_version: str, full_model_version: str
) -> str:
    """Build the notice shown when a model is requested with a wildcard version.

    Args:
        model_id (str): JumpStart model ID that was requested.
        wildcard_model_version (str): the wildcard version string the caller supplied.
        full_model_version (str): concrete version the caller could pin to instead.

    Returns:
        str: three space-separated sentences -- what was used, the pinning
        suggestion, and the major-version compatibility warning.
    """
    usage = (
        f"Using model '{model_id}' with wildcard version identifier "
        f"'{wildcard_model_version}'."
    )
    advice = f"You can pin to version '{full_model_version}' for more stable results."
    return " ".join((usage, advice, _MAJOR_VERSION_WARNING_MSG))
def get_proprietary_model_subscription_msg(
    model_id: str,
    subscription_link: str,
) -> str:
    """Build the informational notice shown when a proprietary model is used.

    Args:
        model_id (str): ID of the proprietary JumpStart model.
        subscription_link (str): AWS Marketplace link where the model can be subscribed to.

    Returns:
        str: an ``INFO:``-prefixed message that ends with ``subscription_link``.
    """
    header = f"INFO: Using proprietary model '{model_id}'."
    footer = f"To subscribe to this model in AWS Marketplace, see {subscription_link}"
    return f"{header} {footer}"
def get_wildcard_proprietary_model_version_msg(
    model_id: str, wildcard_model_version: str, available_versions: List[str]
) -> str:
    """Returns customer-facing message for passing wildcard version to proprietary models.

    Proprietary models do not support wildcard version identifiers, so the message
    explains that, optionally suggests a concrete version, and points at the list
    of valid model IDs.

    Args:
        model_id (str): ID of the proprietary JumpStart model.
        wildcard_model_version (str): the wildcard version string the caller supplied.
        available_versions (List[str]): versions that could be pinned to. When
            non-empty, its first entry is suggested; when empty (or ``None``) the
            suggestion is omitted.

    Returns:
        str: the assembled message.
    """
    msg = (
        f"Proprietary model '{model_id}' does not support "
        f"wildcard version identifier '{wildcard_model_version}'. "
    )
    # Truthiness (rather than len()) also tolerates ``None`` without a TypeError.
    if available_versions:
        msg += f"You can pin to version '{available_versions[0]}'. "
    # Previously this sentence had no verb ("<URL> for a list of ...").
    msg += f"See {MODEL_ID_LIST_WEB_URL} for a list of valid model IDs. "
    return msg
def get_old_model_version_msg(
    model_id: str, current_model_version: str, latest_model_version: str
) -> str:
    """Build the notice shown when an outdated model version is in use.

    Args:
        model_id (str): JumpStart model ID in use.
        current_model_version (str): version currently being used.
        latest_model_version (str): newer version the caller could upgrade to.

    Returns:
        str: three space-separated sentences -- the version in use, the upgrade
        suggestion, and the major-version compatibility warning.
    """
    sentences = [
        f"Using model '{model_id}' with version '{current_model_version}'.",
        (
            f"You can upgrade to version '{latest_model_version}' "
            "to get the latest model specifications."
        ),
        _MAJOR_VERSION_WARNING_MSG,
    ]
    return " ".join(sentences)
def get_proprietary_model_subscription_error(error: ClientError, subscription_link: str) -> None:
    """Raise a Marketplace subscription error when ``error`` signals a missing subscription.

    Despite the ``get_`` prefix, this function never returns a message: it either
    raises ``MarketplaceModelSubscriptionError`` or returns ``None``.

    Args:
        error (ClientError): botocore client error to inspect.
        subscription_link (str): AWS Marketplace link passed to the raised error.

    Raises:
        MarketplaceModelSubscriptionError: if the error code is ``ValidationException``
            and the error message contains ``"not subscribed"``.
    """
    # ``.get`` so that a response without an ``Error``/``Code``/``Message`` entry
    # makes this function return None instead of raising KeyError (or TypeError
    # on a None message) from inside the caller's error handling.
    error_info = error.response.get("Error", {})
    error_code = error_info.get("Code")
    error_message = error_info.get("Message") or ""
    if error_code == "ValidationException" and "not subscribed" in error_message:
        raise MarketplaceModelSubscriptionError(subscription_link) from error
class JumpStartHyperparametersError(ValueError):
    """Raised when the hyperparameters supplied for a JumpStart model are invalid."""

    def __init__(self, message: Optional[str] = None):
        """Record ``message`` on the instance and forward it to ``ValueError``.

        Args:
            message (Optional[str]): description of the hyperparameter problem.
                (Default: None).
        """
        super().__init__(message)
        self.message = message
class VulnerableJumpStartModelError(ValueError):
    """Exception raised when accessing JumpStart model specs flagged as vulnerable.

    Raise this exception only if the scope of attributes accessed in the specifications
    has vulnerabilities. For example, a model training script may have vulnerabilities,
    but not the hosting scripts. In such a case, raise a ``VulnerableJumpStartModelError``
    only when accessing the training specifications.
    """

    def __init__(
        self,
        model_id: Optional[str] = None,
        version: Optional[str] = None,
        vulnerabilities: Optional[List[str]] = None,
        scope: Optional[JumpStartScriptScope] = None,
        message: Optional[str] = None,
    ):
        """Instantiates VulnerableJumpStartModelError exception.

        Args:
            model_id (Optional[str]): model ID of vulnerable JumpStart model.
                (Default: None).
            version (Optional[str]): version of vulnerable JumpStart model.
                (Default: None).
            vulnerabilities (Optional[List[str]]): vulnerabilities associated with model.
                (Default: None).
            scope (Optional[JumpStartScriptScope]): script scope in which the
                vulnerabilities were found; inference and training are supported.
                (Default: None).
            message (Optional[str]): custom error message. When truthy it is used
                verbatim and the other arguments are ignored. (Default: None).

        Raises:
            RuntimeError: if ``message`` is not given and any of ``model_id``,
                ``version``, ``vulnerabilities`` or ``scope`` is ``None``.
            NotImplementedError: if ``scope`` is neither inference nor training.
        """
        if message:
            self.message = message
        else:
            if None in [model_id, version, vulnerabilities, scope]:
                raise RuntimeError(
                    "Must specify `model_id`, `version`, `vulnerabilities`, and scope arguments."
                )
            # The inference and training messages differ only in the script name.
            # Equality comparisons (not a dict lookup) keep the enum's ``==`` semantics.
            if scope == JumpStartScriptScope.INFERENCE:
                script_name = "inference"
            elif scope == JumpStartScriptScope.TRAINING:
                script_name = "training"
            else:
                # ``getattr`` keeps this a NotImplementedError even when ``scope`` is
                # not an enum member (e.g. a plain string) and so has no ``.value``.
                raise NotImplementedError(
                    "Unsupported scope for VulnerableJumpStartModelError: "
                    f"'{getattr(scope, 'value', scope)}'"
                )
            self.message = (
                f"Version '{version}' of JumpStart model '{model_id}' "
                f"has at least 1 vulnerable dependency in the {script_name} script. "
                f"{_VULNERABLE_DEPRECATED_ERROR_RECOMMENDATION} "
                "List of vulnerabilities: "
                f"{', '.join(vulnerabilities)}"  # type: ignore
            )
        super().__init__(self.message)
class DeprecatedJumpStartModelError(ValueError):
    """Raised when deprecated specifications of a JumpStart model are accessed.

    Deprecation applies to a specification, not necessarily to the whole model:
    newer specifications of the same model may still be available. For example,
    if every specification before ``2.0.0`` is deprecated, the SDK raises this
    exception only when a ``1.*`` specification is accessed.
    """

    def __init__(
        self,
        model_id: Optional[str] = None,
        version: Optional[str] = None,
        message: Optional[str] = None,
    ):
        """Build the error from a custom message or from the model ID and version.

        Args:
            model_id (Optional[str]): ID of the JumpStart model. (Default: None).
            version (Optional[str]): deprecated version of the model. (Default: None).
            message (Optional[str]): custom error message; when truthy it is used
                verbatim. (Default: None).

        Raises:
            RuntimeError: if no truthy ``message`` is given and ``model_id`` or
                ``version`` is ``None``.
        """
        if not message:
            if model_id is None or version is None:
                raise RuntimeError("Must specify `model_id` and `version` arguments.")
            message = (
                f"Version '{version}' of JumpStart model '{model_id}' is deprecated. "
                f"{_VULNERABLE_DEPRECATED_ERROR_RECOMMENDATION}"
            )
        self.message = message
        super().__init__(self.message)
class MarketplaceModelSubscriptionError(ValueError):
    """Raised when a JumpStart Marketplace model is deployed without a subscription.

    Deploying a JumpStart Marketplace model requires the caller to be subscribed to
    the corresponding Marketplace product; this error signals that the caller is not.
    """

    def __init__(
        self,
        model_subscription_link: Optional[str] = None,
        message: Optional[str] = None,
    ):
        """Build the error from a custom message or from the default subscription hint.

        Args:
            model_subscription_link (Optional[str]): Marketplace link appended to the
                default message when truthy. (Default: None).
            message (Optional[str]): custom error message; when truthy it is used
                verbatim and the link is ignored. (Default: None).
        """
        if not message:
            parts = [
                "To use a proprietary JumpStart model, "
                "you must first subscribe to the model in AWS Marketplace. "
            ]
            if model_subscription_link:
                parts.append(f"To subscribe to this model, see {model_subscription_link}")
            message = "".join(parts)
        self.message = message
        super().__init__(self.message)