# Source code for sagemaker.serve.mode.local_container_mode

"""Module that defines the LocalContainerMode class"""

from __future__ import absolute_import
from pathlib import Path
import logging
import os
from datetime import datetime, timedelta
from typing import Dict, Type
import base64
import time
import subprocess
import docker

from sagemaker.core.local.utils import check_for_studio

from sagemaker.serve.model_server.tensorflow_serving.server import LocalTensorflowServing
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.utils.logging_agent import pull_logs
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.utils.exceptions import LocalDeepPingException
from sagemaker.serve.model_server.torchserve.server import LocalTorchServe
from sagemaker.serve.model_server.djl_serving.server import LocalDJLServing
from sagemaker.serve.model_server.triton.server import LocalTritonServer
from sagemaker.serve.model_server.tgi.server import LocalTgiServing
from sagemaker.serve.model_server.tei.server import LocalTeiServing
from sagemaker.serve.model_server.multi_model_server.server import LocalMultiModelServer
from sagemaker.core.helper.session_helper import Session

logger = logging.getLogger(__name__)

# Seconds added to "now" to form the stop time handed to pull_logs on each
# iteration of the health-check wait loop.
_PING_HEALTH_CHECK_INTERVAL_SEC = 5

# Message carried by LocalDeepPingException when the container never reports healthy.
_PING_HEALTH_CHECK_FAIL_MSG = (
    "Container did not pass the ping health check. "
    "Please increase container_timeout_seconds or review your inference code."
)

# Candidate Docker socket locations probed, in order, when running inside SageMaker Studio.
STUDIO_DOCKER_SOCKET_PATHS = ["/docker/proxy/docker.sock", "/var/run/docker.sock"]


def _get_docker_client():
    """Get a Docker client, handling SageMaker Studio's non-standard socket path.

    Resolution order:

    1. If ``DOCKER_HOST`` is set, defer entirely to ``docker.from_env()``.
    2. If ``check_for_studio()`` reports a Studio environment, connect to the
       first path in ``STUDIO_DOCKER_SOCKET_PATHS`` that exists on disk.
    3. Otherwise (no Studio, no socket found, or Studio detection/connection
       raised), fall back to ``docker.from_env()``.

    Returns:
        A ``docker.DockerClient``.
    """
    if os.environ.get("DOCKER_HOST"):
        return docker.from_env()
    try:
        if check_for_studio():
            for socket_path in STUDIO_DOCKER_SOCKET_PATHS:
                if os.path.exists(socket_path):
                    return docker.DockerClient(base_url=f"unix://{socket_path}")
    except Exception as exc:  # pylint: disable=broad-except
        # Deliberately best-effort: Studio detection may raise (e.g. NotImplementedError),
        # and any failure here should not prevent the default client from being tried.
        logger.debug(
            "Could not set up a SageMaker Studio Docker client; "
            "falling back to docker.from_env(): %s",
            exc,
        )
    return docker.from_env()


