Add serialisation arguments to Bedrock and ChatBedrock (#13465)

This commit is contained in:
David Duong 2023-11-17 01:33:24 +01:00 committed by GitHub
parent 427331d621
commit ea6e017b85
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 0 deletions

View File

@ -41,6 +41,22 @@ class BedrockChat(BaseChatModel, BedrockBase):
"""Return type of chat model."""
return "amazon_bedrock_chat"
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True
@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}
print(self.region_name)
if self.region_name:
attributes["region_name"] = self.region_name
return attributes
class Config:
"""Configuration for this pydantic object."""

View File

@ -12,6 +12,7 @@ from langchain.utilities.anthropic import (
get_num_tokens_anthropic,
get_token_ids_anthropic,
)
from langchain.utils import get_from_dict_or_env
HUMAN_PROMPT = "\n\nHuman:"
ASSISTANT_PROMPT = "\n\nAssistant:"
@ -195,6 +196,13 @@ class BedrockBase(BaseModel, ABC):
# use default credentials
session = boto3.Session()
values["region_name"] = get_from_dict_or_env(
values,
"region_name",
"AWS_DEFAULT_REGION",
default=None,
)
client_params = {}
if values["region_name"]:
client_params["region_name"] = values["region_name"]
@ -340,6 +348,20 @@ class Bedrock(LLM, BedrockBase):
"""Return type of llm."""
return "amazon_bedrock"
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True
@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}
if self.region_name:
attributes["region_name"] = self.region_name
return attributes
class Config:
"""Configuration for this pydantic object."""