Fix Sagemaker Batch Endpoints (#3249)

Add different typing for @evandiewald 's heplful PR

---------

Co-authored-by: Evan Diewald <evandiewald@gmail.com>
fix_agent_callbacks
Zander Chase 1 year ago committed by GitHub
parent 7e79f8c136
commit 61d40ba042
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -9,7 +9,15 @@
"\n",
"Let's load the SageMaker Endpoints Embeddings class. The class can be used if you host, e.g. your own Hugging Face model on SageMaker.\n",
"\n",
"For instrucstions on how to do this, please see [here](https://www.philschmid.de/custom-inference-huggingface-sagemaker)"
"For instructions on how to do this, please see [here](https://www.philschmid.de/custom-inference-huggingface-sagemaker). **Note**: In order to handle batched requests, you will need to adjust the return line in the `predict_fn()` function within the custom `inference.py` script:\n",
"\n",
"Change from\n",
"\n",
"`return {\"vectors\": sentence_embeddings[0].tolist()}`\n",
"\n",
"to:\n",
"\n",
"`return {\"vectors\": sentence_embeddings.tolist()}`."
]
},
{
@ -29,7 +37,7 @@
"metadata": {},
"outputs": [],
"source": [
"from typing import Dict\n",
"from typing import Dict, List\n",
"from langchain.embeddings import SagemakerEndpointEmbeddings\n",
"from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n",
"import json\n",
@ -39,13 +47,13 @@
" content_type = \"application/json\"\n",
" accepts = \"application/json\"\n",
"\n",
" def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n",
" input_str = json.dumps({\"inputs\": prompt, **model_kwargs})\n",
" def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:\n",
" input_str = json.dumps({\"inputs\": inputs, **model_kwargs})\n",
" return input_str.encode('utf-8')\n",
" \n",
" def transform_output(self, output: bytes) -> str:\n",
"\n",
" def transform_output(self, output: bytes) -> List[List[float]]:\n",
" response_json = json.loads(output.read().decode(\"utf-8\"))\n",
" return response_json[\"embeddings\"]\n",
" return response_json[\"vectors\"]\n",
"\n",
"content_handler = ContentHandler()\n",
"\n",

@ -7,6 +7,10 @@ from langchain.embeddings.base import Embeddings
from langchain.llms.sagemaker_endpoint import ContentHandlerBase
class EmbeddingsContentHandler(ContentHandlerBase[List[str], List[List[float]]]):
"""Content handler for LLM class."""
class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
"""Wrapper around custom Sagemaker Inference Endpoints.
@ -62,7 +66,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
"""
content_handler: ContentHandlerBase
content_handler: EmbeddingsContentHandler
"""The content handler class that provides an input and
output transform functions to handle formats between LLM
and the endpoint.
@ -71,21 +75,21 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
"""
Example:
.. code-block:: python
from langchain.llms.sagemaker_endpoint import ContentHandlerBase
class ContentHandler(ContentHandlerBase):
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
class ContentHandler(EmbeddingsContentHandler):
content_type = "application/json"
accepts = "application/json"
def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
input_str = json.dumps({prompt: prompt, **model_kwargs})
def transform_input(self, prompts: List[str], model_kwargs: Dict) -> bytes:
input_str = json.dumps({prompts: prompts, **model_kwargs})
return input_str.encode('utf-8')
def transform_output(self, output: bytes) -> str:
def transform_output(self, output: bytes) -> List[List[float]]:
response_json = json.loads(output.read().decode("utf-8"))
return response_json[0]["generated_text"]
"""
return response_json["vectors"]
""" # noqa: E501
model_kwargs: Optional[Dict] = None
"""Key word arguments to pass to the model."""
@ -135,7 +139,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
)
return values
def _embedding_func(self, texts: List[str]) -> List[float]:
def _embedding_func(self, texts: List[str]) -> List[List[float]]:
"""Call out to SageMaker Inference embedding endpoint."""
# replace newlines, which can negatively affect performance.
texts = list(map(lambda x: x.replace("\n", " "), texts))
@ -179,7 +183,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
_chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
for i in range(0, len(texts), _chunk_size):
response = self._embedding_func(texts[i : i + _chunk_size])
results.append(response)
results.extend(response)
return results
def embed_query(self, text: str) -> List[float]:
@ -191,4 +195,4 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
Returns:
Embeddings for the text.
"""
return self._embedding_func([text])
return self._embedding_func([text])[0]

@ -1,14 +1,17 @@
"""Wrapper around Sagemaker InvokeEndpoint API."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Mapping, Optional, Union
from abc import abstractmethod
from typing import Any, Dict, Generic, List, Mapping, Optional, TypeVar, Union
from pydantic import Extra, root_validator
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
INPUT_TYPE = TypeVar("INPUT_TYPE", bound=Union[str, List[str]])
OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]]])
class ContentHandlerBase(ABC):
class ContentHandlerBase(Generic[INPUT_TYPE, OUTPUT_TYPE]):
"""A handler class to transform input from LLM to a
format that SageMaker endpoint expects. Similarily,
the class also handles transforming output from the
@ -39,9 +42,7 @@ class ContentHandlerBase(ABC):
"""The MIME type of the response data returned from endpoint"""
@abstractmethod
def transform_input(
self, prompt: Union[str, List[str]], model_kwargs: Dict
) -> bytes:
def transform_input(self, prompt: INPUT_TYPE, model_kwargs: Dict) -> bytes:
"""Transforms the input to a format that model can accept
as the request Body. Should return bytes or seekable file
like object in the format specified in the content_type
@ -49,12 +50,16 @@ class ContentHandlerBase(ABC):
"""
@abstractmethod
def transform_output(self, output: bytes) -> Any:
def transform_output(self, output: bytes) -> OUTPUT_TYPE:
"""Transforms the output from the model to string that
the LLM class expects.
"""
class LLMContentHandler(ContentHandlerBase[str, str]):
"""Content handler for LLM class."""
class SagemakerEndpoint(LLM):
"""Wrapper around custom Sagemaker Inference Endpoints.
@ -110,7 +115,7 @@ class SagemakerEndpoint(LLM):
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
"""
content_handler: ContentHandlerBase
content_handler: LLMContentHandler
"""The content handler class that provides an input and
output transform functions to handle formats between LLM
and the endpoint.
@ -120,7 +125,9 @@ class SagemakerEndpoint(LLM):
Example:
.. code-block:: python
class ContentHandler(ContentHandlerBase):
from langchain.llms.sagemaker_endpoint import LLMContentHandler
class ContentHandler(LLMContentHandler):
content_type = "application/json"
accepts = "application/json"

Loading…
Cancel
Save