core(minor): Add bulk add messages to BaseChatMessageHistory interface (#15709)

* Add bulk add_messages method to the interface.
* Update documentation for add_ai_message and add_user_message to
mark them as candidates for deprecation. We should stop using them,
because they encourage inefficient (one-message-at-a-time) usage patterns.
This commit is contained in:
Eugene Yurtsev 2024-01-31 11:59:39 -08:00 committed by GitHub
parent af8c5c185b
commit 2e5949b6f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 121 additions and 7 deletions

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Union
from typing import List, Sequence, Union
from langchain_core.messages import (
AIMessage,
@ -14,9 +14,18 @@ from langchain_core.messages import (
class BaseChatMessageHistory(ABC):
"""Abstract base class for storing chat message history.
See `ChatMessageHistory` for default implementation.
Implementations should override the add_messages method to handle bulk addition
of messages.
The default implementation of add_message will correctly call add_messages, so
it is not necessary to implement both methods.
When used for updating history, users should favor usage of `add_messages`
over `add_message` or other variants like `add_user_message` and `add_ai_message`
to avoid unnecessary round-trips to the underlying persistence layer.
Example: Shows a default implementation.
Example:
.. code-block:: python
class FileChatMessageHistory(BaseChatMessageHistory):
@ -29,8 +38,13 @@ class BaseChatMessageHistory(ABC):
messages = json.loads(f.read())
return messages_from_dict(messages)
def add_message(self, message: BaseMessage) -> None:
messages = self.messages.append(_message_to_dict(message))
def add_messages(self, messages: Sequence[BaseMessage]) -> None:
all_messages = list(self.messages) # Existing messages
all_messages.extend(messages) # Add new messages
serialized = [message_to_dict(message) for message in all_messages]
# Can be further optimized by only writing new messages
# using append mode.
with open(os.path.join(storage_path, session_id), 'w') as f:
json.dump(serialized, f)
@ -45,6 +59,12 @@ class BaseChatMessageHistory(ABC):
def add_user_message(self, message: Union[HumanMessage, str]) -> None:
"""Convenience method for adding a human message string to the store.
Please note that this is a convenience method. Code should favor the
bulk add_messages interface instead to save on round-trips to the underlying
persistence layer.
This method may be deprecated in a future release.
Args:
message: The human message to add
"""
@ -56,6 +76,12 @@ class BaseChatMessageHistory(ABC):
def add_ai_message(self, message: Union[AIMessage, str]) -> None:
"""Convenience method for adding an AI message string to the store.
Please note that this is a convenience method. Code should favor the bulk
add_messages interface instead to save on round-trips to the underlying
persistence layer.
This method may be deprecated in a future release.
Args:
message: The AI message to add.
"""
@ -64,18 +90,38 @@ class BaseChatMessageHistory(ABC):
else:
self.add_message(AIMessage(content=message))
@abstractmethod
def add_message(self, message: BaseMessage) -> None:
"""Add a Message object to the store.
Args:
message: A BaseMessage object to store.
"""
raise NotImplementedError()
if type(self).add_messages != BaseChatMessageHistory.add_messages:
# This means that the sub-class has implemented an efficient add_messages
# method, so we should route usage of add_message to that.
self.add_messages([message])
else:
raise NotImplementedError(
"add_message is not implemented for this class. "
"Please implement add_message or add_messages."
)
def add_messages(self, messages: Sequence[BaseMessage]) -> None:
    """Store several messages in one call.

    This default version forwards each message, in order, to
    ``add_message``. Subclasses should override it with a genuine bulk
    write so that the underlying store is not hit once per message.

    Args:
        messages: The BaseMessage objects to store, in order.
    """
    for single_message in messages:
        self.add_message(single_message)
@abstractmethod
def clear(self) -> None:
    """Delete every message held in the store.

    Declared abstract, so each concrete history supplies its own version.
    """
def __str__(self) -> str:
    """Render the stored messages as one string via ``get_buffer_string``."""
    history = self.messages
    return get_buffer_string(history)

View File

@ -0,0 +1,68 @@
from typing import List, Sequence
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, HumanMessage
def test_add_message_implementation_only() -> None:
    """Check a history class that overrides only ``add_message``.

    Both direct ``add_message`` calls and the inherited bulk
    ``add_messages`` call must append to the shared backing list.
    """

    class SampleChatHistory(BaseChatMessageHistory):
        def __init__(self, *, store: List[BaseMessage]) -> None:
            self.store = store

        def add_message(self, message: BaseMessage) -> None:
            """Append one message to the backing list."""
            self.store.append(message)

        def clear(self) -> None:
            """Not needed by this test."""
            raise NotImplementedError()

    backing: List[BaseMessage] = []
    history = SampleChatHistory(store=backing)

    # Single-message path, first call.
    history.add_message(HumanMessage(content="Hello"))
    assert len(backing) == 1
    assert backing[0] == HumanMessage(content="Hello")

    # Single-message path, second call.
    history.add_message(HumanMessage(content="World"))
    assert len(backing) == 2
    assert backing[1] == HumanMessage(content="World")

    # Bulk path, served by the base-class add_messages.
    history.add_messages(
        [HumanMessage(content="Hello"), HumanMessage(content="World")]
    )
    assert len(backing) == 4
    assert backing[2] == HumanMessage(content="Hello")
    assert backing[3] == HumanMessage(content="World")
def test_bulk_message_implementation_only() -> None:
    """Test a history class that implements only the bulk ``add_messages``.

    Single-message ``add_message`` calls must be routed by the base class to
    the overridden ``add_messages``, and direct bulk calls must work too.
    """
    store: List[BaseMessage] = []

    class BulkAddHistory(BaseChatMessageHistory):
        def __init__(self, *, store: List[BaseMessage]) -> None:
            self.store = store

        # The parameter is named ``messages`` to match the base-class
        # signature, so keyword calls such as add_messages(messages=...) work.
        def add_messages(self, messages: Sequence[BaseMessage]) -> None:
            """Add a sequence of messages to the store."""
            self.store.extend(messages)

        def clear(self) -> None:
            """Clear the store (not needed by this test)."""
            raise NotImplementedError()

    chat_history = BulkAddHistory(store=store)

    # add_message is not overridden; the base class delegates to add_messages.
    chat_history.add_message(HumanMessage(content="Hello"))
    assert len(store) == 1
    assert store[0] == HumanMessage(content="Hello")

    chat_history.add_message(HumanMessage(content="World"))
    assert len(store) == 2
    assert store[1] == HumanMessage(content="World")

    # Direct bulk call.
    chat_history.add_messages(
        [HumanMessage(content="Hello"), HumanMessage(content="World")]
    )
    assert len(store) == 4
    assert store[2] == HumanMessage(content="Hello")
    assert store[3] == HumanMessage(content="World")