mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
core[patch]: Use InMemoryChatMessageHistory in unit tests (#23916)
Update unit test to use the existing implementation of chat message history
This commit is contained in:
parent
8b84457b17
commit
9787552b00
@ -1,25 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from langchain_core.chat_history import (
|
||||
BaseChatMessageHistory,
|
||||
)
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
|
||||
|
||||
class ChatMessageHistory(BaseChatMessageHistory, BaseModel):
|
||||
"""In memory implementation of chat message history.
|
||||
|
||||
Stores messages in an in memory list.
|
||||
"""
|
||||
|
||||
messages: List[BaseMessage] = Field(default_factory=list)
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Add a self-created message to the store"""
|
||||
if not isinstance(message, BaseMessage):
|
||||
raise ValueError
|
||||
self.messages.append(message)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.messages = []
|
@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union
|
||||
from langchain_core.callbacks import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.chat_history import InMemoryChatMessageHistory
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
@ -11,11 +12,10 @@ from langchain_core.runnables.base import RunnableLambda
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
from langchain_core.runnables.utils import ConfigurableFieldSpec
|
||||
from tests.unit_tests.fake.memory import ChatMessageHistory
|
||||
|
||||
|
||||
def test_interfaces() -> None:
|
||||
history = ChatMessageHistory()
|
||||
history = InMemoryChatMessageHistory()
|
||||
history.add_message(SystemMessage(content="system"))
|
||||
history.add_user_message("human 1")
|
||||
history.add_ai_message("ai")
|
||||
@ -26,12 +26,14 @@ def test_interfaces() -> None:
|
||||
def _get_get_session_history(
|
||||
*,
|
||||
store: Optional[Dict[str, Any]] = None,
|
||||
) -> Callable[..., ChatMessageHistory]:
|
||||
) -> Callable[..., InMemoryChatMessageHistory]:
|
||||
chat_history_store = store if store is not None else {}
|
||||
|
||||
def get_session_history(session_id: str, **kwargs: Any) -> ChatMessageHistory:
|
||||
def get_session_history(
|
||||
session_id: str, **kwargs: Any
|
||||
) -> InMemoryChatMessageHistory:
|
||||
if session_id not in chat_history_store:
|
||||
chat_history_store[session_id] = ChatMessageHistory()
|
||||
chat_history_store[session_id] = InMemoryChatMessageHistory()
|
||||
return chat_history_store[session_id]
|
||||
|
||||
return get_session_history
|
||||
@ -51,7 +53,7 @@ def test_input_messages() -> None:
|
||||
output = with_history.invoke([HumanMessage(content="good bye")], config)
|
||||
assert output == "you said: hello\ngood bye"
|
||||
assert store == {
|
||||
"1": ChatMessageHistory(
|
||||
"1": InMemoryChatMessageHistory(
|
||||
messages=[
|
||||
HumanMessage(content="hello"),
|
||||
AIMessage(content="you said: hello"),
|
||||
@ -76,7 +78,7 @@ async def test_input_messages_async() -> None:
|
||||
output = await with_history.ainvoke([HumanMessage(content="good bye")], config)
|
||||
assert output == "you said: hello\ngood bye"
|
||||
assert store == {
|
||||
"1_async": ChatMessageHistory(
|
||||
"1_async": InMemoryChatMessageHistory(
|
||||
messages=[
|
||||
HumanMessage(content="hello"),
|
||||
AIMessage(content="you said: hello"),
|
||||
@ -485,9 +487,11 @@ def test_using_custom_config_specs() -> None:
|
||||
runnable = RunnableLambda(_fake_llm)
|
||||
store = {}
|
||||
|
||||
def get_session_history(user_id: str, conversation_id: str) -> ChatMessageHistory:
|
||||
def get_session_history(
|
||||
user_id: str, conversation_id: str
|
||||
) -> InMemoryChatMessageHistory:
|
||||
if (user_id, conversation_id) not in store:
|
||||
store[(user_id, conversation_id)] = ChatMessageHistory()
|
||||
store[(user_id, conversation_id)] = InMemoryChatMessageHistory()
|
||||
return store[(user_id, conversation_id)]
|
||||
|
||||
with_message_history = RunnableWithMessageHistory(
|
||||
@ -524,7 +528,7 @@ def test_using_custom_config_specs() -> None:
|
||||
AIMessage(content="you said: hello"),
|
||||
]
|
||||
assert store == {
|
||||
("user1", "1"): ChatMessageHistory(
|
||||
("user1", "1"): InMemoryChatMessageHistory(
|
||||
messages=[
|
||||
HumanMessage(content="hello"),
|
||||
AIMessage(content="you said: hello"),
|
||||
@ -542,7 +546,7 @@ def test_using_custom_config_specs() -> None:
|
||||
AIMessage(content="you said: goodbye"),
|
||||
]
|
||||
assert store == {
|
||||
("user1", "1"): ChatMessageHistory(
|
||||
("user1", "1"): InMemoryChatMessageHistory(
|
||||
messages=[
|
||||
HumanMessage(content="hello"),
|
||||
AIMessage(content="you said: hello"),
|
||||
@ -562,7 +566,7 @@ def test_using_custom_config_specs() -> None:
|
||||
AIMessage(content="you said: meow"),
|
||||
]
|
||||
assert store == {
|
||||
("user1", "1"): ChatMessageHistory(
|
||||
("user1", "1"): InMemoryChatMessageHistory(
|
||||
messages=[
|
||||
HumanMessage(content="hello"),
|
||||
AIMessage(content="you said: hello"),
|
||||
@ -570,7 +574,7 @@ def test_using_custom_config_specs() -> None:
|
||||
AIMessage(content="you said: goodbye"),
|
||||
]
|
||||
),
|
||||
("user2", "1"): ChatMessageHistory(
|
||||
("user2", "1"): InMemoryChatMessageHistory(
|
||||
messages=[
|
||||
HumanMessage(content="meow"),
|
||||
AIMessage(content="you said: meow"),
|
||||
@ -596,9 +600,11 @@ async def test_using_custom_config_specs_async() -> None:
|
||||
runnable = RunnableLambda(_fake_llm)
|
||||
store = {}
|
||||
|
||||
def get_session_history(user_id: str, conversation_id: str) -> ChatMessageHistory:
|
||||
def get_session_history(
|
||||
user_id: str, conversation_id: str
|
||||
) -> InMemoryChatMessageHistory:
|
||||
if (user_id, conversation_id) not in store:
|
||||
store[(user_id, conversation_id)] = ChatMessageHistory()
|
||||
store[(user_id, conversation_id)] = InMemoryChatMessageHistory()
|
||||
return store[(user_id, conversation_id)]
|
||||
|
||||
with_message_history = RunnableWithMessageHistory(
|
||||
@ -635,7 +641,7 @@ async def test_using_custom_config_specs_async() -> None:
|
||||
AIMessage(content="you said: hello"),
|
||||
]
|
||||
assert store == {
|
||||
("user1_async", "1_async"): ChatMessageHistory(
|
||||
("user1_async", "1_async"): InMemoryChatMessageHistory(
|
||||
messages=[
|
||||
HumanMessage(content="hello"),
|
||||
AIMessage(content="you said: hello"),
|
||||
@ -653,7 +659,7 @@ async def test_using_custom_config_specs_async() -> None:
|
||||
AIMessage(content="you said: goodbye"),
|
||||
]
|
||||
assert store == {
|
||||
("user1_async", "1_async"): ChatMessageHistory(
|
||||
("user1_async", "1_async"): InMemoryChatMessageHistory(
|
||||
messages=[
|
||||
HumanMessage(content="hello"),
|
||||
AIMessage(content="you said: hello"),
|
||||
@ -673,7 +679,7 @@ async def test_using_custom_config_specs_async() -> None:
|
||||
AIMessage(content="you said: meow"),
|
||||
]
|
||||
assert store == {
|
||||
("user1_async", "1_async"): ChatMessageHistory(
|
||||
("user1_async", "1_async"): InMemoryChatMessageHistory(
|
||||
messages=[
|
||||
HumanMessage(content="hello"),
|
||||
AIMessage(content="you said: hello"),
|
||||
@ -681,7 +687,7 @@ async def test_using_custom_config_specs_async() -> None:
|
||||
AIMessage(content="you said: goodbye"),
|
||||
]
|
||||
),
|
||||
("user2_async", "1_async"): ChatMessageHistory(
|
||||
("user2_async", "1_async"): InMemoryChatMessageHistory(
|
||||
messages=[
|
||||
HumanMessage(content="meow"),
|
||||
AIMessage(content="you said: meow"),
|
||||
|
Loading…
Reference in New Issue
Block a user