diff --git a/docs/modules/document_loaders/examples/sagemaker.ipynb b/docs/modules/document_loaders/examples/sagemaker.ipynb new file mode 100644 index 00000000..be779dfa --- /dev/null +++ b/docs/modules/document_loaders/examples/sagemaker.ipynb @@ -0,0 +1,122 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip3 install langchain boto3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.docstore.document import Document" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "example_doc_1 = \"\"\"\n", + "Peter and Elizabeth took a taxi to attend the night party in the city. While in the party, Elizabeth collapsed and was rushed to the hospital.\n", + "Since she was diagnosed with a brain injury, the doctor told Peter to stay besides her until she gets well.\n", + "Therefore, Peter stayed with her at the hospital for 3 days without leaving.\n", + "\"\"\"\n", + "\n", + "docs = [\n", + " Document(\n", + " page_content=example_doc_1,\n", + " )\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Dict\n", + "\n", + "from langchain import PromptTemplate, SagemakerEndpoint\n", + "from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n", + "from langchain.chains.question_answering import load_qa_chain\n", + "import json\n", + "\n", + "query = \"\"\"How long was Elizabeth hospitalized?\n", + "\"\"\"\n", + "\n", + "prompt_template = \"\"\"Use the following pieces of context to answer the question at the end.\n", + "\n", + "{context}\n", + "\n", + "Question: {question}\n", + "Answer:\"\"\"\n", + "PROMPT = PromptTemplate(\n", + " template=prompt_template, input_variables=[\"context\", \"question\"]\n", + ")\n", + "\n", + "class ContentHandler(ContentHandlerBase):\n", + " content_type = \"application/json\"\n", + " accepts = \"application/json\"\n", + "\n", + " def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n", + " input_str = json.dumps({prompt: prompt, **model_kwargs})\n", + " return input_str.encode('utf-8')\n", + " \n", + " def transform_output(self, output: bytes) -> str:\n", + " response_json = json.loads(output.read().decode(\"utf-8\"))\n", + " return response_json[0][\"generated_text\"]\n", + "\n", + "content_handler = ContentHandler()\n", + "\n", + "chain = load_qa_chain(\n", + " llm=SagemakerEndpoint(\n", + " endpoint_name=\"endpoint-name\", \n", + " credentials_profile_name=\"credentials-profile-name\", \n", + " region_name=\"us-west-2\", \n", + " model_kwargs={\"temperature\":1e-10},\n", + " content_handler=content_handler\n", + " ),\n", + " prompt=PROMPT\n", + ")\n", + "\n", + "chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + }, + "vscode": { + "interpreter": { + "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/langchain/__init__.py b/langchain/__init__.py index 3b8e0ac0..86bf4f7a 100644 --- a/langchain/__init__.py +++ b/langchain/__init__.py @@ -33,6 +33,7 @@ from langchain.llms import ( Modal, OpenAI, Petals, + SagemakerEndpoint, StochasticAI, Writer, ) @@ -90,6 +91,7 @@ __all__ = [ "ReActChain", "Wikipedia", "HuggingFaceHub", + "SagemakerEndpoint", "HuggingFacePipeline", "SQLDatabase", "SQLDatabaseChain", diff --git a/langchain/llms/__init__.py b/langchain/llms/__init__.py index ed9610d8..06bc999f 100644 --- a/langchain/llms/__init__.py +++ b/langchain/llms/__init__.py @@ -19,6 +19,7 @@ from langchain.llms.nlpcloud import NLPCloud from langchain.llms.openai import AzureOpenAI, OpenAI, OpenAIChat from langchain.llms.petals import Petals from langchain.llms.promptlayer_openai import PromptLayerOpenAI, PromptLayerOpenAIChat +from langchain.llms.sagemaker_endpoint import SagemakerEndpoint from langchain.llms.self_hosted import SelfHostedPipeline from langchain.llms.self_hosted_hugging_face import SelfHostedHuggingFaceLLM from langchain.llms.stochasticai import StochasticAI @@ -40,6 +41,7 @@ __all__ = [ "Petals", "HuggingFaceEndpoint", "HuggingFaceHub", + "SagemakerEndpoint", "HuggingFacePipeline", "AI21", "AzureOpenAI", @@ -64,6 +66,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = { "huggingface_hub": HuggingFaceHub, "huggingface_endpoint": HuggingFaceEndpoint, "modal": Modal, + "sagemaker_endpoint": SagemakerEndpoint, "nlpcloud": NLPCloud, "openai": OpenAI, "petals": Petals, diff --git a/langchain/llms/sagemaker_endpoint.py b/langchain/llms/sagemaker_endpoint.py new file mode 100644 index 00000000..246f38e4 --- /dev/null +++ b/langchain/llms/sagemaker_endpoint.py @@ -0,0 +1,235 @@ +"""Wrapper around Sagemaker InvokeEndpoint API.""" +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Mapping, Optional + +from pydantic import BaseModel, Extra, root_validator + +from langchain.llms.base import LLM +from langchain.llms.utils import enforce_stop_tokens + + +class ContentHandlerBase(ABC): + """A handler class to transform input from LLM to a + format that SageMaker endpoint expects. Similarily, + the class also handles transforming output from the + SageMaker endpoint to a format that LLM class expects. + """ + + """ + Example: + .. code-block:: python + + class ContentHandler(ContentHandlerBase): + content_type = "application/json" + accepts = "application/json" + + def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes: + input_str = json.dumps({prompt: prompt, **model_kwargs}) + return input_str.encode('utf-8') + + def transform_output(self, output: bytes) -> str: + response_json = json.loads(output.read().decode("utf-8")) + return response_json[0]["generated_text"] + """ + + content_type: Optional[str] = "text/plain" + """The MIME type of the input data passed to endpoint""" + + accepts: Optional[str] = "text/plain" + """The MIME type of the response data returned from endpoint""" + + @abstractmethod + def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes: + """Transforms the input to a format that model can accept + as the request Body. Should return bytes or seekable file + like object in the format specified in the content_type + request header. + """ + + @abstractmethod + def transform_output(self, output: bytes) -> str: + """Transforms the output from the model to string that + the LLM class expects. + """ + + +class SagemakerEndpoint(LLM, BaseModel): + """Wrapper around custom Sagemaker Inference Endpoints. + + To use, you must supply the endpoint name from your deployed + Sagemaker model & the region where it is deployed. + + To authenticate, the AWS client uses the following methods to + automatically load credentials: + https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html + + If a specific credential profile should be used, you must pass + the name of the profile from the ~/.aws/credentials file that is to be used. + + Make sure the credentials / roles used have the required policies to + access the Sagemaker endpoint. + See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html + """ + + """ + Example: + .. code-block:: python + + from langchain import SagemakerEndpoint + endpoint_name = ( + "my-endpoint-name" + ) + region_name = ( + "us-west-2" + ) + credentials_profile_name = ( + "default" + ) + se = SagemakerEndpoint( + endpoint_name=endpoint_name, + region_name=region_name, + credentials_profile_name=credentials_profile_name + ) + """ + client: Any #: :meta private: + + endpoint_name: str = "" + """The name of the endpoint from the deployed Sagemaker model. + Must be unique within an AWS Region.""" + + region_name: str = "" + """The aws region where the Sagemaker model is deployed, eg. `us-west-2`.""" + + credentials_profile_name: Optional[str] = None + """The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which + has either access keys or role information specified. + If not specified, the default credential profile or, if on an EC2 instance, + credentials from IMDS will be used. + See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html + """ + + content_handler: ContentHandlerBase + """The content handler class that provides an input and + output transform functions to handle formats between LLM + and the endpoint. + """ + + """ + Example: + .. code-block:: python + + class ContentHandler(ContentHandlerBase): + content_type = "application/json" + accepts = "application/json" + + def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes: + input_str = json.dumps({prompt: prompt, **model_kwargs}) + return input_str.encode('utf-8') + + def transform_output(self, output: bytes) -> str: + response_json = json.loads(output.read().decode("utf-8")) + return response_json[0]["generated_text"] + """ + + model_kwargs: Optional[Dict] = None + """Key word arguments to pass to the model.""" + + endpoint_kwargs: Optional[Dict] = None + """Optional attributes passed to the invoke_endpoint + function. See `boto3`_. docs for more info. + .. _boto3: + """ + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that AWS credentials to and python package exists in environment.""" + try: + import boto3 + + try: + if values["credentials_profile_name"] is not None: + session = boto3.Session( + profile_name=values["credentials_profile_name"] + ) + else: + # use default credentials + session = boto3.Session() + + values["client"] = session.client( + "sagemaker-runtime", region_name=values["region_name"] + ) + + except Exception as e: + raise ValueError( + "Could not load credentials to authenticate with AWS client. " + "Please check that credentials in the specified " + "profile name are valid." + ) from e + + except ImportError: + raise ValueError( + "Could not import boto3 python package. " + "Please it install it with `pip install boto3`." + ) + return values + + @property + def _identifying_params(self) -> Mapping[str, Any]: + """Get the identifying parameters.""" + _model_kwargs = self.model_kwargs or {} + return { + **{"endpoint_name": self.endpoint_name}, + **{"model_kwargs": _model_kwargs}, + } + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "sagemaker_endpoint" + + def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + """Call out to Sagemaker inference endpoint. + + Args: + prompt: The prompt to pass into the model. + stop: Optional list of stop words to use when generating. + + Returns: + The string generated by the model. + + Example: + .. code-block:: python + + response = se("Tell me a joke.") + """ + _model_kwargs = self.model_kwargs or {} + _endpoint_kwargs = self.endpoint_kwargs or {} + + body = self.content_handler.transform_input(prompt, _model_kwargs) + content_type = self.content_handler.content_type + accepts = self.content_handler.accepts + + # send request + try: + response = self.client.invoke_endpoint( + EndpointName=self.endpoint_name, + Body=body, + ContentType=content_type, + Accept=accepts, + **_endpoint_kwargs, + ) + except Exception as e: + raise ValueError(f"Error raised by inference endpoint: {e}") + + text = self.content_handler.transform_output(response["Body"]) + if stop is not None: + # This is a bit hacky, but I can't figure out a better way to enforce + # stop tokens when making calls to the sagemaker endpoint. + text = enforce_stop_tokens(text, stop) + + return text