forked from Archives/langchain
Fix Sagemaker Batch Endpoints (#3249)
Add different typing for @evandiewald's helpful PR --------- Co-authored-by: Evan Diewald <evandiewald@gmail.com>
This commit is contained in:
parent
7e79f8c136
commit
61d40ba042
@ -9,7 +9,15 @@
|
||||
"\n",
|
||||
"Let's load the SageMaker Endpoints Embeddings class. The class can be used if you host, e.g. your own Hugging Face model on SageMaker.\n",
|
||||
"\n",
|
||||
"For instrucstions on how to do this, please see [here](https://www.philschmid.de/custom-inference-huggingface-sagemaker)"
|
||||
"For instructions on how to do this, please see [here](https://www.philschmid.de/custom-inference-huggingface-sagemaker). **Note**: In order to handle batched requests, you will need to adjust the return line in the `predict_fn()` function within the custom `inference.py` script:\n",
|
||||
"\n",
|
||||
"Change from\n",
|
||||
"\n",
|
||||
"`return {\"vectors\": sentence_embeddings[0].tolist()}`\n",
|
||||
"\n",
|
||||
"to:\n",
|
||||
"\n",
|
||||
"`return {\"vectors\": sentence_embeddings.tolist()}`."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -29,7 +37,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Dict\n",
|
||||
"from typing import Dict, List\n",
|
||||
"from langchain.embeddings import SagemakerEndpointEmbeddings\n",
|
||||
"from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n",
|
||||
"import json\n",
|
||||
@ -39,13 +47,13 @@
|
||||
" content_type = \"application/json\"\n",
|
||||
" accepts = \"application/json\"\n",
|
||||
"\n",
|
||||
" def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n",
|
||||
" input_str = json.dumps({\"inputs\": prompt, **model_kwargs})\n",
|
||||
" def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:\n",
|
||||
" input_str = json.dumps({\"inputs\": inputs, **model_kwargs})\n",
|
||||
" return input_str.encode('utf-8')\n",
|
||||
" \n",
|
||||
" def transform_output(self, output: bytes) -> str:\n",
|
||||
"\n",
|
||||
" def transform_output(self, output: bytes) -> List[List[float]]:\n",
|
||||
" response_json = json.loads(output.read().decode(\"utf-8\"))\n",
|
||||
" return response_json[\"embeddings\"]\n",
|
||||
" return response_json[\"vectors\"]\n",
|
||||
"\n",
|
||||
"content_handler = ContentHandler()\n",
|
||||
"\n",
|
||||
|
@ -7,6 +7,10 @@ from langchain.embeddings.base import Embeddings
|
||||
from langchain.llms.sagemaker_endpoint import ContentHandlerBase
|
||||
|
||||
|
||||
class EmbeddingsContentHandler(ContentHandlerBase[List[str], List[List[float]]]):
|
||||
"""Content handler for embeddings class."""
|
||||
|
||||
|
||||
class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
|
||||
"""Wrapper around custom Sagemaker Inference Endpoints.
|
||||
|
||||
@ -62,7 +66,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
|
||||
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
|
||||
"""
|
||||
|
||||
content_handler: ContentHandlerBase
|
||||
content_handler: EmbeddingsContentHandler
|
||||
"""The content handler class that provides an input and
|
||||
output transform functions to handle formats between LLM
|
||||
and the endpoint.
|
||||
@ -72,20 +76,20 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms.sagemaker_endpoint import ContentHandlerBase
|
||||
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
|
||||
|
||||
class ContentHandler(ContentHandlerBase):
|
||||
class ContentHandler(EmbeddingsContentHandler):
|
||||
content_type = "application/json"
|
||||
accepts = "application/json"
|
||||
|
||||
def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
|
||||
input_str = json.dumps({prompt: prompt, **model_kwargs})
|
||||
def transform_input(self, prompts: List[str], model_kwargs: Dict) -> bytes:
|
||||
input_str = json.dumps({"inputs": prompts, **model_kwargs})
|
||||
return input_str.encode('utf-8')
|
||||
|
||||
def transform_output(self, output: bytes) -> str:
|
||||
def transform_output(self, output: bytes) -> List[List[float]]:
|
||||
response_json = json.loads(output.read().decode("utf-8"))
|
||||
return response_json[0]["generated_text"]
|
||||
"""
|
||||
return response_json["vectors"]
|
||||
""" # noqa: E501
|
||||
|
||||
model_kwargs: Optional[Dict] = None
|
||||
"""Key word arguments to pass to the model."""
|
||||
@ -135,7 +139,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
return values
|
||||
|
||||
def _embedding_func(self, texts: List[str]) -> List[float]:
|
||||
def _embedding_func(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Call out to SageMaker Inference embedding endpoint."""
|
||||
# replace newlines, which can negatively affect performance.
|
||||
texts = list(map(lambda x: x.replace("\n", " "), texts))
|
||||
@ -179,7 +183,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
|
||||
_chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
|
||||
for i in range(0, len(texts), _chunk_size):
|
||||
response = self._embedding_func(texts[i : i + _chunk_size])
|
||||
results.append(response)
|
||||
results.extend(response)
|
||||
return results
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
@ -191,4 +195,4 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
return self._embedding_func([text])
|
||||
return self._embedding_func([text])[0]
|
||||
|
@ -1,14 +1,17 @@
|
||||
"""Wrapper around Sagemaker InvokeEndpoint API."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Mapping, Optional, Union
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, Generic, List, Mapping, Optional, TypeVar, Union
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
|
||||
INPUT_TYPE = TypeVar("INPUT_TYPE", bound=Union[str, List[str]])
|
||||
OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]]])
|
||||
|
||||
class ContentHandlerBase(ABC):
|
||||
|
||||
class ContentHandlerBase(Generic[INPUT_TYPE, OUTPUT_TYPE]):
|
||||
"""A handler class to transform input from LLM to a
|
||||
format that SageMaker endpoint expects. Similarly,
|
||||
the class also handles transforming output from the
|
||||
@ -39,9 +42,7 @@ class ContentHandlerBase(ABC):
|
||||
"""The MIME type of the response data returned from endpoint"""
|
||||
|
||||
@abstractmethod
|
||||
def transform_input(
|
||||
self, prompt: Union[str, List[str]], model_kwargs: Dict
|
||||
) -> bytes:
|
||||
def transform_input(self, prompt: INPUT_TYPE, model_kwargs: Dict) -> bytes:
|
||||
"""Transforms the input to a format that model can accept
|
||||
as the request Body. Should return bytes or seekable file
|
||||
like object in the format specified in the content_type
|
||||
@ -49,12 +50,16 @@ class ContentHandlerBase(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def transform_output(self, output: bytes) -> Any:
|
||||
def transform_output(self, output: bytes) -> OUTPUT_TYPE:
|
||||
"""Transforms the output from the model to string that
|
||||
the LLM class expects.
|
||||
"""
|
||||
|
||||
|
||||
class LLMContentHandler(ContentHandlerBase[str, str]):
|
||||
"""Content handler for LLM class."""
|
||||
|
||||
|
||||
class SagemakerEndpoint(LLM):
|
||||
"""Wrapper around custom Sagemaker Inference Endpoints.
|
||||
|
||||
@ -110,7 +115,7 @@ class SagemakerEndpoint(LLM):
|
||||
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
|
||||
"""
|
||||
|
||||
content_handler: ContentHandlerBase
|
||||
content_handler: LLMContentHandler
|
||||
"""The content handler class that provides an input and
|
||||
output transform functions to handle formats between LLM
|
||||
and the endpoint.
|
||||
@ -120,7 +125,9 @@ class SagemakerEndpoint(LLM):
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
class ContentHandler(ContentHandlerBase):
|
||||
from langchain.llms.sagemaker_endpoint import LLMContentHandler
|
||||
|
||||
class ContentHandler(LLMContentHandler):
|
||||
content_type = "application/json"
|
||||
accepts = "application/json"
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user