mirror of
https://github.com/hwchase17/langchain
synced 2024-11-02 09:40:22 +00:00
core(minor): Add bulk add messages to BaseChatMessageHistory interface (#15709)
* Add bulk add_messages method to the interface. * Update documentation for add_ai_message and add_human_message to denote them as being marked for deprecation. We should stop using them as they create more incorrect (inefficient) ways of doing things
This commit is contained in:
parent
af8c5c185b
commit
2e5949b6f8
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Union
|
||||
from typing import List, Sequence, Union
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
@ -14,9 +14,18 @@ from langchain_core.messages import (
|
||||
class BaseChatMessageHistory(ABC):
|
||||
"""Abstract base class for storing chat message history.
|
||||
|
||||
See `ChatMessageHistory` for default implementation.
|
||||
Implementations should over-ride the add_messages method to handle bulk addition
|
||||
of messages.
|
||||
|
||||
The default implementation of add_message will correctly call add_messages, so
|
||||
it is not necessary to implement both methods.
|
||||
|
||||
When used for updating history, users should favor usage of `add_messages`
|
||||
over `add_message` or other variants like `add_user_message` and `add_ai_message`
|
||||
to avoid unnecessary round-trips to the underlying persistence layer.
|
||||
|
||||
Example: Shows a default implementation.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
class FileChatMessageHistory(BaseChatMessageHistory):
|
||||
@ -29,8 +38,13 @@ class BaseChatMessageHistory(ABC):
|
||||
messages = json.loads(f.read())
|
||||
return messages_from_dict(messages)
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
messages = self.messages.append(_message_to_dict(message))
|
||||
def add_messages(self, messages: Sequence[BaseMessage]) -> None:
|
||||
all_messages = list(self.messages) # Existing messages
|
||||
all_messages.extend(messages) # Add new messages
|
||||
|
||||
serialized = [message_to_dict(message) for message in all_messages]
|
||||
# Can be further optimized by only writing new messages
|
||||
# using append mode.
|
||||
with open(os.path.join(storage_path, session_id), 'w') as f:
|
||||
json.dump(f, messages)
|
||||
|
||||
@ -45,6 +59,12 @@ class BaseChatMessageHistory(ABC):
|
||||
def add_user_message(self, message: Union[HumanMessage, str]) -> None:
|
||||
"""Convenience method for adding a human message string to the store.
|
||||
|
||||
Please note that this is a convenience method. Code should favor the
|
||||
bulk add_messages interface instead to save on round-trips to the underlying
|
||||
persistence layer.
|
||||
|
||||
This method may be deprecated in a future release.
|
||||
|
||||
Args:
|
||||
message: The human message to add
|
||||
"""
|
||||
@ -56,6 +76,12 @@ class BaseChatMessageHistory(ABC):
|
||||
def add_ai_message(self, message: Union[AIMessage, str]) -> None:
|
||||
"""Convenience method for adding an AI message string to the store.
|
||||
|
||||
Please note that this is a convenience method. Code should favor the bulk
|
||||
add_messages interface instead to save on round-trips to the underlying
|
||||
persistence layer.
|
||||
|
||||
This method may be deprecated in a future release.
|
||||
|
||||
Args:
|
||||
message: The AI message to add.
|
||||
"""
|
||||
@ -64,18 +90,38 @@ class BaseChatMessageHistory(ABC):
|
||||
else:
|
||||
self.add_message(AIMessage(content=message))
|
||||
|
||||
@abstractmethod
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Add a Message object to the store.
|
||||
|
||||
Args:
|
||||
message: A BaseMessage object to store.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
if type(self).add_messages != BaseChatMessageHistory.add_messages:
|
||||
# This means that the sub-class has implemented an efficient add_messages
|
||||
# method, so we should usage of add_message to that.
|
||||
self.add_messages([message])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"add_message is not implemented for this class. "
|
||||
"Please implement add_message or add_messages."
|
||||
)
|
||||
|
||||
def add_messages(self, messages: Sequence[BaseMessage]) -> None:
|
||||
"""Add a list of messages.
|
||||
|
||||
Implementations should over-ride this method to handle bulk addition of messages
|
||||
in an efficient manner to avoid unnecessary round-trips to the underlying store.
|
||||
|
||||
Args:
|
||||
messages: A list of BaseMessage objects to store.
|
||||
"""
|
||||
for message in messages:
|
||||
self.add_message(message)
|
||||
|
||||
@abstractmethod
|
||||
def clear(self) -> None:
|
||||
"""Remove all messages from the store"""
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return a string representation of the chat history."""
|
||||
return get_buffer_string(self.messages)
|
||||
|
0
libs/core/tests/unit_tests/chat_history/__init__.py
Normal file
0
libs/core/tests/unit_tests/chat_history/__init__.py
Normal file
68
libs/core/tests/unit_tests/chat_history/test_chat_history.py
Normal file
68
libs/core/tests/unit_tests/chat_history/test_chat_history.py
Normal file
@ -0,0 +1,68 @@
|
||||
from typing import List, Sequence
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import BaseMessage, HumanMessage
|
||||
|
||||
|
||||
def test_add_message_implementation_only() -> None:
|
||||
"""Test implementation of add_message only."""
|
||||
|
||||
class SampleChatHistory(BaseChatMessageHistory):
|
||||
def __init__(self, *, store: List[BaseMessage]) -> None:
|
||||
self.store = store
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Add a message to the store."""
|
||||
self.store.append(message)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear the store."""
|
||||
raise NotImplementedError()
|
||||
|
||||
store: List[BaseMessage] = []
|
||||
chat_history = SampleChatHistory(store=store)
|
||||
chat_history.add_message(HumanMessage(content="Hello"))
|
||||
assert len(store) == 1
|
||||
assert store[0] == HumanMessage(content="Hello")
|
||||
chat_history.add_message(HumanMessage(content="World"))
|
||||
assert len(store) == 2
|
||||
assert store[1] == HumanMessage(content="World")
|
||||
|
||||
chat_history.add_messages(
|
||||
[HumanMessage(content="Hello"), HumanMessage(content="World")]
|
||||
)
|
||||
assert len(store) == 4
|
||||
assert store[2] == HumanMessage(content="Hello")
|
||||
assert store[3] == HumanMessage(content="World")
|
||||
|
||||
|
||||
def test_bulk_message_implementation_only() -> None:
|
||||
"""Test that SampleChatHistory works as expected."""
|
||||
store: List[BaseMessage] = []
|
||||
|
||||
class BulkAddHistory(BaseChatMessageHistory):
|
||||
def __init__(self, *, store: List[BaseMessage]) -> None:
|
||||
self.store = store
|
||||
|
||||
def add_messages(self, message: Sequence[BaseMessage]) -> None:
|
||||
"""Add a message to the store."""
|
||||
self.store.extend(message)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear the store."""
|
||||
raise NotImplementedError()
|
||||
|
||||
chat_history = BulkAddHistory(store=store)
|
||||
chat_history.add_message(HumanMessage(content="Hello"))
|
||||
assert len(store) == 1
|
||||
assert store[0] == HumanMessage(content="Hello")
|
||||
chat_history.add_message(HumanMessage(content="World"))
|
||||
assert len(store) == 2
|
||||
assert store[1] == HumanMessage(content="World")
|
||||
|
||||
chat_history.add_messages(
|
||||
[HumanMessage(content="Hello"), HumanMessage(content="World")]
|
||||
)
|
||||
assert len(store) == 4
|
||||
assert store[2] == HumanMessage(content="Hello")
|
||||
assert store[3] == HumanMessage(content="World")
|
Loading…
Reference in New Issue
Block a user