mirror of https://github.com/hwchase17/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1066 lines
31 KiB
Python
1066 lines
31 KiB
Python
"""Module that contains tests for runnable.astream_events API."""
|
|
from itertools import cycle
|
|
from typing import AsyncIterator, List, Sequence, cast
|
|
|
|
import pytest
|
|
|
|
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
|
|
from langchain_core.documents import Document
|
|
from langchain_core.messages import (
|
|
AIMessage,
|
|
AIMessageChunk,
|
|
HumanMessage,
|
|
SystemMessage,
|
|
)
|
|
from langchain_core.prompt_values import ChatPromptValue
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
from langchain_core.retrievers import BaseRetriever
|
|
from langchain_core.runnables import (
|
|
RunnableLambda,
|
|
)
|
|
from langchain_core.runnables.schema import StreamEvent
|
|
from langchain_core.tools import tool
|
|
from tests.unit_tests.fake.chat_model import GenericFakeChatModel
|
|
from tests.unit_tests.fake.llm import FakeStreamingListLLM
|
|
|
|
|
|
def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]:
|
|
"""Removes the run ids from events."""
|
|
return cast(List[StreamEvent], [{**event, "run_id": ""} for event in events])
|
|
|
|
|
|
async def _as_async_iterator(iterable: List) -> AsyncIterator:
|
|
"""Converts an iterable into an async iterator."""
|
|
for item in iterable:
|
|
yield item
|
|
|
|
|
|
async def _collect_events(events: AsyncIterator[StreamEvent]) -> List[StreamEvent]:
    """Drain an event stream into a list normalised for equality checks.

    The returned events are copies with ``run_id`` blanked out and with their
    ``tags`` sorted.
    """
    drained: List[StreamEvent] = []
    async for emitted in events:
        drained.append(emitted)

    normalised = _with_nulled_run_id(drained)
    for entry in normalised:
        # Give the tags a canonical order before they are compared.
        entry["tags"] = sorted(entry["tags"])
    return normalised
|
|
|
|
|
|
async def test_event_stream_with_single_lambda() -> None:
    """Check the events streamed by a single ``RunnableLambda``."""

    def reverse(s: str) -> str:
        """Reverse a string."""
        return s[::-1]

    def expect(kind: str, name: str, data: dict, *tags: str) -> dict:
        """Build an expected event; extra positional arguments are its tags."""
        return {
            "data": data,
            "event": kind,
            "metadata": {},
            "name": name,
            "run_id": "",
            "tags": list(tags),
        }

    chain = RunnableLambda(func=reverse)
    observed = await _collect_events(chain.astream_events("hello"))

    assert observed == [
        expect("on_chain_start", "reverse", {"input": "hello"}),
        expect("on_chain_stream", "reverse", {"chunk": "olleh"}),
        expect("on_chain_end", "reverse", {"output": "olleh"}),
    ]
|
|
|
|
|
|
async def test_event_stream_with_triple_lambda() -> None:
    """Check the events streamed by a sequence of three named lambdas."""

    def reverse(s: str) -> str:
        """Reverse a string."""
        return s[::-1]

    def expect(kind: str, name: str, data: dict, *tags: str) -> dict:
        """Build an expected event; extra positional arguments are its tags."""
        return {
            "data": data,
            "event": kind,
            "metadata": {},
            "name": name,
            "run_id": "",
            "tags": list(tags),
        }

    r = RunnableLambda(func=reverse)
    first = r.with_config({"run_name": "1"})
    second = r.with_config({"run_name": "2"})
    third = r.with_config({"run_name": "3"})
    chain = first | second | third

    observed = await _collect_events(chain.astream_events("hello"))

    # The expected order interleaves the steps: a step's end event is listed
    # after the start event of the step that follows it.
    assert observed == [
        expect("on_chain_start", "RunnableSequence", {"input": "hello"}),
        expect("on_chain_start", "1", {}, "seq:step:1"),
        expect("on_chain_stream", "1", {"chunk": "olleh"}, "seq:step:1"),
        expect("on_chain_start", "2", {}, "seq:step:2"),
        expect("on_chain_end", "1", {"input": "hello", "output": "olleh"}, "seq:step:1"),
        expect("on_chain_stream", "2", {"chunk": "hello"}, "seq:step:2"),
        expect("on_chain_start", "3", {}, "seq:step:3"),
        expect("on_chain_end", "2", {"input": "olleh", "output": "hello"}, "seq:step:2"),
        expect("on_chain_stream", "3", {"chunk": "olleh"}, "seq:step:3"),
        expect("on_chain_stream", "RunnableSequence", {"chunk": "olleh"}),
        expect("on_chain_end", "3", {"input": "hello", "output": "olleh"}, "seq:step:3"),
        expect("on_chain_end", "RunnableSequence", {"output": "olleh"}),
    ]
