Source code for sagemaker.core.remote_function.core.serialization

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""SageMaker remote function data serializer/deserializer."""
from __future__ import absolute_import

import dataclasses
import json

import io

import sys
import hashlib
import pickle

from typing import Any, Callable, Union

import cloudpickle
from tblib import pickling_support

from sagemaker.core.remote_function.errors import (
    ServiceError,
    SerializationError,
    DeserializationError,
)
from sagemaker.core.s3 import S3Downloader, S3Uploader
from sagemaker.core.helper.session_helper import Session
from ._custom_dispatch_table import dispatch_table

# Note: do not use os.path.join for s3 uris, fails on windows


def _get_python_version():
    """Returns the current python version."""
    return f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"


@dataclasses.dataclass
class _MetaData:
    """Metadata about the serialized data or functions."""

    sha256_hash: str
    version: str = "2023-04-24"
    python_version: str = _get_python_version()
    serialization_module: str = "cloudpickle"

    def to_json(self):
        """Converts metadata to json string."""
        return json.dumps(dataclasses.asdict(self)).encode()

    @staticmethod
    def from_json(s):
        """Converts json string to metadata object."""
        try:
            obj = json.loads(s)
        except json.decoder.JSONDecodeError:
            raise DeserializationError("Corrupt metadata file. It is not a valid json file.")

        sha256_hash = obj.get("sha256_hash")
        metadata = _MetaData(sha256_hash=sha256_hash)
        metadata.version = obj.get("version")
        metadata.python_version = obj.get("python_version")
        metadata.serialization_module = obj.get("serialization_module")

        if not sha256_hash:
            raise DeserializationError(
                "Corrupt metadata file. SHA256 hash for the serialized data does not exist. "
                "Please make sure to install SageMaker SDK version >= 2.156.0 on the client side "
                "and try again."
            )

        if not (
            metadata.version == "2023-04-24" and metadata.serialization_module == "cloudpickle"
        ):
            raise DeserializationError(
                f"Corrupt metadata file. Serialization approach {s} is not supported."
            )

        return metadata


class CloudpickleSerializer:
    """Serializer using cloudpickle."""

    @staticmethod
    def serialize(obj: Any) -> bytes:
        """Serialize a python object to bytes using cloudpickle.

        Args:
            obj: object to be serialized and persisted

        Raises:
            SerializationError: when fail to serialize object to bytes.
        """
        try:
            buffer = io.BytesIO()
            pickler = cloudpickle.CloudPickler(buffer)
            # Layer the project-specific reducers on top of this pickler's own
            # dispatch table so they apply to this pickler instance only.
            base_table = pickle.Pickler.dispatch_table.__get__(pickler)  # pylint: disable=no-member
            layered_table = base_table.new_child(dispatch_table)
            pickle.Pickler.dispatch_table.__set__(  # pylint: disable=no-member
                pickler, layered_table
            )
            pickler.dump(obj)
            return buffer.getvalue()
        except Exception as e:
            # Special-case the experiments Run guard so users get actionable advice.
            if isinstance(
                e, NotImplementedError
            ) and "Instance of Run type is not allowed to be pickled." in str(e):
                raise SerializationError(
                    """You are trying to pass a sagemaker.experiments.run.Run object to a remote function or are trying to access a global sagemaker.experiments.run.Run object from within the function. This is not supported. You must use `load_run` to load an existing Run in the remote function or instantiate a new Run in the function."""
                )
            raise SerializationError(
                "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e))
            ) from e

    @staticmethod
    def deserialize(s3_uri: str, bytes_to_deserialize: bytes) -> Any:
        """Deserialize bytes that were previously produced by :meth:`serialize`.

        Args:
            s3_uri (str): S3 uri the bytes were downloaded from; used in error messages only.
            bytes_to_deserialize: bytes to be deserialized.

        Returns :
            The deserialized python object.

        Raises:
            DeserializationError: when fail to deserialize the bytes.
        """
        try:
            return cloudpickle.loads(bytes_to_deserialize)
        except Exception as e:
            raise DeserializationError(
                "Error when deserializing bytes downloaded from {}: {}. "
                "NOTE: this may be caused by inconsistent sagemaker python sdk versions "
                "where remote function runs versus the one used on client side. "
                "If the sagemaker versions do not match, a warning message would "
                "be logged starting with 'Inconsistent sagemaker versions found'. "
                "Please check it to validate.".format(s3_uri, repr(e))
            ) from e
# TODO: use the dask serializer when dask.distributed is installed in the user's environment.
def serialize_func_to_s3(
    func: Callable, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None
):
    """Serializes function and uploads it to S3.

    Args:
        func: function to be serialized and persisted.
        sagemaker_session (sagemaker.core.helper.session.Session): The underlying
            Boto3 session which AWS service calls are delegated to.
        s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
        s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.

    Raises:
        SerializationError: when fail to serialize function to bytes.
    """
    serialized_func = CloudpickleSerializer.serialize(func)
    _upload_payload_and_metadata_to_s3(
        bytes_to_upload=serialized_func,
        s3_uri=s3_uri,
        sagemaker_session=sagemaker_session,
        s3_kms_key=s3_kms_key,
    )
def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str) -> Callable:
    """Downloads a serialized function from S3, verifies it, and deserializes it.

    Reads the metadata document first, then the pickled payload, and checks the
    payload's SHA256 digest against the one recorded in the metadata before
    attempting deserialization.

    Args:
        sagemaker_session (sagemaker.core.helper.session.Session): The underlying
            sagemaker session which AWS service calls are delegated to.
        s3_uri (str): S3 root uri under which the serialized artifacts were uploaded.

    Returns :
        The deserialized function.

    Raises:
        DeserializationError: when fail to deserialize the function, or when the
            integrity check fails.
    """
    metadata = _MetaData.from_json(
        _read_bytes_from_s3(f"{s3_uri}/metadata.json", sagemaker_session)
    )

    payload_uri = f"{s3_uri}/payload.pkl"
    payload_bytes = _read_bytes_from_s3(payload_uri, sagemaker_session)
    _perform_integrity_check(expected_hash_value=metadata.sha256_hash, buffer=payload_bytes)

    return CloudpickleSerializer.deserialize(payload_uri, payload_bytes)
def serialize_obj_to_s3(
    obj: Any, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None
):
    """Serializes data object and uploads it to S3.

    Args:
        obj: object to be serialized and persisted.
        sagemaker_session (sagemaker.core.helper.session.Session): The underlying
            Boto3 session which AWS service calls are delegated to.
        s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
        s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.

    Raises:
        SerializationError: when fail to serialize object to bytes.
    """
    serialized_obj = CloudpickleSerializer.serialize(obj)
    _upload_payload_and_metadata_to_s3(
        bytes_to_upload=serialized_obj,
        s3_uri=s3_uri,
        sagemaker_session=sagemaker_session,
        s3_kms_key=s3_kms_key,
    )
def json_serialize_obj_to_s3(
    obj: Any,
    json_key: str,
    sagemaker_session: Session,
    s3_uri: str,
    s3_kms_key: str = None,
):
    """Json serializes data object and uploads it to S3.

    If a function step's output is data referenced by other steps via JsonGet,
    its output should be json serialized and uploaded to S3.

    Args:
        obj: (Any) object to be serialized and persisted.
        json_key: (str) the json key pointing to function step output.
        sagemaker_session (sagemaker.core.helper.session.Session): The underlying
            Boto3 session which AWS service calls are delegated to.
        s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
        s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.

    Raises:
        TypeError: if ``obj`` triggers a TypeError unrelated to JSON serializability.
    """
    try:
        json_serialized_result = json.dumps({json_key: obj, "Exception": None})
    except TypeError as e:
        # Bug fix: previously a TypeError whose message did not contain
        # "is not JSON serializable" was silently swallowed, and the initial
        # empty *dict* (not a JSON string) was uploaded as the file body.
        if "is not JSON serializable" not in str(e):
            raise
        # Record the failure in-band so downstream JsonGet consumers see
        # an explicit Exception entry instead of the raw (unserializable) value.
        json_serialized_result = json.dumps(
            {
                json_key: None,
                "Exception": f"The function return ({obj}) is not JSON serializable.",
            }
        )

    S3Uploader.upload_string_as_file_body(
        body=json_serialized_result,
        desired_s3_uri=s3_uri,
        sagemaker_session=sagemaker_session,
        kms_key=s3_kms_key,
    )
def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str) -> Any:
    """Downloads a serialized data object from S3, verifies it, and deserializes it.

    Args:
        sagemaker_session (sagemaker.core.helper.session.Session): The underlying
            sagemaker session which AWS service calls are delegated to.
        s3_uri (str): S3 root uri under which the serialized artifacts were uploaded.

    Returns :
        The deserialized python object.

    Raises:
        DeserializationError: when fail to deserialize the object, or when the
            integrity check fails.
    """
    metadata = _MetaData.from_json(
        _read_bytes_from_s3(f"{s3_uri}/metadata.json", sagemaker_session)
    )

    payload_uri = f"{s3_uri}/payload.pkl"
    payload_bytes = _read_bytes_from_s3(payload_uri, sagemaker_session)
    _perform_integrity_check(expected_hash_value=metadata.sha256_hash, buffer=payload_bytes)

    return CloudpickleSerializer.deserialize(payload_uri, payload_bytes)
def serialize_exception_to_s3(
    exc: Exception, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None
):
    """Serializes exception with traceback and uploads it to S3.

    Args:
        exc: Exception to be serialized and persisted.
        sagemaker_session (sagemaker.core.helper.session.Session): The underlying
            Boto3 session which AWS service calls are delegated to.
        s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
        s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.

    Raises:
        SerializationError: when fail to serialize exception to bytes.
    """
    # tblib makes tracebacks picklable so the remote stack trace survives the round trip.
    pickling_support.install()
    _upload_payload_and_metadata_to_s3(
        bytes_to_upload=CloudpickleSerializer.serialize(exc),
        s3_uri=s3_uri,
        sagemaker_session=sagemaker_session,
        s3_kms_key=s3_kms_key,
    )
def _upload_payload_and_metadata_to_s3(
    bytes_to_upload: Union[bytes, io.BytesIO],
    s3_uri: str,
    sagemaker_session: Session,
    s3_kms_key,
):
    """Uploads serialized payload and its metadata document to s3.

    The payload goes to ``<s3_uri>/payload.pkl`` and a metadata json (carrying
    the payload's SHA256 digest for later integrity checking) goes to
    ``<s3_uri>/metadata.json``.

    Args:
        bytes_to_upload (bytes): Serialized bytes to upload.
        s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
        sagemaker_session (sagemaker.core.helper.session.Session): The underlying
            Boto3 session which AWS service calls are delegated to.
        s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
    """
    _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)

    digest = _compute_hash(bytes_to_upload)
    metadata_bytes = _MetaData(digest).to_json()
    _upload_bytes_to_s3(
        metadata_bytes,
        f"{s3_uri}/metadata.json",
        s3_kms_key,
        sagemaker_session,
    )
def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str) -> Any:
    """Downloads a serialized exception from S3, verifies it, and deserializes it.

    Args:
        sagemaker_session (sagemaker.core.helper.session.Session): The underlying
            sagemaker session which AWS service calls are delegated to.
        s3_uri (str): S3 root uri under which the serialized artifacts were uploaded.

    Returns :
        The deserialized exception with traceback.

    Raises:
        DeserializationError: when fail to deserialize the exception, or when the
            integrity check fails.
    """
    metadata = _MetaData.from_json(
        _read_bytes_from_s3(f"{s3_uri}/metadata.json", sagemaker_session)
    )

    payload_uri = f"{s3_uri}/payload.pkl"
    payload_bytes = _read_bytes_from_s3(payload_uri, sagemaker_session)
    _perform_integrity_check(expected_hash_value=metadata.sha256_hash, buffer=payload_bytes)

    return CloudpickleSerializer.deserialize(payload_uri, payload_bytes)
def _upload_bytes_to_s3(b: Union[bytes, io.BytesIO], s3_uri, s3_kms_key, sagemaker_session): """Wrapping s3 uploading with exception translation for remote function.""" try: S3Uploader.upload_bytes(b, s3_uri, kms_key=s3_kms_key, sagemaker_session=sagemaker_session) except Exception as e: raise ServiceError( "Failed to upload serialized bytes to {}: {}".format(s3_uri, repr(e)) ) from e def _read_bytes_from_s3(s3_uri, sagemaker_session): """Wrapping s3 downloading with exception translation for remote function.""" try: return S3Downloader.read_bytes(s3_uri, sagemaker_session=sagemaker_session) except Exception as e: raise ServiceError( "Failed to read serialized bytes from {}: {}".format(s3_uri, repr(e)) ) from e def _compute_hash(buffer: bytes) -> str: """Compute the sha256 hash""" return hashlib.sha256(buffer).hexdigest() def _perform_integrity_check(expected_hash_value: str, buffer: bytes): """Performs integrity checks for serialized code/arguments uploaded to s3. Verifies whether the hash read from s3 matches the hash calculated during remote function execution. """ actual_hash_value = _compute_hash(buffer=buffer) if expected_hash_value != actual_hash_value: raise DeserializationError( "Integrity check for the serialized function or data failed. " "Please restrict access to your S3 bucket" )