mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
132 lines
3.6 KiB
Plaintext
132 lines
3.6 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# SageMakerEndpoint\n",
|
|
"\n",
|
|
"This notebook shows how to use an LLM hosted on a SageMaker endpoint to answer a question about a small example document."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# %pip installs into the environment of the running kernel (a bare !pip3 may target a different Python).\n",
|
|
"%pip install langchain boto3"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain.docstore.document import Document"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# One short example document; the question asked below is answered from its text.\n",
|
|
"example_doc_1 = \"\"\"\n",
|
|
"Peter and Elizabeth took a taxi to attend the night party in the city. While in the party, Elizabeth collapsed and was rushed to the hospital.\n",
|
|
"Since she was diagnosed with a brain injury, the doctor told Peter to stay besides her until she gets well.\n",
|
|
"Therefore, Peter stayed with her at the hospital for 3 days without leaving.\n",
|
|
"\"\"\"\n",
|
|
"\n",
|
|
"docs = [Document(page_content=example_doc_1)]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from typing import Dict\n",
|
|
"\n",
|
|
"from langchain import PromptTemplate, SagemakerEndpoint\n",
|
|
"from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n",
|
|
"from langchain.chains.question_answering import load_qa_chain\n",
|
|
"import json\n",
|
|
"\n",
|
|
"prompt_template = \"\"\"Use the following pieces of context to answer the question at the end.\n",
|
|
"\n",
|
|
"{context}\n",
|
|
"\n",
|
|
"Question: {question}\n",
|
|
"Answer:\"\"\"\n",
|
|
"# `input_variables` must list exactly the `{...}` placeholders used in the template above.\n",
|
|
"PROMPT = PromptTemplate(\n",
|
|
"    template=prompt_template,\n",
|
|
"    input_variables=[\"context\", \"question\"],\n",
|
|
")\n",
|
|
"\n",
|
|
"query = \"\"\"How long was Elizabeth hospitalized?\n",
|
|
"\"\"\"\n",
|
|
"\n",
|
|
"class ContentHandler(ContentHandlerBase):\n",
|
|
"    \"\"\"Convert a prompt into an endpoint request and an endpoint response into text.\n",
|
|
"\n",
|
|
"    NOTE: the payload shapes assume a Hugging Face-style text-generation container\n",
|
|
"    (request `{\"inputs\": ..., \"parameters\": {...}}`, response `[{\"generated_text\": ...}]`).\n",
|
|
"    Adjust both methods to match your model's request/response schema.\n",
|
|
"    \"\"\"\n",
|
|
"\n",
|
|
"    content_type = \"application/json\"  # MIME type of the request body\n",
|
|
"    accepts = \"application/json\"  # MIME type expected in the response\n",
|
|
"\n",
|
|
"    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n",
|
|
"        \"\"\"Serialize the prompt and generation kwargs into a UTF-8 JSON request body.\"\"\"\n",
|
|
"        # Send the prompt under a fixed key, with generation settings nested under\n",
|
|
"        # \"parameters\"; keying the dict by the prompt text itself yields an unusable payload.\n",
|
|
"        payload = {\"inputs\": prompt, \"parameters\": model_kwargs}\n",
|
|
"        return json.dumps(payload).encode(\"utf-8\")\n",
|
|
"\n",
|
|
"    def transform_output(self, output: bytes) -> str:\n",
|
|
"        \"\"\"Decode the JSON response body and return the first generated text.\"\"\"\n",
|
|
"        # `output` is consumed via .read(), i.e. it is a stream-like body\n",
|
|
"        # (presumably the boto3 invoke_endpoint response body), not raw bytes.\n",
|
|
"        response_json = json.loads(output.read().decode(\"utf-8\"))\n",
|
|
"        return response_json[0][\"generated_text\"]\n",
|
|
"\n",
|
|
"content_handler = ContentHandler()\n",
|
|
"\n",
|
|
"# Placeholder endpoint name, AWS credentials profile and region: replace with your own values.\n",
|
|
"llm = SagemakerEndpoint(\n",
|
|
"    endpoint_name=\"endpoint-name\",\n",
|
|
"    credentials_profile_name=\"credentials-profile-name\",\n",
|
|
"    region_name=\"us-west-2\",\n",
|
|
"    model_kwargs={\"temperature\": 1e-10},  # near-zero temperature, presumably for near-deterministic output\n",
|
|
"    content_handler=content_handler,\n",
|
|
")\n",
|
|
"chain = load_qa_chain(llm=llm, prompt=PROMPT)\n",
|
|
"\n",
|
|
"chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.1"
|
|
},
|
|
"vscode": {
|
|
"interpreter": {
|
|
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
|
|
}
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|