Source code for sagemaker.core.utils.code_injection.codec
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import logging
from dataclasses import asdict
import re
from sagemaker.core.utils.code_injection.shape_dag import SHAPE_DAG
from sagemaker.core.utils.code_injection.constants import (
BASIC_TYPES,
STRUCTURE_TYPE,
LIST_TYPE,
MAP_TYPE,
)
from io import BytesIO
[docs]
def pascal_to_snake(pascal_str):
    """
    Converts a PascalCase string to snake_case.

    Two regex passes are applied before lowercasing:

    * the first inserts an underscore before every capitalised word that is
      preceded by some character (``"HTTPResponse"`` -> ``"HTTP_Response"``);
    * the second inserts an underscore between a lowercase letter or digit and
      the uppercase letter that follows it (``"ModelARN"`` -> ``"Model_ARN"``).

    Args:
        pascal_str (str): The PascalCase string to be converted.

    Returns:
        str: The converted snake_case string.
    """
    words_split = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", pascal_str)
    boundaries_split = re.sub("([a-z0-9])([A-Z])", r"\1_\2", words_split)
    return boundaries_split.lower()
[docs]
def deserialize(data, cls) -> object:
    """
    Deserialize the given data into an instance of the specified class.

    Only the top-level keys of ``data`` are converted from PascalCase to
    snake_case; the values are passed to the constructor unchanged.

    Args:
        data (dict): The data to be deserialized, keyed by PascalCase names.
        cls (str or type): The class, or the name of a class, to deserialize
            into. A name is resolved against this module's global namespace.

    Returns:
        object: An instance of the specified class with the deserialized data.

    Raises:
        KeyError: If ``cls`` is a name that is not defined in this module's
            global namespace.
    """
    # Lazy %-style arguments: the (possibly large) payload is only rendered
    # into a string when DEBUG logging is actually enabled.
    logging.debug("Deserialize: pascal cased data: %s", data)
    # Convert the keys to snake_case so they match constructor parameter names.
    data = {pascal_to_snake(k): v for k, v in data.items()}
    logging.debug("Deserialize: snake cased data: %s", data)
    # A class given by name is looked up in this module's globals.
    if isinstance(cls, str):
        cls = globals()[cls]
    return cls(**data)
[docs]
def snake_to_pascal(snake_str):
    """
    Convert a snake_case string to PascalCase.

    The input is split on underscores and every piece is passed through
    ``str.title``. This upper-cases the first letter of each alphabetic run
    and lower-cases the remaining letters (``"s3_uri"`` -> ``"S3Uri"``,
    ``"ARN"`` -> ``"Arn"``). Empty pieces from consecutive underscores
    contribute nothing.

    Args:
        snake_str (str): The snake_case string to be converted.

    Returns:
        str: The PascalCase string.
    """
    return "".join(map(str.title, snake_str.split("_")))
[docs]
def serialize(data) -> dict:
    """
    Serializes the given dataclass instance into a dictionary.

    ``dataclasses.asdict`` first converts ``data`` (recursively) into plain
    dicts, lists and tuples. Then only the top-level keys are converted from
    snake_case to PascalCase, and top-level entries whose value is ``None``
    are dropped. Keys of nested dictionaries are left as produced by
    ``asdict``.

    Args:
        data: The dataclass instance to be serialized.

    Returns:
        dict: The serialized data with PascalCase top-level keys.

    Raises:
        TypeError: If ``data`` is not a dataclass instance (raised by
            ``dataclasses.asdict``).
    """
    data_dict = asdict(data)
    # Convert the top-level keys to PascalCase, omitting unset (None) fields.
    return {snake_to_pascal(k): v for k, v in data_dict.items() if v is not None}
def _evaluate_list_type(raw_list, shape) -> list:
    """
    Evaluates a list type based on the given shape.

    Args:
        raw_list (list): The raw list to be evaluated.
        shape (dict): The shape of the list. Its ``"member_type"`` entry selects
            how items are evaluated, and its ``"member_shape"`` entry names the
            shape of each item.

    Returns:
        list: The evaluated list. For basic member types this is ``raw_list``
        itself (not a copy); otherwise it is a new list with one evaluated
        entry per input item, in the same order.

    Raises:
        ValueError: If an unhandled list member type is encountered.
    """
    member_type = shape["member_type"]
    member_shape = shape["member_shape"]

    if member_type in BASIC_TYPES:
        # Basic values need no conversion; the input list is returned as-is.
        return raw_list
    if member_type == STRUCTURE_TYPE:
        # Each item is a structure, converted by the module-level ``transform``
        # using the member shape.
        return [transform(item, member_shape) for item in raw_list]
    if member_type == LIST_TYPE:
        # Each item is itself a list: look up the nested list shape by name
        # and recurse.
        return [_evaluate_list_type(item, SHAPE_DAG[member_shape]) for item in raw_list]
    if member_type == MAP_TYPE:
        # Each item is a map: look up the nested map shape by name and
        # evaluate it.
        return [_evaluate_map_type(item, SHAPE_DAG[member_shape]) for item in raw_list]
    raise ValueError(
        f"Unhandled List member type "
        f"[{member_type}] encountered. "
        "Needs additional logic for support"
    )
def _evaluate_map_type(raw_map, shape) -> dict:
    """
    Evaluates a map type based on the given shape.

    Args:
        raw_map (dict): The raw map to be evaluated.
        shape (dict): The shape of the map. Its ``"key_type"``,
            ``"value_type"`` and ``"value_shape"`` entries are read.

    Returns:
        dict: The evaluated map. For basic value types this is ``raw_map``
        itself (not a copy); otherwise it is a new dict with the same keys
        and evaluated values.

    Raises:
        ValueError: If the map key type is not ``"string"``, or if an
            unhandled map value type is encountered.
    """
    key_type = shape["key_type"]
    value_type = shape["value_type"]
    value_shape = shape["value_shape"]

    # Only string-keyed maps are supported.
    if key_type != "string":
        raise ValueError(
            f"Unhandled Map key type "
            f"[{key_type}] encountered. "
            "Needs additional logic for support"
        )

    if value_type in BASIC_TYPES:
        # Basic values need no conversion; the input map is returned as-is.
        # Ex. response["map_member"] = {"key":"value"}
        return raw_map
    if value_type == STRUCTURE_TYPE:
        # Each value is a structure, converted by the module-level ``transform``.
        return {k: transform(v, value_shape) for k, v in raw_map.items()}
    if value_type == LIST_TYPE:
        # Each value is a list: look up the list shape by name and evaluate it.
        return {
            k: _evaluate_list_type(v, SHAPE_DAG[value_shape]) for k, v in raw_map.items()
        }
    if value_type == MAP_TYPE:
        # Each value is a nested map: look up its shape by name and recurse.
        return {
            k: _evaluate_map_type(v, SHAPE_DAG[value_shape]) for k, v in raw_map.items()
        }
    # The original message said "List member type" here; this branch rejects
    # a map *value* type.
    raise ValueError(
        f"Unhandled Map value type "
        f"[{value_type}] encountered. "
        "Needs additional logic for support"
    )