mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Fix Sagemaker Batch Endpoints (#3249)
Add different typing for @evandiewald 's heplful PR --------- Co-authored-by: Evan Diewald <evandiewald@gmail.com>
This commit is contained in:
parent
7e79f8c136
commit
61d40ba042
@ -9,7 +9,15 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"Let's load the SageMaker Endpoints Embeddings class. The class can be used if you host, e.g. your own Hugging Face model on SageMaker.\n",
|
"Let's load the SageMaker Endpoints Embeddings class. The class can be used if you host, e.g. your own Hugging Face model on SageMaker.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"For instrucstions on how to do this, please see [here](https://www.philschmid.de/custom-inference-huggingface-sagemaker)"
|
"For instructions on how to do this, please see [here](https://www.philschmid.de/custom-inference-huggingface-sagemaker). **Note**: In order to handle batched requests, you will need to adjust the return line in the `predict_fn()` function within the custom `inference.py` script:\n",
|
||||||
|
"\n",
|
||||||
|
"Change from\n",
|
||||||
|
"\n",
|
||||||
|
"`return {\"vectors\": sentence_embeddings[0].tolist()}`\n",
|
||||||
|
"\n",
|
||||||
|
"to:\n",
|
||||||
|
"\n",
|
||||||
|
"`return {\"vectors\": sentence_embeddings.tolist()}`."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -29,7 +37,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from typing import Dict\n",
|
"from typing import Dict, List\n",
|
||||||
"from langchain.embeddings import SagemakerEndpointEmbeddings\n",
|
"from langchain.embeddings import SagemakerEndpointEmbeddings\n",
|
||||||
"from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n",
|
"from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n",
|
||||||
"import json\n",
|
"import json\n",
|
||||||
@ -39,13 +47,13 @@
|
|||||||
" content_type = \"application/json\"\n",
|
" content_type = \"application/json\"\n",
|
||||||
" accepts = \"application/json\"\n",
|
" accepts = \"application/json\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
" def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n",
|
" def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:\n",
|
||||||
" input_str = json.dumps({\"inputs\": prompt, **model_kwargs})\n",
|
" input_str = json.dumps({\"inputs\": inputs, **model_kwargs})\n",
|
||||||
" return input_str.encode('utf-8')\n",
|
" return input_str.encode('utf-8')\n",
|
||||||
"\n",
|
"\n",
|
||||||
" def transform_output(self, output: bytes) -> str:\n",
|
" def transform_output(self, output: bytes) -> List[List[float]]:\n",
|
||||||
" response_json = json.loads(output.read().decode(\"utf-8\"))\n",
|
" response_json = json.loads(output.read().decode(\"utf-8\"))\n",
|
||||||
" return response_json[\"embeddings\"]\n",
|
" return response_json[\"vectors\"]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"content_handler = ContentHandler()\n",
|
"content_handler = ContentHandler()\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -7,6 +7,10 @@ from langchain.embeddings.base import Embeddings
|
|||||||
from langchain.llms.sagemaker_endpoint import ContentHandlerBase
|
from langchain.llms.sagemaker_endpoint import ContentHandlerBase
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingsContentHandler(ContentHandlerBase[List[str], List[List[float]]]):
|
||||||
|
"""Content handler for LLM class."""
|
||||||
|
|
||||||
|
|
||||||
class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
|
class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
|
||||||
"""Wrapper around custom Sagemaker Inference Endpoints.
|
"""Wrapper around custom Sagemaker Inference Endpoints.
|
||||||
|
|
||||||
@ -62,7 +66,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
|
|||||||
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
|
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
|
||||||
"""
|
"""
|
||||||
|
|
||||||
content_handler: ContentHandlerBase
|
content_handler: EmbeddingsContentHandler
|
||||||
"""The content handler class that provides an input and
|
"""The content handler class that provides an input and
|
||||||
output transform functions to handle formats between LLM
|
output transform functions to handle formats between LLM
|
||||||
and the endpoint.
|
and the endpoint.
|
||||||
@ -72,20 +76,20 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
|
|||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from langchain.llms.sagemaker_endpoint import ContentHandlerBase
|
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
|
||||||
|
|
||||||
class ContentHandler(ContentHandlerBase):
|
class ContentHandler(EmbeddingsContentHandler):
|
||||||
content_type = "application/json"
|
content_type = "application/json"
|
||||||
accepts = "application/json"
|
accepts = "application/json"
|
||||||
|
|
||||||
def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
|
def transform_input(self, prompts: List[str], model_kwargs: Dict) -> bytes:
|
||||||
input_str = json.dumps({prompt: prompt, **model_kwargs})
|
input_str = json.dumps({prompts: prompts, **model_kwargs})
|
||||||
return input_str.encode('utf-8')
|
return input_str.encode('utf-8')
|
||||||
|
|
||||||
def transform_output(self, output: bytes) -> str:
|
def transform_output(self, output: bytes) -> List[List[float]]:
|
||||||
response_json = json.loads(output.read().decode("utf-8"))
|
response_json = json.loads(output.read().decode("utf-8"))
|
||||||
return response_json[0]["generated_text"]
|
return response_json["vectors"]
|
||||||
"""
|
""" # noqa: E501
|
||||||
|
|
||||||
model_kwargs: Optional[Dict] = None
|
model_kwargs: Optional[Dict] = None
|
||||||
"""Key word arguments to pass to the model."""
|
"""Key word arguments to pass to the model."""
|
||||||
@ -135,7 +139,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
|
|||||||
)
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def _embedding_func(self, texts: List[str]) -> List[float]:
|
def _embedding_func(self, texts: List[str]) -> List[List[float]]:
|
||||||
"""Call out to SageMaker Inference embedding endpoint."""
|
"""Call out to SageMaker Inference embedding endpoint."""
|
||||||
# replace newlines, which can negatively affect performance.
|
# replace newlines, which can negatively affect performance.
|
||||||
texts = list(map(lambda x: x.replace("\n", " "), texts))
|
texts = list(map(lambda x: x.replace("\n", " "), texts))
|
||||||
@ -179,7 +183,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
|
|||||||
_chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
|
_chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
|
||||||
for i in range(0, len(texts), _chunk_size):
|
for i in range(0, len(texts), _chunk_size):
|
||||||
response = self._embedding_func(texts[i : i + _chunk_size])
|
response = self._embedding_func(texts[i : i + _chunk_size])
|
||||||
results.append(response)
|
results.extend(response)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
@ -191,4 +195,4 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
|
|||||||
Returns:
|
Returns:
|
||||||
Embeddings for the text.
|
Embeddings for the text.
|
||||||
"""
|
"""
|
||||||
return self._embedding_func([text])
|
return self._embedding_func([text])[0]
|
||||||
|
@ -1,14 +1,17 @@
|
|||||||
"""Wrapper around Sagemaker InvokeEndpoint API."""
|
"""Wrapper around Sagemaker InvokeEndpoint API."""
|
||||||
from abc import ABC, abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Any, Dict, List, Mapping, Optional, Union
|
from typing import Any, Dict, Generic, List, Mapping, Optional, TypeVar, Union
|
||||||
|
|
||||||
from pydantic import Extra, root_validator
|
from pydantic import Extra, root_validator
|
||||||
|
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import LLM
|
||||||
from langchain.llms.utils import enforce_stop_tokens
|
from langchain.llms.utils import enforce_stop_tokens
|
||||||
|
|
||||||
|
INPUT_TYPE = TypeVar("INPUT_TYPE", bound=Union[str, List[str]])
|
||||||
|
OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]]])
|
||||||
|
|
||||||
class ContentHandlerBase(ABC):
|
|
||||||
|
class ContentHandlerBase(Generic[INPUT_TYPE, OUTPUT_TYPE]):
|
||||||
"""A handler class to transform input from LLM to a
|
"""A handler class to transform input from LLM to a
|
||||||
format that SageMaker endpoint expects. Similarily,
|
format that SageMaker endpoint expects. Similarily,
|
||||||
the class also handles transforming output from the
|
the class also handles transforming output from the
|
||||||
@ -39,9 +42,7 @@ class ContentHandlerBase(ABC):
|
|||||||
"""The MIME type of the response data returned from endpoint"""
|
"""The MIME type of the response data returned from endpoint"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def transform_input(
|
def transform_input(self, prompt: INPUT_TYPE, model_kwargs: Dict) -> bytes:
|
||||||
self, prompt: Union[str, List[str]], model_kwargs: Dict
|
|
||||||
) -> bytes:
|
|
||||||
"""Transforms the input to a format that model can accept
|
"""Transforms the input to a format that model can accept
|
||||||
as the request Body. Should return bytes or seekable file
|
as the request Body. Should return bytes or seekable file
|
||||||
like object in the format specified in the content_type
|
like object in the format specified in the content_type
|
||||||
@ -49,12 +50,16 @@ class ContentHandlerBase(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def transform_output(self, output: bytes) -> Any:
|
def transform_output(self, output: bytes) -> OUTPUT_TYPE:
|
||||||
"""Transforms the output from the model to string that
|
"""Transforms the output from the model to string that
|
||||||
the LLM class expects.
|
the LLM class expects.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class LLMContentHandler(ContentHandlerBase[str, str]):
|
||||||
|
"""Content handler for LLM class."""
|
||||||
|
|
||||||
|
|
||||||
class SagemakerEndpoint(LLM):
|
class SagemakerEndpoint(LLM):
|
||||||
"""Wrapper around custom Sagemaker Inference Endpoints.
|
"""Wrapper around custom Sagemaker Inference Endpoints.
|
||||||
|
|
||||||
@ -110,7 +115,7 @@ class SagemakerEndpoint(LLM):
|
|||||||
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
|
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
|
||||||
"""
|
"""
|
||||||
|
|
||||||
content_handler: ContentHandlerBase
|
content_handler: LLMContentHandler
|
||||||
"""The content handler class that provides an input and
|
"""The content handler class that provides an input and
|
||||||
output transform functions to handle formats between LLM
|
output transform functions to handle formats between LLM
|
||||||
and the endpoint.
|
and the endpoint.
|
||||||
@ -120,7 +125,9 @@ class SagemakerEndpoint(LLM):
|
|||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
class ContentHandler(ContentHandlerBase):
|
from langchain.llms.sagemaker_endpoint import LLMContentHandler
|
||||||
|
|
||||||
|
class ContentHandler(LLMContentHandler):
|
||||||
content_type = "application/json"
|
content_type = "application/json"
|
||||||
accepts = "application/json"
|
accepts = "application/json"
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user