Source code for sagemaker.serve.local_resources

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Local Model and Endpoint classes for V3 ModelBuilder local mode support.

These classes provide sagemaker-core compatible interfaces for local deployment,
wrapping V2 local mode functionality.
"""

from __future__ import absolute_import
import datetime
import json
import logging
from typing import Any, Dict, Optional, Tuple
import io
import json
from sagemaker.serve.utils.types import ModelServer
from sagemaker.core.serializers import JSONSerializer, IdentitySerializer
from sagemaker.core.deserializers import JSONDeserializer, BytesDeserializer
from sagemaker.core.resources import Model

logger = logging.getLogger(__name__)

APPLICATION_X_NPY = "application/x-npy"

# Default (serializer, deserializer) pair per model server.
# Triton is intentionally absent: it gets serializer/deserializer from schema_builder.
# TorchServe uses the identity/bytes pair; every other listed server
# (TF Serving included) expects JSON. Each entry owns its own instances.
DEFAULT_SERIALIZERS_BY_SERVER: Dict[ModelServer, Tuple] = {
    ModelServer.TORCHSERVE: (IdentitySerializer(), BytesDeserializer()),
    **{
        json_server: (JSONSerializer(), JSONDeserializer())
        for json_server in (
            ModelServer.TENSORFLOW_SERVING,
            ModelServer.DJL_SERVING,
            ModelServer.TEI,
            ModelServer.TGI,
            ModelServer.MMS,
            ModelServer.SMD,
        )
    },
}


class InvokeEndpointOutput:
    """Response wrapper to match sagemaker-core Endpoint.invoke() output format.

    Attributes are plain and mutable; no validation or copying is performed.
    """

    def __init__(self, body: bytes, content_type: str = "application/json"):
        """Store the response payload and its MIME type.

        Args:
            body: Response payload, stored as given.
            content_type: MIME type describing ``body``.
        """
        self.body, self.content_type = body, content_type
class LocalEndpoint:
    """Local endpoint that mimics sagemaker.core.Endpoint interface.

    This class wraps V2 LocalSession endpoint functionality to provide
    a unified interface compatible with sagemaker-core Endpoint resources.
    """

    def __init__(
        self,
        endpoint_name: str,
        endpoint_config_name: str,
        local_session=None,
        local_model=None,
        in_process_mode=False,
        local_container_mode_obj=None,
        in_process_mode_obj=None,
        model_server=None,
        secret_key=None,
        serializer=None,
        deserializer=None,
        container_config="auto",
        **kwargs
    ):
        """Initialize local endpoint.

        Args:
            endpoint_name: Name of the endpoint
            endpoint_config_name: Name of the endpoint configuration
            local_session: LocalSession instance; a new one is created when omitted
            local_model: Model resource associated with the endpoint (stored only)
            in_process_mode: When truthy, ``invoke()`` and the ping use
                ``in_process_mode_obj`` instead of ``local_container_mode_obj``
            local_container_mode_obj: Object exposing the ``_invoke_*`` methods and
                ``schema_builder`` used in local container mode
            in_process_mode_obj: Object exposing ``_invoke_serving`` and
                ``schema_builder`` used in in-process mode
            model_server: ``ModelServer`` value selecting the container invoke route
            secret_key: Passed to the container mode object's ``create_server`` by ``create()``
            serializer: Request serializer; ``JSONSerializer`` is used when omitted
            deserializer: Response deserializer; ``JSONDeserializer`` is used when omitted
            container_config: ``"host"``, ``"bridge"`` or ``"auto"``; resolved by ``create()``
            **kwargs: Accepted for call-site compatibility; not used here
        """
        self.endpoint_name = endpoint_name
        self.endpoint_config_name = endpoint_config_name
        self.creation_time = datetime.datetime.now()
        self._local_model = local_model
        self.in_process_mode = in_process_mode
        self.local_container_mode_obj = local_container_mode_obj
        self.in_process_mode_obj = in_process_mode_obj
        self.model_server = model_server
        self.secret_key = secret_key
        self.serializer = serializer
        self.deserializer = deserializer
        self.container_config = container_config

        # Import V3 LocalSession lazily so it is only required when no session is supplied
        if local_session is None:
            from sagemaker.core.local.local_session import LocalSession

            self._local_session = LocalSession()
        else:
            self._local_session = local_session

    @property
    def endpoint_status(self) -> str:
        """Get endpoint status.

        Implementation based on V2 LocalSession.describe_endpoint()
        Reference: /sagemaker/local/local_session.py:describe_endpoint()

        Returns:
            The ``EndpointStatus`` reported by the local session, or ``"Failed"``
            when the endpoint cannot be described for any reason.
        """
        try:
            endpoint_info = self._local_session.sagemaker_client.describe_endpoint(
                EndpointName=self.endpoint_name
            )
            return endpoint_info["EndpointStatus"]
        except Exception:
            # Deliberate best effort: any lookup failure is reported as a status
            return "Failed"

    def _universal_deep_ping(self) -> Tuple[bool, Any]:
        """Universal ping function that works for all model servers.

        Invokes the endpoint with the schema builder's sample input.

        Returns:
            ``(healthy, response)``. ``healthy`` is True when a non-None response
            was obtained; on any failure other than a 422 the result is
            ``(False, None)``.

        Raises:
            LocalModelInvocationException: If the invocation failed with a
                "422 Client Error: Unprocessable Entity" error.
        """
        logger.info("Pinging local endpoint...")
        try:
            # Get sample input from the schema builder of the active mode object
            mode_obj = (
                self.in_process_mode_obj
                if self.in_process_mode
                else self.local_container_mode_obj
            )
            sample_input = mode_obj.schema_builder.sample_input

            # Use unified invoke interface
            invoke_response = self.invoke(body=sample_input)

            if self.in_process_mode:
                # IN_PROCESS: Response is already deserialized
                response = invoke_response.body
            else:
                # LOCAL_CONTAINER: Response is a JSON byte stream that needs decoding
                response = json.loads(invoke_response.body.read().decode("utf-8"))

            return (response is not None, response)
        except Exception as e:
            if "422 Client Error: Unprocessable Entity for url" in str(e):
                from sagemaker.serve.utils.exceptions import LocalModelInvocationException

                raise LocalModelInvocationException(str(e)) from e
            # The ping is best effort, but keep the cause discoverable
            logger.debug("Local endpoint ping failed: %s", e)
            return (False, None)

    @staticmethod
    def _container_routes() -> Dict[Any, Tuple[str, bool, bool]]:
        """Map each serialized-payload server to its invoke route.

        Each value is ``(mode-object method name, pass content_type to the
        deserializer, wrap the result in a list)``. Triton is handled
        separately because it takes the payload unserialized.
        """
        return {
            ModelServer.TORCHSERVE: ("_invoke_torch_serve", False, False),
            ModelServer.DJL_SERVING: ("_invoke_djl_serving", True, False),
            ModelServer.TGI: ("_invoke_tgi_serving", True, True),
            ModelServer.MMS: ("_invoke_multi_model_server_serving", True, True),
            ModelServer.TENSORFLOW_SERVING: ("_invoke_tensorflow_serving", False, False),
            ModelServer.TEI: ("_invoke_serving", False, False),
        }

    def _invoke_container_route(
        self, route, body, serializer, deserializer, content_type, accept
    ):
        """Serialize ``body``, call the routed invoke method and deserialize the reply."""
        method_name, pass_content_type, as_list = route
        # A str body is forwarded as-is; anything else goes through the serializer
        payload = body if isinstance(body, str) else serializer.serialize(body)
        invoke_fn = getattr(self.local_container_mode_obj, method_name)
        stream = io.BytesIO(invoke_fn(payload, content_type, accept))
        if pass_content_type:
            result = deserializer.deserialize(stream, content_type)
        else:
            # NOTE(review): this relies on the deserializer accepting a call
            # without content_type - confirm for the default JSONDeserializer.
            result = deserializer.deserialize(stream)
        return [result] if as_list else result

    def invoke(
        self,
        body: Any,
        content_type: str = "application/json",
        accept: str = "application/json",
        **kwargs
    ) -> "InvokeEndpointOutput":
        """Invoke the local endpoint using model server-specific logic.

        Args:
            body: Request payload.
            content_type: Request MIME type. In container mode the default value
                is replaced by the serializer's ``CONTENT_TYPE``.
            accept: Response MIME type. In container mode the default value is
                replaced by the deserializer's (first) ``ACCEPT`` entry.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            InvokeEndpointOutput: In in-process mode ``body`` is the deserialized
            response; in container mode it is a ``BytesIO`` of the JSON-encoded
            response.

        Raises:
            ValueError: If the required mode object or model server is missing,
                or the model server is not supported.
        """
        if self.in_process_mode:
            if not self.in_process_mode_obj:
                raise ValueError("In Process container mode not available")

            serializer = self.serializer or JSONSerializer()
            deserializer = self.deserializer or JSONDeserializer()
            raw_response = self.in_process_mode_obj._invoke_serving(
                serializer.serialize(body), content_type, accept
            )
            # NOTE(review): deserialize() is called without content_type here - confirm
            # the configured deserializer supports that.
            return InvokeEndpointOutput(
                body=deserializer.deserialize(io.BytesIO(raw_response)),
                content_type=accept,
            )

        if not self.model_server or not self.local_container_mode_obj:
            raise ValueError("Model server or container mode not available")

        # Get serializers (JSON defaults when none were provided)
        serializer = self.serializer or JSONSerializer()
        deserializer = self.deserializer or JSONDeserializer()

        # Explicit non-default values from the caller win; the defaults are
        # replaced by what the serializer/deserializer advertise.
        content_type = (
            content_type if content_type != "application/json" else serializer.CONTENT_TYPE
        )
        deserializer_accept = deserializer.ACCEPT
        if not isinstance(deserializer_accept, str):
            deserializer_accept = deserializer_accept[0]
        accept = accept if accept != "application/json" else deserializer_accept

        # Route to appropriate model server invoke method
        if self.model_server == ModelServer.TRITON:
            # Triton: Direct data, no serialization, fixed content types (V2 pattern)
            # NOTE(review): the raw response is JSON-encoded below - confirm the
            # Triton invoke result is JSON-serializable.
            response_data = self.local_container_mode_obj._invoke_triton_server(
                body, APPLICATION_X_NPY, APPLICATION_X_NPY
            )
        else:
            route = self._container_routes().get(self.model_server)
            if route is None:
                raise ValueError(f"Unsupported model server: {self.model_server}")
            response_data = self._invoke_container_route(
                route, body, serializer, deserializer, content_type, accept
            )

        # Return in sagemaker-core compatible format
        return InvokeEndpointOutput(
            body=io.BytesIO(json.dumps(response_data).encode("utf-8")),
            content_type=accept,
        )

    @classmethod
    def create(
        cls,
        endpoint_name: str,
        endpoint_config_name: Optional[str] = None,
        local_model: Optional["Model"] = None,
        local_session=None,
        in_process_mode=False,
        local_container_mode_obj=None,
        in_process_mode_obj=None,
        model_server=None,
        secret_key=None,
        serializer=None,
        deserializer=None,
        container_config="auto",
        **kwargs
    ) -> "LocalEndpoint":
        """Create and start local endpoint.

        In in-process mode only the in-process server is started. In local
        container mode the container is started and the endpoint config and
        endpoint are then registered with the local session.

        Args:
            endpoint_name: Name of the endpoint.
            endpoint_config_name: Config name; defaults to ``"<endpoint_name>-config"``.
            local_model: Model whose primary container image/environment and
                model name are used; required in local container mode.
            local_session: LocalSession instance; created when omitted.
            in_process_mode: Start an in-process server instead of a container.
            local_container_mode_obj: Required in local container mode.
            in_process_mode_obj: Required in in-process mode.
            model_server, secret_key, serializer, deserializer, container_config:
                Forwarded to the constructor.
            **kwargs: Forwarded to the constructor; ``container_timeout_seconds``
                (default 300) is also passed to the container's ``create_server``.

        Returns:
            The started ``LocalEndpoint``.

        Raises:
            ValueError: If the object required by the selected mode is missing.
        """
        # Fail fast with a clear message instead of an AttributeError on None
        # after part of the endpoint has already been set up.
        if in_process_mode:
            if in_process_mode_obj is None:
                raise ValueError("In Process container mode not available")
        else:
            if local_container_mode_obj is None:
                raise ValueError("Local container mode not available")
            if local_model is None:
                raise ValueError("local_model is required to create a local container endpoint")

        if local_session is None:
            from sagemaker.core.local.local_session import LocalSession

            local_session = LocalSession()

        # Create endpoint instance first so we can reference its ping method
        endpoint = cls(
            endpoint_name=endpoint_name,
            endpoint_config_name=endpoint_config_name or f"{endpoint_name}-config",
            local_session=local_session,
            local_model=local_model,
            in_process_mode=in_process_mode,
            local_container_mode_obj=local_container_mode_obj,
            in_process_mode_obj=in_process_mode_obj,
            model_server=model_server,
            secret_key=secret_key,
            serializer=serializer,
            deserializer=deserializer,
            container_config=container_config,
            **kwargs
        )

        if in_process_mode:
            endpoint.in_process_mode_obj.create_server(
                ping_fn=endpoint._universal_deep_ping
            )
            return endpoint

        # Start container with ping function
        endpoint.local_container_mode_obj.create_server(
            image=local_model.primary_container.image,
            container_timeout_seconds=kwargs.get("container_timeout_seconds", 300),
            secret_key=endpoint.secret_key,
            ping_fn=endpoint._universal_deep_ping,
            env_vars=local_model.primary_container.environment or {},
            model_path=endpoint.local_container_mode_obj.model_path,
            container_config=_get_container_config(endpoint.container_config),
        )

        # Register endpoint config with V2 LocalSession
        production_variants = [
            {
                "VariantName": "AllTraffic",
                "ModelName": local_model.model_name,
                "InitialInstanceCount": 1,
                "InstanceType": "local",
            }
        ]
        local_session.sagemaker_client.create_endpoint_config(
            EndpointConfigName=endpoint.endpoint_config_name,
            ProductionVariants=production_variants,
        )

        # Then create endpoint
        local_session.sagemaker_client.create_endpoint(
            EndpointName=endpoint_name,
            EndpointConfigName=endpoint.endpoint_config_name,
        )

        return endpoint

    @classmethod
    def get(cls, endpoint_name: str, local_session=None) -> Optional["LocalEndpoint"]:
        """Get existing local endpoint.

        Implementation based on V2 LocalSession.describe_endpoint()
        Reference: /sagemaker/local/local_session.py:describe_endpoint()

        Returns:
            A ``LocalEndpoint`` carrying only the name, config name and session
            (no mode objects or model server are attached, so ``invoke()``
            raises ``ValueError`` until they are set), or None when the
            endpoint cannot be described.
        """
        if local_session is None:
            from sagemaker.core.local.local_session import LocalSession

            local_session = LocalSession()

        try:
            # Call V2 describe_endpoint to get endpoint info
            endpoint_info = local_session.sagemaker_client.describe_endpoint(
                EndpointName=endpoint_name
            )
            return cls(
                endpoint_name=endpoint_name,
                endpoint_config_name=endpoint_info["EndpointConfigName"],
                local_session=local_session,
            )
        except Exception:
            # Endpoint not found (any lookup failure is treated the same way)
            return None

    def refresh(self) -> "LocalEndpoint":
        """Refresh endpoint state.

        Implementation based on V2 LocalSession.describe_endpoint()
        Reference: /sagemaker/local/local_session.py:describe_endpoint()

        Returns:
            This instance, with ``endpoint_config_name`` updated from the session.
        """
        endpoint_info = self._local_session.sagemaker_client.describe_endpoint(
            EndpointName=self.endpoint_name
        )
        # Update attributes from V2 response
        self.endpoint_config_name = endpoint_info["EndpointConfigName"]
        return self

    def delete(self) -> None:
        """Delete local endpoint and cleanup container.

        Implementation based on V2 LocalSession.delete_endpoint()
        Reference: /sagemaker/local/local_session.py:delete_endpoint()
        This calls _LocalEndpoint.stop() which stops the Docker container
        """
        self._local_session.sagemaker_client.delete_endpoint(
            EndpointName=self.endpoint_name
        )

    def update(self, endpoint_config_name: str) -> None:
        """Update endpoint configuration.

        V2 Reference: /sagemaker/local/local_session.py:update_endpoint()
        Note: V2 raises NotImplementedError for update_endpoint

        Raises:
            NotImplementedError: Always; updates are not supported in local mode.
        """
        raise NotImplementedError("Update endpoint is not supported in local mode")
class LocalEndpointConfig:
    """Local endpoint configuration that mimics sagemaker.core.EndpointConfig interface."""

    def __init__(
        self,
        endpoint_config_name: str,
        production_variants: list,
        local_session=None,
        **kwargs
    ):
        """Initialize local endpoint config.

        Args:
            endpoint_config_name: Name of the endpoint configuration
            production_variants: List of production variant configurations
            local_session: LocalSession instance; a new one is created when omitted
            **kwargs: Accepted for call-site compatibility; not used
        """
        self.endpoint_config_name = endpoint_config_name
        self.production_variants = production_variants
        self.creation_time = datetime.datetime.now()

        if local_session is not None:
            self._local_session = local_session
        else:
            # Imported lazily: only needed when the caller supplies no session
            from sagemaker.core.local.local_session import LocalSession

            self._local_session = LocalSession()

    @classmethod
    def create(
        cls,
        endpoint_config_name: str,
        production_variants: list,
        local_session=None,
        **kwargs
    ) -> "LocalEndpointConfig":
        """Create local endpoint configuration.

        Implementation based on V2 LocalSession.create_endpoint_config()
        Reference: /sagemaker/local/local_session.py:create_endpoint_config()

        Returns:
            The new ``LocalEndpointConfig``, after it has been registered with
            the session's ``sagemaker_client``.
        """
        session = local_session
        if session is None:
            from sagemaker.core.local.local_session import LocalSession

            session = LocalSession()

        # Build the wrapper first, then register the config with the session
        new_config = cls(
            endpoint_config_name=endpoint_config_name,
            production_variants=production_variants,
            local_session=session,
        )
        request = {
            "EndpointConfigName": endpoint_config_name,
            "ProductionVariants": production_variants,
        }
        session.sagemaker_client.create_endpoint_config(**request)
        return new_config

    def delete(self) -> None:
        """Delete local endpoint configuration.

        Implementation based on V2 LocalSession.delete_endpoint_config()
        Reference: /sagemaker/local/local_session.py:delete_endpoint_config()
        """
        client = self._local_session.sagemaker_client
        client.delete_endpoint_config(EndpointConfigName=self.endpoint_config_name)
def _get_container_config(config: str) -> dict: """Get container configuration based on config type.""" if config == "host": return {"network_mode": "host"} elif config == "bridge": return {"ports": {'8080/tcp': 8080}} elif config == "auto": import platform if platform.system().lower() == "linux": return {"network_mode": "host"} else: return {"ports": {'8080/tcp': 8080}} else: raise ValueError("container_config must be 'host', 'bridge', or 'auto'")