Source code for sagemaker.core.serializers.base

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Implements base methods for serializing data for an inference endpoint."""
from __future__ import absolute_import

import abc
from collections.abc import Iterable
import csv
import io
import json
import numpy as np
from pandas import DataFrame
from six import with_metaclass

# Lazy import to avoid circular dependency with amazon modules
# from sagemaker.core.serializers.utils import write_numpy_to_dense_tensor
from sagemaker.core.common_utils import DeferredError

try:
    import scipy.sparse
except ImportError as e:
    scipy = DeferredError(e)


class BaseSerializer(abc.ABC):
    """Interface that every request serializer must implement.

    A concrete serializer has to override two members: the ``serialize``
    method and the ``CONTENT_TYPE`` attribute.
    """

    @property
    @abc.abstractmethod
    def CONTENT_TYPE(self):
        """MIME type of the payload sent to the inference endpoint."""

    @abc.abstractmethod
    def serialize(self, data):
        """Convert ``data`` into the media type named by ``CONTENT_TYPE``.

        Args:
            data (object): The payload to serialize.

        Returns:
            object: The serialized payload to use for a request.
        """
class SimpleBaseSerializer(with_metaclass(abc.ABCMeta, BaseSerializer)):
    """Abstract base class for creation of new serializers.

    This class extends the API of :class:`~sagemaker.serializers.BaseSerializer` with more
    user-friendly options for setting the Content-Type header, in situations where it can be
    provided at init and freely updated.
    """

    def __init__(self, content_type="application/json"):
        """Initialize a ``SimpleBaseSerializer`` instance.

        Args:
            content_type (str): The MIME type to signal to the inference endpoint when sending
                request data (default: "application/json").

        Raises:
            ValueError: If ``content_type`` is not a ``str``.
        """
        super(SimpleBaseSerializer, self).__init__()
        if not isinstance(content_type, str):
            # Wrap the operand in a one-tuple: a bare tuple on the right of ``%`` is
            # unpacked as the argument list, which would raise TypeError instead of
            # the intended ValueError when a tuple is passed as content_type.
            raise ValueError(
                "content_type must be a string specifying the MIME type of the data sent in "
                "requests: e.g. 'application/json', 'text/csv', etc. Got %s" % (content_type,)
            )
        self.content_type = content_type

    @property
    def CONTENT_TYPE(self):
        """The data MIME type set in the Content-Type header on prediction endpoint requests."""
        # Read on every access so later assignments to ``self.content_type`` take effect.
        return self.content_type
class CSVSerializer(SimpleBaseSerializer):
    """Serialize data of various formats to a CSV-formatted string."""

    def __init__(self, content_type="text/csv"):
        """Initialize a ``CSVSerializer`` instance.

        Args:
            content_type (str): The MIME type to signal to the inference endpoint when sending
                request data (default: "text/csv").
        """
        super(CSVSerializer, self).__init__(content_type=content_type)

    def serialize(self, data):
        """Serialize data of various formats to a CSV-formatted string.

        Args:
            data (object): Data to be serialized. Can be a NumPy array, list,
                file, Pandas DataFrame, or buffer.

        Returns:
            str: The data serialized as a CSV-formatted string. For objects with a
                ``read`` method, whatever ``read()`` returns is passed through unchanged.

        Raises:
            ValueError: If the data is an empty sized container or is of a type
                that cannot be written as a CSV row.
        """
        # Files and buffers are assumed to already hold CSV content.
        if hasattr(data, "read"):
            return data.read()

        if isinstance(data, DataFrame):
            return data.to_csv(header=False, index=False)

        is_mutable_sequence_like = self._is_sequence_like(data) and hasattr(data, "__setitem__")
        # Probe len()/data[0] only after the sequence check has passed; otherwise
        # unsized inputs (e.g. an int) would fail with a TypeError from len()
        # instead of reaching the descriptive ValueError in _serialize_row.
        has_multiple_rows = (
            is_mutable_sequence_like and len(data) > 0 and self._is_sequence_like(data[0])
        )
        if has_multiple_rows:
            return "\n".join(self._serialize_row(row) for row in data)

        return self._serialize_row(data)

    def _serialize_row(self, data):
        """Serialize data as a CSV-formatted row.

        Args:
            data (object): Data to be serialized in a row.

        Returns:
            str: The data serialized as a CSV-formatted row.

        Raises:
            ValueError: If ``data`` is empty or has no length.
        """
        # Strings are passed through untouched: they are treated as ready-made rows.
        if isinstance(data, str):
            return data

        if isinstance(data, np.ndarray):
            data = np.ndarray.flatten(data)

        if hasattr(data, "__len__"):
            if len(data) == 0:
                raise ValueError("Cannot serialize empty array")
            csv_buffer = io.StringIO()
            csv_writer = csv.writer(csv_buffer, delimiter=",")
            csv_writer.writerow(data)
            # csv.writer terminates the row with a line ending; strip it so that
            # callers can join rows with their own separator.
            return csv_buffer.getvalue().rstrip("\r\n")

        raise ValueError("Unable to handle input format: %s" % type(data))

    def _is_sequence_like(self, data):
        """Return True if ``data`` is iterable and subscriptable."""
        return hasattr(data, "__iter__") and hasattr(data, "__getitem__")
class NumpySerializer(SimpleBaseSerializer):
    """Serialize data to bytes using the .npy format."""

    def __init__(self, dtype=None, content_type="application/x-npy"):
        """Initialize a ``NumpySerializer`` instance.

        Args:
            dtype (str): The dtype used when converting non-array input into a
                NumPy array (default: None, which lets NumPy infer it).
            content_type (str): The MIME type to signal to the inference endpoint when sending
                request data (default: "application/x-npy").
        """
        super(NumpySerializer, self).__init__(content_type=content_type)
        self.dtype = dtype

    def serialize(self, data):
        """Serialize data to bytes using the .npy format.

        Args:
            data (object): Data to be serialized. Can be a NumPy array, list,
                file, or buffer.

        Returns:
            bytes: The data serialized in the .npy format. For objects with a
                ``read`` method, whatever ``read()`` returns is passed through unchanged.

        Raises:
            ValueError: If ``data`` is an empty NumPy array or an empty list.
        """
        # Existing arrays are saved as-is; ``self.dtype`` is not applied to them.
        if isinstance(data, np.ndarray):
            if data.size == 0:
                raise ValueError("Cannot serialize empty array.")
            return self._serialize_array(data)

        if isinstance(data, list):
            if len(data) == 0:
                raise ValueError("Cannot serialize empty array.")
            return self._serialize_array(np.array(data, self.dtype))

        # files and buffers. Assumed to hold npy-formatted data.
        if hasattr(data, "read"):
            return data.read()

        # Honour the configured dtype for every other convertible input
        # (tuples, scalars, ...), consistent with the list branch above.
        return self._serialize_array(np.array(data, self.dtype))

    def _serialize_array(self, array):
        """Save a NumPy array in the .npy format.

        Args:
            array (numpy.ndarray): The array to serialize.

        Returns:
            bytes: The contents of the buffer the array was saved into.
        """
        buffer = io.BytesIO()
        np.save(buffer, array)
        return buffer.getvalue()
class JSONSerializer(SimpleBaseSerializer):
    """Turn request data into a JSON-formatted string."""

    def serialize(self, data):
        """Produce a JSON string from several supported input kinds.

        Dictionaries have their top-level NumPy array values converted to
        lists before encoding; a NumPy array is encoded as a (nested) list;
        objects exposing ``read`` are read and returned without encoding;
        anything else is handed straight to ``json.dumps``.

        Args:
            data (object): Data to be serialized.

        Returns:
            str: The data serialized as a JSON string.
        """
        if isinstance(data, dict):
            converted = {}
            for key, value in data.items():
                if isinstance(value, np.ndarray):
                    converted[key] = value.tolist()
                else:
                    converted[key] = value
            return json.dumps(converted)

        # Files and buffers are assumed to already contain JSON.
        if hasattr(data, "read"):
            return data.read()

        payload = data.tolist() if isinstance(data, np.ndarray) else data
        return json.dumps(payload)
class IdentitySerializer(SimpleBaseSerializer):
    """Pass request data through untouched.

    Handy when the payload is already in wire format, for example the raw
    bytes obtained from an image file's ``.read()`` method.
    """

    def __init__(self, content_type="application/octet-stream"):
        """Create an ``IdentitySerializer``.

        Args:
            content_type (str): The MIME type to signal to the inference endpoint when sending
                request data (default: "application/octet-stream").
        """
        super().__init__(content_type=content_type)

    def serialize(self, data):
        """Hand back ``data`` exactly as received.

        Args:
            data (object): Data to be serialized.

        Returns:
            object: The very same object that was passed in.
        """
        return data
class JSONLinesSerializer(SimpleBaseSerializer):
    """Turn request data into a JSON Lines formatted string."""

    def __init__(self, content_type="application/jsonlines"):
        """Create a ``JSONLinesSerializer``.

        Args:
            content_type (str): The MIME type to signal to the inference endpoint when sending
                request data (default: "application/jsonlines").
        """
        super().__init__(content_type=content_type)

    def serialize(self, data):
        """Produce newline-separated JSON values from the input.

        Args:
            data (object): Data to be serialized. A string is returned as-is,
                a file-like object is read, and any other iterable has each
                element JSON-encoded onto its own line.

        Returns:
            str: The data serialized as a string containing newline-separated
                JSON values.

        Raises:
            ValueError: If ``data`` is not a string, file-like, or iterable.
        """
        if isinstance(data, str):
            return data

        if hasattr(data, "read"):
            return data.read()

        if not isinstance(data, Iterable):
            raise ValueError("Object of type %s is not JSON Lines serializable." % type(data))

        encoded_lines = (json.dumps(element) for element in data)
        return "\n".join(encoded_lines)
class SparseMatrixSerializer(SimpleBaseSerializer):
    """Encode a SciPy sparse matrix in the .npz format."""

    def __init__(self, content_type="application/x-npz"):
        """Create a ``SparseMatrixSerializer``.

        Args:
            content_type (str): The MIME type to signal to the inference endpoint when sending
                request data (default: "application/x-npz").
        """
        super().__init__(content_type=content_type)

    def serialize(self, data):
        """Write a sparse matrix with ``scipy.sparse.save_npz`` and return the bytes.

        Sparse matrices can be in the ``csc``, ``csr``, ``bsr``, ``dia`` or
        ``coo`` formats.

        Args:
            data (scipy.sparse.spmatrix): The sparse matrix to serialize.

        Returns:
            bytes: The contents of the in-memory buffer the matrix was saved to.
        """
        with io.BytesIO() as sink:
            scipy.sparse.save_npz(sink, data)
            return sink.getvalue()
class LibSVMSerializer(SimpleBaseSerializer):
    """Pass through data that is already in LibSVM text format.

    The data must already be in LIBSVM file format:
    <label> <index1>:<value1> <index2>:<value2> ...

    It is suitable for sparse datasets since it does not store zero-valued
    features.
    """

    def __init__(self, content_type="text/libsvm"):
        """Create a ``LibSVMSerializer``.

        Args:
            content_type (str): The MIME type to signal to the inference endpoint when sending
                request data (default: "text/libsvm").
        """
        super().__init__(content_type=content_type)

    def serialize(self, data):
        """Return LibSVM-formatted content from a string or file-like object.

        Args:
            data (object): Data to be serialized. Can be a string or a
                file-like object.

        Returns:
            str: The string itself, or the result of calling ``read()`` on a
                file-like object.

        Raises:
            ValueError: If unable to handle input format
        """
        if not isinstance(data, str):
            if not hasattr(data, "read"):
                raise ValueError("Unable to handle input format: %s" % type(data))
            return data.read()

        return data
class DataSerializer(SimpleBaseSerializer):
    """Serialize data in any file by extracting raw bytes from the file."""

    def __init__(self, content_type="file-path/raw-bytes"):
        """Initialize a ``DataSerializer`` instance.

        Args:
            content_type (str): The MIME type to signal to the inference endpoint when sending
                request data (default: "file-path/raw-bytes").
        """
        super(DataSerializer, self).__init__(content_type=content_type)

    def serialize(self, data):
        """Serialize file data to raw bytes.

        Args:
            data (object): Data to be serialized. The data can be a string
                representing a file path, the raw bytes from a file, or a
                dict whose ``"data"`` entry is one of those.

        Returns:
            bytes: The file contents for a path, or the input bytes unchanged.

        Raises:
            ValueError: If the file cannot be opened or read, or if ``data``
                is of an unsupported type.
        """
        if isinstance(data, str):
            try:
                with open(data, "rb") as data_file:
                    return data_file.read()
            except Exception as e:
                # Chain explicitly so the underlying error is kept as __cause__.
                raise ValueError(f"Could not open/read file: {data}. {e}") from e

        if isinstance(data, bytes):
            return data

        # Unwrap a {"data": ...} envelope and serialize its payload.
        if isinstance(data, dict) and "data" in data:
            return self.serialize(data["data"])

        raise ValueError(f"Object of type {type(data)} is not Data serializable.")
class StringSerializer(SimpleBaseSerializer):
    """Turn a string into its UTF-8 byte representation."""

    def __init__(self, content_type="text/plain"):
        """Create a ``StringSerializer``.

        Args:
            content_type (str): The MIME type to signal to the inference endpoint when sending
                request data (default: "text/plain").
        """
        super().__init__(content_type=content_type)

    def serialize(self, data):
        """Encode ``data`` as UTF-8.

        Args:
            data (object): Data to be serialized; must be a ``str``.

        Returns:
            bytes: The UTF-8 encoding of the input string.

        Raises:
            ValueError: If ``data`` is not a ``str``.
        """
        if not isinstance(data, str):
            raise ValueError(f"Object of type {type(data)} is not String serializable.")

        return data.encode("utf-8")
class TorchTensorSerializer(SimpleBaseSerializer):
    """Serialize a torch.Tensor by converting it to NumPy and using ``NumpySerializer``."""

    def __init__(self, content_type="tensor/pt"):
        """Initialize a ``TorchTensorSerializer`` instance.

        Args:
            content_type (str): The MIME type to signal to the inference endpoint when sending
                request data (default: "tensor/pt").
        """
        super(TorchTensorSerializer, self).__init__(content_type=content_type)
        # Imported here rather than at module level, so torch is only needed
        # once this serializer is actually instantiated.
        from torch import Tensor

        self.torch_tensor = Tensor
        self.numpy_serializer = NumpySerializer()

    def serialize(self, data):
        """Serialize a torch.Tensor via ``NumpySerializer``.

        Args:
            data (object): Data to be serialized. The data must be of torch.Tensor type.

        Returns:
            object: Whatever ``NumpySerializer.serialize`` returns for the
                tensor's NumPy representation.

        Raises:
            ValueError: If ``data`` is not a torch.Tensor, or if converting or
                serializing the tensor fails.
        """
        if not isinstance(data, self.torch_tensor):
            raise ValueError("Object of type %s is not a torch.Tensor" % type(data))

        try:
            return self.numpy_serializer.serialize(data.detach().numpy())
        except Exception as e:
            # The message is a single clean literal: a backslash line continuation
            # inside the string would leak a stray backslash/whitespace into it.
            raise ValueError(
                "Unable to serialize your data because: %s. "
                "Please provide custom serialization in InferenceSpec. " % (e,)
            ) from e
# TODO fix the unit test for this serializer
class RecordSerializer(SimpleBaseSerializer):
    """Encode a NumPy array as records for an inference request."""

    def __init__(self, content_type="application/x-recordio-protobuf"):
        """Create a ``RecordSerializer``.

        Args:
            content_type (str): The MIME type to signal to the inference endpoint when sending
                request data (default: "application/x-recordio-protobuf").
        """
        super().__init__(content_type=content_type)

    def serialize(self, data):
        """Write a 1D or 2D array into an in-memory buffer of records.

        A 1D array is treated as a single row. The actual encoding is
        delegated to ``write_numpy_to_dense_tensor``.

        Args:
            data (numpy.ndarray): The data to serialize.

        Returns:
            io.BytesIO: The buffer that was written to, rewound to position 0.

        Raises:
            ValueError: If the array is neither 1D nor 2D.
        """
        # Promote a flat vector to a one-row matrix.
        if len(data.shape) == 1:
            data = data.reshape(1, data.shape[0])

        rank = len(data.shape)
        if rank != 2:
            raise ValueError("Expected a 1D or 2D array, but got a %dD array instead." % rank)

        # Lazy import to avoid circular dependency
        from sagemaker.core.serializers.utils import write_numpy_to_dense_tensor

        records = io.BytesIO()
        write_numpy_to_dense_tensor(records, data)
        records.seek(0)
        return records