@ -82,6 +82,15 @@
"]"
"]"
]
]
},
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example to initialize with external boto3 session\n",
"\n",
"### for cross account scenarios"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Dict\n",
"\n",
"from langchain.prompts import PromptTemplate\nfrom langchain.llms import SagemakerEndpoint\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.llms import SagemakerEndpoint\n",
"from langchain.llms.sagemaker_endpoint import LLMContentHandler\n",
"from langchain.chains.question_answering import load_qa_chain\n",
"import json\n",
"import boto3\n",
"\n",
"query = \"\"\"How long was Elizabeth hospitalized?\n",
"\"\"\"\n",
"\n",
"prompt_template = \"\"\"Use the following pieces of context to answer the question at the end.\n",
"\n",
"{context}\n",
"\n",
"Question: {question}\n",
"Answer:\"\"\"\n",
"PROMPT = PromptTemplate(\n",
" template=prompt_template, input_variables=[\"context\", \"question\"]\n",
")\n",
"\n",
"roleARN = 'arn:aws:iam::123456789:role/cross-account-role'\n",
"sts_client = boto3.client('sts')\n",
"response = sts_client.assume_role(RoleArn=roleARN, \n",
" RoleSessionName='CrossAccountSession')\n",
"\n",
"client = boto3.client(\n",
" \"sagemaker-runtime\",\n",
" region_name=\"us-west-2\", \n",
" aws_access_key_id=response['Credentials']['AccessKeyId'],\n",
" aws_secret_access_key=response['Credentials']['SecretAccessKey'],\n",
" aws_session_token = response['Credentials']['SessionToken']\n",
")\n",
"\n",
"class ContentHandler(LLMContentHandler):\n",
" content_type = \"application/json\"\n",
" accepts = \"application/json\"\n",
"\n",
" def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n",
" input_str = json.dumps({prompt: prompt, **model_kwargs})\n",
" return input_str.encode(\"utf-8\")\n",
"\n",
" def transform_output(self, output: bytes) -> str:\n",
" response_json = json.loads(output.read().decode(\"utf-8\"))\n",
" return response_json[0][\"generated_text\"]\n",
"\n",
"\n",
"content_handler = ContentHandler()\n",
"\n",
"chain = load_qa_chain(\n",
" llm=SagemakerEndpoint(\n",
" endpoint_name=\"endpoint-name\",\n",
" client=client,\n",
" model_kwargs={\"temperature\": 1e-10},\n",
" content_handler=content_handler,\n",
" ),\n",
" prompt=PROMPT,\n",
")\n",
"\n",
"chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Dict\n",
"\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.llms import SagemakerEndpoint\n",
"from langchain.llms.sagemaker_endpoint import LLMContentHandler\n",
"from langchain.llms.sagemaker_endpoint import LLMContentHandler\n",
"from langchain.chains.question_answering import load_qa_chain\n",
"from langchain.chains.question_answering import load_qa_chain\n",
"import json\n",
"import json\n",