diff --git a/libs/langchain/langchain/memory/chat_message_histories/dynamodb.py b/libs/langchain/langchain/memory/chat_message_histories/dynamodb.py index 34ba5694d7..318afdce13 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/dynamodb.py +++ b/libs/langchain/langchain/memory/chat_message_histories/dynamodb.py @@ -38,6 +38,8 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory): This argument is optional, but useful when using composite dynamodb keys, or isolating records based off of application details such as a user id. This may also contain global and local secondary index keys. + kms_key_id: an optional AWS KMS Key ID, AWS KMS Key ARN, or AWS KMS Alias for + client-side encryption """ def __init__( @@ -48,6 +50,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory): primary_key_name: str = "SessionId", key: Optional[Dict[str, str]] = None, boto3_session: Optional[Session] = None, + kms_key_id: Optional[str] = None, ): if boto3_session: client = boto3_session.resource("dynamodb") @@ -66,6 +69,32 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory): self.session_id = session_id self.key: Dict = key or {primary_key_name: session_id} + if kms_key_id: + try: + from dynamodb_encryption_sdk.encrypted.table import EncryptedTable + from dynamodb_encryption_sdk.identifiers import CryptoAction + from dynamodb_encryption_sdk.material_providers.aws_kms import ( + AwsKmsCryptographicMaterialsProvider, + ) + from dynamodb_encryption_sdk.structures import AttributeActions + except ImportError as e: + raise ImportError( + "Unable to import dynamodb_encryption_sdk, please install with " + "`pip install dynamodb-encryption-sdk`." + ) from e + + actions = AttributeActions( + default_action=CryptoAction.DO_NOTHING, + attribute_actions={"History": CryptoAction.ENCRYPT_AND_SIGN}, + ) + aws_kms_cmp = AwsKmsCryptographicMaterialsProvider(key_id=kms_key_id) + self.table = EncryptedTable( + table=self.table, + materials_provider=aws_kms_cmp, + attribute_actions=actions, + auto_refresh_table_indexes=False, + ) + @property def messages(self) -> List[BaseMessage]: # type: ignore """Retrieve the messages from DynamoDB"""