core[patch]: Add astream events config test (#17055)

Verify that astream events propagates config correctly

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/16905/head^2
Eugene Yurtsev 5 months ago committed by GitHub
parent 609ea019b2
commit fbab8baac5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,23 +1,29 @@
"""Module that contains tests for runnable.astream_events API."""
from itertools import cycle
from typing import AsyncIterator, List, Sequence, cast
from typing import Any, AsyncIterator, Dict, List, Sequence, cast
import pytest
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.documents import Document
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.prompt_values import ChatPromptValue
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import (
ConfigurableField,
Runnable,
RunnableLambda,
)
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.schema import StreamEvent
from langchain_core.tools import tool
from tests.unit_tests.fake.chat_model import GenericFakeChatModel
@ -1079,3 +1085,130 @@ async def test_runnable_each() -> None:
with pytest.raises(NotImplementedError):
async for _ in add_one_map.astream_events([1, 2, 3], version="v1"):
pass
async def test_events_astream_config() -> None:
"""Test that astream events support accepting config"""
infinite_cycle = cycle([AIMessage(content="hello world!")])
good_world_on_repeat = cycle([AIMessage(content="Goodbye world")])
model = GenericFakeChatModel(messages=infinite_cycle).configurable_fields(
messages=ConfigurableField(
id="messages",
name="Messages",
description="Messages return by the LLM",
)
)
model_02 = model.with_config({"configurable": {"messages": good_world_on_repeat}})
assert model_02.invoke("hello") == AIMessage(content="Goodbye world")
events = await _collect_events(model_02.astream_events("hello", version="v1"))
assert events == [
{
"data": {"input": "hello"},
"event": "on_chat_model_start",
"metadata": {},
"name": "RunnableConfigurableFields",
"run_id": "",
"tags": [],
},
{
"data": {"chunk": AIMessageChunk(content="Goodbye")},
"event": "on_chat_model_stream",
"metadata": {},
"name": "RunnableConfigurableFields",
"run_id": "",
"tags": [],
},
{
"data": {"chunk": AIMessageChunk(content=" ")},
"event": "on_chat_model_stream",
"metadata": {},
"name": "RunnableConfigurableFields",
"run_id": "",
"tags": [],
},
{
"data": {"chunk": AIMessageChunk(content="world")},
"event": "on_chat_model_stream",
"metadata": {},
"name": "RunnableConfigurableFields",
"run_id": "",
"tags": [],
},
{
"data": {"output": AIMessageChunk(content="Goodbye world")},
"event": "on_chat_model_end",
"metadata": {},
"name": "RunnableConfigurableFields",
"run_id": "",
"tags": [],
},
]
async def test_runnable_with_message_history() -> None:
class InMemoryHistory(BaseChatMessageHistory, BaseModel):
"""In memory implementation of chat message history."""
# Attention: for the tests use an Any type to work-around a pydantic issue
# where it re-instantiates a list, so mutating the list doesn't end up mutating
# the content in the store!
# Using Any type here rather than List[BaseMessage] due to pydantic issue!
messages: Any
def add_message(self, message: BaseMessage) -> None:
"""Add a self-created message to the store."""
self.messages.append(message)
def clear(self) -> None:
self.messages = []
# Here we use a global variable to store the chat message history.
# This will make it easier to inspect it to see the underlying results.
store: Dict = {}
def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
"""Get a chat message history"""
if session_id not in store:
store[session_id] = []
return InMemoryHistory(messages=store[session_id])
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="world")])
prompt = ChatPromptTemplate.from_messages(
[
("system", "You are a cat"),
MessagesPlaceholder(variable_name="history"),
("human", "{question}"),
]
)
model = GenericFakeChatModel(messages=infinite_cycle)
chain: Runnable = prompt | model
with_message_history = RunnableWithMessageHistory(
chain,
get_session_history=get_by_session_id,
input_messages_key="question",
history_messages_key="history",
)
with_message_history.with_config(
{"configurable": {"session_id": "session-123"}}
).invoke({"question": "hello"})
assert store == {
"session-123": [HumanMessage(content="hello"), AIMessage(content="hello")]
}
with_message_history.with_config(
{"configurable": {"session_id": "session-123"}}
).invoke({"question": "meow"})
assert store == {
"session-123": [
HumanMessage(content="hello"),
AIMessage(content="hello"),
HumanMessage(content="meow"),
AIMessage(content="world"),
]
}

Loading…
Cancel
Save