diff --git a/langchain/embeddings/bedrock.py b/langchain/embeddings/bedrock.py index e884b139..16381b14 100644 --- a/langchain/embeddings/bedrock.py +++ b/langchain/embeddings/bedrock.py @@ -68,6 +68,10 @@ class BedrockEmbeddings(BaseModel, Embeddings): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that AWS credentials to and python package exists in environment.""" + + if "client" in values: + return values + try: import boto3 diff --git a/langchain/llms/bedrock.py b/langchain/llms/bedrock.py index a59a81d2..a985b2b0 100644 --- a/langchain/llms/bedrock.py +++ b/langchain/llms/bedrock.py @@ -99,6 +99,11 @@ class Bedrock(LLM): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that AWS credentials to and python package exists in environment.""" + + # Skip creating new client if passed in constructor + if "client" in values: + return values + try: import boto3