# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module stores notebook utils related to SageMaker JumpStart."""
from __future__ import absolute_import
import copy
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import cmp_to_key
import json
import os
from typing import Any, Generator, List, Optional, Tuple, Union, Set, Dict
from packaging.version import Version
from sagemaker.core.jumpstart import accessors
from sagemaker.core.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
PROPRIETARY_MODEL_SPEC_PREFIX,
)
from sagemaker.core.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
from sagemaker.core.jumpstart.filters import (
SPECIAL_SUPPORTED_FILTER_KEYS,
ProprietaryModelFilterIdentifiers,
BooleanValues,
Identity,
SpecialSupportedFilterKeys,
)
from sagemaker.core.jumpstart.filters import (
Constant,
ModelFilter,
Operator,
evaluate_filter_expression,
)
from sagemaker.core.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
from sagemaker.core.jumpstart.utils import (
get_jumpstart_content_bucket,
get_region_fallback,
get_sagemaker_version,
verify_model_region_and_return_specs,
validate_model_id_and_get_type,
)
from sagemaker.core.helper.session_helper import Session
# Thread count used to evaluate manifest headers in parallel: twice the CPU count
# (assuming 4 CPUs when ``os.cpu_count()`` returns None), capped at 32 workers.
MAX_SEARCH_WORKERS = min((os.cpu_count() or 4) * 2, 32)
def _compare_model_version_tuples( # pylint: disable=too-many-return-statements
model_version_1: Optional[Tuple[str, str]] = None,
model_version_2: Optional[Tuple[str, str]] = None,
) -> int:
"""Performs comparison of sdk specs paths, in order to sort them.
Args:
model_version_1 (Tuple[str, str]): The first model ID and version tuple to compare.
model_version_2 (Tuple[str, str]): The second model ID and version tuple to compare.
"""
if model_version_1 is None or model_version_2 is None:
if model_version_2 is not None:
return -1
if model_version_1 is not None:
return 1
return 0
model_id_1, version_1 = model_version_1
model_id_2, version_2 = model_version_2
if model_id_1 < model_id_2:
return -1
if model_id_2 < model_id_1:
return 1
if Version(version_1) < Version(version_2):
return 1
if Version(version_2) < Version(version_1):
return -1
return 0
def _model_filter_in_operator_generator(filter_operator: Operator) -> Generator:
    """Yield each operator inside ``filter_operator`` whose unresolved value is a ModelFilter.

    Args:
        filter_operator (Operator): The (possibly compound) filter operator to iterate over.

    Yields:
        Operator: Every iterated operator whose ``unresolved_value`` is a ``ModelFilter``.
    """
    yield from (
        candidate
        for candidate in filter_operator
        if isinstance(candidate.unresolved_value, ModelFilter)
    )
def _put_resolved_booleans_into_filter(
    filter_operator: Operator, model_filters_to_resolved_values: Dict[ModelFilter, BooleanValues]
) -> None:
    """Assign a resolved boolean to every model-filter operator inside ``filter_operator``.

    Each operator that wraps a ``ModelFilter`` gets its ``resolved_value`` set to that
    filter's entry in ``model_filters_to_resolved_values``. If the filter has no entry,
    ``BooleanValues.UNKNOWN`` is used. The operators are mutated in place.

    Args:
        filter_operator (Operator): The filter whose model-filter operators are updated.
        model_filters_to_resolved_values (Dict[ModelFilter, BooleanValues]): Evaluation
            results keyed by model filter.
    """
    fallback = BooleanValues.UNKNOWN
    for filter_node in _model_filter_in_operator_generator(filter_operator):
        filter_node.resolved_value = model_filters_to_resolved_values.get(
            filter_node.unresolved_value, fallback
        )
def _populate_model_filters_to_resolved_values(
    manifest_specs_cached_values: Dict[str, Any],
    model_filters_to_resolved_values: Dict[ModelFilter, BooleanValues],
    model_filters: Operator,
) -> None:
    """Evaluate every model filter whose key already has a cached metadata value.

    Filters whose key is missing from ``manifest_specs_cached_values`` are skipped. All other
    filters are evaluated against their cached value. Each result is stored in
    ``model_filters_to_resolved_values``, overwriting any earlier entry; that mapping is
    mutated in place.

    Args:
        manifest_specs_cached_values (Dict[str, Any]): Metadata values gathered so far,
            keyed by filter key.
        model_filters_to_resolved_values (Dict[ModelFilter, BooleanValues]): Output mapping
            from model filter to its evaluation result.
        model_filters (Operator): The model filters to evaluate. Within this module the
            caller passes a set of ``ModelFilter`` objects.
    """
    for model_filter in model_filters:
        filter_key = model_filter.key
        if filter_key not in manifest_specs_cached_values:
            continue
        model_filters_to_resolved_values[model_filter] = evaluate_filter_expression(
            model_filter, manifest_specs_cached_values[filter_key]
        )
