Added Streaming Capability to SageMaker LLMs (#10535)

This PR adds the ability to request a streaming response from the SageMaker LLM by
leveraging the `invoke_endpoint_with_response_stream` capability in `boto3`. It is
heavily based on the AWS blog post announcement linked
[here](https://aws.amazon.com/blogs/machine-learning/elevating-the-generative-ai-experience-introducing-streaming-support-in-amazon-sagemaker-hosting/).

It does not add any additional dependencies since it uses the existing
`boto3` version.
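
For reference, here is a minimal usage sketch (not part of this PR) showing how the new `streaming` flag could be combined with a streaming callback handler. The endpoint name, AWS profile, and response schema below are illustrative assumptions; when `streaming=True`, the endpoint is expected to emit newline-delimited JSON objects with an `outputs` field, which is what the new streaming branch parses.

```python
import json
from typing import Dict

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint


class ContentHandler(LLMContentHandler):
    """Serialize prompts for, and parse non-streaming responses from, the endpoint."""

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        return json.dumps({"inputs": prompt, "parameters": model_kwargs}).encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]


llm = SagemakerEndpoint(
    endpoint_name="my-llm-endpoint",            # hypothetical endpoint name
    credentials_profile_name="my-aws-profile",  # hypothetical AWS profile
    region_name="us-east-1",
    content_handler=ContentHandler(),
    streaming=True,  # the new flag added in this PR
    # Streaming only happens when a run manager is available, so attach a callback
    # that receives each token via on_llm_new_token and prints it as it arrives.
    callbacks=[StreamingStdOutCallbackHandler()],
)

print(llm("Tell me a joke."))
```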

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Commit 4236ae3851 (parent d9670a5945), authored by Juan Daza and committed by GitHub.

@@ -1,6 +1,8 @@
"""Sagemaker InvokeEndpoint API."""
import io
import json
from abc import abstractmethod
from typing import Any, Dict, Generic, List, Mapping, Optional, TypeVar, Union
from typing import Any, Dict, Generic, Iterator, List, Mapping, Optional, TypeVar, Union
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
@@ -8,7 +10,66 @@ from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import Extra, root_validator
INPUT_TYPE = TypeVar("INPUT_TYPE", bound=Union[str, List[str]])
OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]]])
OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]], Iterator])
class LineIterator:
    """
    A helper class for parsing the byte stream input.

    The output of the model will be in the following format:

    b'{"outputs": [" a"]}\n'
    b'{"outputs": [" challenging"]}\n'
    b'{"outputs": [" problem"]}\n'
    ...

    While usually each PayloadPart event from the event stream will
    contain a byte array with a full json, this is not guaranteed
    and some of the json objects may be split across PayloadPart events.
    For example:

    {'PayloadPart': {'Bytes': b'{"outputs": '}}
    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}

    This class accounts for this by buffering the bytes from the event
    stream and only yielding complete lines (ending with a '\n' character)
    from that buffer. It maintains the position of the last read to ensure
    that previously returned bytes are not exposed again.

    For more details see:
    https://aws.amazon.com/blogs/machine-learning/elevating-the-generative-ai-experience-introducing-streaming-support-in-amazon-sagemaker-hosting/
    """

    def __init__(self, stream: Any) -> None:
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self) -> "LineIterator":
        return self

    def __next__(self) -> Any:
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if self.read_pos < self.buffer.getbuffer().nbytes:
                    continue
                raise
            if "PayloadPart" not in chunk:
                # Unknown Event Type
                continue
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])
class ContentHandlerBase(Generic[INPUT_TYPE, OUTPUT_TYPE]):
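
As a small, self-contained sketch of the buffering behavior described in the `LineIterator` docstring above, the following hypothetical snippet feeds the class a list of dicts shaped like SageMaker `PayloadPart` events, including one JSON object split across two events, and shows that only complete lines are yielded:

```python
import json

from langchain.llms.sagemaker_endpoint import LineIterator

# Simulated event stream: the second JSON object is split across two PayloadPart events.
fake_stream = [
    {"PayloadPart": {"Bytes": b'{"outputs": [" a"]}\n'}},
    {"PayloadPart": {"Bytes": b'{"outputs": '}},
    {"PayloadPart": {"Bytes": b'[" problem"]}\n'}},
]

for line in LineIterator(fake_stream):
    # Each yielded line is one complete JSON document, without the trailing newline.
    print(json.loads(line)["outputs"][0])
# Prints: " a" then " problem"
```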
@@ -151,6 +212,9 @@ class SagemakerEndpoint(LLM):
    and the endpoint.
    """

    streaming: bool = False
    """Whether to stream the results."""

    """
    Example:
        .. code-block:: python
@@ -264,7 +328,28 @@ class SagemakerEndpoint(LLM):
        content_type = self.content_handler.content_type
        accepts = self.content_handler.accepts

        # send request
        if self.streaming and run_manager:
            try:
                resp = self.client.invoke_endpoint_with_response_stream(
                    EndpointName=self.endpoint_name,
                    Body=body,
                    ContentType=self.content_handler.content_type,
                    **_endpoint_kwargs,
                )
                iterator = LineIterator(resp["Body"])
                current_completion: str = ""
                for line in iterator:
                    resp = json.loads(line)
                    resp_output = resp.get("outputs")[0]
                    if stop is not None:
                        # Uses same approach as below
                        resp_output = enforce_stop_tokens(resp_output, stop)
                    current_completion += resp_output
                    run_manager.on_llm_new_token(resp_output)
                return current_completion
            except Exception as e:
                raise ValueError(
                    f"Error raised by streaming inference endpoint: {e}"
                )
        else:
            try:
                response = self.client.invoke_endpoint(
                    EndpointName=self.endpoint_name,
