mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
7cf2d2759d
Added missing docstrings. Formatted docstrings to a consistent form.
371 lines
13 KiB
Python
371 lines
13 KiB
Python
"""Sagemaker InvokeEndpoint API."""
|
|
import io
|
|
import json
|
|
from abc import abstractmethod
|
|
from typing import Any, Dict, Generic, Iterator, List, Mapping, Optional, TypeVar, Union
|
|
|
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
|
from langchain_core.language_models.llms import LLM
|
|
from langchain_core.pydantic_v1 import Extra, root_validator
|
|
|
|
from langchain_community.llms.utils import enforce_stop_tokens
|
|
|
|
# Type variable for what a content handler's transform_input accepts as the
# prompt: either one string or a list of strings.
INPUT_TYPE = TypeVar(
    "INPUT_TYPE",
    bound=Union[str, List[str]],
)

# Type variable for what a content handler's transform_output produces:
# a string, a list of float lists, or an iterator.
OUTPUT_TYPE = TypeVar(
    "OUTPUT_TYPE",
    bound=Union[str, List[List[float]], Iterator],
)
|
|
|
|
|
|
class LineIterator:
    """Parse the byte stream input.

    The output of the model will be in the following format:

        b'{"outputs": [" a"]}\\n'
        b'{"outputs": [" challenging"]}\\n'
        b'{"outputs": [" problem"]}\\n'
        ...

    While usually each PayloadPart event from the event stream will
    contain a byte array with a full json, this is not guaranteed
    and some of the json objects may be split across PayloadPart events.

    For example:

        {'PayloadPart': {'Bytes': b'{"outputs": '}}
        {'PayloadPart': {'Bytes': b'[" problem"]}\\n'}}

    This class accounts for this by concatenating the bytes of each
    PayloadPart event into an internal buffer and then yielding lines
    (split on the '\\n' character, which is stripped) from that buffer.
    It maintains the position of the last read to ensure that previous
    bytes are not exposed again. Events without a 'PayloadPart' key are
    skipped. If the stream ends with bytes that are not terminated by a
    newline, those bytes are yielded as the final line.

    For more details see:
    https://aws.amazon.com/blogs/machine-learning/elevating-the-generative-ai-experience-introducing-streaming-support-in-amazon-sagemaker-hosting/
    """

    def __init__(self, stream: Any) -> None:
        """Wrap an iterable of event dicts.

        Args:
            stream: Iterable of events; events carrying data are expected to
                look like ``{"PayloadPart": {"Bytes": b"..."}}``.
        """
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        # Offset of the first byte in ``buffer`` not yet returned to the caller.
        self.read_pos = 0

    def __iter__(self) -> "LineIterator":
        return self

    def __next__(self) -> bytes:
        """Return the next line (without its trailing newline).

        Raises:
            StopIteration: When the underlying stream is exhausted and every
                buffered byte has been returned.
        """
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                # ``line`` holds everything left in the buffer after read_pos
                # (readline stopped at EOF, not at a newline). Emit it as the
                # final line; retrying would loop forever on the exhausted
                # iterator.
                if line:
                    self.read_pos += len(line)
                    return line
                raise
            if "PayloadPart" not in chunk:
                # Unknown event type: ignore it.
                continue
            # Append new bytes at the end; the next loop iteration seeks back
            # to read_pos before reading.
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])
|
|
|
|
|
|
class ContentHandlerBase(Generic[INPUT_TYPE, OUTPUT_TYPE]):
    """Handler class to transform input from LLM to a
    format that SageMaker endpoint expects.

    Similarly, the class handles transforming output from the
    SageMaker endpoint to a format that LLM class expects.

    NOTE: this class does not use ``ABCMeta``, so ``@abstractmethod`` is not
    enforced when a subclass is instantiated; subclasses should override both
    ``transform_input`` and ``transform_output``.

    Example:
        .. code-block:: python

            class ContentHandler(ContentHandlerBase):
                content_type = "application/json"
                accepts = "application/json"

                def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
                    input_str = json.dumps({"prompt": prompt, **model_kwargs})
                    return input_str.encode('utf-8')

                def transform_output(self, output: bytes) -> str:
                    response_json = json.loads(output.read().decode("utf-8"))
                    return response_json[0]["generated_text"]
    """

    content_type: Optional[str] = "text/plain"
    """The MIME type of the input data passed to endpoint"""

    accepts: Optional[str] = "text/plain"
    """The MIME type of the response data returned from endpoint"""

    @abstractmethod
    def transform_input(self, prompt: INPUT_TYPE, model_kwargs: Dict) -> bytes:
        """Transforms the input to a format that model can accept
        as the request Body. Should return bytes or seekable file
        like object in the format specified in the content_type
        request header.

        Args:
            prompt: The input to send to the endpoint.
            model_kwargs: Extra model parameters to include in the request.
        """

    @abstractmethod
    def transform_output(self, output: bytes) -> OUTPUT_TYPE:
        """Transforms the output from the model to the ``OUTPUT_TYPE``
        that the caller (e.g. the LLM class, which expects a string) expects.

        Args:
            output: The response body returned by the endpoint. Presumably a
                file-like object in practice (the example above calls
                ``.read()`` on it) — confirm against the caller.
        """
