Source code for sagemaker.core.iterators

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Implements iterators for deserializing data returned from an inference streaming endpoint."""
from __future__ import absolute_import

from abc import ABC, abstractmethod
import io

from sagemaker.core.exceptions import ModelStreamError, InternalStreamFailure
from sagemaker.core.common_utils import _MAX_BUFFER_SIZE


def handle_stream_errors(chunk):
    """Raise the matching exception when a stream chunk carries an error event.

    Intended for chunks produced while iterating the response of the
    `invoke_endpoint_with_response_stream` API. A chunk that contains neither
    error key is left untouched and the function returns ``None``.

    Args:
        chunk (dict): One event taken from a
            `botocore.eventstream.EventStream` response object.

    Raises:
        ModelStreamError: If the chunk has a ``"ModelStreamError"`` entry. The
            entry's ``"Message"`` becomes the exception message and its
            ``"ErrorCode"`` is passed as ``code``.
        InternalStreamFailure: If the chunk has an ``"InternalStreamFailure"``
            entry. The entry's ``"Message"`` becomes the exception message.
    """
    # The model error is checked first, so it wins if both keys are present.
    if "ModelStreamError" in chunk:
        model_error = chunk["ModelStreamError"]
        message = model_error["Message"]
        error_code = model_error["ErrorCode"]
        raise ModelStreamError(message, code=error_code)

    if "InternalStreamFailure" in chunk:
        internal_failure = chunk["InternalStreamFailure"]
        raise InternalStreamFailure(internal_failure["Message"])
class BaseIterator(ABC):
    """Abstract foundation for iterators over a streaming inference response.

    Subclasses must override both ``__iter__`` and ``__next__``; the class
    cannot be instantiated until they do.

    Guidelines for an iterator over the Streaming Inference API response
    (https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/
    sagemaker-runtime/client/invoke_endpoint_with_response_stream.html):

    1. It accepts a botocore.eventstream.EventStream response.
    2. Its ``__next__`` implementation should:
        2.1. Concatenate and hand out the next piece of the response taken
             from the botocore.eventstream.EventStream, parsing
             ``response_chunk["PayloadPart"]["Bytes"]`` along the way.
        2.2. Deal with errors whenever ``PayloadPart`` is missing from an
             event (using `iterators.handle_stream_errors` is recommended).
    """

    def __init__(self, event_stream):
        """Keep a reference to the event stream that will be iterated.

        Args:
            event_stream (botocore.eventstream.EventStream): Event stream
                object whose events the concrete iterator consumes.
        """
        self.event_stream = event_stream

    @abstractmethod
    def __next__(self):
        """Abstract hook: produce the next element of the iteration."""

    @abstractmethod
    def __iter__(self):
        """Abstract hook: hand back the iterator object itself."""
        return self
class ByteIterator(BaseIterator):
    """A helper class for parsing the byte Event Stream input to provide Byte iteration."""

    def __init__(self, event_stream):
        """Initialises a ByteIterator object.

        Args:
            event_stream (botocore.eventstream.EventStream): Event Stream object
                to be iterated.
        """
        super().__init__(event_stream)
        self.byte_iterator = iter(event_stream)

    def __iter__(self):
        """Returns the iterator object itself, which allows the object to be iterated.

        Returns:
            ByteIterator: This same object.
        """
        return self

    def __next__(self):
        """Returns the bytes of the next ``PayloadPart`` event in the stream.

        Events without a ``PayloadPart`` entry are first passed to
        `handle_stream_errors`, which raises if the event is an error. Any
        other unrecognised event is printed and skipped.

        Returns:
            The value stored under ``chunk["PayloadPart"]["Bytes"]``.

        Raises:
            StopIteration: When the underlying event stream is exhausted.
            ModelStreamError: Propagated from `handle_stream_errors`.
            InternalStreamFailure: Propagated from `handle_stream_errors`.
        """
        # Loop so that non-payload events are skipped; the first payload event
        # found is returned immediately.
        while True:
            chunk = next(self.byte_iterator)
            if "PayloadPart" not in chunk:
                # Handle API response errors and force terminate.
                handle_stream_errors(chunk)
                # Report and move on to the next event. The chunk is a dict, so
                # it is passed as a separate print argument rather than
                # concatenated to the string (which would raise TypeError).
                print("Unknown event type:", chunk)
                continue
            return chunk["PayloadPart"]["Bytes"]
class LineIterator(BaseIterator):
    """A helper class for parsing the byte Event Stream input to provide Line iteration."""

    def __init__(self, event_stream):
        """Initialises a LineIterator object.

        Args:
            event_stream (botocore.eventstream.EventStream): Event Stream object
                to be iterated.
        """
        super().__init__(event_stream)
        self.byte_iterator = iter(self.event_stream)
        # Holds bytes received from the stream that have not yet been returned.
        self.buffer = io.BytesIO()
        # Offset in ``self.buffer`` of the first byte not yet returned.
        self.read_pos = 0

    def __iter__(self):
        """Returns the iterator object itself, which allows the object to be iterated.

        Returns:
            LineIterator: This same object.
        """
        return self

    def __next__(self):
        r"""Returns the next line of the event stream.

        The output of the event stream will be in the following format:

        ```
        b'{"outputs": [" a"]}\n'
        b'{"outputs": [" challenging"]}\n'
        b'{"outputs": [" problem"]}\n'
        ...
        ```

        While usually each PayloadPart event from the event stream will contain
        a byte array with a full json, this is not guaranteed and some of the
        json objects may be split across PayloadPart events. For example:

        ```
        {'PayloadPart': {'Bytes': b'{"outputs": '}}
        {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
        ```

        This method accounts for this by appending incoming bytes to an
        internal buffer and returning one complete line (terminated by '\n')
        at a time. It tracks the last read position so that bytes already
        returned are not exposed again.

        If the stream ends while unterminated bytes remain in the buffer, those
        bytes are returned as a final line.

        Returns:
            bytes: One line from the event stream, without the trailing '\n'.

        Raises:
            StopIteration: When the stream is exhausted and no buffered bytes
                remain.
            RuntimeError: If the bytes still waiting for a newline, plus the
                incoming chunk, would exceed ``_MAX_BUFFER_SIZE``.
            ModelStreamError: Propagated from `handle_stream_errors`.
            InternalStreamFailure: Propagated from `handle_stream_errors`.
        """
        # Loop until a complete line is available, pulling more events from the
        # stream whenever the buffered data does not yet contain a newline.
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                self.read_pos += len(line)
                return line[:-1]

            # No newline found, so ``line`` is every unread byte in the buffer
            # (possibly empty).
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if line:
                    # The stream ended without a final newline. Return the
                    # remainder once; retrying here would loop forever because
                    # no more data can arrive.
                    self.read_pos += len(line)
                    return line
                raise

            if "PayloadPart" not in chunk:
                # Handle API response errors and force terminate.
                handle_stream_errors(chunk)
                # Report and move on to the next event. The chunk is a dict, so
                # it is passed as a separate print argument rather than
                # concatenated to the string (which would raise TypeError).
                print("Unknown event type:", chunk)
                continue

            payload = chunk["PayloadPart"]["Bytes"]

            # Limit the bytes still waiting for a newline. Lines already
            # returned to the caller do not count against the limit.
            if len(line) + len(payload) > _MAX_BUFFER_SIZE:
                raise RuntimeError(
                    f"Line buffer exceeded maximum size of {_MAX_BUFFER_SIZE} bytes. "
                    f"No newline found in stream."
                )

            if self.read_pos:
                # Drop the bytes that were already returned so that the buffer
                # does not grow for the lifetime of the stream.
                self.buffer = io.BytesIO()
                self.buffer.write(line)
                self.read_pos = 0
            else:
                self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(payload)