|
|
@ -1,23 +1,29 @@
|
|
|
|
"""Module that contains tests for runnable.astream_events API."""
|
|
|
|
"""Module that contains tests for runnable.astream_events API."""
|
|
|
|
from itertools import cycle
|
|
|
|
from itertools import cycle
|
|
|
|
from typing import AsyncIterator, List, Sequence, cast
|
|
|
|
from typing import Any, AsyncIterator, Dict, List, Sequence, cast
|
|
|
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
|
|
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
|
|
|
|
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
|
|
|
|
|
|
|
|
from langchain_core.chat_history import BaseChatMessageHistory
|
|
|
|
from langchain_core.documents import Document
|
|
|
|
from langchain_core.documents import Document
|
|
|
|
from langchain_core.messages import (
|
|
|
|
from langchain_core.messages import (
|
|
|
|
AIMessage,
|
|
|
|
AIMessage,
|
|
|
|
AIMessageChunk,
|
|
|
|
AIMessageChunk,
|
|
|
|
|
|
|
|
BaseMessage,
|
|
|
|
HumanMessage,
|
|
|
|
HumanMessage,
|
|
|
|
SystemMessage,
|
|
|
|
SystemMessage,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
from langchain_core.prompt_values import ChatPromptValue
|
|
|
|
from langchain_core.prompt_values import ChatPromptValue
|
|
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
|
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
|
|
|
|
|
|
|
from langchain_core.pydantic_v1 import BaseModel
|
|
|
|
from langchain_core.retrievers import BaseRetriever
|
|
|
|
from langchain_core.retrievers import BaseRetriever
|
|
|
|
from langchain_core.runnables import (
|
|
|
|
from langchain_core.runnables import (
|
|
|
|
|
|
|
|
ConfigurableField,
|
|
|
|
|
|
|
|
Runnable,
|
|
|
|
RunnableLambda,
|
|
|
|
RunnableLambda,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
from langchain_core.runnables.history import RunnableWithMessageHistory
|
|
|
|
from langchain_core.runnables.schema import StreamEvent
|
|
|
|
from langchain_core.runnables.schema import StreamEvent
|
|
|
|
from langchain_core.tools import tool
|
|
|
|
from langchain_core.tools import tool
|
|
|
|
from tests.unit_tests.fake.chat_model import GenericFakeChatModel
|
|
|
|
from tests.unit_tests.fake.chat_model import GenericFakeChatModel
|
|
|
@ -1079,3 +1085,130 @@ async def test_runnable_each() -> None:
|
|
|
|
with pytest.raises(NotImplementedError):
|
|
|
|
with pytest.raises(NotImplementedError):
|
|
|
|
async for _ in add_one_map.astream_events([1, 2, 3], version="v1"):
|
|
|
|
async for _ in add_one_map.astream_events([1, 2, 3], version="v1"):
|
|
|
|
pass
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def test_events_astream_config() -> None:
|
|
|
|
|
|
|
|
"""Test that astream events support accepting config"""
|
|
|
|
|
|
|
|
infinite_cycle = cycle([AIMessage(content="hello world!")])
|
|
|
|
|
|
|
|
good_world_on_repeat = cycle([AIMessage(content="Goodbye world")])
|
|
|
|
|
|
|
|
model = GenericFakeChatModel(messages=infinite_cycle).configurable_fields(
|
|
|
|
|
|
|
|
messages=ConfigurableField(
|
|
|
|
|
|
|
|
id="messages",
|
|
|
|
|
|
|
|
name="Messages",
|
|
|
|
|
|
|
|
description="Messages return by the LLM",
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_02 = model.with_config({"configurable": {"messages": good_world_on_repeat}})
|
|
|
|
|
|
|
|
assert model_02.invoke("hello") == AIMessage(content="Goodbye world")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
events = await _collect_events(model_02.astream_events("hello", version="v1"))
|
|
|
|
|
|
|
|
assert events == [
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
"data": {"input": "hello"},
|
|
|
|
|
|
|
|
"event": "on_chat_model_start",
|
|
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
|
|
"name": "RunnableConfigurableFields",
|
|
|
|
|
|
|
|
"run_id": "",
|
|
|
|
|
|
|
|
"tags": [],
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
"data": {"chunk": AIMessageChunk(content="Goodbye")},
|
|
|
|
|
|
|
|
"event": "on_chat_model_stream",
|
|
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
|
|
"name": "RunnableConfigurableFields",
|
|
|
|
|
|
|
|
"run_id": "",
|
|
|
|
|
|
|
|
"tags": [],
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
"data": {"chunk": AIMessageChunk(content=" ")},
|
|
|
|
|
|
|
|
"event": "on_chat_model_stream",
|
|
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
|
|
"name": "RunnableConfigurableFields",
|
|
|
|
|
|
|
|
"run_id": "",
|
|
|
|
|
|
|
|
"tags": [],
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
"data": {"chunk": AIMessageChunk(content="world")},
|
|
|
|
|
|
|
|
"event": "on_chat_model_stream",
|
|
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
|
|
"name": "RunnableConfigurableFields",
|
|
|
|
|
|
|
|
"run_id": "",
|
|
|
|
|
|
|
|
"tags": [],
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
"data": {"output": AIMessageChunk(content="Goodbye world")},
|
|
|
|
|
|
|
|
"event": "on_chat_model_end",
|
|
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
|
|
"name": "RunnableConfigurableFields",
|
|
|
|
|
|
|
|
"run_id": "",
|
|
|
|
|
|
|
|
"tags": [],
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def test_runnable_with_message_history() -> None:
|
|
|
|
|
|
|
|
class InMemoryHistory(BaseChatMessageHistory, BaseModel):
|
|
|
|
|
|
|
|
"""In memory implementation of chat message history."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Attention: for the tests use an Any type to work-around a pydantic issue
|
|
|
|
|
|
|
|
# where it re-instantiates a list, so mutating the list doesn't end up mutating
|
|
|
|
|
|
|
|
# the content in the store!
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Using Any type here rather than List[BaseMessage] due to pydantic issue!
|
|
|
|
|
|
|
|
messages: Any
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def add_message(self, message: BaseMessage) -> None:
|
|
|
|
|
|
|
|
"""Add a self-created message to the store."""
|
|
|
|
|
|
|
|
self.messages.append(message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clear(self) -> None:
|
|
|
|
|
|
|
|
self.messages = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Here we use a global variable to store the chat message history.
|
|
|
|
|
|
|
|
# This will make it easier to inspect it to see the underlying results.
|
|
|
|
|
|
|
|
store: Dict = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
|
|
|
|
|
|
|
|
"""Get a chat message history"""
|
|
|
|
|
|
|
|
if session_id not in store:
|
|
|
|
|
|
|
|
store[session_id] = []
|
|
|
|
|
|
|
|
return InMemoryHistory(messages=store[session_id])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="world")])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
prompt = ChatPromptTemplate.from_messages(
|
|
|
|
|
|
|
|
[
|
|
|
|
|
|
|
|
("system", "You are a cat"),
|
|
|
|
|
|
|
|
MessagesPlaceholder(variable_name="history"),
|
|
|
|
|
|
|
|
("human", "{question}"),
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
model = GenericFakeChatModel(messages=infinite_cycle)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
chain: Runnable = prompt | model
|
|
|
|
|
|
|
|
with_message_history = RunnableWithMessageHistory(
|
|
|
|
|
|
|
|
chain,
|
|
|
|
|
|
|
|
get_session_history=get_by_session_id,
|
|
|
|
|
|
|
|
input_messages_key="question",
|
|
|
|
|
|
|
|
history_messages_key="history",
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
with_message_history.with_config(
|
|
|
|
|
|
|
|
{"configurable": {"session_id": "session-123"}}
|
|
|
|
|
|
|
|
).invoke({"question": "hello"})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert store == {
|
|
|
|
|
|
|
|
"session-123": [HumanMessage(content="hello"), AIMessage(content="hello")]
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with_message_history.with_config(
|
|
|
|
|
|
|
|
{"configurable": {"session_id": "session-123"}}
|
|
|
|
|
|
|
|
).invoke({"question": "meow"})
|
|
|
|
|
|
|
|
assert store == {
|
|
|
|
|
|
|
|
"session-123": [
|
|
|
|
|
|
|
|
HumanMessage(content="hello"),
|
|
|
|
|
|
|
|
AIMessage(content="hello"),
|
|
|
|
|
|
|
|
HumanMessage(content="meow"),
|
|
|
|
|
|
|
|
AIMessage(content="world"),
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
}
|
|
|
|