From 20c742d8a254466f080268b3f15dd9d99204935a Mon Sep 17 00:00:00 2001 From: Tze Min <40569118+tmin97@users.noreply.github.com> Date: Fri, 8 Sep 2023 05:58:28 +0800 Subject: [PATCH] Enhancement: add parameter boto3_session for AWS DynamoDB cross account use cases (#10326) - Description: to allow boto3 assume role for AWS cross account use cases to read and update the chat history, - Issue: use case I faced in my company, - Dependencies: no - Tag maintainer: @baskaryan , - Twitter handle: @tmin97 --------- Co-authored-by: Bagatur --- .../memory/chat_message_histories/dynamodb.py | 46 +++++++++++++++---- 1 file changed, 37 insertions(+), 9 deletions(-) diff --git a/libs/langchain/langchain/memory/chat_message_histories/dynamodb.py b/libs/langchain/langchain/memory/chat_message_histories/dynamodb.py index 704efa9ea7..06d7897dbd 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/dynamodb.py +++ b/libs/langchain/langchain/memory/chat_message_histories/dynamodb.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import logging -from typing import Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional from langchain.schema import ( BaseChatMessageHistory, @@ -11,6 +13,9 @@ from langchain.schema.messages import ( messages_to_dict, ) +if TYPE_CHECKING: + from boto3.session import Session + logger = logging.getLogger(__name__) @@ -42,13 +47,21 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory): endpoint_url: Optional[str] = None, primary_key_name: str = "SessionId", key: Optional[Dict[str, str]] = None, + boto3_session: Optional[Session] = None, ): - import boto3 - - if endpoint_url: - client = boto3.resource("dynamodb", endpoint_url=endpoint_url) + if boto3_session: + client = boto3_session.resource("dynamodb") else: - client = boto3.resource("dynamodb") + try: + import boto3 + except ImportError as e: + raise ImportError( + "Unable to import boto3, please install with `pip install boto3`." + ) from e + if endpoint_url: + client = boto3.resource("dynamodb", endpoint_url=endpoint_url) + else: + client = boto3.resource("dynamodb") self.table = client.Table(table_name) self.session_id = session_id self.key: Dict = key or {primary_key_name: session_id} @@ -56,7 +69,12 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory): @property def messages(self) -> List[BaseMessage]: # type: ignore """Retrieve the messages from DynamoDB""" - from botocore.exceptions import ClientError + try: + from botocore.exceptions import ClientError + except ImportError as e: + raise ImportError( + "Unable to import botocore, please install with `pip install botocore`." + ) from e response = None try: @@ -77,7 +95,12 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory): def add_message(self, message: BaseMessage) -> None: """Append the message to the record in DynamoDB""" - from botocore.exceptions import ClientError + try: + from botocore.exceptions import ClientError + except ImportError as e: + raise ImportError( + "Unable to import botocore, please install with `pip install botocore`." + ) from e messages = messages_to_dict(self.messages) _message = _message_to_dict(message) @@ -90,7 +113,12 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory): def clear(self) -> None: """Clear session memory from DynamoDB""" - from botocore.exceptions import ClientError + try: + from botocore.exceptions import ClientError + except ImportError as e: + raise ImportError( + "Unable to import botocore, please install with `pip install botocore`." + ) from e try: self.table.delete_item(self.key)