|
|
|
|
|
|
async def test_event_stream_with_triple_lambda_test_filtering() -> None:
    """Check that events can be filtered by run name and by tag."""

    def reverse(s: str) -> str:
        """Reverse a string."""
        return s[::-1]

    def expect(kind: str, name: str, data: dict, *tags: str) -> dict:
        """Build an expected event; extra positional arguments are its tags."""
        return {
            "data": data,
            "event": kind,
            "metadata": {},
            "name": name,
            "run_id": "",
            "tags": list(tags),
        }

    r = RunnableLambda(func=reverse)
    untagged = r.with_config({"run_name": "1"})
    tagged_two = r.with_config({"run_name": "2", "tags": ["my_tag"]})
    tagged_three = r.with_config({"run_name": "3", "tags": ["my_tag"]})
    chain = untagged | tagged_two | tagged_three

    # Selecting by name: only the events of the run called "1" are expected.
    by_name = await _collect_events(chain.astream_events("hello", include_names=["1"]))
    assert by_name == [
        expect("on_chain_start", "1", {}, "seq:step:1"),
        expect("on_chain_stream", "1", {"chunk": "olleh"}, "seq:step:1"),
        expect("on_chain_end", "1", {"input": "hello", "output": "olleh"}, "seq:step:1"),
    ]

    # Selecting by tag while excluding "2": only the events of "3" are expected.
    by_tag = await _collect_events(
        chain.astream_events("hello", include_tags=["my_tag"], exclude_names=["2"])
    )
    step_three_tags = ("my_tag", "seq:step:3")
    assert by_tag == [
        expect("on_chain_start", "3", {}, *step_three_tags),
        expect("on_chain_stream", "3", {"chunk": "olleh"}, *step_three_tags),
        expect(
            "on_chain_end",
            "3",
            {"input": "hello", "output": "olleh"},
            *step_three_tags,
        ),
    ]
|
|
|
|
|
|
async def test_event_stream_with_lambdas_from_lambda() -> None:
    """Check the events streamed by an anonymous lambda given a run name."""

    def expect(kind: str, data: dict) -> dict:
        """Build the expected ``my_lambda`` event of type ``kind``."""
        return {
            "data": data,
            "event": kind,
            "metadata": {},
            "name": "my_lambda",
            "run_id": "",
            "tags": [],
        }

    as_lambdas = RunnableLambda(lambda x: {"answer": "goodbye"}).with_config(
        {"run_name": "my_lambda"}
    )
    observed = await _collect_events(as_lambdas.astream_events({"question": "hello"}))

    assert observed == [
        expect("on_chain_start", {"input": {"question": "hello"}}),
        expect("on_chain_stream", {"chunk": {"answer": "goodbye"}}),
        expect("on_chain_end", {"output": {"answer": "goodbye"}}),
    ]
