core[patch]: Add astream events config test (#17055)

Verify that astream events propagates config correctly

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/16905/head^2
Eugene Yurtsev 5 months ago committed by GitHub
parent 609ea019b2
commit fbab8baac5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,23 +1,29 @@
"""Module that contains tests for runnable.astream_events API.""" """Module that contains tests for runnable.astream_events API."""
from itertools import cycle from itertools import cycle
from typing import AsyncIterator, List, Sequence, cast from typing import Any, AsyncIterator, Dict, List, Sequence, cast
import pytest import pytest
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.messages import ( from langchain_core.messages import (
AIMessage, AIMessage,
AIMessageChunk, AIMessageChunk,
BaseMessage,
HumanMessage, HumanMessage,
SystemMessage, SystemMessage,
) )
from langchain_core.prompt_values import ChatPromptValue from langchain_core.prompt_values import ChatPromptValue
from langchain_core.prompts import ChatPromptTemplate from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.retrievers import BaseRetriever from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import ( from langchain_core.runnables import (
ConfigurableField,
Runnable,
RunnableLambda, RunnableLambda,
) )
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.schema import StreamEvent from langchain_core.runnables.schema import StreamEvent
from langchain_core.tools import tool from langchain_core.tools import tool
from tests.unit_tests.fake.chat_model import GenericFakeChatModel from tests.unit_tests.fake.chat_model import GenericFakeChatModel
@ -1079,3 +1085,130 @@ async def test_runnable_each() -> None:
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
async for _ in add_one_map.astream_events([1, 2, 3], version="v1"): async for _ in add_one_map.astream_events([1, 2, 3], version="v1"):
pass pass
async def test_events_astream_config() -> None:
"""Test that astream events support accepting config"""
infinite_cycle = cycle([AIMessage(content="hello world!")])
good_world_on_repeat = cycle([AIMessage(content="Goodbye world")])
model = GenericFakeChatModel(messages=infinite_cycle).configurable_fields(
messages=ConfigurableField(
id="messages",
name="Messages",
description="Messages return by the LLM",
)
)
model_02 = model.with_config({"configurable": {"messages": good_world_on_repeat}})
assert model_02.invoke("hello") == AIMessage(content="Goodbye world")
events = await _collect_events(model_02.astream_events("hello", version="v1"))
assert events == [
{
"data": {"input": "hello"},
"event": "on_chat_model_start",
"metadata": {},
"name": "RunnableConfigurableFields",
"run_id": "",
"tags": [],
},
{
"data": {"chunk": AIMessageChunk(content="Goodbye")},
"event": "on_chat_model_stream",
"metadata": {},
"name": "RunnableConfigurableFields",
"run_id": "",
"tags": [],
},
{
"data": {"chunk": AIMessageChunk(content=" ")},
"event": "on_chat_model_stream",
"metadata": {},
"name": "RunnableConfigurableFields",
"run_id": "",
"tags": [],
},
{
"data": {"chunk": AIMessageChunk(content="world")},
"event": "on_chat_model_stream",
"metadata": {},
"name": "RunnableConfigurableFields",
"run_id": "",
"tags": [],
},
{
"data": {"output": AIMessageChunk(content="Goodbye world")},
"event": "on_chat_model_end",
"metadata": {},
"name": "RunnableConfigurableFields",
"run_id": "",
"tags": [],
},
]
async def test_runnable_with_message_history() -> None:
class InMemoryHistory(BaseChatMessageHistory, BaseModel):
"""In memory implementation of chat message history."""
# Attention: for the tests use an Any type to work-around a pydantic issue
# where it re-instantiates a list, so mutating the list doesn't end up mutating
# the content in the store!
# Using Any type here rather than List[BaseMessage] due to pydantic issue!
messages: Any
def add_message(self, message: BaseMessage) -> None:
"""Add a self-created message to the store."""
self.messages.append(message)
def clear(self) -> None:
self.messages = []
# Here we use a global variable to store the chat message history.
# This will make it easier to inspect it to see the underlying results.
store: Dict = {}
def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
"""Get a chat message history"""
if session_id not in store:
store[session_id] = []
return InMemoryHistory(messages=store[session_id])
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="world")])
prompt = ChatPromptTemplate.from_messages(
[
("system", "You are a cat"),
MessagesPlaceholder(variable_name="history"),
("human", "{question}"),
]
)
model = GenericFakeChatModel(messages=infinite_cycle)
chain: Runnable = prompt | model
with_message_history = RunnableWithMessageHistory(
chain,
get_session_history=get_by_session_id,
input_messages_key="question",
history_messages_key="history",
)
with_message_history.with_config(
{"configurable": {"session_id": "session-123"}}
).invoke({"question": "hello"})
assert store == {
"session-123": [HumanMessage(content="hello"), AIMessage(content="hello")]
}
with_message_history.with_config(
{"configurable": {"session_id": "session-123"}}
).invoke({"question": "meow"})
assert store == {
"session-123": [
HumanMessage(content="hello"),
AIMessage(content="hello"),
HumanMessage(content="meow"),
AIMessage(content="world"),
]
}

Loading…
Cancel
Save