Fix Sagemaker Batch Endpoints (#3249)

Add different typing for @evandiewald's helpful PR

---------

Co-authored-by: Evan Diewald <evandiewald@gmail.com>
This commit is contained in:
Zander Chase 2023-04-22 08:49:51 -07:00 committed by GitHub
parent 7e79f8c136
commit 61d40ba042
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 48 additions and 29 deletions

View File

@ -9,7 +9,15 @@
"\n", "\n",
"Let's load the SageMaker Endpoints Embeddings class. The class can be used if you host, e.g. your own Hugging Face model on SageMaker.\n", "Let's load the SageMaker Endpoints Embeddings class. The class can be used if you host, e.g. your own Hugging Face model on SageMaker.\n",
"\n", "\n",
"For instrucstions on how to do this, please see [here](https://www.philschmid.de/custom-inference-huggingface-sagemaker)" "For instructions on how to do this, please see [here](https://www.philschmid.de/custom-inference-huggingface-sagemaker). **Note**: In order to handle batched requests, you will need to adjust the return line in the `predict_fn()` function within the custom `inference.py` script:\n",
"\n",
"Change from\n",
"\n",
"`return {\"vectors\": sentence_embeddings[0].tolist()}`\n",
"\n",
"to:\n",
"\n",
"`return {\"vectors\": sentence_embeddings.tolist()}`."
] ]
}, },
{ {
@ -29,7 +37,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from typing import Dict\n", "from typing import Dict, List\n",
"from langchain.embeddings import SagemakerEndpointEmbeddings\n", "from langchain.embeddings import SagemakerEndpointEmbeddings\n",
"from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n", "from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n",
"import json\n", "import json\n",
@ -39,13 +47,13 @@
" content_type = \"application/json\"\n", " content_type = \"application/json\"\n",
" accepts = \"application/json\"\n", " accepts = \"application/json\"\n",
"\n", "\n",
" def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n", " def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:\n",
" input_str = json.dumps({\"inputs\": prompt, **model_kwargs})\n", " input_str = json.dumps({\"inputs\": inputs, **model_kwargs})\n",
" return input_str.encode('utf-8')\n", " return input_str.encode('utf-8')\n",
"\n", "\n",
" def transform_output(self, output: bytes) -> str:\n", " def transform_output(self, output: bytes) -> List[List[float]]:\n",
" response_json = json.loads(output.read().decode(\"utf-8\"))\n", " response_json = json.loads(output.read().decode(\"utf-8\"))\n",
" return response_json[\"embeddings\"]\n", " return response_json[\"vectors\"]\n",
"\n", "\n",
"content_handler = ContentHandler()\n", "content_handler = ContentHandler()\n",
"\n", "\n",

View File

@ -7,6 +7,10 @@ from langchain.embeddings.base import Embeddings
from langchain.llms.sagemaker_endpoint import ContentHandlerBase from langchain.llms.sagemaker_endpoint import ContentHandlerBase
class EmbeddingsContentHandler(ContentHandlerBase[List[str], List[List[float]]]):
"""Content handler for LLM class."""
class SagemakerEndpointEmbeddings(BaseModel, Embeddings): class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
"""Wrapper around custom Sagemaker Inference Endpoints. """Wrapper around custom Sagemaker Inference Endpoints.
@ -62,7 +66,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
""" """
content_handler: ContentHandlerBase content_handler: EmbeddingsContentHandler
"""The content handler class that provides an input and """The content handler class that provides an input and
output transform functions to handle formats between LLM output transform functions to handle formats between LLM
and the endpoint. and the endpoint.
@ -72,20 +76,20 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
Example: Example:
.. code-block:: python .. code-block:: python
from langchain.llms.sagemaker_endpoint import ContentHandlerBase from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
class ContentHandler(ContentHandlerBase): class ContentHandler(EmbeddingsContentHandler):
content_type = "application/json" content_type = "application/json"
accepts = "application/json" accepts = "application/json"
def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes: def transform_input(self, prompts: List[str], model_kwargs: Dict) -> bytes:
input_str = json.dumps({prompt: prompt, **model_kwargs}) input_str = json.dumps({prompts: prompts, **model_kwargs})
return input_str.encode('utf-8') return input_str.encode('utf-8')
def transform_output(self, output: bytes) -> str: def transform_output(self, output: bytes) -> List[List[float]]:
response_json = json.loads(output.read().decode("utf-8")) response_json = json.loads(output.read().decode("utf-8"))
return response_json[0]["generated_text"] return response_json["vectors"]
""" """ # noqa: E501
model_kwargs: Optional[Dict] = None model_kwargs: Optional[Dict] = None
"""Key word arguments to pass to the model.""" """Key word arguments to pass to the model."""
@ -135,7 +139,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
) )
return values return values
def _embedding_func(self, texts: List[str]) -> List[float]: def _embedding_func(self, texts: List[str]) -> List[List[float]]:
"""Call out to SageMaker Inference embedding endpoint.""" """Call out to SageMaker Inference embedding endpoint."""
# replace newlines, which can negatively affect performance. # replace newlines, which can negatively affect performance.
texts = list(map(lambda x: x.replace("\n", " "), texts)) texts = list(map(lambda x: x.replace("\n", " "), texts))
@ -179,7 +183,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
_chunk_size = len(texts) if chunk_size > len(texts) else chunk_size _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
for i in range(0, len(texts), _chunk_size): for i in range(0, len(texts), _chunk_size):
response = self._embedding_func(texts[i : i + _chunk_size]) response = self._embedding_func(texts[i : i + _chunk_size])
results.append(response) results.extend(response)
return results return results
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
@ -191,4 +195,4 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
Returns: Returns:
Embeddings for the text. Embeddings for the text.
""" """
return self._embedding_func([text]) return self._embedding_func([text])[0]

