You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/core/tests/unit_tests/runnables/test_runnable_events.py

1066 lines
31 KiB
Python

"""Module that contains tests for runnable.astream_events API."""
from itertools import cycle
from typing import AsyncIterator, List, Sequence, cast
import pytest
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
from langchain_core.documents import Document
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
HumanMessage,
SystemMessage,
)
from langchain_core.prompt_values import ChatPromptValue
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import (
RunnableLambda,
)
from langchain_core.runnables.schema import StreamEvent
from langchain_core.tools import tool
from tests.unit_tests.fake.chat_model import GenericFakeChatModel
from tests.unit_tests.fake.llm import FakeStreamingListLLM
def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]:
"""Removes the run ids from events."""
return cast(List[StreamEvent], [{**event, "run_id": ""} for event in events])
async def _as_async_iterator(iterable: List) -> AsyncIterator:
"""Converts an iterable into an async iterator."""
for item in iterable:
yield item
async def _collect_events(events: AsyncIterator[StreamEvent]) -> List[StreamEvent]:
"""Collect the events and remove the run ids."""
materialized_events = [event async for event in events]
events_ = _with_nulled_run_id(materialized_events)
for event in events_:
event["tags"] = sorted(event["tags"])
return events_
async def test_event_stream_with_single_lambda() -> None:
"""Test the event stream with a tool."""
def reverse(s: str) -> str:
"""Reverse a string."""
return s[::-1]
chain = RunnableLambda(func=reverse)
events = await _collect_events(chain.astream_events("hello"))
assert events == [
{
"data": {"input": "hello"},
"event": "on_chain_start",
"metadata": {},
"name": "reverse",
"run_id": "",
"tags": [],
},
{
"data": {"chunk": "olleh"},
"event": "on_chain_stream",
"metadata": {},
"name": "reverse",
"run_id": "",
"tags": [],
},
{
"data": {"output": "olleh"},
"event": "on_chain_end",
"metadata": {},
"name": "reverse",
"run_id": "",
"tags": [],
},
]
async def test_event_stream_with_triple_lambda() -> None:
def reverse(s: str) -> str:
"""Reverse a string."""
return s[::-1]
r = RunnableLambda(func=reverse)
chain = (
r.with_config({"run_name": "1"})
| r.with_config({"run_name": "2"})
| r.with_config({"run_name": "3"})
)
events = await _collect_events(chain.astream_events("hello"))
assert events == [
{
"data": {"input": "hello"},
"event": "on_chain_start",
"metadata": {},
"name": "RunnableSequence",
"run_id": "",
"tags": [],
},
{
"data": {},
"event": "on_chain_start",
"metadata": {},
"name": "1",
"run_id": "",
"tags": ["seq:step:1"],
},
{
"data": {"chunk": "olleh"},
"event": "on_chain_stream",
"metadata": {},
"name": "1",
"run_id": "",
"tags": ["seq:step:1"],
},
{
"data": {},
"event": "on_chain_start",
"metadata": {},
"name": "2",
"run_id": "",
"tags": ["seq:step:2"],
},
{
"data": {"input": "hello", "output": "olleh"},
"event": "on_chain_end",
"metadata": {},
"name": "1",
"run_id": "",
"tags": ["seq:step:1"],
},
{
"data": {"chunk": "hello"},
"event": "on_chain_stream",
"metadata": {},
"name": "2",
"run_id": "",
"tags": ["seq:step:2"],
},
{
"data": {},
"event": "on_chain_start",
"metadata": {},
"name": "3",
"run_id": "",
"tags": ["seq:step:3"],
},
{
"data": {"input": "olleh", "output": "hello"},
"event": "on_chain_end",
"metadata": {},
"name": "2",
"run_id": "",
"tags": ["seq:step:2"],
},
{
"data": {"chunk": "olleh"},
"event": "on_chain_stream",
"metadata": {},
"name": "3",
"run_id": "",
"tags": ["seq:step:3"],
},
{
"data": {"chunk": "olleh"},
"event": "on_chain_stream",
"metadata": {},
"name": "RunnableSequence",
"run_id": "",
"tags": [],
},
{
"data": {"input": "hello", "output": "olleh"},
"event": "on_chain_end",
"metadata": {},
"name": "3",
"run_id": "",
"tags": ["seq:step:3"],
},
{
"data": {"output": "olleh"},
"event": "on_chain_end",
"metadata": {},
"name": "RunnableSequence",
"run_id": "",
"tags": [],
},
]
async def test_event_stream_with_triple_lambda_test_filtering() -> None:
"""Test filtering based on tags / names"""
def reverse(s: str) -> str:
"""Reverse a string."""
return s[::-1]
r = RunnableLambda(func=reverse)
chain = (
r.with_config({"run_name": "1"})
| r.with_config({"run_name": "2", "tags": ["my_tag"]})
| r.with_config({"run_name": "3", "tags": ["my_tag"]})
)
events = await _collect_events(chain.astream_events("hello", include_names=["1"]))
assert events == [
{
"data": {},
"event": "on_chain_start",
"metadata": {},
"name": "1",
"run_id": "",
"tags": ["seq:step:1"],
},
{
"data": {"chunk": "olleh"},
"event": "on_chain_stream",
"metadata": {},
"name": "1",
"run_id": "",
"tags": ["seq:step:1"],
},
{
"data": {"input": "hello", "output": "olleh"},
"event": "on_chain_end",
"metadata": {},
"name": "1",
"run_id": "",
"tags": ["seq:step:1"],
},
]
events = await _collect_events(
chain.astream_events("hello", include_tags=["my_tag"], exclude_names=["2"])
)
assert events == [
{
"data": {},
"event": "on_chain_start",
"metadata": {},
"name": "3",
"run_id": "",
"tags": ["my_tag", "seq:step:3"],
},
{
"data": {"chunk": "olleh"},
"event": "on_chain_stream",
"metadata": {},
"name": "3",
"run_id": "",
"tags": ["my_tag", "seq:step:3"],
},
{
"data": {"input": "hello", "output": "olleh"},
"event": "on_chain_end",
"metadata": {},
"name": "3",
"run_id": "",
"tags": ["my_tag", "seq:step:3"],
},
]
async def test_event_stream_with_lambdas_from_lambda() -> None:
as_lambdas = RunnableLambda(lambda x: {"answer": "goodbye"}).with_config(
{"run_name": "my_lambda"}
)
events = await _collect_events(as_lambdas.astream_events({"question": "hello"}))
assert events == [
{
"data": {"input": {"question": "hello"}},
"event": "on_chain_start",
"metadata": {},
"name": "my_lambda",
"run_id": "",
"tags": [],
},
{
"data": {"chunk": {"answer": "goodbye"}},
"event": "on_chain_stream",
"metadata": {},
"name": "my_lambda",
"run_id": "",
"tags": [],
},
{
"data": {"output": {"answer": "goodbye"}},
"event": "on_chain_end",
"metadata": {},
"name": "my_lambda",
"run_id": "",
"tags": [],
},
]
async def test_event_stream_with_simple_chain() -> None:
"""Test as event stream."""
template = ChatPromptTemplate.from_messages(
[("system", "You are Cat Agent 007"), ("human", "{question}")]
).with_config({"run_name": "my_template", "tags": ["my_template"]})
infinite_cycle = cycle(
[AIMessage(content="hello world!"), AIMessage(content="goodbye world!")]
)
# When streaming GenericFakeChatModel breaks AIMessage into chunks based on spaces
model = (
GenericFakeChatModel(messages=infinite_cycle)
.with_config(
{
"metadata": {"a": "b"},
"tags": ["my_model"],
"run_name": "my_model",
}
)
.bind(stop="<stop_token>")
)
chain = (template | model).with_config(
{
"metadata": {"foo": "bar"},
"tags": ["my_chain"],
"run_name": "my_chain",
}
)
events = await _collect_events(chain.astream_events({"question": "hello"}))
assert events == [
{
"data": {"input": {"question": "hello"}},
"event": "on_chain_start",
"metadata": {"foo": "bar"},
"name": "my_chain",
"run_id": "",
"tags": ["my_chain"],
},
{
"data": {"input": {"question": "hello"}},
"event": "on_prompt_start",
"metadata": {"foo": "bar"},
"name": "my_template",
"run_id": "",
"tags": ["my_chain", "my_template", "seq:step:1"],
},
{
"data": {
"input": {"question": "hello"},
"output": ChatPromptValue(
messages=[
SystemMessage(content="You are Cat Agent 007"),
HumanMessage(content="hello"),
]
),
},
"event": "on_prompt_end",
"metadata": {"foo": "bar"},
"name": "my_template",
"run_id": "",
"tags": ["my_chain", "my_template", "seq:step:1"],
},
{
"data": {
"input": {
"messages": [
[
SystemMessage(content="You are Cat Agent 007"),
HumanMessage(content="hello"),
]
]
}
},
"event": "on_chat_model_start",
"metadata": {"a": "b", "foo": "bar"},
"name": "my_model",
"run_id": "",
"tags": ["my_chain", "my_model", "seq:step:2"],
},
{
"data": {"chunk": AIMessageChunk(content="hello")},
"event": "on_chain_stream",
"metadata": {"foo": "bar"},
"name": "my_chain",
"run_id": "",
"tags": ["my_chain"],
},
{
"data": {"chunk": AIMessageChunk(content="hello")},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "foo": "bar"},
"name": "my_model",
"run_id": "",
"tags": ["my_chain", "my_model", "seq:step:2"],
},
{
"data": {"chunk": AIMessageChunk(content=" ")},
"event": "on_chain_stream",
"metadata": {"foo": "bar"},
"name": "my_chain",
"run_id": "",
"tags": ["my_chain"],
},
{
"data": {"chunk": AIMessageChunk(content=" ")},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "foo": "bar"},
"name": "my_model",
"run_id": "",
"tags": ["my_chain", "my_model", "seq:step:2"],
},
{
"data": {"chunk": AIMessageChunk(content="world!")},
"event": "on_chain_stream",
"metadata": {"foo": "bar"},
"name": "my_chain",
"run_id": "",
"tags": ["my_chain"],
},
{
"data": {"chunk": AIMessageChunk(content="world!")},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "foo": "bar"},
"name": "my_model",
"run_id": "",
"tags": ["my_chain", "my_model", "seq:step:2"],
},
{
"data": {
"input": {
"messages": [
[
SystemMessage(content="You are Cat Agent 007"),
HumanMessage(content="hello"),
]
]
},
"output": {
"generations": [
[
{
"generation_info": None,
"message": AIMessageChunk(content="hello world!"),
"text": "hello world!",
"type": "ChatGenerationChunk",
}
]
],
"llm_output": None,
"run": None,
},
},
"event": "on_chat_model_end",
"metadata": {"a": "b", "foo": "bar"},
"name": "my_model",
"run_id": "",
"tags": ["my_chain", "my_model", "seq:step:2"],
},
{
"data": {"output": AIMessageChunk(content="hello world!")},
"event": "on_chain_end",
"metadata": {"foo": "bar"},
"name": "my_chain",
"run_id": "",
"tags": ["my_chain"],
},
]
async def test_event_streaming_with_tools() -> None:
"""Test streaming events with different tool definitions."""
@tool
def parameterless() -> str:
"""A tool that does nothing."""
return "hello"
@tool
def with_callbacks(callbacks: Callbacks) -> str:
"""A tool that does nothing."""
return "world"
@tool
def with_parameters(x: int, y: str) -> dict:
"""A tool that does nothing."""
return {"x": x, "y": y}
@tool
def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict:
"""A tool that does nothing."""
return {"x": x, "y": y}
# type ignores below because the tools don't appear to be runnables to type checkers
# we can remove as soon as that's fixed
events = await _collect_events(parameterless.astream_events({})) # type: ignore
assert events == [
{
"data": {"input": {}},
"event": "on_tool_start",
"metadata": {},
"name": "parameterless",
"run_id": "",
"tags": [],
},
{
"data": {"chunk": "hello"},
"event": "on_tool_stream",
"metadata": {},
"name": "parameterless",
"run_id": "",
"tags": [],
},
{
"data": {"output": "hello"},
"event": "on_tool_end",
"metadata": {},
"name": "parameterless",
"run_id": "",
"tags": [],
},
]
events = await _collect_events(with_callbacks.astream_events({})) # type: ignore
assert events == [
{
"data": {"input": {}},
"event": "on_tool_start",
"metadata": {},
"name": "with_callbacks",
"run_id": "",
"tags": [],
},
{
"data": {"chunk": "world"},
"event": "on_tool_stream",
"metadata": {},
"name": "with_callbacks",
"run_id": "",
"tags": [],
},
{
"data": {"output": "world"},
"event": "on_tool_end",
"metadata": {},
"name": "with_callbacks",
"run_id": "",
"tags": [],
},
]
events = await _collect_events(with_parameters.astream_events({"x": 1, "y": "2"})) # type: ignore
assert events == [
{
"data": {"input": {"x": 1, "y": "2"}},
"event": "on_tool_start",
"metadata": {},
"name": "with_parameters",
"run_id": "",
"tags": [],
},
{
"data": {"chunk": {"x": 1, "y": "2"}},
"event": "on_tool_stream",
"metadata": {},
"name": "with_parameters",
"run_id": "",
"tags": [],
},
{
"data": {"output": {"x": 1, "y": "2"}},
"event": "on_tool_end",
"metadata": {},
"name": "with_parameters",
"run_id": "",
"tags": [],
},
]
events = await _collect_events(
with_parameters_and_callbacks.astream_events({"x": 1, "y": "2"}) # type: ignore
)
assert events == [
{
"data": {"input": {"x": 1, "y": "2"}},
"event": "on_tool_start",
"metadata": {},
"name": "with_parameters_and_callbacks",
"run_id": "",
"tags": [],
},
{
"data": {"chunk": {"x": 1, "y": "2"}},
"event": "on_tool_stream",
"metadata": {},
"name": "with_parameters_and_callbacks",
"run_id": "",
"tags": [],
},
{
"data": {"output": {"x": 1, "y": "2"}},
"event": "on_tool_end",
"metadata": {},
"name": "with_parameters_and_callbacks",
"run_id": "",
"tags": [],
},
]
class HardCodedRetriever(BaseRetriever):
documents: List[Document]
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
return self.documents
async def test_event_stream_with_retriever() -> None:
"""Test the event stream with a retriever."""
retriever = HardCodedRetriever(
documents=[
Document(
page_content="hello world!",
metadata={"foo": "bar"},
),
Document(
page_content="goodbye world!",
metadata={"food": "spare"},
),
]
)
events = await _collect_events(retriever.astream_events({"query": "hello"}))
assert events == [
{
"data": {
"input": {"query": "hello"},
},
"event": "on_retriever_start",
"metadata": {},
"name": "HardCodedRetriever",
"run_id": "",
"tags": [],
},
{
"data": {
"chunk": [
Document(page_content="hello world!", metadata={"foo": "bar"}),
Document(page_content="goodbye world!", metadata={"food": "spare"}),
]
},
"event": "on_retriever_stream",
"metadata": {},
"name": "HardCodedRetriever",
"run_id": "",
"tags": [],
},
{
"data": {
"output": [
Document(page_content="hello world!", metadata={"foo": "bar"}),
Document(page_content="goodbye world!", metadata={"food": "spare"}),
],
},
"event": "on_retriever_end",
"metadata": {},
"name": "HardCodedRetriever",
"run_id": "",
"tags": [],
},
]
async def test_event_stream_with_retriever_and_formatter() -> None:
"""Test the event stream with a retriever."""
retriever = HardCodedRetriever(
documents=[
Document(
page_content="hello world!",
metadata={"foo": "bar"},
),
Document(
page_content="goodbye world!",
metadata={"food": "spare"},
),
]
)
def format_docs(docs: List[Document]) -> str:
"""Format the docs."""
return ", ".join([doc.page_content for doc in docs])
chain = retriever | format_docs
events = await _collect_events(chain.astream_events("hello"))
assert events == [
{
"data": {"input": "hello"},
"event": "on_chain_start",
"metadata": {},
"name": "RunnableSequence",
"run_id": "",
"tags": [],
},
{
"data": {"input": {"query": "hello"}},
"event": "on_retriever_start",
"metadata": {},
"name": "Retriever",
"run_id": "",
"tags": ["seq:step:1"],
},
{
"data": {
"input": {"query": "hello"},
"output": {
"documents": [
Document(page_content="hello world!", metadata={"foo": "bar"}),
Document(
page_content="goodbye world!", metadata={"food": "spare"}
),
]
},
},
"event": "on_retriever_end",
"metadata": {},
"name": "Retriever",
"run_id": "",
"tags": ["seq:step:1"],
},
{
"data": {},
"event": "on_chain_start",
"metadata": {},
"name": "format_docs",
"run_id": "",
"tags": ["seq:step:2"],
},
{
"data": {"chunk": "hello world!, goodbye world!"},
"event": "on_chain_stream",
"metadata": {},
"name": "format_docs",
"run_id": "",
"tags": ["seq:step:2"],
},
{
"data": {"chunk": "hello world!, goodbye world!"},
"event": "on_chain_stream",
"metadata": {},
"name": "RunnableSequence",
"run_id": "",
"tags": [],
},
{
"data": {
"input": [
Document(page_content="hello world!", metadata={"foo": "bar"}),
Document(page_content="goodbye world!", metadata={"food": "spare"}),
],
"output": "hello world!, goodbye world!",
},
"event": "on_chain_end",
"metadata": {},
"name": "format_docs",
"run_id": "",
"tags": ["seq:step:2"],
},
{
"data": {"output": "hello world!, goodbye world!"},
"event": "on_chain_end",
"metadata": {},
"name": "RunnableSequence",
"run_id": "",
"tags": [],
},
]
async def test_event_stream_on_chain_with_tool() -> None:
"""Test the event stream with a tool."""
@tool
def concat(a: str, b: str) -> str:
"""A tool that does nothing."""
return a + b
def reverse(s: str) -> str:
"""Reverse a string."""
return s[::-1]
# For whatever reason type annotations fail here because reverse
# does not appear to be a runnable
chain = concat | reverse # type: ignore
events = await _collect_events(chain.astream_events({"a": "hello", "b": "world"}))
assert events == [
{
"data": {"input": {"a": "hello", "b": "world"}},
"event": "on_chain_start",
"metadata": {},
"name": "RunnableSequence",
"run_id": "",
"tags": [],
},
{
"data": {"input": {"a": "hello", "b": "world"}},
"event": "on_tool_start",
"metadata": {},
"name": "concat",
"run_id": "",
"tags": ["seq:step:1"],
},
{
"data": {"input": {"a": "hello", "b": "world"}, "output": "helloworld"},
"event": "on_tool_end",
"metadata": {},
"name": "concat",
"run_id": "",
"tags": ["seq:step:1"],
},
{
"data": {},
"event": "on_chain_start",
"metadata": {},
"name": "reverse",
"run_id": "",
"tags": ["seq:step:2"],
},
{
"data": {"chunk": "dlrowolleh"},
"event": "on_chain_stream",
"metadata": {},
"name": "reverse",
"run_id": "",
"tags": ["seq:step:2"],
},
{
"data": {"chunk": "dlrowolleh"},
"event": "on_chain_stream",
"metadata": {},
"name": "RunnableSequence",
"run_id": "",
"tags": [],
},
{
"data": {"input": "helloworld", "output": "dlrowolleh"},
"event": "on_chain_end",
"metadata": {},
"name": "reverse",
"run_id": "",
"tags": ["seq:step:2"],
},
{
"data": {"output": "dlrowolleh"},
"event": "on_chain_end",
"metadata": {},
"name": "RunnableSequence",
"run_id": "",
"tags": [],
},
]
async def test_event_stream_with_retry() -> None:
"""Test the event stream with a tool."""
def success(inputs: str) -> str:
return "success"
def fail(inputs: str) -> None:
"""Simple func."""
raise Exception("fail")
chain = RunnableLambda(success) | RunnableLambda(fail).with_retry(
stop_after_attempt=1,
)
iterable = chain.astream_events("q")
events = []
for _ in range(10):
try:
next_chunk = await iterable.__anext__()
events.append(next_chunk)
except Exception:
break
events = _with_nulled_run_id(events)
for event in events:
event["tags"] = sorted(event["tags"])
assert events == [
{
"data": {"input": "q"},
"event": "on_chain_start",
"metadata": {},
"name": "RunnableSequence",
"run_id": "",
"tags": [],
},
{
"data": {},
"event": "on_chain_start",
"metadata": {},
"name": "success",
"run_id": "",
"tags": ["seq:step:1"],
},
{
"data": {"chunk": "success"},
"event": "on_chain_stream",
"metadata": {},
"name": "success",
"run_id": "",
"tags": ["seq:step:1"],
},
{
"data": {},
"event": "on_chain_start",
"metadata": {},
"name": "fail",
"run_id": "",
"tags": ["seq:step:2"],
},
{
"data": {"input": "q", "output": "success"},
"event": "on_chain_end",
"metadata": {},
"name": "success",
"run_id": "",
"tags": ["seq:step:1"],
},
{
"data": {"input": "success", "output": None},
"event": "on_chain_end",
"metadata": {},
"name": "fail",
"run_id": "",
"tags": ["seq:step:2"],
},
]
async def test_with_llm() -> None:
"""Test with regular llm."""
prompt = ChatPromptTemplate.from_messages(
[("system", "You are Cat Agent 007"), ("human", "{question}")]
).with_config({"run_name": "my_template", "tags": ["my_template"]})
llm = FakeStreamingListLLM(responses=["abc"])
chain = prompt | llm
events = await _collect_events(chain.astream_events({"question": "hello"}))
assert events == [
{
"data": {"input": {"question": "hello"}},
"event": "on_chain_start",
"metadata": {},
"name": "RunnableSequence",
"run_id": "",
"tags": [],
},
{
"data": {"input": {"question": "hello"}},
"event": "on_prompt_start",
"metadata": {},
"name": "my_template",
"run_id": "",
"tags": ["my_template", "seq:step:1"],
},
{
"data": {
"input": {"question": "hello"},
"output": ChatPromptValue(
messages=[
SystemMessage(content="You are Cat Agent 007"),
HumanMessage(content="hello"),
]
),
},
"event": "on_prompt_end",
"metadata": {},
"name": "my_template",
"run_id": "",
"tags": ["my_template", "seq:step:1"],
},
{
"data": {
"input": {"prompts": ["System: You are Cat Agent 007\n" "Human: hello"]}
},
"event": "on_llm_start",
"metadata": {},
"name": "FakeStreamingListLLM",
"run_id": "",
"tags": ["seq:step:2"],
},
{
"data": {
"input": {
"prompts": ["System: You are Cat Agent 007\n" "Human: hello"]
},
"output": {
"generations": [
[{"generation_info": None, "text": "abc", "type": "Generation"}]
],
"llm_output": None,
"run": None,
},
},
"event": "on_llm_end",
"metadata": {},
"name": "FakeStreamingListLLM",
"run_id": "",
"tags": ["seq:step:2"],
},
{
"data": {"chunk": "a"},
"event": "on_chain_stream",
"metadata": {},
"name": "RunnableSequence",
"run_id": "",
"tags": [],
},
{
"data": {"chunk": "b"},
"event": "on_chain_stream",
"metadata": {},
"name": "RunnableSequence",
"run_id": "",
"tags": [],
},
{
"data": {"chunk": "c"},
"event": "on_chain_stream",
"metadata": {},
"name": "RunnableSequence",
"run_id": "",
"tags": [],
},
{
"data": {"output": "abc"},
"event": "on_chain_end",
"metadata": {},
"name": "RunnableSequence",
"run_id": "",
"tags": [],
},
]
async def test_runnable_each() -> None:
"""Test runnable each astream_events."""
async def add_one(x: int) -> int:
return x + 1
add_one_map = RunnableLambda(add_one).map() # type: ignore
assert await add_one_map.ainvoke([1, 2, 3]) == [2, 3, 4]
with pytest.raises(NotImplementedError):
async for _ in add_one_map.astream_events([1, 2, 3]):
pass