From 9a0356d276f74cc213782c833918ca6c83130e1d Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Wed, 19 Apr 2023 21:05:20 -0700 Subject: [PATCH] Harrison/file chat history (#3198) Co-authored-by: Young Lee --- .../memory/chat_message_histories/__init__.py | 2 + .../memory/chat_message_histories/file.py | 53 ++++++++++++++ tests/unit_tests/memory/__init__.py | 1 + .../memory/chat_message_histories/__init__.py | 1 + .../chat_message_histories/test_file.py | 71 +++++++++++++++++++ 5 files changed, 128 insertions(+) create mode 100644 langchain/memory/chat_message_histories/file.py create mode 100644 tests/unit_tests/memory/__init__.py create mode 100644 tests/unit_tests/memory/chat_message_histories/__init__.py create mode 100644 tests/unit_tests/memory/chat_message_histories/test_file.py diff --git a/langchain/memory/chat_message_histories/__init__.py b/langchain/memory/chat_message_histories/__init__.py index ee8a6222..d113e280 100644 --- a/langchain/memory/chat_message_histories/__init__.py +++ b/langchain/memory/chat_message_histories/__init__.py @@ -1,4 +1,5 @@ from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessageHistory +from langchain.memory.chat_message_histories.file import FileChatMessageHistory from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory @@ -6,4 +7,5 @@ __all__ = [ "DynamoDBChatMessageHistory", "RedisChatMessageHistory", "PostgresChatMessageHistory", + "FileChatMessageHistory", ] diff --git a/langchain/memory/chat_message_histories/file.py b/langchain/memory/chat_message_histories/file.py new file mode 100644 index 00000000..37ca6f27 --- /dev/null +++ b/langchain/memory/chat_message_histories/file.py @@ -0,0 +1,53 @@ +import json +import logging +from pathlib import Path +from typing import List + +from langchain.schema import ( + AIMessage, + BaseChatMessageHistory, + BaseMessage, + HumanMessage, + messages_from_dict, + messages_to_dict, +) + +logger = logging.getLogger(__name__) + + +class FileChatMessageHistory(BaseChatMessageHistory): + """ + Chat message history that stores history in a local file. + + Args: + file_path: path of the local file to store the messages. + """ + + def __init__(self, file_path: str): + self.file_path = Path(file_path) + if not self.file_path.exists(): + self.file_path.touch() + self.file_path.write_text(json.dumps([])) + + @property + def messages(self) -> List[BaseMessage]: # type: ignore + """Retrieve the messages from the local file""" + items = json.loads(self.file_path.read_text()) + messages = messages_from_dict(items) + return messages + + def add_user_message(self, message: str) -> None: + self.append(HumanMessage(content=message)) + + def add_ai_message(self, message: str) -> None: + self.append(AIMessage(content=message)) + + def append(self, message: BaseMessage) -> None: + """Append the message to the record in the local file""" + messages = messages_to_dict(self.messages) + messages.append(messages_to_dict([message])[0]) + self.file_path.write_text(json.dumps(messages)) + + def clear(self) -> None: + """Clear session memory from the local file""" + self.file_path.write_text(json.dumps([])) diff --git a/tests/unit_tests/memory/__init__.py b/tests/unit_tests/memory/__init__.py new file mode 100644 index 00000000..2494d102 --- /dev/null +++ b/tests/unit_tests/memory/__init__.py @@ -0,0 +1 @@ +"""Unit tests for memory module""" diff --git a/tests/unit_tests/memory/chat_message_histories/__init__.py b/tests/unit_tests/memory/chat_message_histories/__init__.py new file mode 100644 index 00000000..eed005a6 --- /dev/null +++ b/tests/unit_tests/memory/chat_message_histories/__init__.py @@ -0,0 +1 @@ +"""Unit tests for chat_message_history modules""" diff --git a/tests/unit_tests/memory/chat_message_histories/test_file.py b/tests/unit_tests/memory/chat_message_histories/test_file.py new file mode 100644 index 00000000..13962370 --- /dev/null +++ b/tests/unit_tests/memory/chat_message_histories/test_file.py @@ -0,0 +1,71 @@ +import tempfile +from pathlib import Path +from typing import Generator + +import pytest + +from langchain.memory.chat_message_histories import FileChatMessageHistory +from langchain.schema import AIMessage, HumanMessage + + +@pytest.fixture +def file_chat_message_history() -> Generator[FileChatMessageHistory, None, None]: + with tempfile.TemporaryDirectory() as temp_dir: + file_path = Path(temp_dir) / "test_chat_history.json" + file_chat_message_history = FileChatMessageHistory(str(file_path)) + yield file_chat_message_history + + +def test_add_messages(file_chat_message_history: FileChatMessageHistory) -> None: + file_chat_message_history.add_user_message("Hello!") + file_chat_message_history.add_ai_message("Hi there!") + + messages = file_chat_message_history.messages + assert len(messages) == 2 + assert isinstance(messages[0], HumanMessage) + assert isinstance(messages[1], AIMessage) + assert messages[0].content == "Hello!" + assert messages[1].content == "Hi there!" + + +def test_clear_messages(file_chat_message_history: FileChatMessageHistory) -> None: + file_chat_message_history.add_user_message("Hello!") + file_chat_message_history.add_ai_message("Hi there!") + + file_chat_message_history.clear() + messages = file_chat_message_history.messages + assert len(messages) == 0 + + +def test_multiple_sessions(file_chat_message_history: FileChatMessageHistory) -> None: + # First session + file_chat_message_history.add_user_message("Hello, AI!") + file_chat_message_history.add_ai_message("Hello, how can I help you?") + file_chat_message_history.add_user_message("Tell me a joke.") + file_chat_message_history.add_ai_message( + "Why did the chicken cross the road? To get to the other side!" + ) + + # Ensure the messages are added correctly in the first session + messages = file_chat_message_history.messages + assert len(messages) == 4 + assert messages[0].content == "Hello, AI!" + assert messages[1].content == "Hello, how can I help you?" + assert messages[2].content == "Tell me a joke." + expected_content = "Why did the chicken cross the road? To get to the other side!" + assert messages[3].content == expected_content + + # Second session (reinitialize FileChatMessageHistory) + file_path = file_chat_message_history.file_path + second_session_chat_message_history = FileChatMessageHistory( + file_path=str(file_path) + ) + + # Ensure the history is maintained in the second session + messages = second_session_chat_message_history.messages + assert len(messages) == 4 + assert messages[0].content == "Hello, AI!" + assert messages[1].content == "Hello, how can I help you?" + assert messages[2].content == "Tell me a joke." + expected_content = "Why did the chicken cross the road? To get to the other side!" + assert messages[3].content == expected_content