From ccb6238de19f743947a4fa66193484b104a221da Mon Sep 17 00:00:00 2001 From: Eduard van Valkenburg Date: Mon, 29 May 2023 16:18:59 +0200 Subject: [PATCH] Implemented appending arbitrary messages (#5293) # Implemented appending arbitrary messages to the base chat message history, the in-memory and cosmos ones. As discussed this is the alternative way instead of #4480, with a add_message method added that takes a BaseMessage as input, so that the user can control what is in the base message like kwargs. Fixes # (issue) ## Before submitting ## Who can review? Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested: @hwchase17 --------- Co-authored-by: Harrison Chase --- .../chat_message_histories/cassandra.py | 10 +-------- .../chat_message_histories/cosmos_db.py | 17 +++++---------- .../memory/chat_message_histories/dynamodb.py | 10 +-------- .../memory/chat_message_histories/file.py | 10 +-------- .../chat_message_histories/firestore.py | 14 +++---------- .../chat_message_histories/in_memory.py | 10 +++------ .../memory/chat_message_histories/momento.py | 20 +----------------- .../memory/chat_message_histories/mongodb.py | 10 +-------- .../memory/chat_message_histories/postgres.py | 10 +-------- .../memory/chat_message_histories/redis.py | 10 +-------- .../memory/chat_message_histories/sql.py | 10 +-------- .../memory/chat_message_histories/zep.py | 8 +------ langchain/schema.py | 21 ++++++++----------- .../memory/chat_message_histories/test_zep.py | 2 +- 14 files changed, 30 insertions(+), 132 deletions(-) diff --git a/langchain/memory/chat_message_histories/cassandra.py b/langchain/memory/chat_message_histories/cassandra.py index d424792a..0b468dca 100644 --- a/langchain/memory/chat_message_histories/cassandra.py +++ b/langchain/memory/chat_message_histories/cassandra.py @@ -3,10 +3,8 @@ import logging from typing import List from langchain.schema import ( - AIMessage, BaseChatMessageHistory, BaseMessage, - HumanMessage, _message_to_dict, messages_from_dict, ) @@ -143,13 +141,7 @@ class CassandraChatMessageHistory(BaseChatMessageHistory): return messages - def add_user_message(self, message: str) -> None: - self.append(HumanMessage(content=message)) - - def add_ai_message(self, message: str) -> None: - self.append(AIMessage(content=message)) - - def append(self, message: BaseMessage) -> None: + def add_message(self, message: BaseMessage) -> None: """Append the message to the record in Cassandra""" import uuid diff --git a/langchain/memory/chat_message_histories/cosmos_db.py b/langchain/memory/chat_message_histories/cosmos_db.py index 3c021928..5318c805 100644 --- a/langchain/memory/chat_message_histories/cosmos_db.py +++ b/langchain/memory/chat_message_histories/cosmos_db.py @@ -6,10 +6,8 @@ from types import TracebackType from typing import TYPE_CHECKING, Any, List, Optional, Type from langchain.schema import ( - AIMessage, BaseChatMessageHistory, BaseMessage, - HumanMessage, messages_from_dict, messages_to_dict, ) @@ -145,18 +143,13 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory): if "messages" in item and len(item["messages"]) > 0: self.messages = messages_from_dict(item["messages"]) - def add_user_message(self, message: str) -> None: - """Add a user message to the memory.""" - self.upsert_messages(HumanMessage(content=message)) + def add_message(self, message: BaseMessage) -> None: + """Add a self-created message to the store""" + self.messages.append(message) + self.upsert_messages() - def add_ai_message(self, message: str) -> None: - """Add a AI message to the memory.""" - self.upsert_messages(AIMessage(content=message)) - - def upsert_messages(self, new_message: Optional[BaseMessage] = None) -> None: + def upsert_messages(self) -> None: """Update the cosmosdb item.""" - if new_message: - self.messages.append(new_message) if not self._container: raise ValueError("Container not initialized") self._container.upsert_item( diff --git a/langchain/memory/chat_message_histories/dynamodb.py b/langchain/memory/chat_message_histories/dynamodb.py index 413183ea..2fc05688 100644 --- a/langchain/memory/chat_message_histories/dynamodb.py +++ b/langchain/memory/chat_message_histories/dynamodb.py @@ -2,10 +2,8 @@ import logging from typing import List from langchain.schema import ( - AIMessage, BaseChatMessageHistory, BaseMessage, - HumanMessage, _message_to_dict, messages_from_dict, messages_to_dict, @@ -53,13 +51,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory): messages = messages_from_dict(items) return messages - def add_user_message(self, message: str) -> None: - self.append(HumanMessage(content=message)) - - def add_ai_message(self, message: str) -> None: - self.append(AIMessage(content=message)) - - def append(self, message: BaseMessage) -> None: + def add_message(self, message: BaseMessage) -> None: """Append the message to the record in DynamoDB""" from botocore.exceptions import ClientError diff --git a/langchain/memory/chat_message_histories/file.py b/langchain/memory/chat_message_histories/file.py index 37ca6f27..0fbbf1e7 100644 --- a/langchain/memory/chat_message_histories/file.py +++ b/langchain/memory/chat_message_histories/file.py @@ -4,10 +4,8 @@ from pathlib import Path from typing import List from langchain.schema import ( - AIMessage, BaseChatMessageHistory, BaseMessage, - HumanMessage, messages_from_dict, messages_to_dict, ) @@ -36,13 +34,7 @@ class FileChatMessageHistory(BaseChatMessageHistory): messages = messages_from_dict(items) return messages - def add_user_message(self, message: str) -> None: - self.append(HumanMessage(content=message)) - - def add_ai_message(self, message: str) -> None: - self.append(AIMessage(content=message)) - - def append(self, message: BaseMessage) -> None: + def add_message(self, message: BaseMessage) -> None: """Append the message to the record in the local file""" messages = messages_to_dict(self.messages) messages.append(messages_to_dict([message])[0]) diff --git a/langchain/memory/chat_message_histories/firestore.py b/langchain/memory/chat_message_histories/firestore.py index dbbf3ff1..3e325682 100644 --- a/langchain/memory/chat_message_histories/firestore.py +++ b/langchain/memory/chat_message_histories/firestore.py @@ -5,10 +5,8 @@ import logging from typing import TYPE_CHECKING, List, Optional from langchain.schema import ( - AIMessage, BaseChatMessageHistory, BaseMessage, - HumanMessage, messages_from_dict, messages_to_dict, ) @@ -81,18 +79,12 @@ class FirestoreChatMessageHistory(BaseChatMessageHistory): if "messages" in data and len(data["messages"]) > 0: self.messages = messages_from_dict(data["messages"]) - def add_user_message(self, message: str) -> None: - """Add a user message to the memory.""" - self.upsert_messages(HumanMessage(content=message)) - - def add_ai_message(self, message: str) -> None: - """Add a AI message to the memory.""" - self.upsert_messages(AIMessage(content=message)) + def add_message(self, message: BaseMessage) -> None: + self.messages.append(message) + self.upsert_messages() def upsert_messages(self, new_message: Optional[BaseMessage] = None) -> None: """Update the Firestore document.""" - if new_message: - self.messages.append(new_message) if not self._document: raise ValueError("Document not initialized") self._document.set( diff --git a/langchain/memory/chat_message_histories/in_memory.py b/langchain/memory/chat_message_histories/in_memory.py index 0760bd3c..bcb60d2e 100644 --- a/langchain/memory/chat_message_histories/in_memory.py +++ b/langchain/memory/chat_message_histories/in_memory.py @@ -3,21 +3,17 @@ from typing import List from pydantic import BaseModel from langchain.schema import ( - AIMessage, BaseChatMessageHistory, BaseMessage, - HumanMessage, ) class ChatMessageHistory(BaseChatMessageHistory, BaseModel): messages: List[BaseMessage] = [] - def add_user_message(self, message: str) -> None: - self.messages.append(HumanMessage(content=message)) - - def add_ai_message(self, message: str) -> None: - self.messages.append(AIMessage(content=message)) + def add_message(self, message: BaseMessage) -> None: + """Add a self-created message to the store""" + self.messages.append(message) def clear(self) -> None: self.messages = [] diff --git a/langchain/memory/chat_message_histories/momento.py b/langchain/memory/chat_message_histories/momento.py index 1bc74981..885fe16b 100644 --- a/langchain/memory/chat_message_histories/momento.py +++ b/langchain/memory/chat_message_histories/momento.py @@ -5,10 +5,8 @@ from datetime import timedelta from typing import TYPE_CHECKING, Any, Optional from langchain.schema import ( - AIMessage, BaseChatMessageHistory, BaseMessage, - HumanMessage, _message_to_dict, messages_from_dict, ) @@ -143,23 +141,7 @@ class MomentoChatMessageHistory(BaseChatMessageHistory): else: raise Exception(f"Unexpected response: {fetch_response}") - def add_user_message(self, message: str) -> None: - """Store a user message in the cache. - - Args: - message (str): The message to store. - """ - self.__add_message(HumanMessage(content=message)) - - def add_ai_message(self, message: str) -> None: - """Store an AI message in the cache. - - Args: - message (str): The message to store. - """ - self.__add_message(AIMessage(content=message)) - - def __add_message(self, message: BaseMessage) -> None: + def add_message(self, message: BaseMessage) -> None: """Store a message in the cache. Args: diff --git a/langchain/memory/chat_message_histories/mongodb.py b/langchain/memory/chat_message_histories/mongodb.py index 7995609b..3455d812 100644 --- a/langchain/memory/chat_message_histories/mongodb.py +++ b/langchain/memory/chat_message_histories/mongodb.py @@ -3,10 +3,8 @@ import logging from typing import List from langchain.schema import ( - AIMessage, BaseChatMessageHistory, BaseMessage, - HumanMessage, _message_to_dict, messages_from_dict, ) @@ -68,13 +66,7 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory): messages = messages_from_dict(items) return messages - def add_user_message(self, message: str) -> None: - self.append(HumanMessage(content=message)) - - def add_ai_message(self, message: str) -> None: - self.append(AIMessage(content=message)) - - def append(self, message: BaseMessage) -> None: + def add_message(self, message: BaseMessage) -> None: """Append the message to the record in MongoDB""" from pymongo import errors diff --git a/langchain/memory/chat_message_histories/postgres.py b/langchain/memory/chat_message_histories/postgres.py index ddca8444..4080acf9 100644 --- a/langchain/memory/chat_message_histories/postgres.py +++ b/langchain/memory/chat_message_histories/postgres.py @@ -3,10 +3,8 @@ import logging from typing import List from langchain.schema import ( - AIMessage, BaseChatMessageHistory, BaseMessage, - HumanMessage, _message_to_dict, messages_from_dict, ) @@ -55,13 +53,7 @@ class PostgresChatMessageHistory(BaseChatMessageHistory): messages = messages_from_dict(items) return messages - def add_user_message(self, message: str) -> None: - self.append(HumanMessage(content=message)) - - def add_ai_message(self, message: str) -> None: - self.append(AIMessage(content=message)) - - def append(self, message: BaseMessage) -> None: + def add_message(self, message: BaseMessage) -> None: """Append the message to the record in PostgreSQL""" from psycopg import sql diff --git a/langchain/memory/chat_message_histories/redis.py b/langchain/memory/chat_message_histories/redis.py index d9c7a9a3..b32ece7b 100644 --- a/langchain/memory/chat_message_histories/redis.py +++ b/langchain/memory/chat_message_histories/redis.py @@ -3,10 +3,8 @@ import logging from typing import List, Optional from langchain.schema import ( - AIMessage, BaseChatMessageHistory, BaseMessage, - HumanMessage, _message_to_dict, messages_from_dict, ) @@ -52,13 +50,7 @@ class RedisChatMessageHistory(BaseChatMessageHistory): messages = messages_from_dict(items) return messages - def add_user_message(self, message: str) -> None: - self.append(HumanMessage(content=message)) - - def add_ai_message(self, message: str) -> None: - self.append(AIMessage(content=message)) - - def append(self, message: BaseMessage) -> None: + def add_message(self, message: BaseMessage) -> None: """Append the message to the record in Redis""" self.redis_client.lpush(self.key, json.dumps(_message_to_dict(message))) if self.ttl: diff --git a/langchain/memory/chat_message_histories/sql.py b/langchain/memory/chat_message_histories/sql.py index 40f8691c..6151ad19 100644 --- a/langchain/memory/chat_message_histories/sql.py +++ b/langchain/memory/chat_message_histories/sql.py @@ -7,10 +7,8 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker from langchain.schema import ( - AIMessage, BaseChatMessageHistory, BaseMessage, - HumanMessage, _message_to_dict, messages_from_dict, ) @@ -61,13 +59,7 @@ class SQLChatMessageHistory(BaseChatMessageHistory): messages = messages_from_dict(items) return messages - def add_user_message(self, message: str) -> None: - self.append(HumanMessage(content=message)) - - def add_ai_message(self, message: str) -> None: - self.append(AIMessage(content=message)) - - def append(self, message: BaseMessage) -> None: + def add_message(self, message: BaseMessage) -> None: """Append the message to the record in db""" with self.Session() as session: jsonstr = json.dumps(_message_to_dict(message)) diff --git a/langchain/memory/chat_message_histories/zep.py b/langchain/memory/chat_message_histories/zep.py index a0b03620..698b76ae 100644 --- a/langchain/memory/chat_message_histories/zep.py +++ b/langchain/memory/chat_message_histories/zep.py @@ -116,13 +116,7 @@ class ZepChatMessageHistory(BaseChatMessageHistory): return None return zep_memory - def add_user_message(self, message: str) -> None: - self.append(HumanMessage(content=message)) - - def add_ai_message(self, message: str) -> None: - self.append(AIMessage(content=message)) - - def append(self, message: BaseMessage) -> None: + def add_message(self, message: BaseMessage) -> None: """Append the message to the Zep memory history""" from zep_python import Memory, Message diff --git a/langchain/schema.py b/langchain/schema.py index 1e1edeb4..4a04bd04 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -234,18 +234,11 @@ class BaseChatMessageHistory(ABC): messages = json.loads(f.read()) return messages_from_dict(messages) - def add_user_message(self, message: str): - message_ = HumanMessage(content=message) - messages = self.messages.append(_message_to_dict(_message)) + def add_message(self, message: BaseMessage) -> None: + messages = self.messages.append(_message_to_dict(message)) with open(os.path.join(storage_path, session_id), 'w') as f: json.dump(f, messages) - - def add_ai_message(self, message: str): - message_ = AIMessage(content=message) - messages = self.messages.append(_message_to_dict(_message)) - with open(os.path.join(storage_path, session_id), 'w') as f: - json.dump(f, messages) - + def clear(self): with open(os.path.join(storage_path, session_id), 'w') as f: f.write("[]") @@ -253,13 +246,17 @@ class BaseChatMessageHistory(ABC): messages: List[BaseMessage] - @abstractmethod def add_user_message(self, message: str) -> None: """Add a user message to the store""" + self.add_message(HumanMessage(content=message)) - @abstractmethod def add_ai_message(self, message: str) -> None: """Add an AI message to the store""" + self.add_message(AIMessage(content=message)) + + def add_message(self, message: BaseMessage) -> None: + """Add a self-created message to the store""" + raise NotImplementedError @abstractmethod def clear(self) -> None: diff --git a/tests/unit_tests/memory/chat_message_histories/test_zep.py b/tests/unit_tests/memory/chat_message_histories/test_zep.py index 49f9ac07..8dd1b4ac 100644 --- a/tests/unit_tests/memory/chat_message_histories/test_zep.py +++ b/tests/unit_tests/memory/chat_message_histories/test_zep.py @@ -60,7 +60,7 @@ def test_add_ai_message(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) @pytest.mark.requires("zep_python") def test_append(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> None: - zep_chat.append(AIMessage(content="test message")) + zep_chat.add_message(AIMessage(content="test message")) zep_chat.zep_client.add_memory.assert_called_once() # type: ignore