Mirror of https://github.com/hwchase17/langchain
Synced 2024-10-31 15:20:26 +00:00
Commit fdba711d28: Updated `integrations/embeddings`: fixed titles; added links, descriptions. Updated `integrations/providers`.
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1f83f273",
   "metadata": {},
   "source": [
    "# SageMaker\n",
    "\n",
"Let's load the `SageMaker Endpoints Embeddings` class. The class can be used if you host, e.g. your own Hugging Face model on SageMaker.\n",
    "\n",
    "For instructions on how to do this, please see [here](https://www.philschmid.de/custom-inference-huggingface-sagemaker).\n",
    "\n",
    "**Note**: In order to handle batched requests, you will need to adjust the return line in the `predict_fn()` function within the custom `inference.py` script:\n",
    "\n",
    "Change from\n",
    "\n",
    "`return {\"vectors\": sentence_embeddings[0].tolist()}`\n",
    "\n",
    "to:\n",
    "\n",
    "`return {\"vectors\": sentence_embeddings.tolist()}`."
   ]
  },
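  {
   "cell_type": "markdown",
   "id": "b3a1f0c2",
   "metadata": {},
   "source": [
    "For illustration, here is a minimal sketch of what such a custom `inference.py` could look like. It is not part of the LangChain integration: the mean-pooling logic and the `AutoModel`/`AutoTokenizer` calls are assumptions based on a typical Sentence Transformers deployment, so adapt them to your own model.\n",
    "\n",
    "```python\n",
    "# inference.py: illustrative sketch under the assumptions above, not the official script\n",
    "import torch\n",
    "from transformers import AutoModel, AutoTokenizer\n",
    "\n",
    "\n",
    "def model_fn(model_dir):\n",
    "    # Load the tokenizer and model from the unpacked model artifact.\n",
    "    tokenizer = AutoTokenizer.from_pretrained(model_dir)\n",
    "    model = AutoModel.from_pretrained(model_dir)\n",
    "    return model, tokenizer\n",
    "\n",
    "\n",
    "def predict_fn(data, model_and_tokenizer):\n",
    "    model, tokenizer = model_and_tokenizer\n",
    "    sentences = data[\"inputs\"]\n",
    "    encoded = tokenizer(sentences, padding=True, truncation=True, return_tensors=\"pt\")\n",
    "    with torch.no_grad():\n",
    "        model_output = model(**encoded)\n",
    "    # Mean-pool the token embeddings, ignoring padding tokens.\n",
    "    mask = encoded[\"attention_mask\"].unsqueeze(-1).float()\n",
    "    sentence_embeddings = (model_output.last_hidden_state * mask).sum(1) / mask.sum(1)\n",
    "    # Return all embeddings, not just the first, so batched requests work.\n",
    "    return {\"vectors\": sentence_embeddings.tolist()}\n",
    "```"
   ]
  },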
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88d366bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip3 install langchain boto3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1e9b926a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict, List\n",
    "from langchain.embeddings import SagemakerEndpointEmbeddings\n",
    "from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler\n",
    "import json\n",
    "\n",
    "\n",
    "class ContentHandler(EmbeddingsContentHandler):\n",
    "    content_type = \"application/json\"\n",
    "    accepts = \"application/json\"\n",
    "\n",
    "    def transform_input(self, inputs: List[str], model_kwargs: Dict) -> bytes:\n",
    "        \"\"\"\n",
    "        Transforms the input into bytes that can be consumed by SageMaker endpoint.\n",
    "        Args:\n",
    "            inputs: List of input strings.\n",
    "            model_kwargs: Additional keyword arguments to be passed to the endpoint.\n",
    "        Returns:\n",
    "            The transformed bytes input.\n",
    "        \"\"\"\n",
" # Example: inference.py expects a JSON string with a \"inputs\" key:\n",
    "        input_str = json.dumps({\"inputs\": inputs, **model_kwargs})\n",
    "        return input_str.encode(\"utf-8\")\n",
    "\n",
    "    def transform_output(self, output: bytes) -> List[List[float]]:\n",
    "        \"\"\"\n",
    "        Transforms the bytes output from the endpoint into a list of embeddings.\n",
    "        Args:\n",
    "            output: The bytes output from SageMaker endpoint.\n",
    "        Returns:\n",
" The transformed output - list of embeddings\n",
    "        Note:\n",
    "            The length of the outer list is the number of input strings.\n",
    "            The length of the inner lists is the embedding dimension.\n",
    "        \"\"\"\n",
    "        # Example: inference.py returns a JSON string with the list of\n",
    "        # embeddings in a \"vectors\" key:\n",
    "        response_json = json.loads(output.read().decode(\"utf-8\"))\n",
    "        return response_json[\"vectors\"]\n",
    "\n",
    "\n",
    "content_handler = ContentHandler()\n",
    "\n",
    "\n",
    "embeddings = SagemakerEndpointEmbeddings(\n",
    "    # credentials_profile_name=\"credentials-profile-name\",\n",
    "    endpoint_name=\"huggingface-pytorch-inference-2023-03-21-16-14-03-834\",\n",
    "    region_name=\"us-east-1\",\n",
    "    content_handler=content_handler,\n",
    ")"
   ]
  },
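  {
   "cell_type": "markdown",
   "id": "c7d4e2a1",
   "metadata": {},
   "source": [
    "The content handler is what ties the LangChain class to your endpoint's request/response format, so `transform_input` and `transform_output` must match the JSON that your `inference.py` expects and returns. As a hedged sketch: since `transform_input` receives `model_kwargs`, extra parameters can be forwarded to the endpoint if your installed version of `SagemakerEndpointEmbeddings` exposes a `model_kwargs` field; check your LangChain version, and treat the parameter below as hypothetical.\n",
    "\n",
    "```python\n",
    "# Sketch: assumes `model_kwargs` is supported by your installed version of the class.\n",
    "embeddings_with_kwargs = SagemakerEndpointEmbeddings(\n",
    "    endpoint_name=\"huggingface-pytorch-inference-2023-03-21-16-14-03-834\",\n",
    "    region_name=\"us-east-1\",\n",
    "    content_handler=content_handler,\n",
    "    model_kwargs={\"normalize\": True},  # hypothetical parameter; your inference.py would need to handle it\n",
    ")\n",
    "```"
   ]
  },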
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe9797b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "query_result = embeddings.embed_query(\"foo\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "76f1b752",
   "metadata": {},
   "outputs": [],
   "source": [
    "doc_results = embeddings.embed_documents([\"foo\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fff99b21",
   "metadata": {},
   "outputs": [],
   "source": [
    "doc_results"
   ]
  },
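  {
   "cell_type": "markdown",
   "id": "e4b8a6d2",
   "metadata": {},
   "source": [
    "As the `transform_output` docstring notes, the outer list has one entry per input string and each inner list has the embedding dimension, which you can confirm with a quick check:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a2f9c17",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One embedding per input document; each embedding is a list of floats.\n",
    "len(doc_results), len(doc_results[0])"
   ]
  },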
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aaad49f8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "7377c2ccc78bc62c2683122d48c8cd1fb85a53850a1b1fc29736ed39852c9885"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}