From fbab8baac507a1f80bb26f9513195d18b0b217fc Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 5 Feb 2024 20:24:58 -0500 Subject: [PATCH] core[patch]: Add astream events config test (#17055) Verify that astream events propagates config correctly --------- Co-authored-by: Bagatur --- .../runnables/test_runnable_events.py | 137 +++++++++++++++++- 1 file changed, 135 insertions(+), 2 deletions(-) diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events.py b/libs/core/tests/unit_tests/runnables/test_runnable_events.py index 6d5de10941..a31e15e7ba 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events.py @@ -1,23 +1,29 @@ """Module that contains tests for runnable.astream_events API.""" from itertools import cycle -from typing import AsyncIterator, List, Sequence, cast +from typing import Any, AsyncIterator, Dict, List, Sequence, cast import pytest from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks +from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.documents import Document from langchain_core.messages import ( AIMessage, AIMessageChunk, + BaseMessage, HumanMessage, SystemMessage, ) from langchain_core.prompt_values import ChatPromptValue -from langchain_core.prompts import ChatPromptTemplate +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.pydantic_v1 import BaseModel from langchain_core.retrievers import BaseRetriever from langchain_core.runnables import ( + ConfigurableField, + Runnable, RunnableLambda, ) +from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_core.runnables.schema import StreamEvent from langchain_core.tools import tool from tests.unit_tests.fake.chat_model import GenericFakeChatModel @@ -1079,3 +1085,130 @@ async def test_runnable_each() -> None: with pytest.raises(NotImplementedError): async for _ in add_one_map.astream_events([1, 2, 3], version="v1"): pass + + +async def test_events_astream_config() -> None: + """Test that astream events support accepting config""" + infinite_cycle = cycle([AIMessage(content="hello world!")]) + good_world_on_repeat = cycle([AIMessage(content="Goodbye world")]) + model = GenericFakeChatModel(messages=infinite_cycle).configurable_fields( + messages=ConfigurableField( + id="messages", + name="Messages", + description="Messages return by the LLM", + ) + ) + + model_02 = model.with_config({"configurable": {"messages": good_world_on_repeat}}) + assert model_02.invoke("hello") == AIMessage(content="Goodbye world") + + events = await _collect_events(model_02.astream_events("hello", version="v1")) + assert events == [ + { + "data": {"input": "hello"}, + "event": "on_chat_model_start", + "metadata": {}, + "name": "RunnableConfigurableFields", + "run_id": "", + "tags": [], + }, + { + "data": {"chunk": AIMessageChunk(content="Goodbye")}, + "event": "on_chat_model_stream", + "metadata": {}, + "name": "RunnableConfigurableFields", + "run_id": "", + "tags": [], + }, + { + "data": {"chunk": AIMessageChunk(content=" ")}, + "event": "on_chat_model_stream", + "metadata": {}, + "name": "RunnableConfigurableFields", + "run_id": "", + "tags": [], + }, + { + "data": {"chunk": AIMessageChunk(content="world")}, + "event": "on_chat_model_stream", + "metadata": {}, + "name": "RunnableConfigurableFields", + "run_id": "", + "tags": [], + }, + { + "data": {"output": AIMessageChunk(content="Goodbye world")}, + "event": "on_chat_model_end", + "metadata": {}, + "name": "RunnableConfigurableFields", + "run_id": "", + "tags": [], + }, + ] + + +async def test_runnable_with_message_history() -> None: + class InMemoryHistory(BaseChatMessageHistory, BaseModel): + """In memory implementation of chat message history.""" + + # Attention: for the tests use an Any type to work-around a pydantic issue + # where it re-instantiates a list, so mutating the list doesn't end up mutating + # the content in the store! + + # Using Any type here rather than List[BaseMessage] due to pydantic issue! + messages: Any + + def add_message(self, message: BaseMessage) -> None: + """Add a self-created message to the store.""" + self.messages.append(message) + + def clear(self) -> None: + self.messages = [] + + # Here we use a global variable to store the chat message history. + # This will make it easier to inspect it to see the underlying results. + store: Dict = {} + + def get_by_session_id(session_id: str) -> BaseChatMessageHistory: + """Get a chat message history""" + if session_id not in store: + store[session_id] = [] + return InMemoryHistory(messages=store[session_id]) + + infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="world")]) + + prompt = ChatPromptTemplate.from_messages( + [ + ("system", "You are a cat"), + MessagesPlaceholder(variable_name="history"), + ("human", "{question}"), + ] + ) + model = GenericFakeChatModel(messages=infinite_cycle) + + chain: Runnable = prompt | model + with_message_history = RunnableWithMessageHistory( + chain, + get_session_history=get_by_session_id, + input_messages_key="question", + history_messages_key="history", + ) + with_message_history.with_config( + {"configurable": {"session_id": "session-123"}} + ).invoke({"question": "hello"}) + + assert store == { + "session-123": [HumanMessage(content="hello"), AIMessage(content="hello")] + } + + with_message_history.with_config( + {"configurable": {"session_id": "session-123"}} + ).invoke({"question": "meow"}) + assert store == { + "session-123": [ + HumanMessage(content="hello"), + AIMessage(content="hello"), + HumanMessage(content="meow"), + AIMessage(content="world"), + ] + }