forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
196 lines
6.3 KiB
Python
196 lines
6.3 KiB
Python
import json
|
|
from typing import Any, Dict, List, Mapping, Optional
|
|
|
|
from pydantic import Extra, root_validator
|
|
|
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
|
from langchain.llms.base import LLM
|
|
from langchain.llms.utils import enforce_stop_tokens
|
|
|
|
|
|
class LLMInputOutputAdapter:
    """Translate between LangChain calls and provider-specific payloads.

    ``prepare_input`` builds the request body a given model provider expects,
    and ``prepare_output`` pulls the generated text out of the provider's
    JSON response.
    """

    @classmethod
    def prepare_input(
        cls, provider: str, prompt: str, model_kwargs: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Build the request body for ``provider``.

        Args:
            provider: Provider name, e.g. ``"anthropic"``, ``"ai21"`` or
                ``"amazon"``. Any other value gets the generic layout.
            prompt: The prompt text to send.
            model_kwargs: Extra model parameters; never mutated.

        Returns:
            A new dict ready to be JSON-serialized as the request body.
        """
        # Copy first so the caller's dict is never modified.
        payload = {**model_kwargs}

        if provider == "amazon":
            # Amazon models nest the parameters under "textGenerationConfig".
            return {"inputText": prompt, "textGenerationConfig": payload}

        if provider in ("anthropic", "ai21"):
            payload["prompt"] = prompt
            if provider == "anthropic":
                # Supply a token limit only when the caller gave none.
                payload.setdefault("max_tokens_to_sample", 50)
            return payload

        # Any other provider: flat parameters plus "inputText".
        payload["inputText"] = prompt
        return payload

    @classmethod
    def prepare_output(cls, provider: str, response: Any) -> str:
        """Extract the generated text from a provider response.

        Args:
            provider: Provider name used to pick the response layout.
            response: Mapping-like object whose ``"body"`` entry exposes a
                ``read()`` method returning the JSON payload.

        Returns:
            The text found at the provider-specific location in the payload.
        """
        raw_body = response.get("body").read()

        if provider == "anthropic":
            return json.loads(raw_body.decode()).get("completion")

        parsed = json.loads(raw_body)
        if provider == "ai21":
            first_completion = parsed.get("completions")[0]
            return first_completion.get("data").get("text")

        return parsed.get("results")[0].get("outputText")
|
|
|
|
|
|
class Bedrock(LLM):
    """LLM provider to invoke Bedrock models.

    To authenticate, the AWS client uses the following methods to
    automatically load credentials:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If a specific credential profile should be used, you must pass
    the name of the profile from the ~/.aws/credentials file that is to be used.

    Make sure the credentials / roles used have the required policies to
    access the Bedrock service.

    Example:
        .. code-block:: python

            llm = Bedrock(
                credentials_profile_name="default",
                model_id="amazon.titan-tg1-large"
            )
    """

    client: Any  #: :meta private:

    region_name: Optional[str] = None
    """The aws region e.g., `us-west-2`. Falls back to AWS_DEFAULT_REGION env variable
    or region specified in ~/.aws/config in case it is not provided here.
    """

    credentials_profile_name: Optional[str] = None
    """The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
    has either access keys or role information specified.
    If not specified, the default credential profile or, if on an EC2 instance,
    credentials from IMDS will be used.
    See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
    """

    model_id: str
    """Id of the model to call, e.g., amazon.titan-tg1-large, this is
    equivalent to the modelId property in the list-foundation-models api"""

    model_kwargs: Optional[Dict] = None
    """Key word arguments to pass to the model."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that AWS credentials and the boto3 package are available.

        If no client was supplied, a ``bedrock`` client is created from a
        boto3 session (using ``credentials_profile_name`` and ``region_name``
        when given) and stored in ``values["client"]``.

        Raises:
            ModuleNotFoundError: If boto3 cannot be imported.
            ValueError: If the session or client cannot be created.
        """
        # Skip creating new client if passed in constructor.
        # ``.get`` is used throughout: a field that failed its own validation
        # may be absent from ``values`` and must not surface as a KeyError.
        if values.get("client") is not None:
            return values

        try:
            import boto3
        except ImportError as e:
            raise ModuleNotFoundError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            ) from e

        try:
            profile_name = values.get("credentials_profile_name")
            if profile_name is not None:
                session = boto3.Session(profile_name=profile_name)
            else:
                # use default credentials
                session = boto3.Session()

            client_params: Dict[str, Any] = {}
            if values.get("region_name"):
                client_params["region_name"] = values["region_name"]

            values["client"] = session.client("bedrock", **client_params)
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                "profile name are valid."
            ) from e

        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters.

        ``model_id`` is included so that instances targeting different models
        with the same ``model_kwargs`` are not treated as identical.
        """
        _model_kwargs = self.model_kwargs or {}
        return {
            "model_id": self.model_id,
            **{"model_kwargs": _model_kwargs},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "amazon_bedrock"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Bedrock service model.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Accepted for interface compatibility; not used here.
            **kwargs: Extra model parameters; these override ``model_kwargs``.

        Returns:
            The string generated by the model.

        Raises:
            ValueError: If invoking the model or reading its response fails.

        Example:
            .. code-block:: python

                response = llm("Tell me a joke.")
        """
        _model_kwargs = self.model_kwargs or {}

        # The provider is the part of the model id before the first ".".
        provider = self.model_id.split(".")[0]
        params = {**_model_kwargs, **kwargs}
        input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
        body = json.dumps(input_body)
        accept = "application/json"
        contentType = "application/json"

        try:
            response = self.client.invoke_model(
                body=body, modelId=self.model_id, accept=accept, contentType=contentType
            )
            text = LLMInputOutputAdapter.prepare_output(provider, response)
        except Exception as e:
            # Chain the cause so the original traceback is kept.
            raise ValueError(f"Error raised by bedrock service: {e}") from e

        if stop is not None:
            text = enforce_stop_tokens(text, stop)

        return text
|