[docs]
def list_jumpstart_tasks(  # pylint: disable=redefined-builtin
    filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
    region: Optional[str] = None,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> List[str]:
    """List tasks for JumpStart, and optionally apply filters to result.

    Only open-weight models are considered. Each matching model contributes the task
    component that ``extract_framework_task_model`` extracts from its model ID.

    Args:
        filter (Union[Operator, str]): Optional. The filter to apply to list tasks. This can be
            either an ``Operator`` type filter (e.g. ``And("task == ic", "framework == pytorch")``),
            or simply a string filter which will get serialized into an Identity filter.
            (e.g. ``"task == ic"``). If this argument is not supplied, all tasks will be listed.
            (Default: Constant(BooleanValues.TRUE)).
        region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
            models. (Default: None).
        sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to
            use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

    Returns:
        List[str]: The distinct tasks of the matching models, in ascending order.
    """
    if not region:
        region = get_region_fallback(sagemaker_session=sagemaker_session)

    open_weight_models = _generate_jumpstart_model_versions(
        filter=filter,
        region=region,
        sagemaker_session=sagemaker_session,
        model_type=JumpStartModelType.OPEN_WEIGHTS,
    )

    tasks: Set[str] = set()
    for model_id, _version in open_weight_models:
        _framework, task, _model = extract_framework_task_model(model_id)
        tasks.add(task)
    return sorted(tasks)
[docs]
def list_jumpstart_frameworks(  # pylint: disable=redefined-builtin
    filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
    region: Optional[str] = None,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> List[str]:
    """List frameworks for JumpStart, and optionally apply filters to result.

    Only open-weight models are considered. Each matching model contributes the framework
    component that ``extract_framework_task_model`` extracts from its model ID.

    Args:
        filter (Union[Operator, str]): Optional. The filter to apply to list frameworks. This can be
            either an ``Operator`` type filter (e.g. ``And("task == ic", "framework == pytorch")``),
            or simply a string filter which will get serialized into an Identity filter.
            (e.g. ``"task == ic"``). If this argument is not supplied, all frameworks will be listed.
            (Default: Constant(BooleanValues.TRUE)).
        region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
            models. (Default: None).
        sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session
            to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

    Returns:
        List[str]: The distinct frameworks of the matching models, in ascending order.
    """
    if not region:
        region = get_region_fallback(sagemaker_session=sagemaker_session)

    matching_models = _generate_jumpstart_model_versions(
        filter=filter,
        region=region,
        sagemaker_session=sagemaker_session,
        model_type=JumpStartModelType.OPEN_WEIGHTS,
    )

    frameworks: Set[str] = set()
    for model_id, _version in matching_models:
        framework, _task, _model = extract_framework_task_model(model_id)
        frameworks.add(framework)
    return sorted(frameworks)
[docs]
def list_jumpstart_scripts(  # pylint: disable=redefined-builtin
    filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
    region: Optional[str] = None,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> List[str]:
    """List scripts for JumpStart, and optionally apply filters to result.

    If the filter is the always-true default (a ``Constant`` resolving to
    ``BooleanValues.TRUE``, or a string equal to it case-insensitively), every script scope
    is returned without inspecting any model. Otherwise, open-weight models matching the
    filter are scanned:

    - every match contributes the inference scope;
    - matches whose specs have ``training_supported`` also contribute the training scope.

    The scan stops early once all scopes have been seen.

    Args:
        filter (Union[Operator, str]): Optional. The filter to apply to list scripts. This can be
            either an ``Operator`` type filter (e.g. ``And("task == ic", "framework == pytorch")``),
            or simply a string filter which will get serialized into an Identity filter.
            (e.g. ``"task == ic"``). If this argument is not supplied, all scripts will be listed.
            (Default: Constant(BooleanValues.TRUE)).
        region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
            models. (Default: None).
        sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to
            use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

    Returns:
        List[str]: The ``JumpStartScriptScope`` values found, in ascending order.
    """
    region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )

    all_scripts: Set[str] = {scope.value for scope in JumpStartScriptScope}

    if (isinstance(filter, Constant) and filter.resolved_value == BooleanValues.TRUE) or (
        isinstance(filter, str) and filter.lower() == BooleanValues.TRUE.lower()
    ):
        return sorted(all_scripts)

    # Store plain ``.value`` strings (not enum members) so this path returns the same element
    # type as the shortcut above. The early-exit comparison against ``all_scripts`` then holds
    # regardless of how the enum hashes.
    scripts: Set[str] = set()
    for model_id, version in _generate_jumpstart_model_versions(
        filter=filter,
        region=region,
        sagemaker_session=sagemaker_session,
        model_type=JumpStartModelType.OPEN_WEIGHTS,
    ):
        scripts.add(JumpStartScriptScope.INFERENCE.value)
        model_specs = verify_model_region_and_return_specs(
            region=region,
            model_id=model_id,
            version=version,
            sagemaker_session=sagemaker_session,
            scope=JumpStartScriptScope.INFERENCE,
        )
        if model_specs.training_supported:
            scripts.add(JumpStartScriptScope.TRAINING.value)
        if scripts == all_scripts:
            break
    return sorted(scripts)
def _is_valid_version(version: str) -> bool:
"""Checks if the version is convertable to Version class."""
try:
Version(version)
return True
except Exception: # pylint: disable=broad-except
return False
[docs]
def list_jumpstart_models(  # pylint: disable=redefined-builtin
    filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
    region: Optional[str] = None,
    list_incomplete_models: bool = False,
    list_old_models: bool = False,
    list_versions: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> List[Union[Tuple[str], Tuple[str, str]]]:
    """List models for JumpStart, and optionally apply filters to result.

    Args:
        filter (Union[Operator, str]): Optional. The filter to apply to list models. This can be
            either an ``Operator`` type filter (e.g. ``And("task == ic", "framework == pytorch")``),
            or simply a string filter which will get serialized into an Identity filter.
            (e.g. ``"task == ic"``). If this argument is not supplied, all models will be listed.
            (Default: Constant(BooleanValues.TRUE)).
        region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
            models. (Default: None).
        list_incomplete_models (bool): Optional. If a model does not contain metadata fields
            requested by the filter, and the filter cannot be resolved to a include/not include,
            whether the model should be included. By default, these models are omitted from results.
            (Default: False).
        list_old_models (bool): Optional. If there are older versions of a model, whether the older
            versions should be included in the returned result. (Default: False).
        list_versions (bool): Optional. True if versions for models should be returned in addition
            to the id of the model. (Default: False).
        sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use
            to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

    Returns:
        List: If ``list_versions`` is False, the distinct matching model IDs in ascending order.
        Otherwise, ``(model_id, version)`` tuples ordered by model ID ascending, then version
        descending.
    """
    region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )

    # Versions that ``Version`` can parse are stored as ``Version`` objects so they compare
    # numerically; any other version is kept as its raw string.
    model_id_version_dict: Dict[str, List[Union[Version, str]]] = {}
    for model_id, version in _generate_jumpstart_model_versions(
        filter=filter,
        region=region,
        list_incomplete_models=list_incomplete_models,
        sagemaker_session=sagemaker_session,
    ):
        model_version = Version(version) if _is_valid_version(version) else version
        model_id_version_dict.setdefault(model_id, []).append(model_version)

    if not list_versions:
        return sorted(model_id_version_dict)

    if not list_old_models:
        # Build a new mapping rather than rewriting the dict while iterating over it.
        latest_versions: Dict[str, List[Union[Version, str]]] = {}
        for model_id, versions in model_id_version_dict.items():
            try:
                latest = max(versions)
            except TypeError:
                # ``Version`` and ``str`` cannot be compared with each other; fall back to
                # comparing every version of this model as a string.
                latest = max(str(v) for v in versions)
            latest_versions[model_id] = [latest]
        model_id_version_dict = latest_versions

    model_id_version_set: Set[Tuple[str, str]] = {
        (model_id, str(version))
        for model_id, versions in model_id_version_dict.items()
        for version in versions
    }
    return sorted(model_id_version_set, key=cmp_to_key(_compare_model_version_tuples))
def _generate_jumpstart_model_versions(  # pylint: disable=redefined-builtin
    filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
    region: Optional[str] = None,
    list_incomplete_models: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    model_type: Optional[JumpStartModelType] = None,
) -> Generator:
    """Generate models for JumpStart, and optionally apply filters to result.

    Each model header is evaluated in two stages:

    1. Against values available without a download: the manifest fields, plus the task,
       framework and model type derived from the header.
    2. Only if stage 1 leaves the filter unresolved, against the model's spec file, which is
       read from S3.

    Headers are evaluated concurrently, so results are yielded in completion order rather
    than manifest order.

    Args:
        filter (Union[Operator, str]): Optional. The filter to apply to generate models. This can be
            either an ``Operator`` type filter (e.g. ``And("task == ic", "framework == pytorch")``),
            or simply a string filter which will get serialized into an Identity filter.
            (e.g. ``"task == ic"``). If this argument is not supplied, all models will be generated.
            The caller's filter object is not modified. (Default: Constant(BooleanValues.TRUE)).
        region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
            models. (Default: None).
        list_incomplete_models (bool): Optional. If a model does not contain metadata fields
            requested by the filter, and the filter cannot be resolved to a include/not include,
            whether the model should be included. By default, these models are omitted from
            results. (Default: False).
        sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session
            to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        model_type (JumpStartModelType): Optional. ``JumpStartModelType.OPEN_WEIGHTS`` or
            ``JumpStartModelType.PROPRIETARY`` restricts generation to that manifest. Any other
            value, including ``None``, generates models from both. (Default: None).

    Yields:
        Tuple[str, str]: ``(model_id, version)`` for each model that passes the filter.

    Raises:
        NotImplementedError: If a filter key uses multi-level (dotted) metadata indexing.
        RuntimeError: If the filter is still unevaluated after applying manifest or spec values.
    """
    region = region or get_region_fallback(
        sagemaker_session=sagemaker_session,
    )

    prop_models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(
        region=region,
        s3_client=sagemaker_session.s3_client,
        model_type=JumpStartModelType.PROPRIETARY,
    )
    open_weight_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(
        region=region,
        s3_client=sagemaker_session.s3_client,
        model_type=JumpStartModelType.OPEN_WEIGHTS,
    )
    models_manifest_list = (
        open_weight_manifest_list
        if model_type == JumpStartModelType.OPEN_WEIGHTS
        else (
            prop_models_manifest_list
            if model_type == JumpStartModelType.PROPRIETARY
            else open_weight_manifest_list + prop_models_manifest_list
        )
    )

    if isinstance(filter, str):
        filter = Identity(filter)
    else:
        # The loop below may rewrite model-filter values in place (``set_value``). Work on a
        # private copy so those rewrites do not leak into the caller's operator.
        filter = copy.deepcopy(filter)

    # Either manifest may be empty; indexing ``[0]`` unconditionally would raise IndexError.
    manifest_keys: Set[str] = set()
    for manifest in (open_weight_manifest_list, prop_models_manifest_list):
        if manifest:
            manifest_keys.update(manifest[0].__slots__)

    all_keys: Set[str] = set()
    model_filters: Set[ModelFilter] = set()
    for operator in _model_filter_in_operator_generator(filter):
        model_filter = operator.unresolved_value
        key = model_filter.key
        all_keys.add(key)
        # Every proprietary identifier is normalized to the single PROPRIETARY value.
        if model_filter.key == SpecialSupportedFilterKeys.MODEL_TYPE and model_filter.value in {
            identifier.value for identifier in ProprietaryModelFilterIdentifiers
        }:
            model_filter.set_value(JumpStartModelType.PROPRIETARY.value)
        model_filters.add(model_filter)

    for key in all_keys:
        if "." in key:
            raise NotImplementedError(f"No support for multiple level metadata indexing ('{key}').")

    metadata_filter_keys = all_keys - SPECIAL_SUPPORTED_FILTER_KEYS
    # Keys resolvable from the manifest header vs. keys that need the full spec file.
    required_manifest_keys = manifest_keys.intersection(metadata_filter_keys)
    possible_spec_keys = metadata_filter_keys - manifest_keys

    is_task_filter = SpecialSupportedFilterKeys.TASK in all_keys
    is_framework_filter = SpecialSupportedFilterKeys.FRAMEWORK in all_keys
    is_model_type_filter = SpecialSupportedFilterKeys.MODEL_TYPE in all_keys

    def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, str]]:
        """Return ``(model_id, version)`` if ``model_manifest`` passes the filter, else None."""
        copied_filter = copy.deepcopy(filter)
        manifest_specs_cached_values: Dict[str, Union[bool, int, float, str, dict, list]] = {}
        model_filters_to_resolved_values: Dict[ModelFilter, BooleanValues] = {}

        for val in required_manifest_keys:
            manifest_specs_cached_values[val] = getattr(model_manifest, val)

        if is_task_filter:
            manifest_specs_cached_values[SpecialSupportedFilterKeys.TASK] = (
                extract_framework_task_model(model_manifest.model_id)[1]
            )
        if is_framework_filter:
            manifest_specs_cached_values[SpecialSupportedFilterKeys.FRAMEWORK] = (
                extract_framework_task_model(model_manifest.model_id)[0]
            )
        if is_model_type_filter:
            manifest_specs_cached_values[SpecialSupportedFilterKeys.MODEL_TYPE] = (
                extract_model_type_filter_representation(model_manifest.spec_key)
            )

        # Skip models whose ``min_version`` exceeds the version reported by
        # ``get_sagemaker_version()``.
        if Version(model_manifest.min_version) > Version(get_sagemaker_version()):
            return None

        _populate_model_filters_to_resolved_values(
            manifest_specs_cached_values,
            model_filters_to_resolved_values,
            model_filters,
        )
        _put_resolved_booleans_into_filter(copied_filter, model_filters_to_resolved_values)
        copied_filter.eval()

        if copied_filter.resolved_value in [BooleanValues.TRUE, BooleanValues.FALSE]:
            if copied_filter.resolved_value == BooleanValues.TRUE:
                return (model_manifest.model_id, model_manifest.version)
            return None

        if copied_filter.resolved_value == BooleanValues.UNEVALUATED:
            raise RuntimeError(
                "Filter expression in unevaluated state after using "
                "values from model manifest. Model ID and version that "
                f"is failing: {(model_manifest.model_id, model_manifest.version)}."
            )

        # Manifest values did not settle the filter; fetch the full spec for remaining keys.
        copied_filter_2 = copy.deepcopy(filter)

        # spec is downloaded to thread's memory. since each thread
        # accesses a unique s3 spec, there is no need to use the JS caching utils.
        # spec only stays in memory for lifecycle of thread.
        model_specs = JumpStartModelSpecs(
            json.loads(
                sagemaker_session.read_s3_file(
                    get_jumpstart_content_bucket(region), model_manifest.spec_key
                )
            )
        )

        for val in possible_spec_keys:
            if hasattr(model_specs, val):
                manifest_specs_cached_values[val] = getattr(model_specs, val)

        _populate_model_filters_to_resolved_values(
            manifest_specs_cached_values,
            model_filters_to_resolved_values,
            model_filters,
        )
        _put_resolved_booleans_into_filter(copied_filter_2, model_filters_to_resolved_values)
        copied_filter_2.eval()

        if copied_filter_2.resolved_value == BooleanValues.UNEVALUATED:
            raise RuntimeError(
                "Filter expression in unevaluated state after using values from model specs. "
                "Model ID and version that is failing: "
                f"{(model_manifest.model_id, model_manifest.version)}."
            )

        # ``list_incomplete_models`` admits only models whose filter stays UNKNOWN (missing
        # metadata). Models that definitively evaluate to FALSE are always excluded.
        if copied_filter_2.resolved_value == BooleanValues.TRUE or (
            copied_filter_2.resolved_value == BooleanValues.UNKNOWN and list_incomplete_models
        ):
            return (model_manifest.model_id, model_manifest.version)
        return None

    with ThreadPoolExecutor(max_workers=MAX_SEARCH_WORKERS) as executor:
        futures = [executor.submit(evaluate_model, header) for header in models_manifest_list]
        try:
            for future in as_completed(futures):
                # ``result()`` re-raises any exception raised inside ``evaluate_model``.
                result = future.result()
                if result:
                    yield result
        finally:
            # The consumer may stop early (e.g. ``break``, which closes this generator), or an
            # evaluation may have failed. Either way, cancel the headers that have not started,
            # because leaving the ``with`` block waits for every queued evaluation.
            for future in futures:
                future.cancel()
[docs]
def get_model_url(
    model_id: str,
    model_version: str,
    region: Optional[str] = None,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    config_name: Optional[str] = None,
) -> str:
    """Retrieve web url describing pretrained model.

    Args:
        model_id (str): The model ID for which to retrieve the url.
        model_version (str): The model version for which to retrieve the url.
        region (str): Optional. The region from which to retrieve metadata.
            (Default: None)
        sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use
            to retrieve the model url. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        config_name (str): Optional. Forwarded as ``config_name`` to the inference spec
            lookup, presumably to select a specific model configuration. (Default: None).

    Returns:
        str: The ``url`` attribute of the model's inference specs.
    """
    resolved_model_type = validate_model_id_and_get_type(
        model_id=model_id,
        model_version=model_version,
        region=region,
        sagemaker_session=sagemaker_session,
    )

    lookup_region = region or get_region_fallback(sagemaker_session=sagemaker_session)

    inference_specs = verify_model_region_and_return_specs(
        region=lookup_region,
        model_id=model_id,
        version=model_version,
        sagemaker_session=sagemaker_session,
        scope=JumpStartScriptScope.INFERENCE,
        model_type=resolved_model_type,
        config_name=config_name,
    )
    return inference_specs.url