|
|
|
|
|
|
class LLMContentHandler(ContentHandlerBase[str, str]):
    """Content handler for LLM class.

    Specializes ``ContentHandlerBase`` with ``str`` as both the prompt (input)
    type and the transformed output type. ``SagemakerEndpoint.content_handler``
    is declared with this type, so custom handlers for that class subclass it
    and implement ``transform_input`` / ``transform_output``.
    """
|
|
|
|
|
|
class SagemakerEndpoint(LLM):
    """Sagemaker Inference Endpoint models.

    To use, you must supply the endpoint name from your deployed
    Sagemaker model & the region where it is deployed.

    To authenticate, the AWS client uses the following methods to
    automatically load credentials:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If a specific credential profile should be used, you must pass
    the name of the profile from the ~/.aws/credentials file that is to be used.

    Make sure the credentials / roles used have the required policies to
    access the Sagemaker endpoint.
    See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html

    Args:
        region_name: The aws region e.g., `us-west-2`.
            Falls back to AWS_DEFAULT_REGION env variable
            or region specified in ~/.aws/config.

        credentials_profile_name: The name of the profile in the ~/.aws/credentials
            or ~/.aws/config files, which has either access keys or role information
            specified. If not specified, the default credential profile or, if on an
            EC2 instance, credentials from IMDS will be used.

        client: boto3 client for Sagemaker Endpoint

        content_handler: Implementation for model specific LLMContentHandler

    Example:
        .. code-block:: python

            from langchain_community.llms import SagemakerEndpoint
            from langchain_community.llms.sagemaker_endpoint import (
                LLMContentHandler,
            )

            class ContentHandler(LLMContentHandler):
                content_type = "application/json"
                accepts = "application/json"

                def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
                    input_str = json.dumps({"prompt": prompt, **model_kwargs})
                    return input_str.encode("utf-8")

                def transform_output(self, output: bytes) -> str:
                    response_json = json.loads(output.read().decode("utf-8"))
                    return response_json[0]["generated_text"]

            endpoint_name = "my-endpoint-name"
            region_name = "us-west-2"
            credentials_profile_name = "default"
            se = SagemakerEndpoint(
                endpoint_name=endpoint_name,
                region_name=region_name,
                credentials_profile_name=credentials_profile_name,
                content_handler=ContentHandler(),
            )

            # Use with boto3 client
            client = boto3.client(
                "sagemaker-runtime",
                region_name=region_name,
            )
            se = SagemakerEndpoint(
                endpoint_name=endpoint_name,
                client=client,
                content_handler=ContentHandler(),
            )
    """

    client: Any = None
    """Boto3 client for sagemaker runtime"""

    endpoint_name: str = ""
    """The name of the endpoint from the deployed Sagemaker model.
    Must be unique within an AWS Region."""

    region_name: str = ""
    """The aws region where the Sagemaker model is deployed, eg. `us-west-2`."""

    credentials_profile_name: Optional[str] = None
    """The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
    has either access keys or role information specified.
    If not specified, the default credential profile or, if on an EC2 instance,
    credentials from IMDS will be used.
    See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
    """

    content_handler: LLMContentHandler
    """The content handler class that provides an input and
    output transform functions to handle formats between LLM
    and the endpoint.
    """

    streaming: bool = False
    """Whether to stream the results."""

    model_kwargs: Optional[Dict] = None
    """Keyword arguments to pass to the model."""

    endpoint_kwargs: Optional[Dict] = None
    """Optional attributes passed to the invoke_endpoint
    function. See `boto3`_. docs for more info.
    .. _boto3: <https://boto3.amazonaws.com/v1/documentation/api/latest/index.html>
    """

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that AWS credentials and the boto3 package are available.

        If a client was provided externally it is used as-is and nothing else
        is checked. Otherwise a boto3 session is created (from
        ``credentials_profile_name`` when given, default credentials otherwise)
        and a ``sagemaker-runtime`` client is stored in ``values["client"]``.

        Raises:
            ImportError: If boto3 is not installed.
            ValueError: If the session or client cannot be created.
        """
        if values.get("client") is not None:
            return values

        try:
            import boto3
        except ImportError as e:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            ) from e

        try:
            profile_name = values.get("credentials_profile_name")
            if profile_name is not None:
                session = boto3.Session(profile_name=profile_name)
            else:
                # use default credentials
                session = boto3.Session()

            # The field defaults to "", which boto3 would use literally; pass
            # None instead so boto3 falls back to AWS_DEFAULT_REGION / config.
            values["client"] = session.client(
                "sagemaker-runtime", region_name=values.get("region_name") or None
            )
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                "profile name are valid."
            ) from e

        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "endpoint_name": self.endpoint_name,
            "model_kwargs": self.model_kwargs or {},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "sagemaker_endpoint"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Sagemaker inference endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating. The
                returned text is cut at the first stop word; an empty list
                means no stop words.
            run_manager: Callback manager; when set and ``streaming`` is True
                each new piece of text is reported via ``on_llm_new_token``.
            **kwargs: Extra model parameters, merged over ``model_kwargs``.

        Returns:
            The string generated by the model.

        Raises:
            ValueError: If invoking the endpoint (or, when streaming, reading
                the stream) fails.

        Example:
            .. code-block:: python

                response = se("Tell me a joke.")
        """
        _model_kwargs = {**(self.model_kwargs or {}), **kwargs}
        _endpoint_kwargs = self.endpoint_kwargs or {}

        body = self.content_handler.transform_input(prompt, _model_kwargs)
        content_type = self.content_handler.content_type
        accepts = self.content_handler.accepts

        if self.streaming and run_manager:
            try:
                response = self.client.invoke_endpoint_with_response_stream(
                    EndpointName=self.endpoint_name,
                    Body=body,
                    ContentType=content_type,
                    **_endpoint_kwargs,
                )
                current_completion: str = ""
                for line in LineIterator(response["Body"]):
                    chunk_text = json.loads(line).get("outputs")[0]
                    if not stop:
                        current_completion += chunk_text
                        run_manager.on_llm_new_token(chunk_text)
                        continue
                    # Enforce stop words on the accumulated text rather than on
                    # each chunk, so a stop word split across chunks is still
                    # detected, and stop reading once one has been hit.
                    combined = current_completion + chunk_text
                    truncated = enforce_stop_tokens(combined, stop)
                    new_text = truncated[len(current_completion) :]
                    current_completion = truncated
                    if new_text:
                        run_manager.on_llm_new_token(new_text)
                    if len(truncated) < len(combined):
                        break
                return current_completion
            except Exception as e:
                raise ValueError(
                    f"Error raised by streaming inference endpoint: {e}"
                ) from e

        try:
            response = self.client.invoke_endpoint(
                EndpointName=self.endpoint_name,
                Body=body,
                ContentType=content_type,
                Accept=accepts,
                **_endpoint_kwargs,
            )
        except Exception as e:
            raise ValueError(f"Error raised by inference endpoint: {e}") from e

        text = self.content_handler.transform_output(response["Body"])
        if stop:
            # The endpoint is not told about stop words, so they are enforced
            # client-side on the returned text.
            text = enforce_stop_tokens(text, stop)

        return text
|