{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SageMakerEndpoint\n",
    "\n",
    "This notebook goes over how to use an LLM hosted on a SageMaker endpoint. It runs a question-answering chain over a small in-memory document, sending each prompt to the endpoint through a custom content handler.\n",
    "\n",
    "Before running it, replace the endpoint name, AWS credentials profile, and region in the configuration cell below with your own."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install langchain boto3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from typing import Dict\n",
    "\n",
    "from langchain import PromptTemplate, SagemakerEndpoint\n",
    "from langchain.chains.question_answering import load_qa_chain\n",
    "from langchain.docstore.document import Document\n",
    "from langchain.llms.sagemaker_endpoint import ContentHandlerBase"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace these placeholders with your own endpoint, AWS credentials profile and region.\n",
    "ENDPOINT_NAME = \"endpoint-name\"\n",
    "CREDENTIALS_PROFILE_NAME = \"credentials-profile-name\"\n",
    "REGION_NAME = \"us-west-2\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example document\n",
    "\n",
    "A single short document that the chain answers questions about."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "example_doc_1 = \"\"\"\n",
    "Peter and Elizabeth took a taxi to attend the night party in the city. While in the party, Elizabeth collapsed and was rushed to the hospital.\n",
    "Since she was diagnosed with a brain injury, the doctor told Peter to stay besides her until she gets well.\n",
    "Therefore, Peter stayed with her at the hospital for 3 days without leaving.\n",
    "\"\"\"\n",
    "\n",
    "docs = [\n",
    "    Document(\n",
    "        page_content=example_doc_1,\n",
    "    )\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prompt\n",
    "\n",
    "The prompt places the document text in `{context}` ahead of the question."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"\"\"How long was Elizabeth hospitalized?\n",
    "\"\"\"\n",
    "\n",
    "prompt_template = \"\"\"Use the following pieces of context to answer the question at the end.\n",
    "\n",
    "{context}\n",
    "\n",
    "Question: {question}\n",
    "Answer:\"\"\"\n",
    "PROMPT = PromptTemplate(\n",
    "    template=prompt_template, input_variables=[\"context\", \"question\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Content handler\n",
    "\n",
    "The content handler turns the prompt into the request body sent to the endpoint and extracts the generated text from the response. The schema used here (`{\"inputs\": ..., \"parameters\": ...}` in, `[{\"generated_text\": ...}]` out) is the one used by Hugging Face text-generation containers; adjust both methods if your model expects a different format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ContentHandler(ContentHandlerBase):\n",
    "    \"\"\"Serializes prompts for, and parses responses from, the SageMaker endpoint.\"\"\"\n",
    "\n",
    "    content_type = \"application/json\"\n",
    "    accepts = \"application/json\"\n",
    "\n",
    "    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n",
    "        \"\"\"Encode the prompt and model kwargs as a UTF-8 JSON request body.\"\"\"\n",
    "        # Key the prompt under the literal string \"inputs\"; using the `prompt`\n",
    "        # variable as the key would send the prompt text as a field name.\n",
    "        input_str = json.dumps({\"inputs\": prompt, \"parameters\": model_kwargs})\n",
    "        return input_str.encode(\"utf-8\")\n",
    "\n",
    "    def transform_output(self, output: bytes) -> str:\n",
    "        \"\"\"Return the generated text from the endpoint's JSON response.\"\"\"\n",
    "        # Despite the annotation, `output` is consumed with `.read()`, i.e. it is a\n",
    "        # file-like response body rather than raw bytes.\n",
    "        response_json = json.loads(output.read().decode(\"utf-8\"))\n",
    "        return response_json[0][\"generated_text\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Question-answering chain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "content_handler = ContentHandler()\n",
    "\n",
    "chain = load_qa_chain(\n",
    "    llm=SagemakerEndpoint(\n",
    "        endpoint_name=ENDPOINT_NAME,\n",
    "        credentials_profile_name=CREDENTIALS_PROFILE_NAME,\n",
    "        region_name=REGION_NAME,\n",
    "        # Near-zero temperature for (near-)deterministic answers.\n",
    "        model_kwargs={\"temperature\": 1e-10},\n",
    "        content_handler=content_handler,\n",
    "    ),\n",
    "    prompt=PROMPT,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  },
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}