|
|
|
|
|
|
async def test_event_stream_with_simple_chain() -> None:
    """Check the events streamed by a named prompt-then-chat-model chain.

    The chain, the prompt and the model each get their own run name and tags,
    and the chain and model also get metadata. The expected events list the
    chain's tag and metadata on the events of both inner steps as well.
    """

    def expect(kind: str, name: str, data: dict, *tags: str, **metadata: str) -> dict:
        """Build an expected event; positional extras are tags, keywords metadata."""
        return {
            "data": data,
            "event": kind,
            "metadata": metadata,
            "name": name,
            "run_id": "",
            "tags": list(tags),
        }

    prompt_tags = ("my_chain", "my_template", "seq:step:1")
    model_tags = ("my_chain", "my_model", "seq:step:2")

    def from_chain(kind: str, data: dict) -> dict:
        """Expected event attributed to the outer ``my_chain`` run."""
        return expect(kind, "my_chain", data, "my_chain", foo="bar")

    def from_prompt(kind: str, data: dict) -> dict:
        """Expected event attributed to the ``my_template`` step."""
        return expect(kind, "my_template", data, *prompt_tags, foo="bar")

    def from_model(kind: str, data: dict) -> dict:
        """Expected event attributed to the ``my_model`` step."""
        return expect(kind, "my_model", data, *model_tags, a="b", foo="bar")

    def prompt_messages() -> list:
        """Return fresh copies of the messages the template should render."""
        return [
            SystemMessage(content="You are Cat Agent 007"),
            HumanMessage(content="hello"),
        ]

    template = ChatPromptTemplate.from_messages(
        [("system", "You are Cat Agent 007"), ("human", "{question}")]
    ).with_config({"run_name": "my_template", "tags": ["my_template"]})

    infinite_cycle = cycle(
        [AIMessage(content="hello world!"), AIMessage(content="goodbye world!")]
    )
    # When streaming, GenericFakeChatModel splits each AIMessage into chunks on
    # spaces.
    configured_model = GenericFakeChatModel(messages=infinite_cycle).with_config(
        {
            "metadata": {"a": "b"},
            "tags": ["my_model"],
            "run_name": "my_model",
        }
    )
    model = configured_model.bind(stop="<stop_token>")

    chain = (template | model).with_config(
        {
            "metadata": {"foo": "bar"},
            "tags": ["my_chain"],
            "run_name": "my_chain",
        }
    )

    observed = await _collect_events(chain.astream_events({"question": "hello"}))

    assert observed == [
        from_chain("on_chain_start", {"input": {"question": "hello"}}),
        from_prompt("on_prompt_start", {"input": {"question": "hello"}}),
        from_prompt(
            "on_prompt_end",
            {
                "input": {"question": "hello"},
                "output": ChatPromptValue(messages=prompt_messages()),
            },
        ),
        from_model("on_chat_model_start", {"input": {"messages": [prompt_messages()]}}),
        # Each chunk is listed first on the chain, then on the model.
        from_chain("on_chain_stream", {"chunk": AIMessageChunk(content="hello")}),
        from_model("on_chat_model_stream", {"chunk": AIMessageChunk(content="hello")}),
        from_chain("on_chain_stream", {"chunk": AIMessageChunk(content=" ")}),
        from_model("on_chat_model_stream", {"chunk": AIMessageChunk(content=" ")}),
        from_chain("on_chain_stream", {"chunk": AIMessageChunk(content="world!")}),
        from_model("on_chat_model_stream", {"chunk": AIMessageChunk(content="world!")}),
        from_model(
            "on_chat_model_end",
            {
                "input": {"messages": [prompt_messages()]},
                "output": {
                    "generations": [
                        [
                            {
                                "generation_info": None,
                                "message": AIMessageChunk(content="hello world!"),
                                "text": "hello world!",
                                "type": "ChatGenerationChunk",
                            }
                        ]
                    ],
                    "llm_output": None,
                    "run": None,
                },
            },
        ),
        from_chain("on_chain_end", {"output": AIMessageChunk(content="hello world!")}),
    ]
