diff --git a/docs/modules/models/text_embedding/examples/sagemaker-endpoint.ipynb b/docs/modules/models/text_embedding/examples/sagemaker-endpoint.ipynb
index 040e4558..b7a0fb7f 100644
--- a/docs/modules/models/text_embedding/examples/sagemaker-endpoint.ipynb
+++ b/docs/modules/models/text_embedding/examples/sagemaker-endpoint.ipynb
@@ -9,7 +9,15 @@
    "\n",
    "Let's load the SageMaker Endpoints Embeddings class. The class can be used if you host, e.g. your own Hugging Face model on SageMaker.\n",
    "\n",
-   "For instrucstions on how to do this, please see [here](https://www.philschmid.de/custom-inference-huggingface-sagemaker)"
+   "For instructions on how to do this, please see [here](https://www.philschmid.de/custom-inference-huggingface-sagemaker). **Note**: In order to handle batched requests, you will need to adjust the return line in the `predict_fn()` function within the custom `inference.py` script:\n",
+   "\n",
+   "Change from\n",
+   "\n",
+   "`return {\"vectors\": sentence_embeddings[0].tolist()}`\n",
+   "\n",
+   "to:\n",
+   "\n",
+   "`return {\"vectors\": sentence_embeddings.tolist()}`."
   ]
  },
  {
@@ -29,7 +37,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "from typing import Dict\n",
+   "from typing import Dict, List\n",
    "from langchain.embeddings import SagemakerEndpointEmbeddings\n",
    "from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n",
    "import json\n",
@@ -39,13 +47,13 @@
    "    content_type = \"application/json\"\n",
    "    accepts = \"application/json\"\n",
    "\n",
-   "    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n",
-   "        input_str = json.dumps({\"inputs\": prompt, **model_kwargs})\n",
+   "    def transform_input(self, inputs: List[str], model_kwargs: Dict) -> bytes:\n",
+   "        input_str = json.dumps({\"inputs\": inputs, **model_kwargs})\n",
    "        return input_str.encode('utf-8')\n",
-   "    \n",
-   "    def transform_output(self, output: bytes) -> str:\n",
+   "\n",
+   "    def transform_output(self, output: bytes) -> List[List[float]]:\n",
    "        response_json = json.loads(output.read().decode(\"utf-8\"))\n",
-   "        return response_json[\"embeddings\"]\n",
+   "        return response_json[\"vectors\"]\n",
    "\n",
    "content_handler = ContentHandler()\n",
    "\n",
diff --git a/langchain/embeddings/sagemaker_endpoint.py b/langchain/embeddings/sagemaker_endpoint.py
index e1371a7d..25ba961d 100644
--- a/langchain/embeddings/sagemaker_endpoint.py
+++ b/langchain/embeddings/sagemaker_endpoint.py
@@ -7,6 +7,10 @@
 from langchain.embeddings.base import Embeddings
 from langchain.llms.sagemaker_endpoint import ContentHandlerBase
 
 
+class EmbeddingsContentHandler(ContentHandlerBase[List[str], List[List[float]]]):
+    """Content handler for embeddings class."""
+
+
 class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
     """Wrapper around custom Sagemaker Inference Endpoints.
@@ -62,7 +66,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
         See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
     """
 
-    content_handler: ContentHandlerBase
+    content_handler: EmbeddingsContentHandler
     """The content handler class that provides an input and
     output transform functions to handle formats between LLM
    and the endpoint.
@@ -71,21 +75,21 @@
     """
     Example:
         .. code-block:: python
-
-            from langchain.llms.sagemaker_endpoint import ContentHandlerBase
-            class ContentHandler(ContentHandlerBase):
+            from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
+
+            class ContentHandler(EmbeddingsContentHandler):
                 content_type = "application/json"
                 accepts = "application/json"
 
-                def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
-                    input_str = json.dumps({prompt: prompt, **model_kwargs})
+                def transform_input(self, prompts: List[str], model_kwargs: Dict) -> bytes:
+                    input_str = json.dumps({"inputs": prompts, **model_kwargs})
                     return input_str.encode('utf-8')
-
-                def transform_output(self, output: bytes) -> str:
+
+                def transform_output(self, output: bytes) -> List[List[float]]:
                     response_json = json.loads(output.read().decode("utf-8"))
-                    return response_json[0]["generated_text"]
-    """
+                    return response_json["vectors"]
+    """  # noqa: E501
 
     model_kwargs: Optional[Dict] = None
     """Key word arguments to pass to the model."""
@@ -135,7 +139,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
             )
         return values
 
-    def _embedding_func(self, texts: List[str]) -> List[float]:
+    def _embedding_func(self, texts: List[str]) -> List[List[float]]:
         """Call out to SageMaker Inference embedding endpoint."""
         # replace newlines, which can negatively affect performance.
         texts = list(map(lambda x: x.replace("\n", " "), texts))
@@ -179,7 +183,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
         _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
         for i in range(0, len(texts), _chunk_size):
             response = self._embedding_func(texts[i : i + _chunk_size])
-            results.append(response)
+            results.extend(response)
         return results
 
     def embed_query(self, text: str) -> List[float]:
@@ -191,4 +195,4 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
         Returns:
             Embeddings for the text.
         """
-        return self._embedding_func([text])
+        return self._embedding_func([text])[0]
diff --git a/langchain/llms/sagemaker_endpoint.py b/langchain/llms/sagemaker_endpoint.py
index d9efe51a..34f236b9 100644
--- a/langchain/llms/sagemaker_endpoint.py
+++ b/langchain/llms/sagemaker_endpoint.py
@@ -1,14 +1,17 @@
 """Wrapper around Sagemaker InvokeEndpoint API."""
-from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Mapping, Optional, Union
+from abc import abstractmethod
+from typing import Any, Dict, Generic, List, Mapping, Optional, TypeVar, Union
 
 from pydantic import Extra, root_validator
 
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 
+INPUT_TYPE = TypeVar("INPUT_TYPE", bound=Union[str, List[str]])
+OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]]])
 
-class ContentHandlerBase(ABC):
+
+class ContentHandlerBase(Generic[INPUT_TYPE, OUTPUT_TYPE]):
     """A handler class to transform input from LLM to a
     format that SageMaker endpoint expects. Similarily,
     the class also handles transforming output from the
@@ -39,9 +42,7 @@
     """The MIME type of the response data returned from endpoint"""
 
     @abstractmethod
-    def transform_input(
-        self, prompt: Union[str, List[str]], model_kwargs: Dict
-    ) -> bytes:
+    def transform_input(self, prompt: INPUT_TYPE, model_kwargs: Dict) -> bytes:
         """Transforms the input to a format that model can accept
         as the request Body. Should return bytes or seekable file
         like object in the format specified in the
         content_type
@@ -49,12 +50,16 @@
         """
 
     @abstractmethod
-    def transform_output(self, output: bytes) -> Any:
+    def transform_output(self, output: bytes) -> OUTPUT_TYPE:
         """Transforms the output from the model to string that
         the LLM class expects.
         """
 
 
+class LLMContentHandler(ContentHandlerBase[str, str]):
+    """Content handler for LLM class."""
+
+
 class SagemakerEndpoint(LLM):
     """Wrapper around custom Sagemaker Inference Endpoints.
@@ -110,7 +115,7 @@ class SagemakerEndpoint(LLM):
         See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
     """
 
-    content_handler: ContentHandlerBase
+    content_handler: LLMContentHandler
     """The content handler class that provides an input and
     output transform functions to handle formats between LLM
     and the endpoint.
@@ -120,7 +125,9 @@
     Example:
         .. code-block:: python
 
-            class ContentHandler(ContentHandlerBase):
+            from langchain.llms.sagemaker_endpoint import LLMContentHandler
+
+            class ContentHandler(LLMContentHandler):
                 content_type = "application/json"
                 accepts = "application/json"
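Note for reviewers: the notebook hunk above only shows the changed return line of `predict_fn()`. For context, here is a minimal sketch of what a batched `inference.py` handler might look like. The model loading and mean pooling follow the pattern from the linked guide and are assumptions, not part of this diff; only the final return line is what the notebook note actually requires.

```python
# inference.py -- illustrative sketch, not part of this diff.
import torch
from transformers import AutoModel, AutoTokenizer


def model_fn(model_dir: str):
    """Load model and tokenizer from the SageMaker model artifact directory."""
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModel.from_pretrained(model_dir)
    return model, tokenizer


def predict_fn(data: dict, model_and_tokenizer) -> dict:
    """Embed a batch of sentences sent as {"inputs": [...]}."""
    model, tokenizer = model_and_tokenizer
    sentences = data["inputs"]
    encoded = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        output = model(**encoded)
    # Mean-pool token embeddings, ignoring padding positions.
    mask = encoded["attention_mask"].unsqueeze(-1).float()
    summed = (output.last_hidden_state * mask).sum(dim=1)
    sentence_embeddings = summed / mask.sum(dim=1)
    # Return one vector per input, not just the first vector, so that
    # batched requests from embed_documents() round-trip correctly.
    return {"vectors": sentence_embeddings.tolist()}
```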
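End to end, the new typed handlers are used exactly as in the docstring examples above. A client-side sketch, with the endpoint name and region as placeholders (assumptions, not from this diff):

```python
import json
from typing import Dict, List

from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler


class ContentHandler(EmbeddingsContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, inputs: List[str], model_kwargs: Dict) -> bytes:
        # The whole batch goes to the endpoint in a single request body.
        return json.dumps({"inputs": inputs, **model_kwargs}).encode("utf-8")

    def transform_output(self, output: bytes) -> List[List[float]]:
        response_json = json.loads(output.read().decode("utf-8"))
        # One embedding per input text, as returned by the batched predict_fn.
        return response_json["vectors"]


embeddings = SagemakerEndpointEmbeddings(
    endpoint_name="my-embeddings-endpoint",  # placeholder
    region_name="us-east-1",                 # placeholder
    content_handler=ContentHandler(),
)

# embed_documents() sends chunked batches and flattens each response with
# results.extend(...); embed_query() unwraps the single returned vector.
doc_vectors = embeddings.embed_documents(["foo", "bar"], chunk_size=2)
query_vector = embeddings.embed_query("foo")
```

This is also why `embed_documents` switches from `results.append(response)` to `results.extend(response)`: `_embedding_func` now returns a list of vectors per batch, and `embed_query` picks out the single vector with `[0]`.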