Source code for sagemaker.serve.marshalling.triton_translator

"""Implements classes that convert data from and to np.ndarray"""

from __future__ import absolute_import

import logging

import numpy as np

logger = logging.getLogger(__name__)


# pylint: disable=unused-argument
class TorchTensorTranslator:
    """Translate torch.Tensor from and to numpy.ndarray.

    Attributes:
        convert_from_numpy: The ``torch.from_numpy`` callable used by ``deserialize``.
        CONTENT_TYPE: Content type string, ``"tensor/pt"``.
        ACCEPT: Accept type string, ``"tensor/pt"``.
    """

    def __init__(self) -> None:
        # Imported here rather than at module level, so torch is only required
        # when this translator is actually instantiated.
        import torch

        self.convert_from_numpy = torch.from_numpy  # pylint: disable=E1101
        self.CONTENT_TYPE = "tensor/pt"
        self.ACCEPT = "tensor/pt"

    def serialize(self, data, content_type: str = "tensor/pt"):
        """Translate torch.Tensor to numpy ndarray.

        Args:
            data: The tensor to convert.
            content_type: Unused by this method.

        Returns:
            The numpy.ndarray produced from ``data``.

        Raises:
            ValueError: If the conversion fails; the original exception is
                attached as the cause.
        """
        try:
            # detach() drops the autograd graph; cpu() returns the same tensor
            # when it already lives in host memory and copies it there
            # otherwise, since Tensor.numpy() only works on CPU tensors.
            return data.detach().cpu().numpy()
        except Exception as e:
            logger.error(e)
            raise ValueError(
                "Unable to translate data %s to np.ndarray: %s" % (type(data), e)
            ) from e

    def deserialize(self, data, content_type: str = "application/x-npy"):
        """Translate numpy ndarray to torch.Tensor.

        Args:
            data: The array to convert.
            content_type: Unused by this method.

        Returns:
            The result of ``torch.from_numpy(data)``.

        Raises:
            ValueError: If the conversion fails; the original exception is
                attached as the cause.
        """
        try:
            return self.convert_from_numpy(data)
        except Exception as e:
            logger.error(e)
            raise ValueError(
                "Unable to translate data %s to torch.Tensor: %s " % (type(data), e)
            ) from e

    def _deserializer(self):
        """Dummy function to align with DeserializerWrapper in SchemaBuilder.

        Raises:
            ValueError: Always.
        """
        raise ValueError("This method is not meant to be invoked.")
class TensorflowTensorTranslator:
    """Converts tf.Tensor from and to numpy.ndarray.

    Attributes:
        convert_to_tensor: The ``tf.convert_to_tensor`` callable used by ``deserialize``.
        CONTENT_TYPE: Content type string, ``"tensor/tf"``.
        ACCEPT: Accept type string, ``"tensor/tf"``.
    """

    def __init__(self) -> None:
        # tensorflow is imported at construction time, not at module import time.
        import tensorflow as tf

        self.convert_to_tensor = tf.convert_to_tensor
        self.CONTENT_TYPE = "tensor/tf"
        self.ACCEPT = "tensor/tf"

    def serialize(self, data, content_type: str = "tensor/tf"):
        """Translate tf.Tensor to numpy ndarray.

        Args:
            data: Object whose ``numpy()`` method yields the array.
            content_type: Unused by this method.

        Returns:
            Whatever ``data.numpy()`` returns.

        Raises:
            ValueError: If the conversion fails; chained from the original error.
        """
        try:
            as_array = data.numpy()
        except Exception as err:
            logger.error(err)
            raise ValueError("Unable to convert data %s to np.ndarray" % type(data)) from err
        return as_array

    def deserialize(self, data, content_type: str = "application/x-npy"):
        """Translate numpy ndarray to tf.Tensor.

        Args:
            data: The value handed to ``tf.convert_to_tensor``.
            content_type: Unused by this method.

        Returns:
            The result of ``self.convert_to_tensor(data)``.

        Raises:
            ValueError: If the conversion fails; chained from the original error.
        """
        try:
            as_tensor = self.convert_to_tensor(data)
        except Exception as err:
            logger.error(err)
            raise ValueError("Unable to convert data %s to tf.Tensor" % type(data)) from err
        return as_tensor

    def _deserializer(self):
        """Dummy function to align with DeserializerWrapper in SchemaBuilder.

        Raises:
            ValueError: Always.
        """
        raise ValueError("This method is not meant to be invoked.")
class NumpyTranslator:
    """A dummy class to make sure the translator interface is aligned.

    Both ``serialize`` and ``deserialize`` hand back their input untouched.

    Attributes:
        CONTENT_TYPE: Content type string, ``"application/x-npy"``.
        ACCEPT: Accept type string, ``"application/x-npy"``.
    """

    def __init__(self) -> None:
        self.CONTENT_TYPE = "application/x-npy"
        self.ACCEPT = "application/x-npy"

    def serialize(self, data, content_type: str = "application/x-npy"):
        """Return ``data`` unchanged.

        Args:
            data: The value to pass through.
            content_type: Unused by this method.

        Returns:
            The same ``data`` object that was passed in.
        """
        return data

    def deserialize(self, data, content_type: str = "application/x-npy"):
        """Return ``data`` unchanged.

        Args:
            data: The value to pass through.
            content_type: Unused by this method.

        Returns:
            The same ``data`` object that was passed in.
        """
        return data

    def _deserializer(self):
        """Dummy function to align with DeserializerWrapper in SchemaBuilder.

        Raises:
            ValueError: Always.
        """
        raise ValueError("This method is not meant to be invoked.")
class ListTranslator:
    """Translate python list from and to numpy.ndarray.

    Attributes:
        CONTENT_TYPE: Content type string, ``"application/list"``.
        ACCEPT: Accept type string, ``"application/list"``.
    """

    def __init__(self) -> None:
        self.CONTENT_TYPE = "application/list"
        self.ACCEPT = "application/list"

    def serialize(self, data, content_type: str = "application/list"):
        """Convert ``data`` to a numpy ndarray via ``np.array``.

        Args:
            data: The value to convert.
            content_type: Unused by this method.

        Returns:
            The array built by ``np.array(data)``.

        Raises:
            ValueError: If ``np.array`` fails; chained from the original error.
        """
        try:
            as_array = np.array(data)
        except Exception as err:
            logger.error(err)
            raise ValueError("Unable to convert data %s to np.ndarray" % type(data)) from err
        return as_array

    def deserialize(self, data, content_type: str = "application/x-npy"):
        """Convert ``data`` to a python list via its ``tolist()`` method.

        Args:
            data: Object exposing ``tolist()``.
            content_type: Unused by this method.

        Returns:
            Whatever ``data.tolist()`` returns.

        Raises:
            ValueError: If the call fails; chained from the original error.
        """
        try:
            as_list = data.tolist()
        except Exception as err:
            logger.error(err)
            raise ValueError("Unable to convert data %s to python list" % type(data)) from err
        return as_list

    def _deserializer(self):
        """Dummy function to align with DeserializerWrapper in SchemaBuilder.

        Raises:
            ValueError: Always.
        """
        raise ValueError("This method is not meant to be invoked.")