# Source code for sagemaker.core.jumpstart.notebook_utils

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module stores notebook utils related to SageMaker JumpStart."""
from __future__ import absolute_import
import copy

from concurrent.futures import ThreadPoolExecutor, as_completed

from functools import cmp_to_key
import json
import os
from typing import Any, Generator, List, Optional, Tuple, Union, Set, Dict
from packaging.version import Version
from sagemaker.core.jumpstart import accessors
from sagemaker.core.jumpstart.constants import (
    DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    PROPRIETARY_MODEL_SPEC_PREFIX,
)
from sagemaker.core.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
from sagemaker.core.jumpstart.filters import (
    SPECIAL_SUPPORTED_FILTER_KEYS,
    ProprietaryModelFilterIdentifiers,
    BooleanValues,
    Identity,
    SpecialSupportedFilterKeys,
)
from sagemaker.core.jumpstart.filters import (
    Constant,
    ModelFilter,
    Operator,
    evaluate_filter_expression,
)
from sagemaker.core.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
from sagemaker.core.jumpstart.utils import (
    get_jumpstart_content_bucket,
    get_region_fallback,
    get_sagemaker_version,
    verify_model_region_and_return_specs,
    validate_model_id_and_get_type,
)
from sagemaker.core.helper.session_helper import Session

# Maximum number of worker threads used when evaluating model manifests concurrently:
# twice the CPU count (4 CPUs assumed when ``os.cpu_count()`` returns None), capped at 32.
MAX_SEARCH_WORKERS = min(2 * (os.cpu_count() or 4), 32)


def _compare_model_version_tuples(  # pylint: disable=too-many-return-statements
    model_version_1: Optional[Tuple[str, str]] = None,
    model_version_2: Optional[Tuple[str, str]] = None,
) -> int:
    """Performs comparison of sdk specs paths, in order to sort them.

    Args:
        model_version_1 (Tuple[str, str]): The first model ID and version tuple to compare.
        model_version_2 (Tuple[str, str]): The second model ID and version tuple to compare.
    """
    if model_version_1 is None or model_version_2 is None:
        if model_version_2 is not None:
            return -1
        if model_version_1 is not None:
            return 1
        return 0

    model_id_1, version_1 = model_version_1

    model_id_2, version_2 = model_version_2

    if model_id_1 < model_id_2:
        return -1

    if model_id_2 < model_id_1:
        return 1

    if Version(version_1) < Version(version_2):
        return 1

    if Version(version_2) < Version(version_1):
        return -1

    return 0


def _model_filter_in_operator_generator(filter_operator: Operator) -> Generator:
    """Lazily yield the operators of ``filter_operator`` that wrap a ``ModelFilter``.

    ``filter_operator`` is iterated directly; only the items whose ``unresolved_value``
    is a ``ModelFilter`` instance are yielded.
    """
    yield from (
        candidate
        for candidate in filter_operator
        if isinstance(candidate.unresolved_value, ModelFilter)
    )


def _put_resolved_booleans_into_filter(
    filter_operator: Operator, model_filters_to_resolved_values: Dict[ModelFilter, BooleanValues]
) -> None:
    """Write a ``resolved_value`` onto every model-filter operator within ``filter_operator``.

    Each operator's ``ModelFilter`` is looked up in ``model_filters_to_resolved_values``.
    Filters without an entry are set to ``BooleanValues.UNKNOWN``. Operators are modified
    in place; nothing is returned.
    """
    for filter_node in _model_filter_in_operator_generator(filter_operator):
        filter_node.resolved_value = model_filters_to_resolved_values.get(
            filter_node.unresolved_value, BooleanValues.UNKNOWN
        )


def _populate_model_filters_to_resolved_values(
    manifest_specs_cached_values: Dict[str, Any],
    model_filters_to_resolved_values: Dict[ModelFilter, BooleanValues],
    model_filters: Set[ModelFilter],
) -> None:
    """Evaluate each model filter whose key already has a cached metadata value.

    The result of every filter that can be evaluated is written into
    ``model_filters_to_resolved_values``. Filters whose key is absent from
    ``manifest_specs_cached_values`` are skipped and get no entry;
    ``_put_resolved_booleans_into_filter`` treats missing entries as ``BooleanValues.UNKNOWN``.

    Args:
        manifest_specs_cached_values (Dict[str, Any]): Metadata values gathered so far for
            one model, keyed by filter key.
        model_filters_to_resolved_values (Dict[ModelFilter, BooleanValues]): Output mapping,
            updated in place.
        model_filters (Set[ModelFilter]): The individual model filters to evaluate.
    """
    for model_filter in model_filters:
        # Membership test (rather than ``.get``) so a cached value of ``None`` is still evaluated.
        if model_filter.key not in manifest_specs_cached_values:
            continue
        model_filters_to_resolved_values[model_filter] = evaluate_filter_expression(
            model_filter, manifest_specs_cached_values[model_filter.key]
        )


def extract_framework_task_model(model_id: str) -> Tuple[str, str, str]:
    """Split a JumpStart model ID into ``(framework, task, name)``.

    The ID is split on ``-``. The first segment is the framework and the second is the task.
    Everything after the second ``-`` (including any further dashes) is the name. IDs with
    fewer than three segments yield ``("", "", "")``.

    Args:
        model_id (str): The model ID for which to extract the framework/task/model.

    Returns:
        Tuple[str, str, str]: The framework, the task and the remaining model name.
    """
    pieces = model_id.split("-", 2)
    if len(pieces) != 3:
        return "", "", ""
    framework, task, name = pieces
    return framework, task, name
def extract_model_type_filter_representation(spec_key: str) -> str:
    """Classify a model spec key as proprietary or open weights.

    Only the part of ``spec_key`` before the first ``/`` is inspected. If it equals
    ``PROPRIETARY_MODEL_SPEC_PREFIX``, the model is proprietary; otherwise it is open weights.

    Args:
        spec_key (str): The model spec key for which to extract the model type.

    Returns:
        str: ``JumpStartModelType.PROPRIETARY.value`` or
        ``JumpStartModelType.OPEN_WEIGHTS.value``.
    """
    leading_segment, _, _ = spec_key.partition("/")
    return (
        JumpStartModelType.PROPRIETARY.value
        if leading_segment == PROPRIETARY_MODEL_SPEC_PREFIX
        else JumpStartModelType.OPEN_WEIGHTS.value
    )
def list_jumpstart_tasks(  # pylint: disable=redefined-builtin
    filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
    region: Optional[str] = None,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> List[str]:
    """List tasks of open-weight JumpStart models, optionally restricted by a filter.

    The task of a model is the second ``-``-separated segment of its ID, as returned by
    ``extract_framework_task_model``. Model IDs with fewer than three segments contribute ``""``.

    Args:
        filter (Union[Operator, str]): Optional. The filter to apply to list tasks. This can be
            either an ``Operator`` type filter (e.g. ``And("task == ic", "framework == pytorch")``),
            or simply a string filter which will get serialized into an Identity filter.
            (e.g. ``"task == ic"``). If this argument is not supplied, all tasks will be listed.
            (Default: Constant(BooleanValues.TRUE)).
        region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
            models. If None, it is derived from ``sagemaker_session``. (Default: None).
        sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use
            to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

    Returns:
        List[str]: The distinct task names, sorted.
    """
    region = region or get_region_fallback(sagemaker_session=sagemaker_session)
    matching_models = _generate_jumpstart_model_versions(
        filter=filter,
        region=region,
        sagemaker_session=sagemaker_session,
        model_type=JumpStartModelType.OPEN_WEIGHTS,
    )
    return sorted({extract_framework_task_model(model_id)[1] for model_id, _ in matching_models})
def list_jumpstart_frameworks(  # pylint: disable=redefined-builtin
    filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
    region: Optional[str] = None,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> List[str]:
    """List frameworks of open-weight JumpStart models, optionally restricted by a filter.

    The framework of a model is the first ``-``-separated segment of its ID, as returned by
    ``extract_framework_task_model``. Model IDs with fewer than three segments contribute ``""``.

    Args:
        filter (Union[Operator, str]): Optional. The filter to apply to list frameworks. This can
            be either an ``Operator`` type filter (e.g. ``And("task == ic", "framework == pytorch")``),
            or simply a string filter which will get serialized into an Identity filter.
            (eg. ``"task == ic"``). If this argument is not supplied, all frameworks will be listed.
            (Default: Constant(BooleanValues.TRUE)).
        region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
            models. If None, it is derived from ``sagemaker_session``. (Default: None).
        sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use
            to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

    Returns:
        List[str]: The distinct framework names, sorted.
    """
    region = region or get_region_fallback(sagemaker_session=sagemaker_session)
    frameworks = {
        extract_framework_task_model(model_id)[0]
        for model_id, _ in _generate_jumpstart_model_versions(
            filter=filter,
            region=region,
            sagemaker_session=sagemaker_session,
            model_type=JumpStartModelType.OPEN_WEIGHTS,
        )
    }
    return sorted(frameworks)
def list_jumpstart_scripts(  # pylint: disable=redefined-builtin
    filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
    region: Optional[str] = None,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> List[str]:
    """List scripts for JumpStart, and optionally apply filters to result.

    Every matching open-weight model contributes the inference script. Models whose specs
    report ``training_supported`` also contribute the training script. The search stops as
    soon as every ``JumpStartScriptScope`` value has been found.

    Args:
        filter (Union[Operator, str]): Optional. The filter to apply to list scripts. This can be
            either an ``Operator`` type filter (e.g. ``And("task == ic", "framework == pytorch")``),
            or simply a string filter which will get serialized into an Identity filter.
            (e.g. ``"task == ic"``). If this argument is not supplied, all scripts will be listed.
            (Default: Constant(BooleanValues.TRUE)).
        region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
            models. If None, it is derived from ``sagemaker_session``. (Default: None).
        sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use
            to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

    Returns:
        List[str]: The sorted ``JumpStartScriptScope`` values available for matching models.
    """
    region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )

    # Loop-invariant: the complete set of script scope values.
    all_scripts: Set[str] = {scope.value for scope in JumpStartScriptScope}

    # An always-true filter matches every model, so every script scope is available without
    # evaluating any model specs.
    if (isinstance(filter, Constant) and filter.resolved_value == BooleanValues.TRUE) or (
        isinstance(filter, str) and filter.lower() == BooleanValues.TRUE.lower()
    ):
        return sorted(all_scripts)

    scripts: Set[str] = set()
    for model_id, version in _generate_jumpstart_model_versions(
        filter=filter,
        region=region,
        sagemaker_session=sagemaker_session,
        model_type=JumpStartModelType.OPEN_WEIGHTS,
    ):
        # Store ``.value`` strings so the early-exit comparison below and the returned element
        # type do not depend on how the enum members hash/compare against plain strings, and
        # match the fast path above.
        scripts.add(JumpStartScriptScope.INFERENCE.value)
        model_specs = verify_model_region_and_return_specs(
            region=region,
            model_id=model_id,
            version=version,
            sagemaker_session=sagemaker_session,
            scope=JumpStartScriptScope.INFERENCE,
        )
        if model_specs.training_supported:
            scripts.add(JumpStartScriptScope.TRAINING.value)

        # Nothing more can be learned once every scope has been seen.
        if scripts == all_scripts:
            break
    return sorted(scripts)
def _is_valid_version(version: str) -> bool: """Checks if the version is convertable to Version class.""" try: Version(version) return True except Exception: # pylint: disable=broad-except return False
def list_jumpstart_models(  # pylint: disable=redefined-builtin
    filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
    region: Optional[str] = None,
    list_incomplete_models: bool = False,
    list_old_models: bool = False,
    list_versions: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> List[Union[str, Tuple[str, str]]]:
    """List models for JumpStart, and optionally apply filters to result.

    Args:
        filter (Union[Operator, str]): Optional. The filter to apply to list models. This can be
            either an ``Operator`` type filter (e.g. ``And("task == ic", "framework == pytorch")``),
            or simply a string filter which will get serialized into an Identity filter.
            (e.g. ``"task == ic"``). If this argument is not supplied, all models will be listed.
            (Default: Constant(BooleanValues.TRUE)).
        region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
            models. If None, it is derived from ``sagemaker_session``. (Default: None).
        list_incomplete_models (bool): Optional. If a model does not contain metadata fields
            requested by the filter, and the filter cannot be resolved to a include/not include,
            whether the model should be included. By default, these models are omitted from
            results. (Default: False).
        list_old_models (bool): Optional. If there are older versions of a model, whether the
            older versions should be included in the returned result. (Default: False).
        list_versions (bool): Optional. True if versions for models should be returned in
            addition to the id of the model. (Default: False).
        sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use
            to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

    Returns:
        List[Union[str, Tuple[str, str]]]: If ``list_versions`` is False, the sorted list of
        matching model IDs. Otherwise a list of ``(model_id, version)`` tuples, sorted by model
        ID ascending and, within a model, by version descending. Unless ``list_old_models`` is
        True, only the highest version of each model is included. Versions that parse as
        ``packaging`` versions are returned in their normalized string form.
    """
    region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )

    # model ID -> every matching version. Parseable versions are kept as ``Version`` objects so
    # that ``max`` below compares them semantically (1.10 > 1.9) rather than lexically.
    versions_by_model_id: Dict[str, List[Union[Version, str]]] = {}
    for model_id, version in _generate_jumpstart_model_versions(
        filter=filter,
        region=region,
        list_incomplete_models=list_incomplete_models,
        sagemaker_session=sagemaker_session,
    ):
        parsed_version = Version(version) if _is_valid_version(version) else version
        versions_by_model_id.setdefault(model_id, []).append(parsed_version)

    if not list_versions:
        return sorted(versions_by_model_id)

    if not list_old_models:
        for model_id, versions in versions_by_model_id.items():
            try:
                latest_version = max(versions)
            except TypeError:
                # ``Version`` objects and unparseable strings cannot be ordered against each
                # other; fall back to comparing every version as a string.
                latest_version = max(str(v) for v in versions)
            # Replacing the value of an existing key does not resize the dict, so this is safe
            # while iterating over it.
            versions_by_model_id[model_id] = [latest_version]

    model_id_version_set: Set[Tuple[str, str]] = {
        (model_id, str(version))
        for model_id, versions in versions_by_model_id.items()
        for version in versions
    }
    return sorted(model_id_version_set, key=cmp_to_key(_compare_model_version_tuples))
def _generate_jumpstart_model_versions(  # pylint: disable=redefined-builtin
    filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
    region: Optional[str] = None,
    list_incomplete_models: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: Optional[JumpStartModelType] = None,
) -> Generator:
    """Generate models for JumpStart, and optionally apply filters to result.

    Each model header in the selected manifest(s) is evaluated concurrently, using up to
    ``MAX_SEARCH_WORKERS`` threads. The filter is first evaluated with values available on
    the manifest header alone. Only if that leaves it ``UNKNOWN`` is the model's full spec
    read from S3 and the filter evaluated again. Results are yielded in completion order,
    not manifest order.

    Args:
        filter (Union[Operator, str]): Optional. The filter to apply to generate models. This can
            be either an ``Operator`` type filter (e.g. ``And("task == ic", "framework == pytorch")``),
            or simply a string filter which will get serialized into an Identity filter.
            (e.g. ``"task == ic"``). If this argument is not supplied, all models will be
            generated. The caller's filter object is not modified. (Default: Constant(BooleanValues.TRUE)).
        region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
            models. If None, it is derived from ``sagemaker_session``. (Default: None).
        list_incomplete_models (bool): Optional. If a model does not contain metadata fields
            requested by the filter, and the filter cannot be resolved to a include/not include,
            whether the model should be included. By default, these models are omitted from
            results. (Default: False).
        sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use
            to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (Optional[JumpStartModelType]): Optional. Restrict the search to the
            open-weights or the proprietary manifest. If None, both are searched. (Default: None).

    Yields:
        Tuple[str, str]: ``(model_id, version)`` for every model that matches ``filter``.
        Models whose ``min_version`` is newer than ``get_sagemaker_version()`` are skipped.

    Raises:
        NotImplementedError: If a filter key uses nested (dotted) metadata indexing.
        RuntimeError: If the filter is still unevaluated after substituting model metadata.
    """
    region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )

    prop_models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(
        region=region,
        s3_client=sagemaker_session.s3_client,
        model_type=JumpStartModelType.PROPRIETARY,
    )
    open_weight_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(
        region=region,
        s3_client=sagemaker_session.s3_client,
        model_type=JumpStartModelType.OPEN_WEIGHTS,
    )

    if model_type == JumpStartModelType.OPEN_WEIGHTS:
        models_manifest_list = open_weight_manifest_list
    elif model_type == JumpStartModelType.PROPRIETARY:
        models_manifest_list = prop_models_manifest_list
    else:
        models_manifest_list = open_weight_manifest_list + prop_models_manifest_list

    if isinstance(filter, str):
        filter = Identity(filter)
    else:
        # The MODEL_TYPE normalization below mutates model filters in place; work on a private
        # copy so the caller's filter object is left untouched.
        filter = copy.deepcopy(filter)

    # Attribute names available directly on manifest headers. Skip empty manifests so that
    # a missing model category does not raise IndexError.
    manifest_keys: Set[str] = set()
    for manifest_list in (open_weight_manifest_list, prop_models_manifest_list):
        if manifest_list:
            manifest_keys.update(manifest_list[0].__slots__)

    all_keys: Set[str] = set()
    model_filters: Set[ModelFilter] = set()

    for operator in _model_filter_in_operator_generator(filter):
        model_filter = operator.unresolved_value
        all_keys.add(model_filter.key)
        # Any proprietary identifier used in a ``model_type`` filter is normalized to the
        # canonical proprietary model type value before the filter is hashed into the set.
        if model_filter.key == SpecialSupportedFilterKeys.MODEL_TYPE and model_filter.value in {
            identifier.value for identifier in ProprietaryModelFilterIdentifiers
        }:
            model_filter.set_value(JumpStartModelType.PROPRIETARY.value)
        model_filters.add(model_filter)

    for key in all_keys:
        if "." in key:
            raise NotImplementedError(f"No support for multiple level metadata indexing ('{key}').")

    metadata_filter_keys = all_keys - SPECIAL_SUPPORTED_FILTER_KEYS

    # Keys resolvable from the manifest header vs. keys that need the downloaded spec.
    required_manifest_keys = manifest_keys.intersection(metadata_filter_keys)
    possible_spec_keys = metadata_filter_keys - manifest_keys

    is_task_filter = SpecialSupportedFilterKeys.TASK in all_keys
    is_framework_filter = SpecialSupportedFilterKeys.FRAMEWORK in all_keys
    is_model_type_filter = SpecialSupportedFilterKeys.MODEL_TYPE in all_keys

    # Loop-invariant: parse the SDK version once instead of once per model.
    sagemaker_version = Version(get_sagemaker_version())

    def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, str]]:
        """Return ``(model_id, version)`` if ``model_manifest`` matches the filter, else None."""
        copied_filter = copy.deepcopy(filter)
        manifest_specs_cached_values: Dict[str, Union[bool, int, float, str, dict, list]] = {}
        model_filters_to_resolved_values: Dict[ModelFilter, BooleanValues] = {}

        for val in required_manifest_keys:
            manifest_specs_cached_values[val] = getattr(model_manifest, val)

        if is_task_filter:
            manifest_specs_cached_values[SpecialSupportedFilterKeys.TASK] = (
                extract_framework_task_model(model_manifest.model_id)[1]
            )

        if is_framework_filter:
            manifest_specs_cached_values[SpecialSupportedFilterKeys.FRAMEWORK] = (
                extract_framework_task_model(model_manifest.model_id)[0]
            )

        if is_model_type_filter:
            manifest_specs_cached_values[SpecialSupportedFilterKeys.MODEL_TYPE] = (
                extract_model_type_filter_representation(model_manifest.spec_key)
            )

        if Version(model_manifest.min_version) > sagemaker_version:
            return None

        # First pass: evaluate using only the values available on the manifest header.
        _populate_model_filters_to_resolved_values(
            manifest_specs_cached_values,
            model_filters_to_resolved_values,
            model_filters,
        )

        _put_resolved_booleans_into_filter(copied_filter, model_filters_to_resolved_values)

        copied_filter.eval()

        if copied_filter.resolved_value in [BooleanValues.TRUE, BooleanValues.FALSE]:
            if copied_filter.resolved_value == BooleanValues.TRUE:
                return (model_manifest.model_id, model_manifest.version)
            return None

        if copied_filter.resolved_value == BooleanValues.UNEVALUATED:
            raise RuntimeError(
                "Filter expression in unevaluated state after using "
                "values from model manifest. Model ID and version that "
                f"is failing: {(model_manifest.model_id, model_manifest.version)}."
            )

        # Second pass: the result is still UNKNOWN, so fetch the full spec for the keys that
        # are not on the manifest header.
        copied_filter_2 = copy.deepcopy(filter)

        # spec is downloaded to thread's memory. since each thread
        # accesses a unique s3 spec, there is no need to use the JS caching utils.
        # spec only stays in memory for lifecycle of thread.
        model_specs = JumpStartModelSpecs(
            json.loads(
                sagemaker_session.read_s3_file(
                    get_jumpstart_content_bucket(region), model_manifest.spec_key
                )
            )
        )

        for val in possible_spec_keys:
            if hasattr(model_specs, val):
                manifest_specs_cached_values[val] = getattr(model_specs, val)

        _populate_model_filters_to_resolved_values(
            manifest_specs_cached_values,
            model_filters_to_resolved_values,
            model_filters,
        )
        _put_resolved_booleans_into_filter(copied_filter_2, model_filters_to_resolved_values)

        copied_filter_2.eval()

        if copied_filter_2.resolved_value != BooleanValues.UNEVALUATED:
            # A model is included if the filter is TRUE, or if it is still UNKNOWN (missing
            # metadata) and the caller asked for incomplete models. FALSE is always excluded.
            if copied_filter_2.resolved_value == BooleanValues.TRUE or (
                copied_filter_2.resolved_value == BooleanValues.UNKNOWN and list_incomplete_models
            ):
                return (model_manifest.model_id, model_manifest.version)
            return None

        raise RuntimeError(
            "Filter expression in unevaluated state after using values from model specs. "
            "Model ID and version that is failing: "
            f"{(model_manifest.model_id, model_manifest.version)}."
        )

    with ThreadPoolExecutor(max_workers=MAX_SEARCH_WORKERS) as executor:
        futures = [executor.submit(evaluate_model, header) for header in models_manifest_list]
        try:
            for future in as_completed(futures):
                # ``result()`` re-raises any exception raised inside ``evaluate_model``.
                result = future.result()
                if result:
                    yield result
        finally:
            # If the consumer stops early (e.g. ``break``) or an error propagates, cancel the
            # evaluations that have not started yet. Otherwise the executor's shutdown would
            # wait for all of them. This is a no-op for futures that have already finished.
            for future in futures:
                future.cancel()
def get_model_url(
    model_id: str,
    model_version: str,
    region: Optional[str] = None,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
) -> str:
    """Retrieve web url describing pretrained model.

    Args:
        model_id (str): The model ID for which to retrieve the url.
        model_version (str): The model version for which to retrieve the url.
        region (str): Optional. The region from which to retrieve metadata. If None, it is
            derived from ``sagemaker_session``. (Default: None)
        sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use
            to retrieve the model url.
        config_name (Optional[str]): Optional. Config name forwarded to
            ``verify_model_region_and_return_specs`` when retrieving the model specs.
            (Default: None)

    Returns:
        str: The ``url`` attribute of the model's specs.
    """
    # Resolve the region before validating so that model-type validation and the spec lookup
    # both consult the same region.
    region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )

    model_type = validate_model_id_and_get_type(
        model_id=model_id,
        model_version=model_version,
        region=region,
        sagemaker_session=sagemaker_session,
    )

    model_specs = verify_model_region_and_return_specs(
        region=region,
        model_id=model_id,
        version=model_version,
        sagemaker_session=sagemaker_session,
        scope=JumpStartScriptScope.INFERENCE,
        model_type=model_type,
        config_name=config_name,
    )
    return model_specs.url