{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SageMakerEndpoint\n",
"\n",
"[Amazon SageMaker](https://aws.amazon.com/sagemaker/) is a system that can build, train, and deploy machine learning (ML) models for any use case with fully managed infrastructure, tools, and workflows.\n",
"\n",
"This notebook goes over how to use an LLM hosted on a `SageMaker endpoint`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!pip3 install langchain boto3"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Set up"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You have to set up the following required parameters of the `SagemakerEndpoint` call:\n",
"- `endpoint_name`: The name of the endpoint from the deployed Sagemaker model.\n",
"  Must be unique within an AWS Region.\n",
"- `credentials_profile_name`: The name of the profile in the ~/.aws/credentials or ~/.aws/config files that\n",
"  has either access keys or role information specified.\n",
"  If not specified, the default credential profile or, if on an EC2 instance,\n",
"  credentials from IMDS will be used.\n",
"  See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html"
]
},
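{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick, optional sanity check, the cell below verifies that a credentials profile and region resolve to a valid identity before they are passed to `SagemakerEndpoint`. The profile and region names are placeholders; substitute your own values."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import boto3\n",
"\n",
"# Placeholder profile and region; replace with the values you plan to pass to SagemakerEndpoint.\n",
"session = boto3.Session(profile_name=\"credentials-profile-name\", region_name=\"us-west-2\")\n",
"print(session.client(\"sts\").get_caller_identity()[\"Arn\"])"
]
},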
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.docstore.document import Document"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"example_doc_1 = \"\"\"\n",
"Peter and Elizabeth took a taxi to attend the night party in the city. While at the party, Elizabeth collapsed and was rushed to the hospital.\n",
"Since she was diagnosed with a brain injury, the doctor told Peter to stay beside her until she got well.\n",
"Therefore, Peter stayed with her at the hospital for 3 days without leaving.\n",
"\"\"\"\n",
"\n",
"docs = [\n",
"    Document(\n",
"        page_content=example_doc_1,\n",
"    )\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from typing import Dict\n",
"\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.llms import SagemakerEndpoint\n",
"from langchain.llms.sagemaker_endpoint import LLMContentHandler\n",
"from langchain.chains.question_answering import load_qa_chain\n",
"import json\n",
"\n",
"query = \"\"\"How long was Elizabeth hospitalized?\n",
"\"\"\"\n",
"\n",
"prompt_template = \"\"\"Use the following pieces of context to answer the question at the end.\n",
"\n",
"{context}\n",
"\n",
"Question: {question}\n",
"Answer:\"\"\"\n",
"PROMPT = PromptTemplate(\n",
"    template=prompt_template, input_variables=[\"context\", \"question\"]\n",
")\n",
"\n",
"\n",
"class ContentHandler(LLMContentHandler):\n",
"    content_type = \"application/json\"\n",
"    accepts = \"application/json\"\n",
"\n",
"    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n",
"        # The request body must match what the deployed model expects; here the prompt\n",
"        # is sent under the \"inputs\" key alongside any model kwargs.\n",
"        input_str = json.dumps({\"inputs\": prompt, **model_kwargs})\n",
"        return input_str.encode(\"utf-8\")\n",
"\n",
"    def transform_output(self, output: bytes) -> str:\n",
"        response_json = json.loads(output.read().decode(\"utf-8\"))\n",
"        return response_json[0][\"generated_text\"]\n",
"\n",
"\n",
"content_handler = ContentHandler()\n",
"\n",
"chain = load_qa_chain(\n",
"    llm=SagemakerEndpoint(\n",
"        endpoint_name=\"endpoint-name\",\n",
"        credentials_profile_name=\"credentials-profile-name\",\n",
"        region_name=\"us-west-2\",\n",
"        model_kwargs={\"temperature\": 1e-10},\n",
"        content_handler=content_handler,\n",
"    ),\n",
"    prompt=PROMPT,\n",
")\n",
"\n",
"chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)"
]
},
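{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what the `ContentHandler` sends and expects without calling a live endpoint, the sketch below round-trips a sample prompt through `transform_input` and feeds `transform_output` a made-up response body in the `[{\"generated_text\": ...}]` shape assumed above; the response text is purely illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import io\n",
"\n",
"# Request body the handler would send for a sample prompt and the same model kwargs as above.\n",
"print(content_handler.transform_input(\"How long was Elizabeth hospitalized?\", {\"temperature\": 1e-10}))\n",
"\n",
"# Illustrative response body in the format transform_output expects; io.BytesIO stands in\n",
"# for the streaming body a real endpoint returns.\n",
"fake_response = io.BytesIO(json.dumps([{\"generated_text\": \"Elizabeth was hospitalized for 3 days.\"}]).encode(\"utf-8\"))\n",
"print(content_handler.transform_output(fake_response))"
]
}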
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}