diff --git a/libs/community/langchain_community/chat_message_histories/dynamodb.py b/libs/community/langchain_community/chat_message_histories/dynamodb.py index a804e75018..4429d2e1f1 100644 --- a/libs/community/langchain_community/chat_message_histories/dynamodb.py +++ b/libs/community/langchain_community/chat_message_histories/dynamodb.py @@ -38,6 +38,11 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory): This may also contain global and local secondary index keys. kms_key_id: an optional AWS KMS Key ID, AWS KMS Key ARN, or AWS KMS Alias for client-side encryption + ttl: Optional Time-to-live (TTL) in seconds. Allows you to define a per-item + expiration timestamp that indicates when an item can be deleted from the + table. DynamoDB handles deletion of expired items without consuming + write throughput. To enable this feature on the table, follow the + [AWS DynamoDB documentation](https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/time-to-live-ttl-how-to.html) """ def __init__( @@ -49,6 +54,8 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory): key: Optional[Dict[str, str]] = None, boto3_session: Optional[Session] = None, kms_key_id: Optional[str] = None, + ttl: Optional[int] = None, + ttl_key_name: str = "expireAt", ): if boto3_session: client = boto3_session.resource("dynamodb", endpoint_url=endpoint_url) @@ -66,6 +73,8 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory): self.table = client.Table(table_name) self.session_id = session_id self.key: Dict = key or {primary_key_name: session_id} + self.ttl = ttl + self.ttl_key_name = ttl_key_name if kms_key_id: try: @@ -134,7 +143,15 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory): messages.append(_message) try: - self.table.put_item(Item={**self.key, "History": messages}) + if self.ttl: + import time + + expireAt = int(time.time()) + self.ttl + self.table.put_item( + Item={**self.key, "History": messages, self.ttl_key_name: expireAt} + ) + else: + self.table.put_item(Item={**self.key, "History": messages}) except ClientError as err: logger.error(err)