mirror of
https://github.com/hwchase17/langchain
synced 2024-11-11 19:11:02 +00:00
1cc7d4c9eb
- **Description:** Simple change of the Class that ContentHandler inherits from. To create an object of type SagemakerEndpointEmbeddings, the property content_handler must be of type EmbeddingsContentHandler not ContentHandlerBase anymore, - **Twitter handle:** @Juanjo_Torres11 Co-authored-by: Bagatur <baskaryan@gmail.com>
137 lines
3.5 KiB
Plaintext
137 lines
3.5 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "1f83f273",
|
|
"metadata": {},
|
|
"source": [
|
|
"# SageMaker Endpoint Embeddings\n",
|
|
"\n",
|
|
"Let's load the SageMaker Endpoint Embeddings class (`SagemakerEndpointEmbeddings`). The class can be used if you host your own model, e.g. a Hugging Face model, on a SageMaker endpoint.\n",
|
|
"\n",
|
|
"For instructions on how to do this, see [this guide](https://www.philschmid.de/custom-inference-huggingface-sagemaker). **Note**: To handle batched requests, you need to adjust the return line of the `predict_fn()` function in the custom `inference.py` script, so that it returns the embeddings for every input in the batch instead of only the first one:\n",
|
|
"\n",
|
|
"Change from:\n",
|
|
"\n",
|
|
"`return {\"vectors\": sentence_embeddings[0].tolist()}`\n",
|
|
"\n",
|
|
"to:\n",
|
|
"\n",
|
|
"`return {\"vectors\": sentence_embeddings.tolist()}`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "88d366bd",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"!pip3 install langchain boto3"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "1e9b926a",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from typing import Dict, List\n",
|
|
"from langchain.embeddings import SagemakerEndpointEmbeddings\n",
|
|
"from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler\n",
|
|
"import json\n",
|
|
"\n",
|
|
"\n",
|
|
"class ContentHandler(EmbeddingsContentHandler):\n",
|
|
" content_type = \"application/json\"\n",
|
|
" accepts = \"application/json\"\n",
|
|
"\n",
|
|
" def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:\n",
|
|
" input_str = json.dumps({\"inputs\": inputs, **model_kwargs})\n",
|
|
" return input_str.encode(\"utf-8\")\n",
|
|
"\n",
|
|
" def transform_output(self, output: bytes) -> List[List[float]]:\n",
|
|
" response_json = json.loads(output.read().decode(\"utf-8\"))\n",
|
|
" return response_json[\"vectors\"]\n",
|
|
"\n",
|
|
"\n",
|
|
"content_handler = ContentHandler()\n",
|
|
"\n",
|
|
"\n",
|
|
"embeddings = SagemakerEndpointEmbeddings(\n",
|
|
" # endpoint_name=\"endpoint-name\",\n",
|
|
" # credentials_profile_name=\"credentials-profile-name\",\n",
|
|
" endpoint_name=\"huggingface-pytorch-inference-2023-03-21-16-14-03-834\",\n",
|
|
" region_name=\"us-east-1\",\n",
|
|
" content_handler=content_handler,\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "fe9797b8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"query_result = embeddings.embed_query(\"foo\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "76f1b752",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"doc_results = embeddings.embed_documents([\"foo\"])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "fff99b21",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"doc_results"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "aaad49f8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.1"
|
|
},
|
|
"vscode": {
|
|
"interpreter": {
|
|
"hash": "7377c2ccc78bc62c2683122d48c8cd1fb85a53850a1b1fc29736ed39852c9885"
|
|
}
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|