mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Sagemaker Endpoint LLM (#1686)
Updates #965 --------- Co-authored-by: Nimisha Mehta <116048415+nimimeht@users.noreply.github.com> Co-authored-by: Harrison Chase <harrisonchase@Harrisons-MBP.attlocal.net>
This commit is contained in:
parent
cd45adbea2
commit
cdff6c8181
122
docs/modules/document_loaders/examples/sagemaker.ipynb
Normal file
122
docs/modules/document_loaders/examples/sagemaker.ipynb
Normal file
@ -0,0 +1,122 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip3 install langchain boto3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.docstore.document import Document"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"example_doc_1 = \"\"\"\n",
|
||||
"Peter and Elizabeth took a taxi to attend the night party in the city. While in the party, Elizabeth collapsed and was rushed to the hospital.\n",
|
||||
"Since she was diagnosed with a brain injury, the doctor told Peter to stay besides her until she gets well.\n",
|
||||
"Therefore, Peter stayed with her at the hospital for 3 days without leaving.\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"docs = [\n",
|
||||
" Document(\n",
|
||||
" page_content=example_doc_1,\n",
|
||||
" )\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Dict\n",
|
||||
"\n",
|
||||
"from langchain import PromptTemplate, SagemakerEndpoint\n",
|
||||
"from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n",
|
||||
"from langchain.chains.question_answering import load_qa_chain\n",
|
||||
"import json\n",
|
||||
"\n",
|
||||
"query = \"\"\"How long was Elizabeth hospitalized?\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"prompt_template = \"\"\"Use the following pieces of context to answer the question at the end.\n",
|
||||
"\n",
|
||||
"{context}\n",
|
||||
"\n",
|
||||
"Question: {question}\n",
|
||||
"Answer:\"\"\"\n",
|
||||
"PROMPT = PromptTemplate(\n",
|
||||
" template=prompt_template, input_variables=[\"context\", \"question\"]\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"class ContentHandler(ContentHandlerBase):\n",
|
||||
" content_type = \"application/json\"\n",
|
||||
" accepts = \"application/json\"\n",
|
||||
"\n",
|
||||
" def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n",
|
||||
" input_str = json.dumps({prompt: prompt, **model_kwargs})\n",
|
||||
" return input_str.encode('utf-8')\n",
|
||||
" \n",
|
||||
" def transform_output(self, output: bytes) -> str:\n",
|
||||
" response_json = json.loads(output.read().decode(\"utf-8\"))\n",
|
||||
" return response_json[0][\"generated_text\"]\n",
|
||||
"\n",
|
||||
"content_handler = ContentHandler()\n",
|
||||
"\n",
|
||||
"chain = load_qa_chain(\n",
|
||||
" llm=SagemakerEndpoint(\n",
|
||||
" endpoint_name=\"endpoint-name\", \n",
|
||||
" credentials_profile_name=\"credentials-profile-name\", \n",
|
||||
" region_name=\"us-west-2\", \n",
|
||||
" model_kwargs={\"temperature\":1e-10},\n",
|
||||
" content_handler=content_handler\n",
|
||||
" ),\n",
|
||||
" prompt=PROMPT\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -33,6 +33,7 @@ from langchain.llms import (
|
||||
Modal,
|
||||
OpenAI,
|
||||
Petals,
|
||||
SagemakerEndpoint,
|
||||
StochasticAI,
|
||||
Writer,
|
||||
)
|
||||
@ -90,6 +91,7 @@ __all__ = [
|
||||
"ReActChain",
|
||||
"Wikipedia",
|
||||
"HuggingFaceHub",
|
||||
"SagemakerEndpoint",
|
||||
"HuggingFacePipeline",
|
||||
"SQLDatabase",
|
||||
"SQLDatabaseChain",
|
||||
|
@ -19,6 +19,7 @@ from langchain.llms.nlpcloud import NLPCloud
|
||||
from langchain.llms.openai import AzureOpenAI, OpenAI, OpenAIChat
|
||||
from langchain.llms.petals import Petals
|
||||
from langchain.llms.promptlayer_openai import PromptLayerOpenAI, PromptLayerOpenAIChat
|
||||
from langchain.llms.sagemaker_endpoint import SagemakerEndpoint
|
||||
from langchain.llms.self_hosted import SelfHostedPipeline
|
||||
from langchain.llms.self_hosted_hugging_face import SelfHostedHuggingFaceLLM
|
||||
from langchain.llms.stochasticai import StochasticAI
|
||||
@ -40,6 +41,7 @@ __all__ = [
|
||||
"Petals",
|
||||
"HuggingFaceEndpoint",
|
||||
"HuggingFaceHub",
|
||||
"SagemakerEndpoint",
|
||||
"HuggingFacePipeline",
|
||||
"AI21",
|
||||
"AzureOpenAI",
|
||||
@ -64,6 +66,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
|
||||
"huggingface_hub": HuggingFaceHub,
|
||||
"huggingface_endpoint": HuggingFaceEndpoint,
|
||||
"modal": Modal,
|
||||
"sagemaker_endpoint": SagemakerEndpoint,
|
||||
"nlpcloud": NLPCloud,
|
||||
"openai": OpenAI,
|
||||
"petals": Petals,
|
||||
|
235
langchain/llms/sagemaker_endpoint.py
Normal file
235
langchain/llms/sagemaker_endpoint.py
Normal file
@ -0,0 +1,235 @@
|
||||
"""Wrapper around Sagemaker InvokeEndpoint API."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
|
||||
|
||||
class ContentHandlerBase(ABC):
|
||||
"""A handler class to transform input from LLM to a
|
||||
format that SageMaker endpoint expects. Similarily,
|
||||
the class also handles transforming output from the
|
||||
SageMaker endpoint to a format that LLM class expects.
|
||||
"""
|
||||
|
||||
"""
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
class ContentHandler(ContentHandlerBase):
|
||||
content_type = "application/json"
|
||||
accepts = "application/json"
|
||||
|
||||
def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
|
||||
input_str = json.dumps({prompt: prompt, **model_kwargs})
|
||||
return input_str.encode('utf-8')
|
||||
|
||||
def transform_output(self, output: bytes) -> str:
|
||||
response_json = json.loads(output.read().decode("utf-8"))
|
||||
return response_json[0]["generated_text"]
|
||||
"""
|
||||
|
||||
content_type: Optional[str] = "text/plain"
|
||||
"""The MIME type of the input data passed to endpoint"""
|
||||
|
||||
accepts: Optional[str] = "text/plain"
|
||||
"""The MIME type of the response data returned from endpoint"""
|
||||
|
||||
@abstractmethod
|
||||
def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
|
||||
"""Transforms the input to a format that model can accept
|
||||
as the request Body. Should return bytes or seekable file
|
||||
like object in the format specified in the content_type
|
||||
request header.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def transform_output(self, output: bytes) -> str:
|
||||
"""Transforms the output from the model to string that
|
||||
the LLM class expects.
|
||||
"""
|
||||
|
||||
|
||||
class SagemakerEndpoint(LLM, BaseModel):
|
||||
"""Wrapper around custom Sagemaker Inference Endpoints.
|
||||
|
||||
To use, you must supply the endpoint name from your deployed
|
||||
Sagemaker model & the region where it is deployed.
|
||||
|
||||
To authenticate, the AWS client uses the following methods to
|
||||
automatically load credentials:
|
||||
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
|
||||
|
||||
If a specific credential profile should be used, you must pass
|
||||
the name of the profile from the ~/.aws/credentials file that is to be used.
|
||||
|
||||
Make sure the credentials / roles used have the required policies to
|
||||
access the Sagemaker endpoint.
|
||||
See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
|
||||
"""
|
||||
|
||||
"""
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain import SagemakerEndpoint
|
||||
endpoint_name = (
|
||||
"my-endpoint-name"
|
||||
)
|
||||
region_name = (
|
||||
"us-west-2"
|
||||
)
|
||||
credentials_profile_name = (
|
||||
"default"
|
||||
)
|
||||
se = SagemakerEndpoint(
|
||||
endpoint_name=endpoint_name,
|
||||
region_name=region_name,
|
||||
credentials_profile_name=credentials_profile_name
|
||||
)
|
||||
"""
|
||||
client: Any #: :meta private:
|
||||
|
||||
endpoint_name: str = ""
|
||||
"""The name of the endpoint from the deployed Sagemaker model.
|
||||
Must be unique within an AWS Region."""
|
||||
|
||||
region_name: str = ""
|
||||
"""The aws region where the Sagemaker model is deployed, eg. `us-west-2`."""
|
||||
|
||||
credentials_profile_name: Optional[str] = None
|
||||
"""The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
|
||||
has either access keys or role information specified.
|
||||
If not specified, the default credential profile or, if on an EC2 instance,
|
||||
credentials from IMDS will be used.
|
||||
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
|
||||
"""
|
||||
|
||||
content_handler: ContentHandlerBase
|
||||
"""The content handler class that provides an input and
|
||||
output transform functions to handle formats between LLM
|
||||
and the endpoint.
|
||||
"""
|
||||
|
||||
"""
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
class ContentHandler(ContentHandlerBase):
|
||||
content_type = "application/json"
|
||||
accepts = "application/json"
|
||||
|
||||
def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
|
||||
input_str = json.dumps({prompt: prompt, **model_kwargs})
|
||||
return input_str.encode('utf-8')
|
||||
|
||||
def transform_output(self, output: bytes) -> str:
|
||||
response_json = json.loads(output.read().decode("utf-8"))
|
||||
return response_json[0]["generated_text"]
|
||||
"""
|
||||
|
||||
model_kwargs: Optional[Dict] = None
|
||||
"""Key word arguments to pass to the model."""
|
||||
|
||||
endpoint_kwargs: Optional[Dict] = None
|
||||
"""Optional attributes passed to the invoke_endpoint
|
||||
function. See `boto3`_. docs for more info.
|
||||
.. _boto3: <https://boto3.amazonaws.com/v1/documentation/api/latest/index.html>
|
||||
"""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that AWS credentials to and python package exists in environment."""
|
||||
try:
|
||||
import boto3
|
||||
|
||||
try:
|
||||
if values["credentials_profile_name"] is not None:
|
||||
session = boto3.Session(
|
||||
profile_name=values["credentials_profile_name"]
|
||||
)
|
||||
else:
|
||||
# use default credentials
|
||||
session = boto3.Session()
|
||||
|
||||
values["client"] = session.client(
|
||||
"sagemaker-runtime", region_name=values["region_name"]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
"Could not load credentials to authenticate with AWS client. "
|
||||
"Please check that credentials in the specified "
|
||||
"profile name are valid."
|
||||
) from e
|
||||
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import boto3 python package. "
|
||||
"Please it install it with `pip install boto3`."
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
return {
|
||||
**{"endpoint_name": self.endpoint_name},
|
||||
**{"model_kwargs": _model_kwargs},
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "sagemaker_endpoint"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
"""Call out to Sagemaker inference endpoint.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
response = se("Tell me a joke.")
|
||||
"""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
_endpoint_kwargs = self.endpoint_kwargs or {}
|
||||
|
||||
body = self.content_handler.transform_input(prompt, _model_kwargs)
|
||||
content_type = self.content_handler.content_type
|
||||
accepts = self.content_handler.accepts
|
||||
|
||||
# send request
|
||||
try:
|
||||
response = self.client.invoke_endpoint(
|
||||
EndpointName=self.endpoint_name,
|
||||
Body=body,
|
||||
ContentType=content_type,
|
||||
Accept=accepts,
|
||||
**_endpoint_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error raised by inference endpoint: {e}")
|
||||
|
||||
text = self.content_handler.transform_output(response["Body"])
|
||||
if stop is not None:
|
||||
# This is a bit hacky, but I can't figure out a better way to enforce
|
||||
# stop tokens when making calls to the sagemaker endpoint.
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
|
||||
return text
|
Loading…
Reference in New Issue
Block a user