langchain/docs/modules/models/text_embedding/examples/sagemaker-endpoint.ipynb

128 lines
3.1 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"cell_type": "markdown",
"id": "1f83f273",
"metadata": {},
"source": [
"# SageMaker Endpoint Embeddings\n",
"\n",
"Let's load the SageMaker Endpoint Embeddings class. The class can be used if you host, e.g., your own Hugging Face model on SageMaker.\n",
"\n",
"For instructions on how to do this, please see [here](https://www.philschmid.de/custom-inference-huggingface-sagemaker)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "88d366bd",
"metadata": {},
"outputs": [],
"source": [
"!pip3 install langchain boto3"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1e9b926a",
"metadata": {},
"outputs": [],
"source": [
"from typing import Dict\n",
"from langchain.embeddings import SagemakerEndpointEmbeddings\n",
"from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n",
"import json\n",
"\n",
"\n",
"class ContentHandler(ContentHandlerBase):\n",
" content_type = \"application/json\"\n",
" accepts = \"application/json\"\n",
"\n",
" def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n",
" input_str = json.dumps({\"inputs\": prompt, **model_kwargs})\n",
" return input_str.encode('utf-8')\n",
" \n",
" def transform_output(self, output: bytes) -> str:\n",
" response_json = json.loads(output.read().decode(\"utf-8\"))\n",
" return response_json[\"embeddings\"]\n",
"\n",
"content_handler = ContentHandler()\n",
"\n",
"\n",
"embeddings = SagemakerEndpointEmbeddings(\n",
" # endpoint_name=\"endpoint-name\", \n",
" # credentials_profile_name=\"credentials-profile-name\", \n",
" endpoint_name=\"huggingface-pytorch-inference-2023-03-21-16-14-03-834\", \n",
" region_name=\"us-east-1\", \n",
" content_handler=content_handler\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe9797b8",
"metadata": {},
"outputs": [],
"source": [
"query_result = embeddings.embed_query(\"foo\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "76f1b752",
"metadata": {},
"outputs": [],
"source": [
"doc_results = embeddings.embed_documents([\"foo\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fff99b21",
"metadata": {},
"outputs": [],
"source": [
"doc_results"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aaad49f8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
},
"vscode": {
"interpreter": {
"hash": "7377c2ccc78bc62c2683122d48c8cd1fb85a53850a1b1fc29736ed39852c9885"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}