# Source code for sagemaker.serve.inference_recommendation_mixin

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Inference Recommender mixin for SageMaker model optimization.

This module provides the _InferenceRecommenderMixin class that enables SageMaker models
to use Inference Recommender for right-sizing and optimization recommendations.

Key Features:
- Automatic instance type and configuration recommendations
- Load testing with custom traffic patterns
- Performance optimization based on latency and throughput requirements
- Support for both Default and Advanced recommendation jobs

Example:
    Basic usage with a ModelBuilder::
    
        model_builder = ModelBuilder(model="my-model")
        model = model_builder.build()
        
        # Get right-sizing recommendations
        model.right_size(
            sample_payload_url="s3://my-bucket/sample-payload.json",
            supported_content_types=["application/json"],
            supported_instance_types=["ml.m5.large", "ml.m5.xlarge"]
        )
        
        # Deploy with recommendations
        predictor = model.deploy()
"""
from __future__ import absolute_import

# Standard library imports
import logging
import re
from typing import Any, Dict, List, Optional, Tuple

# SageMaker imports
from sagemaker.core.parameter import CategoricalParameter

# ========================================
# Constants
# ========================================

# Lower-case framework name -> framework identifier expected by Inference Recommender.
INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING = dict(
    xgboost="XGBOOST",
    sklearn="SAGEMAKER-SCIKIT-LEARN",
    pytorch="PYTORCH",
    tensorflow="TENSORFLOW",
    mxnet="MXNET",
)

# Both names refer to the same "sagemaker" logger; LOGGER is kept for backward
# compatibility with users who import it.
LOGGER = logging.getLogger("sagemaker")
logger = LOGGER


class Phase:
    """Traffic pattern phase configuration for Advanced Inference Recommendations.

    Defines one phase of a load test with a specific duration, user count, and
    spawn rate. Multiple phases can be combined to create complex traffic patterns.

    Args:
        duration_in_seconds: How long this phase should run.
        initial_number_of_users: Number of concurrent users at the start of the phase.
        spawn_rate: Rate at which new users are added (users per second).

    Attributes:
        to_json (dict): The phase expressed with the API field names
            ``DurationInSeconds``, ``InitialNumberOfUsers`` and ``SpawnRate``.
            Note that this is a plain dict attribute, not a method.

    Example:
        Create a ramp-up phase::

            phase = Phase(
                duration_in_seconds=300,  # 5 minutes
                initial_number_of_users=1,
                spawn_rate=2,  # Add 2 users per second
            )
    """

    def __init__(self, duration_in_seconds: int, initial_number_of_users: int, spawn_rate: int) -> None:
        """Initialize a Phase for load testing.

        Args:
            duration_in_seconds: Duration of this phase in seconds.
            initial_number_of_users: Starting number of concurrent users.
            spawn_rate: Rate of adding new users (users per second).
        """
        self.to_json = {
            "DurationInSeconds": duration_in_seconds,
            "InitialNumberOfUsers": initial_number_of_users,
            "SpawnRate": spawn_rate,
        }

    def __repr__(self) -> str:
        """Return a constructor-style representation built from ``to_json``."""
        payload = self.to_json
        return (
            f"{type(self).__name__}("
            f"duration_in_seconds={payload['DurationInSeconds']!r}, "
            f"initial_number_of_users={payload['InitialNumberOfUsers']!r}, "
            f"spawn_rate={payload['SpawnRate']!r})"
        )
class ModelLatencyThreshold:
    """Latency threshold configuration for Advanced Inference Recommendations.

    Defines an acceptable response latency limit for model inference. Used to
    filter recommendations based on performance requirements.

    Args:
        percentile: Latency percentile to measure (e.g., "P95", "P99").
        value_in_milliseconds: Maximum acceptable latency in milliseconds.

    Attributes:
        to_json (dict): The threshold expressed with the API field names
            ``Percentile`` and ``ValueInMilliseconds``. Note that this is a plain
            dict attribute, not a method.

    Example:
        Set a P95 latency threshold::

            threshold = ModelLatencyThreshold(
                percentile="P95",
                value_in_milliseconds=100,  # 100ms max P95 latency
            )
    """

    def __init__(self, percentile: str, value_in_milliseconds: int) -> None:
        """Initialize a ModelLatencyThreshold.

        Args:
            percentile: Latency percentile (e.g., "P95", "P99").
            value_in_milliseconds: Maximum latency threshold in milliseconds.
        """
        self.to_json = {"Percentile": percentile, "ValueInMilliseconds": value_in_milliseconds}

    def __repr__(self) -> str:
        """Return a constructor-style representation built from ``to_json``."""
        payload = self.to_json
        return (
            f"{type(self).__name__}("
            f"percentile={payload['Percentile']!r}, "
            f"value_in_milliseconds={payload['ValueInMilliseconds']!r})"
        )
class _InferenceRecommenderMixin:
    """Mixin class providing SageMaker Inference Recommender functionality.

    This mixin adds right-sizing capabilities to SageMaker models, enabling
    instance type and configuration recommendations based on model performance
    requirements.

    The mixin provides:
    - Framework detection from the container image (delegated to the host class)
    - Default and Advanced recommendation job types
    - Load testing with custom traffic patterns
    - Performance-based filtering and optimization

    This class is designed to be mixed into Model classes that have:
    - sagemaker_session: SageMaker session for API calls
    - role_arn: IAM role for job execution
    - model_name: Name of the model
    - image_uri: Container image URI (optional, for framework detection)
    - _extract_framework_from_image_uri(): returns ``(framework, version)``; used
      only when ``image_uri`` exists and no ``framework`` is passed to right_size()
    - env_vars (optional): dict of container environment variables into which
      recommended environment values are merged
    """

    def right_size(
        self,
        sample_payload_url: Optional[str] = None,
        supported_content_types: Optional[List[str]] = None,
        supported_instance_types: Optional[List[str]] = None,
        job_name: Optional[str] = None,
        framework: Optional[str] = None,
        framework_version: Optional[str] = None,
        job_duration_in_seconds: Optional[int] = None,
        hyperparameter_ranges: Optional[List[Dict[str, CategoricalParameter]]] = None,
        phases: Optional[List[Phase]] = None,
        traffic_type: Optional[str] = None,
        max_invocations: Optional[int] = None,
        model_latency_thresholds: Optional[List[ModelLatencyThreshold]] = None,
        max_tests: Optional[int] = None,
        max_parallel_tests: Optional[int] = None,
        log_level: Optional[str] = "Verbose",
    ) -> "_InferenceRecommenderMixin":
        """Recommends an instance type for a SageMaker or BYOC model.

        Create a SageMaker ``Model`` or use a registered ``ModelPackage``, to start an
        Inference Recommender job. This method waits for the job to finish and stores
        the results on ``self``.

        An Advanced job is run when any of ``hyperparameter_ranges``, ``phases``,
        ``max_invocations``, ``model_latency_thresholds``, ``max_tests`` or
        ``max_parallel_tests`` is given; otherwise a Default job is run.

        Args:
            sample_payload_url (str): The S3 path where the sample payload is stored.
            supported_content_types (list[str]): The supported MIME types for the input data.
            supported_instance_types (list[str]): A list of the instance types that this
                model is expected to work on. (default: None).
            job_name (str): The name of the Inference Recommendations Job. (default: None).
            framework (str): The machine learning framework of the Image URI. Only required
                to specify if you bring your own custom containers. When omitted and the
                object has an ``image_uri``, the framework is detected from the image and
                mapped through ``INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING`` (default: None).
            framework_version (str): Version of ``framework``. When omitted and the framework
                is auto-detected, the detected version is used (default: None).
            job_duration_in_seconds (int): The maximum job duration that a job can run for.
                (default: None).
            hyperparameter_ranges (list[Dict[str, sagemaker.parameter.CategoricalParameter]]):
                Specifies the hyper parameters to be used during endpoint load tests.
                The ``instance_types`` key must be present in every entry; all other keys
                are sent as environment parameter ranges. (default: None).
                Example::

                    hyperparameter_ranges = [{
                        'instance_types': CategoricalParameter(['ml.c5.xlarge', 'ml.c5.2xlarge']),
                        'OMP_NUM_THREADS': CategoricalParameter(['1', '2', '3', '4'])
                    }]

            phases (list[Phase]): Shape of the traffic pattern to use in the load test
                (default: None).
            traffic_type (str): Specifies the traffic pattern type. Currently only supports
                one type 'PHASES'. Ignored unless ``phases`` is given (default: None).
            max_invocations (int): defines the minimum invocations per minute for the
                endpoint to support (default: None).
            model_latency_thresholds (list[ModelLatencyThreshold]): defines the maximum
                response latency for endpoints to support (default: None).
            max_tests (int): restricts how many endpoints in total are allowed to be
                spun up for this job (default: None).
            max_parallel_tests (int): restricts how many concurrent endpoints this job is
                allowed to spin up (default: None).
            log_level (str): specifies the inline output when waiting for right_size to
                complete (default: "Verbose").

        Returns:
            _InferenceRecommenderMixin: ``self``, with ``inference_recommender_job_results``
            (the value returned by ``sagemaker_session.wait_for_inference_recommendations_job``)
            and ``inference_recommendations`` (its ``InferenceRecommendations`` list, or
            ``[]``) populated.

        Raises:
            ValueError: If a ``hyperparameter_ranges`` entry lacks ``instance_types``.
        """
        # Auto-detect framework from image URI if not provided
        if not framework and hasattr(self, "image_uri"):
            detected_framework, detected_version = self._extract_framework_from_image_uri()
            if detected_framework:
                # Enum members carry their name in ``.value``; plain strings are used
                # as-is. ``str(enum_member)`` would give "EnumClass.MEMBER", which is
                # not a usable framework identifier, so never fall back to it.
                raw_framework = str(getattr(detected_framework, "value", detected_framework))
                framework = INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING.get(
                    raw_framework.lower(), raw_framework
                )
                framework_version = framework_version or detected_version

        endpoint_configurations = self._convert_to_endpoint_configurations_json(
            hyperparameter_ranges=hyperparameter_ranges
        )
        traffic_pattern = self._convert_to_traffic_pattern_json(
            traffic_type=traffic_type, phases=phases
        )
        stopping_conditions = self._convert_to_stopping_conditions_json(
            max_invocations=max_invocations, model_latency_thresholds=model_latency_thresholds
        )
        resource_limit = self._convert_to_resource_limit_json(
            max_tests=max_tests, max_parallel_tests=max_parallel_tests
        )

        # Determine job type based on advanced parameters
        if endpoint_configurations or traffic_pattern or stopping_conditions or resource_limit:
            logger.info("Advanced job parameters specified. Running Advanced recommendation job...")
            job_type = "Advanced"
        else:
            logger.info("No advanced parameters specified. Running Default recommendation job...")
            job_type = "Default"

        # Initialize SageMaker session if needed (method provided by the host class)
        if hasattr(self, "_init_sagemaker_session_if_does_not_exist"):
            self._init_sagemaker_session_if_does_not_exist()

        # Create inference recommendations job
        ret_name = self.sagemaker_session.create_inference_recommendations_job(
            role=getattr(self, "role_arn", None),
            job_name=job_name,
            job_type=job_type,
            job_duration_in_seconds=job_duration_in_seconds,
            model_name=getattr(self, "model_name", None),
            model_package_version_arn=getattr(self, "model_package_arn", None),
            framework=framework,
            framework_version=framework_version,
            sample_payload_url=sample_payload_url,
            supported_content_types=supported_content_types,
            supported_instance_types=supported_instance_types,
            endpoint_configurations=endpoint_configurations,
            traffic_pattern=traffic_pattern,
            stopping_conditions=stopping_conditions,
            resource_limit=resource_limit,
        )

        # Wait for job completion and store results
        self.inference_recommender_job_results = (
            self.sagemaker_session.wait_for_inference_recommendations_job(
                ret_name, log_level=log_level
            )
        )
        self.inference_recommendations = self.inference_recommender_job_results.get(
            "InferenceRecommendations", []
        )
        return self

    def _update_params(self, **kwargs) -> Tuple[Optional[str], Optional[int]]:
        """Resolve the deployment instance type and count, applying recommendations.

        An ``inference_recommendation_id`` takes precedence over right_size() results
        passed as ``inference_recommender_job_results``.

        Args:
            **kwargs: Deployment parameters. Reads ``instance_type``,
                ``initial_instance_count``, ``accelerator_type``,
                ``async_inference_config``, ``serverless_inference_config``,
                ``explainer_config``, ``inference_recommendation_id`` and
                ``inference_recommender_job_results``; missing keys are treated as None.

        Returns:
            Tuple of (instance_type, initial_instance_count): the recommended values
            when a recommendation applies, otherwise the values that were passed in.

        Raises:
            ValueError: Propagated from the recommendation helpers when the requested
                parameter combination is invalid.
        """
        instance_type = kwargs.get("instance_type")
        initial_instance_count = kwargs.get("initial_instance_count")
        accelerator_type = kwargs.get("accelerator_type")
        async_inference_config = kwargs.get("async_inference_config")
        serverless_inference_config = kwargs.get("serverless_inference_config")
        explainer_config = kwargs.get("explainer_config")
        inference_recommendation_id = kwargs.get("inference_recommendation_id")
        inference_recommender_job_results = kwargs.get("inference_recommender_job_results")

        inference_recommendation = None
        if inference_recommendation_id is not None:
            inference_recommendation = self._update_params_for_recommendation_id(
                instance_type=instance_type,
                initial_instance_count=initial_instance_count,
                accelerator_type=accelerator_type,
                async_inference_config=async_inference_config,
                serverless_inference_config=serverless_inference_config,
                inference_recommendation_id=inference_recommendation_id,
                explainer_config=explainer_config,
            )
        elif inference_recommender_job_results is not None:
            inference_recommendation = self._update_params_for_right_size(
                instance_type,
                initial_instance_count,
                accelerator_type,
                serverless_inference_config,
                async_inference_config,
                explainer_config,
            )

        # The helpers return None when the caller's explicit settings should win.
        return inference_recommendation or (instance_type, initial_instance_count)

    def _update_params_for_right_size(
        self,
        instance_type: Optional[str] = None,
        initial_instance_count: Optional[int] = None,
        accelerator_type: Optional[str] = None,
        serverless_inference_config: Optional[Any] = None,
        async_inference_config: Optional[Any] = None,
        explainer_config: Optional[Any] = None,
    ) -> Optional[Tuple[Optional[str], Optional[int]]]:
        """Validates that Inference Recommendation parameters can be used in `model.deploy()`.

        Args:
            instance_type (str): The EC2 instance type to deploy this Model to. For example,
                'ml.p2.xlarge', or 'local' for local mode. If given, it overrides the
                right_size() recommendation (default: None).
            initial_instance_count (int): The initial number of instances to run in the
                ``Endpoint`` created from this ``Model``. If given, it overrides the
                right_size() recommendation (default: None).
            accelerator_type (str): whether accelerator_type has been passed into
                `model.deploy()`.
            serverless_inference_config (sagemaker.serve.serverless.ServerlessInferenceConfig):
                whether serverless_inference_config has been passed into `model.deploy()`.
            async_inference_config (sagemaker.model_monitor.AsyncInferenceConfig):
                whether async_inference_config has been passed into `model.deploy()`.
            explainer_config (sagemaker.explainer.ExplainerConfig): whether explainer_config
                has been passed into `model.deploy()`.

        Returns:
            (str, int) or None: The first real-time (non-serverless) recommendation's
            instance type and instance count — ``(None, None)`` if there is none — or
            None when an explicit deployment setting overrides the recommendations.

        Raises:
            ValueError: If ``accelerator_type`` is given.
        """
        if accelerator_type:
            raise ValueError("accelerator_type is not compatible with right_size().")
        if instance_type or initial_instance_count:
            logger.warning(
                "instance_type or initial_instance_count specified. "
                "Overriding right_size() recommendations."
            )
            return None
        if async_inference_config:
            logger.warning(
                "async_inference_config is specified. Overriding right_size() recommendations."
            )
            return None
        if serverless_inference_config:
            logger.warning(
                "serverless_inference_config is specified. Overriding right_size() recommendations."
            )
            return None
        if explainer_config:
            logger.warning(
                "explainer_config is specified. Overriding right_size() recommendations."
            )
            return None

        return self._filter_recommendations_for_realtime()

    def _update_params_for_recommendation_id(
        self,
        instance_type: Optional[str],
        initial_instance_count: Optional[int],
        accelerator_type: Optional[str],
        async_inference_config: Optional[Any],
        serverless_inference_config: Optional[Any],
        inference_recommendation_id: str,
        explainer_config: Optional[Any],
    ) -> Tuple[str, int]:
        """Update parameters with inference recommendation results.

        Side effects: may merge recommended environment values into ``self.env_vars`` and
        may update ``self.s3_model_data_url`` / ``self.image_uri``.

        Args:
            instance_type (str): The EC2 instance type to deploy this Model to. For example,
                'ml.p2.xlarge', or 'local' for local mode. If not using serverless inference,
                then it is required to deploy a model.
            initial_instance_count (int): The initial number of instances to run in the
                ``Endpoint`` created from this ``Model``. If not using serverless inference,
                then it needs to be a number greater than or equal to 1.
            accelerator_type (str): Type of Elastic Inference accelerator. Not compatible
                with a recommendation id.
            async_inference_config (sagemaker.model_monitor.AsyncInferenceConfig):
                Async endpoint configuration. Not compatible with a recommendation id.
            serverless_inference_config (sagemaker.serve.serverless.ServerlessInferenceConfig):
                Serverless endpoint configuration. Not compatible with a recommendation id.
            inference_recommendation_id (str): The recommendation id, in the form
                ``<job-or-model-name>/<8-character-id>``, of the recommendation to deploy.
            explainer_config (sagemaker.explainer.ExplainerConfig): Online explainability
                configuration. Not compatible with a recommendation id.

        Raises:
            ValueError: If arguments combination check failed in these circumstances:
                - If only one of instance type or instance count specified or
                - If recommendation id does not follow the required format or
                - If recommendation id is not valid or
                - If inference recommendation id is specified along with incompatible parameters

        Returns:
            (str, int): instance type and associated instance count from the selected
            inference recommendation id if the argument combination check passed.
        """
        if instance_type is not None and initial_instance_count is not None:
            logger.warning(
                "Both instance_type and initial_instance_count are specified, "
                "overriding the recommendation result."
            )
            return (instance_type, initial_instance_count)

        # Validate non-compatible parameters with recommendation id
        if accelerator_type is not None:
            raise ValueError("accelerator_type is not compatible with inference_recommendation_id.")
        if async_inference_config is not None:
            raise ValueError(
                "async_inference_config is not compatible with inference_recommendation_id."
            )
        if serverless_inference_config is not None:
            raise ValueError(
                "serverless_inference_config is not compatible with inference_recommendation_id."
            )
        if explainer_config is not None:
            raise ValueError("explainer_config is not compatible with inference_recommendation_id.")

        # Validate recommendation ID format: <name up to 64 chars>/<8 word characters>
        if not re.match(r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,63}\/\w{8}$", inference_recommendation_id):
            raise ValueError(
                f"Invalid inference_recommendation_id format: {inference_recommendation_id}. "
                f"Expected format: <job-or-model-name>/<8-character-id>"
            )
        job_or_model_name = inference_recommendation_id.split("/")[0]

        sage_client = self.sagemaker_session.sagemaker_client
        # Get recommendation from right size job and model
        (
            right_size_recommendation,
            model_recommendation,
            right_size_job_res,
        ) = self._get_recommendation(
            sage_client=sage_client,
            job_or_model_name=job_or_model_name,
            inference_recommendation_id=inference_recommendation_id,
        )

        # Model recommendations carry no instance count, so the caller must supply one.
        if model_recommendation:
            if initial_instance_count is None:
                raise ValueError(
                    "Must specify initial_instance_count when using model recommendation ID."
                )
            self._merge_env_vars(model_recommendation.get("Environment"))
            return (model_recommendation["InstanceType"], initial_instance_count)

        # Update params based on default inference recommendation
        if bool(instance_type) != bool(initial_instance_count):
            raise ValueError(
                "instance_type and initial_instance_count must both be specified together "
                "to override recommendation, or both omitted to use recommendation values."
            )

        input_config = right_size_job_res["InputConfig"]
        model_config = right_size_recommendation["ModelConfiguration"]

        # EnvironmentParameters is read as a list of {"Key": ..., "Value": ...} entries.
        recommend_envs = {
            env["Key"]: env["Value"] for env in model_config.get("EnvironmentParameters") or []
        }
        self._merge_env_vars(recommend_envs)

        self._update_model_artifacts_from_recommendation(sage_client, input_config, model_config)

        endpoint_config = right_size_recommendation["EndpointConfiguration"]
        return (endpoint_config["InstanceType"], endpoint_config["InitialInstanceCount"])

    def _merge_env_vars(self, new_env_vars: Optional[Dict[str, str]]) -> None:
        """Merge recommended environment variables into ``self.env_vars``.

        When ``env_vars`` is missing or None, a new dict is created and assigned, so
        recommended values are kept rather than written into a throwaway dict.

        Args:
            new_env_vars: Environment variables to merge; None or empty is a no-op.
        """
        if not new_env_vars:
            return
        env_vars = getattr(self, "env_vars", None)
        if env_vars is None:
            env_vars = {}
            self.env_vars = env_vars
        env_vars.update(new_env_vars)

    def _update_model_artifacts_from_recommendation(
        self,
        sage_client: Any,
        input_config: Dict[str, Any],
        model_config: Dict[str, Any],
    ) -> None:
        """Point ``s3_model_data_url`` and ``image_uri`` at the recommended artifacts.

        Sources, checked in this order:
        - ``InferenceSpecificationName`` in ``model_config``: the matching additional
          inference specification of the job's input model package.
        - ``CompilationJobName`` in ``model_config``: the compilation job's output
          artifacts and inference image.
        - Otherwise the job's input model package (``ModelPackageVersionArn``) or model
          (``ModelName``) primary container.
        Both attributes are left untouched when none of these apply.

        Args:
            sage_client: SageMaker client used for the describe_* calls.
            input_config: ``InputConfig`` of the recommendations job description.
            model_config: ``ModelConfiguration`` of the selected recommendation.
        """
        if "InferenceSpecificationName" in model_config:
            modelpkg_res = sage_client.describe_model_package(
                ModelPackageName=input_config["ModelPackageVersionArn"]
            )
            spec_name = model_config["InferenceSpecificationName"]
            # DescribeModelPackage lists extra specs under AdditionalInferenceSpecifications;
            # pick the one the recommendation names.
            spec = next(
                (
                    candidate
                    for candidate in modelpkg_res.get("AdditionalInferenceSpecifications") or []
                    if candidate.get("Name") == spec_name
                ),
                None,
            )
            if spec is None:
                # Legacy lookup kept for backward compatibility; raises KeyError if absent.
                spec = modelpkg_res["AdditionalInferenceSpecificationDefinition"]
            container = spec["Containers"][0]
            self.s3_model_data_url = container["ModelDataUrl"]
            self.image_uri = container["Image"]
        elif "CompilationJobName" in model_config:
            compilation_res = sage_client.describe_compilation_job(
                CompilationJobName=model_config["CompilationJobName"]
            )
            self.s3_model_data_url = compilation_res["ModelArtifacts"]["S3ModelArtifacts"]
            self.image_uri = compilation_res["InferenceImage"]
        elif "ModelPackageVersionArn" in input_config:
            modelpkg_res = sage_client.describe_model_package(
                ModelPackageName=input_config["ModelPackageVersionArn"]
            )
            container = modelpkg_res["InferenceSpecification"]["Containers"][0]
            self.s3_model_data_url = container["ModelDataUrl"]
            self.image_uri = container["Image"]
        elif "ModelName" in input_config:
            model_res = sage_client.describe_model(ModelName=input_config["ModelName"])
            self.s3_model_data_url = model_res["PrimaryContainer"]["ModelDataUrl"]
            self.image_uri = model_res["PrimaryContainer"]["Image"]

    def _convert_to_endpoint_configurations_json(
        self, hyperparameter_ranges: Optional[List[Dict[str, CategoricalParameter]]]
    ) -> Optional[List[Dict[str, Any]]]:
        """Convert hyperparameter ranges to endpoint configurations for Advanced jobs.

        Each entry produces one endpoint configuration per value of its
        ``instance_types`` range; every other key in the entry becomes a categorical
        environment parameter range.

        Args:
            hyperparameter_ranges: List of hyperparameter range dictionaries.

        Returns:
            List of endpoint configuration dictionaries, or None if no ranges provided.

        Raises:
            ValueError: If instance_types is not specified in a hyperparameter range.
        """
        if not hyperparameter_ranges:
            return None

        endpoint_configurations_to_json = []
        for parameter_range in hyperparameter_ranges:
            if not parameter_range.get("instance_types"):
                raise ValueError(
                    "instance_types must be defined as a hyperparameter range for Advanced jobs"
                )
            instance_types = parameter_range["instance_types"].values
            # Filtered copy: the caller's dict is never mutated.
            env_ranges = {
                name: param for name, param in parameter_range.items() if name != "instance_types"
            }
            for instance_type in instance_types:
                # Fresh list per configuration so entries do not share a mutable object.
                parameter_ranges = [
                    {"Name": name, "Value": param.values} for name, param in env_ranges.items()
                ]
                endpoint_configurations_to_json.append(
                    {
                        "EnvironmentParameterRanges": {
                            "CategoricalParameterRanges": parameter_ranges
                        },
                        "InstanceType": instance_type,
                    }
                )
        return endpoint_configurations_to_json

    def _convert_to_traffic_pattern_json(
        self, traffic_type: Optional[str], phases: Optional[List[Phase]]
    ) -> Optional[Dict[str, Any]]:
        """Convert traffic pattern parameters for Advanced jobs.

        Args:
            traffic_type: Type of traffic pattern (defaults to "PHASES").
            phases: List of Phase objects defining the load test pattern.

        Returns:
            Traffic pattern dictionary, or None if no phases provided.
        """
        if not phases:
            return None
        return {
            "Phases": [phase.to_json for phase in phases],
            "TrafficType": traffic_type if traffic_type else "PHASES",
        }

    def _convert_to_resource_limit_json(
        self, max_tests: Optional[int], max_parallel_tests: Optional[int]
    ) -> Optional[Dict[str, int]]:
        """Convert resource limit parameters for Advanced jobs.

        Falsy values (None or 0) are treated as "not set".

        Args:
            max_tests: Maximum number of tests to run.
            max_parallel_tests: Maximum number of parallel tests.

        Returns:
            Resource limit dictionary, or None if no limits specified.
        """
        if not max_tests and not max_parallel_tests:
            return None
        resource_limit = {}
        if max_tests:
            resource_limit["MaxNumberOfTests"] = max_tests
        if max_parallel_tests:
            # "MaxParallelOfTests" is the API's field name; do not "fix" the spelling.
            resource_limit["MaxParallelOfTests"] = max_parallel_tests
        return resource_limit

    def _convert_to_stopping_conditions_json(
        self,
        max_invocations: Optional[int],
        model_latency_thresholds: Optional[List[ModelLatencyThreshold]],
    ) -> Optional[Dict[str, Any]]:
        """Convert stopping condition parameters for Advanced jobs.

        Args:
            max_invocations: Maximum number of invocations per minute.
            model_latency_thresholds: List of latency threshold requirements.

        Returns:
            Stopping conditions dictionary, or None if no conditions specified.
        """
        if not max_invocations and not model_latency_thresholds:
            return None
        stopping_conditions = {}
        if max_invocations:
            stopping_conditions["MaxInvocations"] = max_invocations
        if model_latency_thresholds:
            stopping_conditions["ModelLatencyThresholds"] = [
                threshold.to_json for threshold in model_latency_thresholds
            ]
        return stopping_conditions

    def _get_recommendation(
        self, sage_client: Any, job_or_model_name: str, inference_recommendation_id: str
    ) -> Tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
        """Retrieve a recommendation from a right-size job, falling back to the model.

        Args:
            sage_client: SageMaker client for API calls.
            job_or_model_name: Name of the job or model.
            inference_recommendation_id: ID of the specific recommendation.

        Returns:
            Tuple of (right_size_recommendation, model_recommendation, right_size_job_res).
            ``model_recommendation`` is only looked up when the job has no match.

        Raises:
            ValueError: If the recommendation ID is not found in either source.
        """
        model_recommendation = None
        right_size_recommendation, right_size_job_res = self._get_right_size_recommendation(
            sage_client=sage_client,
            job_or_model_name=job_or_model_name,
            inference_recommendation_id=inference_recommendation_id,
        )
        if right_size_recommendation is None:
            model_recommendation = self._get_model_recommendation(
                sage_client=sage_client,
                job_or_model_name=job_or_model_name,
                inference_recommendation_id=inference_recommendation_id,
            )
            if model_recommendation is None:
                raise ValueError(
                    f"Recommendation ID '{inference_recommendation_id}' not found in "
                    f"job '{job_or_model_name}' or associated model recommendations"
                )
        return right_size_recommendation, model_recommendation, right_size_job_res

    def _get_right_size_recommendation(
        self,
        sage_client: Any,
        job_or_model_name: str,
        inference_recommendation_id: str,
    ) -> Tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
        """Get a recommendation from an Inference Recommender job.

        Args:
            sage_client: SageMaker client.
            job_or_model_name: Name of the inference recommendations job.
            inference_recommendation_id: Specific recommendation ID to find.

        Returns:
            Tuple of (recommendation, job_results). ``recommendation`` is None when not
            found; both are None when the job does not exist.
        """
        right_size_recommendation, right_size_job_res = None, None
        try:
            right_size_job_res = sage_client.describe_inference_recommendations_job(
                JobName=job_or_model_name
            )
            if right_size_job_res:
                right_size_recommendation = self._search_recommendation(
                    recommendation_list=right_size_job_res.get("InferenceRecommendations", []),
                    inference_recommendation_id=inference_recommendation_id,
                )
        except sage_client.exceptions.ResourceNotFound:
            # Deliberate: the name may refer to a model instead; the caller falls back.
            pass
        return right_size_recommendation, right_size_job_res

    def _get_model_recommendation(
        self,
        sage_client: Any,
        job_or_model_name: str,
        inference_recommendation_id: str,
    ) -> Optional[Dict[str, Any]]:
        """Get a recommendation from a model's real-time deployment recommendations.

        Args:
            sage_client: SageMaker client.
            job_or_model_name: Name of the model.
            inference_recommendation_id: Specific recommendation ID to find.

        Returns:
            Model recommendation dictionary, or None if not found.
        """
        model_recommendation = None
        try:
            model_res = sage_client.describe_model(ModelName=job_or_model_name)
            if model_res:
                deployment_rec = model_res.get("DeploymentRecommendation", {})
                realtime_recs = deployment_rec.get("RealTimeInferenceRecommendations", [])
                model_recommendation = self._search_recommendation(
                    recommendation_list=realtime_recs,
                    inference_recommendation_id=inference_recommendation_id,
                )
        except sage_client.exceptions.ResourceNotFound:
            # Deliberate: a missing model simply means "no recommendation here".
            pass
        return model_recommendation

    def _search_recommendation(
        self, recommendation_list: List[Dict[str, Any]], inference_recommendation_id: str
    ) -> Optional[Dict[str, Any]]:
        """Return the first recommendation whose ``RecommendationId`` matches, else None.

        Args:
            recommendation_list: List of recommendation dictionaries.
            inference_recommendation_id: ID to search for.
        """
        return next(
            (
                rec
                for rec in recommendation_list
                if rec.get("RecommendationId") == inference_recommendation_id
            ),
            None,
        )

    def _filter_recommendations_for_realtime(self) -> Tuple[Optional[str], Optional[int]]:
        """Find the first real-time (non-serverless) recommendation.

        Returns:
            Tuple of (instance_type, initial_instance_count) for the first recommendation
            whose ``EndpointConfiguration`` has no ``ServerlessConfig``, or (None, None)
            if there is none.

        Note:
            TODO: Integrate right_size + deploy with serverless support
        """
        # ``or []`` also tolerates the attribute being explicitly None.
        for recommendation in getattr(self, "inference_recommendations", None) or []:
            endpoint_config = recommendation.get("EndpointConfiguration", {})
            if "ServerlessConfig" not in endpoint_config:
                return (
                    endpoint_config.get("InstanceType"),
                    endpoint_config.get("InitialInstanceCount"),
                )
        return (None, None)