@@ -1,6 +1,8 @@
"""Sagemaker InvokeEndpoint API."""
|
|
|
|
|
import io
|
|
|
|
|
import json
|
|
|
|
|
from abc import abstractmethod
|
|
|
|
|
from typing import Any, Dict, Generic, List, Mapping, Optional, TypeVar, Union
|
|
|
|
|
from typing import Any, Dict, Generic, Iterator, List, Mapping, Optional, TypeVar, Union
|
|
|
|
|
|
|
|
|
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
|
|
|
|
from langchain.llms.base import LLM
|
|
|
|
@@ -8,7 +10,66 @@ from langchain.llms.utils import enforce_stop_tokens
 from langchain.pydantic_v1 import Extra, root_validator
 
 INPUT_TYPE = TypeVar("INPUT_TYPE", bound=Union[str, List[str]])
-OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]]])
+OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]], Iterator])
+
+
+class LineIterator:
+    """
+    A helper class for parsing the byte stream input.
+
+    The output of the model will be in the following format:
+
+    b'{"outputs": [" a"]}\n'
+    b'{"outputs": [" challenging"]}\n'
+    b'{"outputs": [" problem"]}\n'
+    ...
+
+    While usually each PayloadPart event from the event stream will
+    contain a byte array with a full json, this is not guaranteed
+    and some of the json objects may be split across PayloadPart events.
+
+    For example:
+
+    {'PayloadPart': {'Bytes': b'{"outputs": '}}
+    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
+
+    This class accounts for this by buffering incoming bytes until a
+    complete line (ending with a '\n' character) can be returned from
+    the buffer. It tracks the position of the last read byte so that
+    previously returned bytes are not exposed again.
+
+    For more details see:
+    https://aws.amazon.com/blogs/machine-learning/elevating-the-generative-ai-experience-introducing-streaming-support-in-amazon-sagemaker-hosting/
+    """
+
+    def __init__(self, stream: Any) -> None:
+        self.byte_iterator = iter(stream)
+        self.buffer = io.BytesIO()
+        self.read_pos = 0
+
+    def __iter__(self) -> "LineIterator":
+        return self
+
+    def __next__(self) -> Any:
+        while True:
+            # If the buffer holds a complete line past the read position,
+            # consume it and return it without the trailing newline.
+            self.buffer.seek(self.read_pos)
+            line = self.buffer.readline()
+            if line and line[-1] == ord("\n"):
+                self.read_pos += len(line)
+                return line[:-1]
+            # Otherwise pull the next event from the stream.
+            try:
+                chunk = next(self.byte_iterator)
+            except StopIteration:
+                # Stream exhausted; keep looping while unread bytes remain.
+                if self.read_pos < self.buffer.getbuffer().nbytes:
+                    continue
+                raise
+            if "PayloadPart" not in chunk:
+                # Unknown event type
+                continue
+            # Append the new bytes to the end of the buffer.
+            self.buffer.seek(0, io.SEEK_END)
+            self.buffer.write(chunk["PayloadPart"]["Bytes"])
+
+
 class ContentHandlerBase(Generic[INPUT_TYPE, OUTPUT_TYPE]):
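A minimal sketch of how `LineIterator` reassembles a JSON line that was split across `PayloadPart` events. The `fake_stream` list below is a hand-built stand-in for the boto3 `EventStream`, not real SDK output:

```python
import json

# Hypothetical events: one JSON object split across two PayloadParts,
# followed by a second, complete object in the same chunk.
fake_stream = [
    {"PayloadPart": {"Bytes": b'{"outputs": '}},
    {"PayloadPart": {"Bytes": b'[" a"]}\n{"outputs": [" problem"]}\n'}},
]

# Assumes the LineIterator class from this diff is in scope.
for line in LineIterator(fake_stream):
    print(json.loads(line)["outputs"][0])
# prints " a", then " problem"
```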
@@ -151,6 +212,9 @@ class SagemakerEndpoint(LLM):
     and the endpoint.
     """
 
+    streaming: bool = False
+    """Whether to stream the results."""
+
     """
     Example:
         .. code-block:: python
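A usage sketch of the new `streaming` flag; the endpoint name, region, and content handler are illustrative placeholders, not part of this diff:

```python
import json
from typing import Dict

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint


class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        return json.dumps({"inputs": prompt, **model_kwargs}).encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        # Response schema is model-specific; this assumes an HF-style payload.
        return json.loads(output.read().decode("utf-8"))[0]["generated_text"]


llm = SagemakerEndpoint(
    endpoint_name="my-endpoint",  # placeholder
    region_name="us-west-2",  # placeholder
    content_handler=ContentHandler(),
    streaming=True,  # new flag: routes through invoke_endpoint_with_response_stream
    callbacks=[StreamingStdOutCallbackHandler()],
)
llm("Tell me a joke.")
```

With `streaming=False` (the default) behavior is unchanged and the existing `invoke_endpoint` path is used.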
@@ -264,7 +328,28 @@ class SagemakerEndpoint(LLM):
         content_type = self.content_handler.content_type
         accepts = self.content_handler.accepts
 
         # send request
+        if self.streaming and run_manager:
+            try:
+                resp = self.client.invoke_endpoint_with_response_stream(
+                    EndpointName=self.endpoint_name,
+                    Body=body,
+                    ContentType=self.content_handler.content_type,
+                    **_endpoint_kwargs,
+                )
+                iterator = LineIterator(resp["Body"])
+                current_completion: str = ""
+                for line in iterator:
+                    resp = json.loads(line)
+                    resp_output = resp.get("outputs")[0]
+                    if stop is not None:
+                        # Uses the same approach as the non-streaming branch below
+                        resp_output = enforce_stop_tokens(resp_output, stop)
+                    current_completion += resp_output
+                    run_manager.on_llm_new_token(resp_output)
+                return current_completion
+            except Exception as e:
+                raise ValueError(f"Error raised by streaming inference endpoint: {e}")
+        else:
+            try:
+                response = self.client.invoke_endpoint(
+                    EndpointName=self.endpoint_name,