Source code for sagemaker.serve.model_server.triton.server

"""Placeholder docerting"""

from __future__ import absolute_import
import uuid
import logging
import importlib
import platform

from sagemaker.core import fw_utils
from sagemaker.core.helper.session_helper import Session
from sagemaker.core.common_utils import _is_s3_uri
from sagemaker.serve.utils.uploader import upload
from sagemaker.core.s3.utils import determine_bucket_and_prefix, parse_s3_url
from sagemaker.core.local.local_session import get_docker_host
import docker
from docker.types import DeviceRequest

logger = logging.getLogger(__name__)

# TODO: automatically update memory size
_SHM_SIZE = "2G"


class LocalTritonServer:
    """Start a Triton Inference Server in a local Docker container and query it.

    ``_invoke_triton_server`` reads ``self.schema_builder``, which is never
    assigned in this class.
    NOTE(review): presumably this class is combined with another class that
    provides ``schema_builder`` -- confirm against the callers.
    """

    def __init__(self) -> None:
        # Created on the first call to ``_invoke_triton_server``.
        self.triton_client = None

    def _start_triton_server(
        self,
        docker_client: docker.DockerClient,
        model_path: str,
        secret_key: str,
        image_uri: str,
        env_vars: dict,
    ):
        """Launch the Triton container and keep a handle to it on ``self.container``.

        Args:
            docker_client: Docker client used to start the container.
            model_path: Local directory that contains a ``model_repository``
                sub-directory; that sub-directory is mounted at ``/models``.
            secret_key: Accepted for interface compatibility; not used here.
            image_uri: Triton image to run. GPU devices are requested unless
                the URI contains the substring ``"cpu"``.
            env_vars: Extra environment variables for the container. The
                mapping is copied, never modified; ``None`` is treated as
                empty.
        """
        # NOTE(review): the generated name is stored but is not passed to
        # ``containers.run`` -- confirm whether the container should be named.
        self.container_name = "triton" + uuid.uuid1().hex

        # Work on a copy so the caller's dict is not mutated as a side effect.
        container_env = dict(env_vars) if env_vars else {}
        container_env.update(
            {
                "TRITON_MODEL_DIR": "/models/model",
                "LOCAL_PYTHON": platform.python_version(),
            }
        )

        run_kwargs = {
            "image": image_uri,
            "command": ["tritonserver", "--model-repository=/models"],
            "shm_size": _SHM_SIZE,
            "network_mode": "host",
            "detach": True,
            "auto_remove": True,
            "volumes": {
                model_path + "/model_repository": {"bind": "/models", "mode": "rw"},
            },
            "environment": container_env,
        }
        if "cpu" not in image_uri:
            # Any image without "cpu" in its URI is treated as a GPU image.
            run_kwargs["device_requests"] = [
                DeviceRequest(count=-1, capabilities=[["gpu"]])
            ]

        self.container = docker_client.containers.run(**run_kwargs)

    def _invoke_triton_server(self, payload, *args, **kwargs):
        """Send ``payload`` to the local Triton server and return the decoded output.

        The payload is serialized with ``schema_builder.input_serializer``,
        sent as the single input ``"input_1"`` of the model named ``"model"``,
        and the first output of the response is passed through
        ``schema_builder.output_deserializer``.

        Args:
            payload: Inference input accepted by the schema builder's input
                serializer.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            The deserialized first output of the inference response.
        """
        # Resolved at call time rather than at module import.
        http_client = importlib.import_module("tritonclient.http")

        if not self.triton_client:
            # The client is created once and reused on later calls.
            self.triton_client = http_client.InferenceServerClient(
                url=f"{get_docker_host()}:8000"
            )

        serialized = self.schema_builder.input_serializer.serialize(payload)
        # Only the text after the last underscore of the dtype string is sent.
        datatype = self.schema_builder._input_triton_dtype.split("_")[-1]

        infer_input = http_client.InferInput("input_1", serialized.shape, datatype=datatype)
        infer_input.set_data_from_numpy(serialized, binary_data=True)

        response = self.triton_client.infer(model_name="model", inputs=[infer_input])
        first_output_name = response.get_response().get("outputs")[0].get("name")
        return self.schema_builder.output_deserializer.deserialize(
            response.as_numpy(first_output_name)
        )


class SageMakerTritonServer:
    """Prepare Triton model artifacts and environment variables for SageMaker."""

    def __init__(self) -> None:
        """Create the helper; no state is initialised."""

    def _upload_triton_artifacts(
        self,
        model_path: str,
        sagemaker_session: Session,
        secret_key: str,
        s3_model_data_url: str = None,
        image: str = None,
        should_upload_artifacts: bool = False,
    ):
        """Resolve the S3 location of the Triton artifacts and build the env vars.

        Args:
            model_path: Either an S3 URI, which is returned unchanged, or a
                local directory containing a ``model_repository`` folder.
            sagemaker_session: Session used to pick the bucket and to upload.
            secret_key: Not used by this method.
            s3_model_data_url: Optional S3 URL whose bucket and key prefix
                seed the upload destination.
            image: Image URI forwarded to ``fw_utils.model_code_key_prefix``.
            should_upload_artifacts: When ``model_path`` is local, upload
                ``model_repository`` only if this is true.

        Returns:
            A ``(s3_upload_path, env_vars)`` tuple. ``s3_upload_path`` is
            ``None`` when ``model_path`` is local and no upload was requested.
        """
        if _is_s3_uri(model_path):
            # Artifacts already live in S3; nothing to upload.
            artifact_location = model_path
        elif not should_upload_artifacts:
            artifact_location = None
        else:
            bucket, key_prefix = (
                parse_s3_url(url=s3_model_data_url) if s3_model_data_url else (None, None)
            )
            bucket, code_key_prefix = determine_bucket_and_prefix(
                bucket=bucket,
                key_prefix=fw_utils.model_code_key_prefix(key_prefix, None, image),
                sagemaker_session=sagemaker_session,
            )

            logger.debug(
                "Uploading the model resources to bucket=%s, key_prefix=%s.",
                bucket,
                code_key_prefix,
            )
            artifact_location = upload(
                sagemaker_session,
                model_path + "/model_repository",
                bucket,
                code_key_prefix,
            )
            logger.debug("Model resources uploaded to: %s", artifact_location)

        triton_env = {
            "SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "model",
            "TRITON_MODEL_DIR": "/opt/ml/model/model",
            "LOCAL_PYTHON": platform.python_version(),
        }
        return artifact_location, triton_env