|
|
|
|
|
|
async def test_event_streaming_with_tools() -> None:
    """Check the events streamed by tools with different signatures.

    Four tools are covered: no parameters, only ``callbacks``, plain parameters,
    and plain parameters together with ``callbacks``.
    """

    @tool
    def parameterless() -> str:
        """A tool that does nothing."""
        return "hello"

    @tool
    def with_callbacks(callbacks: Callbacks) -> str:
        """A tool that does nothing."""
        return "world"

    @tool
    def with_parameters(x: int, y: str) -> dict:
        """A tool that does nothing."""
        return {"x": x, "y": y}

    @tool
    def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict:
        """A tool that does nothing."""
        return {"x": x, "y": y}

    def tool_run(name: str, tool_input: dict, result: object) -> List[dict]:
        """Build the start/stream/end events expected from one tool call."""

        def expect(kind: str, data: dict) -> dict:
            return {
                "data": data,
                "event": kind,
                "metadata": {},
                "name": name,
                "run_id": "",
                "tags": [],
            }

        return [
            expect("on_tool_start", {"input": tool_input}),
            expect("on_tool_stream", {"chunk": result}),
            expect("on_tool_end", {"output": result}),
        ]

    # The "type: ignore" markers are needed because type checkers do not see the
    # tools as runnables; they can be dropped once that is fixed.
    observed = await _collect_events(parameterless.astream_events({}))  # type: ignore
    assert observed == tool_run("parameterless", {}, "hello")

    observed = await _collect_events(with_callbacks.astream_events({}))  # type: ignore
    assert observed == tool_run("with_callbacks", {}, "world")

    observed = await _collect_events(with_parameters.astream_events({"x": 1, "y": "2"}))  # type: ignore
    assert observed == tool_run(
        "with_parameters", {"x": 1, "y": "2"}, {"x": 1, "y": "2"}
    )

    observed = await _collect_events(
        with_parameters_and_callbacks.astream_events({"x": 1, "y": "2"})  # type: ignore
    )
    assert observed == tool_run(
        "with_parameters_and_callbacks", {"x": 1, "y": "2"}, {"x": 1, "y": "2"}
    )
