core[patch]: Use InMemoryChatMessageHistory in unit tests (#23916)

Update unit test to use the existing implementation of chat message
history
This commit is contained in:
Eugene Yurtsev 2024-07-05 16:10:54 -04:00 committed by GitHub
parent 8b84457b17
commit 9787552b00
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 44 deletions

View File

@ -1,25 +0,0 @@
from typing import List
from langchain_core.chat_history import (
BaseChatMessageHistory,
)
from langchain_core.messages import BaseMessage
from langchain_core.pydantic_v1 import BaseModel, Field
class ChatMessageHistory(BaseChatMessageHistory, BaseModel):
"""In memory implementation of chat message history.
Stores messages in an in memory list.
"""
messages: List[BaseMessage] = Field(default_factory=list)
def add_message(self, message: BaseMessage) -> None:
"""Add a self-created message to the store"""
if not isinstance(message, BaseMessage):
raise ValueError
self.messages.append(message)
def clear(self) -> None:
self.messages = []

View File

@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, ChatResult
@ -11,11 +12,10 @@ from langchain_core.runnables.base import RunnableLambda
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.utils import ConfigurableFieldSpec
from tests.unit_tests.fake.memory import ChatMessageHistory
def test_interfaces() -> None:
history = ChatMessageHistory()
history = InMemoryChatMessageHistory()
history.add_message(SystemMessage(content="system"))
history.add_user_message("human 1")
history.add_ai_message("ai")
@ -26,12 +26,14 @@ def test_interfaces() -> None:
def _get_get_session_history(
*,
store: Optional[Dict[str, Any]] = None,
) -> Callable[..., ChatMessageHistory]:
) -> Callable[..., InMemoryChatMessageHistory]:
chat_history_store = store if store is not None else {}
def get_session_history(session_id: str, **kwargs: Any) -> ChatMessageHistory:
def get_session_history(
session_id: str, **kwargs: Any
) -> InMemoryChatMessageHistory:
if session_id not in chat_history_store:
chat_history_store[session_id] = ChatMessageHistory()
chat_history_store[session_id] = InMemoryChatMessageHistory()
return chat_history_store[session_id]
return get_session_history
@ -51,7 +53,7 @@ def test_input_messages() -> None:
output = with_history.invoke([HumanMessage(content="good bye")], config)
assert output == "you said: hello\ngood bye"
assert store == {
"1": ChatMessageHistory(
"1": InMemoryChatMessageHistory(
messages=[
HumanMessage(content="hello"),
AIMessage(content="you said: hello"),
@ -76,7 +78,7 @@ async def test_input_messages_async() -> None:
output = await with_history.ainvoke([HumanMessage(content="good bye")], config)
assert output == "you said: hello\ngood bye"
assert store == {
"1_async": ChatMessageHistory(
"1_async": InMemoryChatMessageHistory(
messages=[
HumanMessage(content="hello"),
AIMessage(content="you said: hello"),
@ -485,9 +487,11 @@ def test_using_custom_config_specs() -> None:
runnable = RunnableLambda(_fake_llm)
store = {}
def get_session_history(user_id: str, conversation_id: str) -> ChatMessageHistory:
def get_session_history(
user_id: str, conversation_id: str
) -> InMemoryChatMessageHistory:
if (user_id, conversation_id) not in store:
store[(user_id, conversation_id)] = ChatMessageHistory()
store[(user_id, conversation_id)] = InMemoryChatMessageHistory()
return store[(user_id, conversation_id)]
with_message_history = RunnableWithMessageHistory(
@ -524,7 +528,7 @@ def test_using_custom_config_specs() -> None:
AIMessage(content="you said: hello"),
]
assert store == {
("user1", "1"): ChatMessageHistory(
("user1", "1"): InMemoryChatMessageHistory(
messages=[
HumanMessage(content="hello"),
AIMessage(content="you said: hello"),
@ -542,7 +546,7 @@ def test_using_custom_config_specs() -> None:
AIMessage(content="you said: goodbye"),
]
assert store == {
("user1", "1"): ChatMessageHistory(
("user1", "1"): InMemoryChatMessageHistory(
messages=[
HumanMessage(content="hello"),
AIMessage(content="you said: hello"),
@ -562,7 +566,7 @@ def test_using_custom_config_specs() -> None:
AIMessage(content="you said: meow"),
]
assert store == {
("user1", "1"): ChatMessageHistory(
("user1", "1"): InMemoryChatMessageHistory(
messages=[
HumanMessage(content="hello"),
AIMessage(content="you said: hello"),
@ -570,7 +574,7 @@ def test_using_custom_config_specs() -> None:
AIMessage(content="you said: goodbye"),
]
),
("user2", "1"): ChatMessageHistory(
("user2", "1"): InMemoryChatMessageHistory(
messages=[
HumanMessage(content="meow"),
AIMessage(content="you said: meow"),
@ -596,9 +600,11 @@ async def test_using_custom_config_specs_async() -> None:
runnable = RunnableLambda(_fake_llm)
store = {}
def get_session_history(user_id: str, conversation_id: str) -> ChatMessageHistory:
def get_session_history(
user_id: str, conversation_id: str
) -> InMemoryChatMessageHistory:
if (user_id, conversation_id) not in store:
store[(user_id, conversation_id)] = ChatMessageHistory()
store[(user_id, conversation_id)] = InMemoryChatMessageHistory()
return store[(user_id, conversation_id)]
with_message_history = RunnableWithMessageHistory(
@ -635,7 +641,7 @@ async def test_using_custom_config_specs_async() -> None:
AIMessage(content="you said: hello"),
]
assert store == {
("user1_async", "1_async"): ChatMessageHistory(
("user1_async", "1_async"): InMemoryChatMessageHistory(
messages=[
HumanMessage(content="hello"),
AIMessage(content="you said: hello"),
@ -653,7 +659,7 @@ async def test_using_custom_config_specs_async() -> None:
AIMessage(content="you said: goodbye"),
]
assert store == {
("user1_async", "1_async"): ChatMessageHistory(
("user1_async", "1_async"): InMemoryChatMessageHistory(
messages=[
HumanMessage(content="hello"),
AIMessage(content="you said: hello"),
@ -673,7 +679,7 @@ async def test_using_custom_config_specs_async() -> None:
AIMessage(content="you said: meow"),
]
assert store == {
("user1_async", "1_async"): ChatMessageHistory(
("user1_async", "1_async"): InMemoryChatMessageHistory(
messages=[
HumanMessage(content="hello"),
AIMessage(content="you said: hello"),
@ -681,7 +687,7 @@ async def test_using_custom_config_specs_async() -> None:
AIMessage(content="you said: goodbye"),
]
),
("user2_async", "1_async"): ChatMessageHistory(
("user2_async", "1_async"): InMemoryChatMessageHistory(
messages=[
HumanMessage(content="meow"),
AIMessage(content="you said: meow"),