class LocalContainerMode(
    LocalTorchServe,
    LocalDJLServing,
    LocalTritonServer,
    LocalTgiServing,
    LocalMultiModelServer,
    LocalTensorflowServing,
):
    """A class that holds methods to deploy model to a container in local environment.

    The server-specific ``_start_*`` helpers used by :meth:`create_server` come
    from the mixed-in ``Local*`` server classes; TEI is handled through a
    separate ``LocalTeiServing`` instance.
    """

    # Lower bound (seconds) on the health-check wait. Earlier versions ignored the
    # caller's timeout and always waited this long, so it is kept as a floor.
    _MIN_CONTAINER_TIMEOUT_SEC = 1200

    def __init__(
        self,
        model_server: ModelServer,
        inference_spec: Type[InferenceSpec],
        schema_builder: Type[SchemaBuilder],
        session: Session,
        model_path: str = None,
        env_vars: Dict = None,
    ):
        """Initialize local container mode.

        Args:
            model_server: Which model server :meth:`create_server` starts.
            inference_spec: Object whose ``load`` is called by :meth:`load`.
            schema_builder: Schema builder; forwarded to the TEI serving helper.
            session: Session whose ``boto_session`` is used to create an ECR client.
            model_path: Default model directory for :meth:`load` / :meth:`create_server`.
            env_vars: Default container environment for :meth:`create_server`.
        """
        # Explicitly run initializers further along the MRO; presumably the mixin
        # initializers do not chain via super() themselves — TODO confirm.
        # pylint: disable=bad-super-call
        super().__init__()
        super(LocalDJLServing, self).__init__()
        super(LocalTritonServer, self).__init__()

        self.inference_spec = inference_spec
        self.model_path = model_path
        self.env_vars = env_vars
        self.session = session
        self.schema_builder = schema_builder
        self.ecr = session.boto_session.client("ecr")
        self.model_server = model_server
        self.client = None
        self.container = None
        self.secret_key = None
        self._ping_container = None
        self._invoke_serving = None

    def load(self, model_path: str = None):
        """Load the model through ``self.inference_spec.load``.

        Args:
            model_path: Model directory; falls back to ``self.model_path`` when falsy.

        Returns:
            Whatever ``self.inference_spec.load`` returns.

        Raises:
            Exception: If the path does not exist or is not a directory.
        """
        path = Path(model_path if model_path else self.model_path)
        if not path.exists():
            raise Exception("model_path does not exist")
        if not path.is_dir():
            raise Exception("model_path is not a valid directory")

        return self.inference_spec.load(str(path))

    def prepare(self):
        """No-op in local container mode."""

    def create_server(
        self,
        image: str,
        container_timeout_seconds: int,
        secret_key: str,
        container_config: Dict,
        ping_fn=None,
        env_vars: Dict[str, str] = None,
        model_path: str = None,
        jumpstart: bool = False,
    ):
        """Pull ``image``, start the configured model server, and wait until it is healthy.

        Args:
            image: Container image URI to pull and run.
            container_timeout_seconds: How long, in seconds, to wait for the health
                check to pass. Values below ``_MIN_CONTAINER_TIMEOUT_SEC`` (or ``None``)
                are raised to that minimum.
            secret_key: Forwarded to the server-specific start helper.
            container_config: Accepted for interface compatibility; not used here.
            ping_fn: Zero-argument callable returning ``(healthy, response)``. When
                omitted, the previously stored ``self._ping_container`` is reused.
            env_vars: Container environment; falls back to ``self.env_vars`` when falsy.
            model_path: Model directory; falls back to ``self.model_path`` when falsy.
            jumpstart: Forwarded to the TGI start helper only.

        Raises:
            ValueError: If ``self.model_server`` is not supported in this mode.
            LocalDeepPingException: If the container is not healthy before the timeout.
        """
        self._pull_image(image=image)

        self.destroy_server()

        logger.info("Waiting for model server %s to start up...", self.model_server)

        self._ping_container = ping_fn or self._ping_container

        resolved_model_path = model_path if model_path else self.model_path
        resolved_env_vars = env_vars if env_vars else self.env_vars

        # Each _start_* helper presumably sets self.container, which the wait loop reads.
        if self.model_server == ModelServer.TRITON:
            self._start_triton_server(
                docker_client=self.client,
                model_path=resolved_model_path,
                image_uri=image,
                secret_key=secret_key,
                env_vars=resolved_env_vars,
            )
        elif self.model_server == ModelServer.DJL_SERVING:
            self._start_djl_serving(
                client=self.client,
                image=image,
                model_path=resolved_model_path,
                secret_key=secret_key,
                env_vars=resolved_env_vars,
            )
        elif self.model_server == ModelServer.TORCHSERVE:
            self._start_torch_serve(
                client=self.client,
                image=image,
                model_path=resolved_model_path,
                secret_key=secret_key,
                env_vars=resolved_env_vars,
            )
        elif self.model_server == ModelServer.TGI:
            self._start_tgi_serving(
                client=self.client,
                image=image,
                model_path=resolved_model_path,
                secret_key=secret_key,
                env_vars=resolved_env_vars,
                jumpstart=jumpstart,
            )
        elif self.model_server == ModelServer.MMS:
            self._start_serving(
                client=self.client,
                image=image,
                model_path=resolved_model_path,
                secret_key=secret_key,
                env_vars=resolved_env_vars,
            )
        elif self.model_server == ModelServer.TENSORFLOW_SERVING:
            self._start_tensorflow_serving(
                client=self.client,
                image=image,
                model_path=resolved_model_path,
                secret_key=secret_key,
                env_vars=resolved_env_vars,
            )
        elif self.model_server == ModelServer.TEI:
            tei_serving = LocalTeiServing()
            tei_serving._start_tei_serving(
                client=self.client,
                image=image,
                model_path=resolved_model_path,
                secret_key=secret_key,
                env_vars=resolved_env_vars,
            )
            tei_serving.schema_builder = self.schema_builder
            self.container = tei_serving.container
            self._invoke_serving = tei_serving._invoke_tei_serving
        else:
            # Previously this fell through and crashed on `None.logs(...)` below.
            raise ValueError(
                f"Unsupported model server for local container mode: {self.model_server}"
            )

        # allow some time for container to be ready
        time.sleep(10)

        # Honor the caller's timeout (as _PING_HEALTH_CHECK_FAIL_MSG advises users to
        # increase it), but never wait less than the previously hard-coded minimum.
        container_timeout_seconds = max(
            container_timeout_seconds or 0, self._MIN_CONTAINER_TIMEOUT_SEC
        )

        log_generator = self.container.logs(follow=True, stream=True)
        time_limit = datetime.now() + timedelta(seconds=container_timeout_seconds)
        healthy = False
        while True:
            now = datetime.now()
            final_pull = now > time_limit
            # Forward container logs until the short stop time; the last argument
            # presumably tells pull_logs this is the final pull after the deadline.
            pull_logs(
                (x.decode("UTF-8").rstrip() for x in log_generator),
                log_generator.close,
                datetime.now() + timedelta(seconds=_PING_HEALTH_CHECK_INTERVAL_SEC),
                final_pull,
            )

            if final_pull:
                break

            # allow some time for container to be ready
            time.sleep(10)

            healthy, response = self._ping_container()
            if healthy:
                logger.debug("Ping health check has passed. Returned %s", str(response))
                break

        if not healthy:
            raise LocalDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG)

    def destroy_server(self):
        """Kill the currently tracked container, if any, and forget it.

        Docker API client errors (HTTP 4xx) raised by ``kill`` are ignored; any
        other API error is re-raised wrapped in ``Exception``.
        """
        if self.container:
            try:
                logger.debug("Stopping currently running container...")
                self.container.kill()
            except docker.errors.APIError as exc:
                # is_client_error() is True only for 4xx responses and returns False
                # (instead of raising AttributeError) when the error has no response.
                if not exc.is_client_error():
                    raise Exception("Error encountered when cleaning up local container") from exc
            self.container = None

    def _pull_image(self, image: str):
        """Connect to Docker, log in to ECR if needed, and pull ``image``.

        Sets ``self.client`` to the Docker client used for subsequent operations.

        Raises:
            RuntimeError: If Docker is unreachable, ECR authentication fails, or
                the pull fails with a Docker API error.
            ValueError: If the image cannot be found in the repository.
        """
        # Check if Docker is available first
        try:
            self.client = _get_docker_client()
            self.client.ping()  # Test Docker connection
        except Exception as e:
            raise RuntimeError(
                "Docker is not available or not running. Please ensure Docker is installed and running. "
                f"Error: {e}"
            ) from e

        # Handle ECR authentication for ECR images
        if self._is_ecr_image(image):
            try:
                encoded_token = (
                    self.ecr.get_authorization_token()
                    .get("authorizationData")[0]
                    .get("authorizationToken")
                )
                decoded_token = base64.b64decode(encoded_token).decode("utf-8")
                # The token is "<username>:<password>"; split on the first colon only.
                username, password = decoded_token.split(":", 1)
                ecr_uri = image.split("/")[0]
                # Pass the password on stdin rather than argv so it is not exposed in
                # the process list (docker warns that `-p` is insecure).
                login_command = ["docker", "login", "-u", username, "--password-stdin", ecr_uri]
                subprocess.run(
                    login_command,
                    input=password,
                    check=True,
                    capture_output=True,
                    text=True,
                )
                logger.info("Successfully authenticated with ECR")
            except subprocess.CalledProcessError as e:
                error_msg = f"ECR authentication failed: {e.stderr if e.stderr else str(e)}"
                logger.error(error_msg)
                raise RuntimeError(error_msg) from e
            except Exception as e:
                error_msg = f"ECR authentication error: {str(e)}"
                logger.error(error_msg)
                raise RuntimeError(error_msg) from e

        # Pull the image
        try:
            logger.info("Pulling image %s from repository...", image)
            self.client.images.pull(image)
            logger.info("Successfully pulled image %s", image)
        except docker.errors.NotFound as e:
            raise ValueError(f"Could not find image '{image}' in repository") from e
        except docker.errors.APIError as e:
            raise RuntimeError(f"Failed to pull image '{image}': {e}") from e

    def _is_ecr_image(self, image: str) -> bool:
        """Return True if ``image`` looks like an Amazon ECR image URI."""
        return ".dkr.ecr." in image and ".amazonaws.com" in image