View File

@ -1,14 +1,17 @@
"""Wrapper around Sagemaker InvokeEndpoint API.""" """Wrapper around Sagemaker InvokeEndpoint API."""
from abc import ABC, abstractmethod from abc import abstractmethod
from typing import Any, Dict, List, Mapping, Optional, Union from typing import Any, Dict, Generic, List, Mapping, Optional, TypeVar, Union
from pydantic import Extra, root_validator from pydantic import Extra, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
INPUT_TYPE = TypeVar("INPUT_TYPE", bound=Union[str, List[str]])
OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]]])
class ContentHandlerBase(ABC):
class ContentHandlerBase(Generic[INPUT_TYPE, OUTPUT_TYPE]):
"""A handler class to transform input from LLM to a """A handler class to transform input from LLM to a
format that SageMaker endpoint expects. Similarly, format that SageMaker endpoint expects. Similarly,
the class also handles transforming output from the the class also handles transforming output from the
@ -39,9 +42,7 @@ class ContentHandlerBase(ABC):
"""The MIME type of the response data returned from endpoint""" """The MIME type of the response data returned from endpoint"""
@abstractmethod @abstractmethod
def transform_input( def transform_input(self, prompt: INPUT_TYPE, model_kwargs: Dict) -> bytes:
self, prompt: Union[str, List[str]], model_kwargs: Dict
) -> bytes:
"""Transforms the input to a format that model can accept """Transforms the input to a format that model can accept
as the request Body. Should return bytes or seekable file as the request Body. Should return bytes or seekable file
like object in the format specified in the content_type like object in the format specified in the content_type
@ -49,12 +50,16 @@ class ContentHandlerBase(ABC):
""" """
@abstractmethod @abstractmethod
def transform_output(self, output: bytes) -> Any: def transform_output(self, output: bytes) -> OUTPUT_TYPE:
"""Transforms the output from the model to string that """Transforms the output from the model to string that
the LLM class expects. the LLM class expects.
""" """
class LLMContentHandler(ContentHandlerBase[str, str]):
"""Content handler for LLM class."""
class SagemakerEndpoint(LLM): class SagemakerEndpoint(LLM):
"""Wrapper around custom Sagemaker Inference Endpoints. """Wrapper around custom Sagemaker Inference Endpoints.
@ -110,7 +115,7 @@ class SagemakerEndpoint(LLM):
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
""" """
content_handler: ContentHandlerBase content_handler: LLMContentHandler
"""The content handler class that provides an input and """The content handler class that provides an input and
output transform functions to handle formats between LLM output transform functions to handle formats between LLM
and the endpoint. and the endpoint.
@ -120,7 +125,9 @@ class SagemakerEndpoint(LLM):
Example: Example:
.. code-block:: python .. code-block:: python
class ContentHandler(ContentHandlerBase): from langchain.llms.sagemaker_endpoint import LLMContentHandler
class ContentHandler(LLMContentHandler):
content_type = "application/json" content_type = "application/json"
accepts = "application/json" accepts = "application/json"