|
|
|
|
|
|
class HardCodedRetriever(BaseRetriever):
    """Retriever test double that always returns the documents it was given."""

    # Documents returned as-is by ``_get_relevant_documents``.
    documents: List[Document]

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the stored documents; ``query`` and ``run_manager`` are ignored."""
        return self.documents
|
|
|
|
|
|
async def test_event_stream_with_retriever() -> None:
    """Check the events streamed by a retriever used on its own."""

    def sample_docs() -> List[Document]:
        """Return fresh instances of the two documents used by this test."""
        return [
            Document(page_content="hello world!", metadata={"foo": "bar"}),
            Document(page_content="goodbye world!", metadata={"food": "spare"}),
        ]

    def expect(kind: str, data: dict) -> dict:
        """Build the expected ``HardCodedRetriever`` event of type ``kind``."""
        return {
            "data": data,
            "event": kind,
            "metadata": {},
            "name": "HardCodedRetriever",
            "run_id": "",
            "tags": [],
        }

    retriever = HardCodedRetriever(documents=sample_docs())
    observed = await _collect_events(retriever.astream_events({"query": "hello"}))

    assert observed == [
        expect("on_retriever_start", {"input": {"query": "hello"}}),
        expect("on_retriever_stream", {"chunk": sample_docs()}),
        expect("on_retriever_end", {"output": sample_docs()}),
    ]
|
|
|
|
|
|
async def test_event_stream_with_retriever_and_formatter() -> None:
    """Check the events streamed by a retriever piped into a formatting function."""

    def sample_docs() -> List[Document]:
        """Return fresh instances of the two documents used by this test."""
        return [
            Document(page_content="hello world!", metadata={"foo": "bar"}),
            Document(page_content="goodbye world!", metadata={"food": "spare"}),
        ]

    def format_docs(docs: List[Document]) -> str:
        """Format the docs."""
        return ", ".join([doc.page_content for doc in docs])

    def expect(kind: str, name: str, data: dict, *tags: str) -> dict:
        """Build an expected event; extra positional arguments are its tags."""
        return {
            "data": data,
            "event": kind,
            "metadata": {},
            "name": name,
            "run_id": "",
            "tags": list(tags),
        }

    retriever = HardCodedRetriever(documents=sample_docs())
    chain = retriever | format_docs
    observed = await _collect_events(chain.astream_events("hello"))

    formatted = "hello world!, goodbye world!"
    assert observed == [
        expect("on_chain_start", "RunnableSequence", {"input": "hello"}),
        expect(
            "on_retriever_start",
            "Retriever",
            {"input": {"query": "hello"}},
            "seq:step:1",
        ),
        expect(
            "on_retriever_end",
            "Retriever",
            {"input": {"query": "hello"}, "output": {"documents": sample_docs()}},
            "seq:step:1",
        ),
        expect("on_chain_start", "format_docs", {}, "seq:step:2"),
        expect("on_chain_stream", "format_docs", {"chunk": formatted}, "seq:step:2"),
        expect("on_chain_stream", "RunnableSequence", {"chunk": formatted}),
        expect(
            "on_chain_end",
            "format_docs",
            {"input": sample_docs(), "output": formatted},
            "seq:step:2",
        ),
        expect("on_chain_end", "RunnableSequence", {"output": formatted}),
    ]
|
|
|
|
|
|
async def test_event_stream_on_chain_with_tool() -> None:
    """Check the events streamed by a tool piped into a plain function."""

    @tool
    def concat(a: str, b: str) -> str:
        """A tool that does nothing."""
        return a + b

    def reverse(s: str) -> str:
        """Reverse a string."""
        return s[::-1]

    def expect(kind: str, name: str, data: dict, *tags: str) -> dict:
        """Build an expected event; extra positional arguments are its tags."""
        return {
            "data": data,
            "event": kind,
            "metadata": {},
            "name": name,
            "run_id": "",
            "tags": list(tags),
        }

    # The "type: ignore" is needed because type checkers do not see ``reverse``
    # as a runnable here.
    chain = concat | reverse  # type: ignore

    observed = await _collect_events(chain.astream_events({"a": "hello", "b": "world"}))

    assert observed == [
        expect(
            "on_chain_start",
            "RunnableSequence",
            {"input": {"a": "hello", "b": "world"}},
        ),
        expect(
            "on_tool_start",
            "concat",
            {"input": {"a": "hello", "b": "world"}},
            "seq:step:1",
        ),
        expect(
            "on_tool_end",
            "concat",
            {"input": {"a": "hello", "b": "world"}, "output": "helloworld"},
            "seq:step:1",
        ),
        expect("on_chain_start", "reverse", {}, "seq:step:2"),
        expect("on_chain_stream", "reverse", {"chunk": "dlrowolleh"}, "seq:step:2"),
        expect("on_chain_stream", "RunnableSequence", {"chunk": "dlrowolleh"}),
        expect(
            "on_chain_end",
            "reverse",
            {"input": "helloworld", "output": "dlrowolleh"},
            "seq:step:2",
        ),
        expect("on_chain_end", "RunnableSequence", {"output": "dlrowolleh"}),
    ]
|
|
|
|
|
|
async def test_event_stream_with_retry() -> None:
    """Check the events streamed when a step wrapped in ``with_retry`` raises.

    Events are pulled one at a time until the stream raises or ten events have
    been read; whatever was collected up to that point is then compared.
    """

    def success(inputs: str) -> str:
        return "success"

    def fail(inputs: str) -> None:
        """Simple func."""
        raise Exception("fail")

    def expect(kind: str, name: str, data: dict, *tags: str) -> dict:
        """Build an expected event; extra positional arguments are its tags."""
        return {
            "data": data,
            "event": kind,
            "metadata": {},
            "name": name,
            "run_id": "",
            "tags": list(tags),
        }

    failing_step = RunnableLambda(fail).with_retry(
        stop_after_attempt=1,
    )
    chain = RunnableLambda(success) | failing_step
    stream = chain.astream_events("q")

    collected = []
    for _ in range(10):
        try:
            chunk = await stream.__anext__()
        except Exception:
            # Any error from the stream (including exhaustion) ends collection.
            break
        else:
            collected.append(chunk)

    # Same normalisation as ``_collect_events``: blank run ids, sorted tags.
    observed = _with_nulled_run_id(collected)
    for entry in observed:
        entry["tags"] = sorted(entry["tags"])

    assert observed == [
        expect("on_chain_start", "RunnableSequence", {"input": "q"}),
        expect("on_chain_start", "success", {}, "seq:step:1"),
        expect("on_chain_stream", "success", {"chunk": "success"}, "seq:step:1"),
        expect("on_chain_start", "fail", {}, "seq:step:2"),
        expect(
            "on_chain_end",
            "success",
            {"input": "q", "output": "success"},
            "seq:step:1",
        ),
        expect(
            "on_chain_end",
            "fail",
            {"input": "success", "output": None},
            "seq:step:2",
        ),
    ]
|
|
|
|
|
|
async def test_with_llm() -> None:
    """Check the events streamed by a prompt piped into a (non-chat) LLM."""

    def expect(kind: str, name: str, data: dict, *tags: str) -> dict:
        """Build an expected event; extra positional arguments are its tags."""
        return {
            "data": data,
            "event": kind,
            "metadata": {},
            "name": name,
            "run_id": "",
            "tags": list(tags),
        }

    prompt = ChatPromptTemplate.from_messages(
        [("system", "You are Cat Agent 007"), ("human", "{question}")]
    ).with_config({"run_name": "my_template", "tags": ["my_template"]})
    llm = FakeStreamingListLLM(responses=["abc"])
    chain = prompt | llm

    observed = await _collect_events(chain.astream_events({"question": "hello"}))

    prompt_tags = ("my_template", "seq:step:1")
    # The prompt text the LLM is expected to receive.
    rendered = "System: You are Cat Agent 007\nHuman: hello"

    assert observed == [
        expect("on_chain_start", "RunnableSequence", {"input": {"question": "hello"}}),
        expect(
            "on_prompt_start",
            "my_template",
            {"input": {"question": "hello"}},
            *prompt_tags,
        ),
        expect(
            "on_prompt_end",
            "my_template",
            {
                "input": {"question": "hello"},
                "output": ChatPromptValue(
                    messages=[
                        SystemMessage(content="You are Cat Agent 007"),
                        HumanMessage(content="hello"),
                    ]
                ),
            },
            *prompt_tags,
        ),
        expect(
            "on_llm_start",
            "FakeStreamingListLLM",
            {"input": {"prompts": [rendered]}},
            "seq:step:2",
        ),
        expect(
            "on_llm_end",
            "FakeStreamingListLLM",
            {
                "input": {"prompts": [rendered]},
                "output": {
                    "generations": [
                        [{"generation_info": None, "text": "abc", "type": "Generation"}]
                    ],
                    "llm_output": None,
                    "run": None,
                },
            },
            "seq:step:2",
        ),
        expect("on_chain_stream", "RunnableSequence", {"chunk": "a"}),
        expect("on_chain_stream", "RunnableSequence", {"chunk": "b"}),
        expect("on_chain_stream", "RunnableSequence", {"chunk": "c"}),
        expect("on_chain_end", "RunnableSequence", {"output": "abc"}),
    ]
|
|
|
|
|
|
async def test_runnable_each() -> None:
    """Check ``astream_events`` on a mapped runnable.

    The mapped runnable is expected to work with ``ainvoke`` but to raise
    ``NotImplementedError`` when its events are streamed.
    """

    async def add_one(x: int) -> int:
        return x + 1

    add_one_map = RunnableLambda(add_one).map()  # type: ignore

    incremented = await add_one_map.ainvoke([1, 2, 3])
    assert incremented == [2, 3, 4]

    with pytest.raises(NotImplementedError):
        async for _ in add_one_map.astream_events([1, 2, 3]):
            pass
|