langchain/docs/extras/integrations/text_embedding/sagemaker-endpoint.ipynb

159 lines
4.6 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"cell_type": "markdown",
"id": "1f83f273",
"metadata": {},
"source": [
"# SageMaker\n",
"\n",
"Let's load the `SagemakerEndpointEmbeddings` class. The class can be used if you host a model (e.g. your own Hugging Face model) on SageMaker.\n",
"\n",
"For instructions on how to do this, please see [this guide](https://www.philschmid.de/custom-inference-huggingface-sagemaker).\n",
"\n",
"**Note**: In order to handle batched requests, you will need to adjust the return line in the `predict_fn()` function within the custom `inference.py` script:\n",
"\n",
"Change from\n",
"\n",
"`return {\"vectors\": sentence_embeddings[0].tolist()}`\n",
"\n",
"to:\n",
"\n",
"`return {\"vectors\": sentence_embeddings.tolist()}`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "88d366bd",
"metadata": {},
"outputs": [],
"source": [
"%pip install langchain boto3"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1e9b926a",
"metadata": {},
"outputs": [],
"source": [
"from typing import Dict, List\n",
"from langchain.embeddings import SagemakerEndpointEmbeddings\n",
"from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler\n",
"import json\n",
"\n",
"\n",
"class ContentHandler(EmbeddingsContentHandler):\n",
"    \"\"\"Content handler that (de)serializes JSON payloads for a SageMaker embedding endpoint.\"\"\"\n",
"\n",
"    content_type = \"application/json\"\n",
"    accepts = \"application/json\"\n",
"\n",
"    def transform_input(self, inputs: List[str], model_kwargs: Dict) -> bytes:\n",
"        \"\"\"\n",
"        Transforms the input into bytes that can be consumed by SageMaker endpoint.\n",
"        Args:\n",
"            inputs: List of input strings.\n",
"            model_kwargs: Additional keyword arguments to be passed to the endpoint.\n",
"        Returns:\n",
"            The transformed bytes input.\n",
"        \"\"\"\n",
"        # Example: inference.py expects a JSON string with an \"inputs\" key.\n",
"        # model_kwargs are merged into the top level of the same JSON object.\n",
"        input_str = json.dumps({\"inputs\": inputs, **model_kwargs})\n",
"        return input_str.encode(\"utf-8\")\n",
"\n",
"    def transform_output(self, output: bytes) -> List[List[float]]:\n",
"        \"\"\"\n",
"        Transforms the bytes output from the endpoint into a list of embeddings.\n",
"        Args:\n",
"            output: The bytes output from SageMaker endpoint.\n",
"        Returns:\n",
"            The transformed output - list of embeddings\n",
"        Note:\n",
"            The length of the outer list is the number of input strings.\n",
"            The length of the inner lists is the embedding dimension.\n",
"        \"\"\"\n",
"        # Example: inference.py returns a JSON string with the list of\n",
"        # embeddings in a \"vectors\" key.\n",
"        # NOTE(review): `.read()` is called on `output`, so despite the `bytes`\n",
"        # annotation the value is presumably a stream-like object (e.g. botocore's\n",
"        # StreamingBody) - confirm against the base class.\n",
"        response_json = json.loads(output.read().decode(\"utf-8\"))\n",
"        return response_json[\"vectors\"]\n",
"\n",
"\n",
"content_handler = ContentHandler()\n",
"\n",
"\n",
"embeddings = SagemakerEndpointEmbeddings(\n",
" # credentials_profile_name=\"credentials-profile-name\",\n",
" endpoint_name=\"huggingface-pytorch-inference-2023-03-21-16-14-03-834\",\n",
" region_name=\"us-east-1\",\n",
" content_handler=content_handler,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe9797b8",
"metadata": {},
"outputs": [],
"source": [
"query_result = embeddings.embed_query(\"foo\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "76f1b752",
"metadata": {},
"outputs": [],
"source": [
"doc_results = embeddings.embed_documents([\"foo\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fff99b21",
"metadata": {},
"outputs": [],
"source": [
"doc_results"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aaad49f8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"vscode": {
"interpreter": {
"hash": "7377c2ccc78bc62c2683122d48c8cd1fb85a53850a1b1fc29736ed39852c9885"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}