Populate streamed_output for all runs handled by atransform_stream_with_config (#15599)

This means that users of astream_log() now get streamed output for
virtually all requested runs, whereas previously the only streamed output
was for the root run and raw LLM runs.
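
For consumers, the practical effect looks like this; a minimal sketch, assuming a hypothetical `chain` built as `prompt | llm | parser` (names are illustrative, not from this diff):

```python
# With this change, astream_log() emits streamed_output patches for
# intermediate runs (e.g. an output parser), not only the root run and
# raw LLM runs as before.
async for patch in chain.astream_log({"question": "What is your name?"}):
    for op in patch.ops:
        # paths like /logs/<step>/streamed_output/- now appear for
        # virtually every requested run
        print(op["op"], op["path"])
```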

Branch: pull/15683/head
Author: Nuno Campos (committed via GitHub)
Commit: ef22559f1f (parent: 7025fa23aa)

@@ -1277,6 +1277,8 @@ class Runnable(Generic[Input, Output], ABC):
         """Helper method to transform an Async Iterator of Input values into an Async
         Iterator of Output values, with callbacks.
         Use this to implement `astream()` or `atransform()` in Runnable subclasses."""
+        from langchain_core.tracers.log_stream import LogStreamCallbackHandler
+
         # tee the input so we can iterate over it twice
         input_for_tracing, input_for_transform = atee(input, 2)
         # Start the input iterator to ensure the input runnable starts before this one
@@ -1302,6 +1304,16 @@ class Runnable(Generic[Input, Output], ABC):
             context = copy_context()
             context.run(var_child_runnable_config.set, child_config)
             iterator = context.run(transformer, input_for_transform, **kwargs)  # type: ignore[arg-type]
+            if stream_log := next(
+                (
+                    h
+                    for h in run_manager.handlers
+                    if isinstance(h, LogStreamCallbackHandler)
+                ),
+                None,
+            ):
+                # populates streamed_output in astream_log() output if needed
+                iterator = stream_log.tap_output_aiter(run_manager.run_id, iterator)
             try:
                 while True:
                     if accepts_context(asyncio.create_task):
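
The tap applied above wraps the transformer's async iterator so each chunk is observed on its way through. A standalone sketch of the same pattern, with hypothetical names, assuming nothing beyond the standard library:

```python
from typing import AsyncIterator, Callable, TypeVar

T = TypeVar("T")


async def tap_aiter(
    source: AsyncIterator[T], observe: Callable[[T], None]
) -> AsyncIterator[T]:
    # Yield every chunk unchanged while reporting it to an observer,
    # mirroring how tap_output_aiter forwards chunks to the log stream.
    async for chunk in source:
        observe(chunk)
        yield chunk
```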
@@ -2733,6 +2745,7 @@ class RunnableLambda(Runnable[Input, Output]):
                 ],
             ]
         ] = None,
+        name: Optional[str] = None,
    ) -> None:
        """Create a RunnableLambda from a callable, an async callable, or both.
@ -2766,7 +2779,9 @@ class RunnableLambda(Runnable[Input, Output]):
)
try:
if func_for_name.__name__ != "<lambda>":
if name is not None:
self.name = name
elif func_for_name.__name__ != "<lambda>":
self.name = func_for_name.__name__
except AttributeError:
pass
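
A quick usage sketch of the new `name` argument (the lambda and expected value are illustrative, not taken from this diff):

```python
from langchain_core.runnables import RunnableLambda

# Without `name`, a lambda gets no inferred name because its __name__
# is "<lambda>"; the new keyword lets callers set one explicitly.
add_one = RunnableLambda(lambda x: x + 1, name="add_one")
assert add_one.name == "add_one"
```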
@@ -3046,17 +3061,7 @@ class RunnableLambda(Runnable[Input, Output]):
     def _config(
         self, config: Optional[RunnableConfig], callable: Callable[..., Any]
     ) -> RunnableConfig:
-        config = ensure_config(config)
-
-        if config.get("run_name") is None:
-            try:
-                run_name = callable.__name__
-            except AttributeError:
-                run_name = None
-            if run_name is not None:
-                return patch_config(config, run_name=run_name)
-
-        return config
+        return ensure_config(config)

     def invoke(
         self,

@@ -12,6 +12,7 @@ from typing import (
     Optional,
     Sequence,
     TypedDict,
+    TypeVar,
     Union,
 )
 from uuid import UUID
@@ -128,6 +129,9 @@ class RunLog(RunLogPatch):
         return f"RunLog({pformat(self.state)})"


+T = TypeVar("T")
+
+
 class LogStreamCallbackHandler(BaseTracer):
     """A tracer that streams run logs to a stream."""
@@ -165,6 +169,28 @@ class LogStreamCallbackHandler(BaseTracer):
     def __aiter__(self) -> AsyncIterator[RunLogPatch]:
         return self.receive_stream.__aiter__()

+    async def tap_output_aiter(
+        self, run_id: UUID, output: AsyncIterator[T]
+    ) -> AsyncIterator[T]:
+        """Tap an output async iterator to stream its values to the log."""
+        async for chunk in output:
+            # root run is handled in .astream_log()
+            if run_id != self.root_id:
+                # if we can't find the run, silently ignore it,
+                # e.g. because this run wasn't included in the log
+                if key := self._key_map_by_run_id.get(run_id):
+                    self.send_stream.send_nowait(
+                        RunLogPatch(
+                            {
+                                "op": "add",
+                                "path": f"/logs/{key}/streamed_output/-",
+                                "value": chunk,
+                            }
+                        )
+                    )
+            yield chunk
+
     def include_run(self, run: Run) -> bool:
         if run.id == self.root_id:
             return False
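
The `/logs/{key}/streamed_output/-` path follows JSON Patch (RFC 6902), where a trailing `-` appends to an array. A minimal sketch of how such a patch mutates the run log state, using the third-party `jsonpatch` package purely for illustration (LangChain applies these ops through RunLog itself):

```python
import jsonpatch  # assumption: pip install jsonpatch

# A pared-down run log state with one tapped run, "MyParser".
state = {"logs": {"MyParser": {"streamed_output": []}}}
ops = [{"op": "add", "path": "/logs/MyParser/streamed_output/-", "value": ["bear"]}]

# "-" appends, so each tapped chunk lands at the end of streamed_output.
state = jsonpatch.apply_patch(state, ops)
assert state["logs"]["MyParser"]["streamed_output"] == [["bear"]]
```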

File diff suppressed because one or more lines are too long

@@ -2140,6 +2140,272 @@ async def test_prompt_with_llm(
     assert stream_log == expected


+@freeze_time("2023-01-01")
+async def test_prompt_with_llm_parser(
+    mocker: MockerFixture, snapshot: SnapshotAssertion
+) -> None:
+    prompt = (
+        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+        + "{question}"
+    )
+    llm = FakeStreamingListLLM(responses=["bear, dog, cat", "tomato, lettuce, onion"])
+    parser = CommaSeparatedListOutputParser()
+
+    chain: Runnable = prompt | llm | parser
+
+    assert isinstance(chain, RunnableSequence)
+    assert chain.first == prompt
+    assert chain.middle == [llm]
+    assert chain.last == parser
+    assert dumps(chain, pretty=True) == snapshot
+
+    # Test invoke
+    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
+    llm_spy = mocker.spy(llm.__class__, "ainvoke")
+    parser_spy = mocker.spy(parser.__class__, "ainvoke")
+    tracer = FakeTracer()
+    assert await chain.ainvoke(
+        {"question": "What is your name?"}, dict(callbacks=[tracer])
+    ) == ["bear", "dog", "cat"]
+    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
+    assert llm_spy.call_args.args[1] == ChatPromptValue(
+        messages=[
+            SystemMessage(content="You are a nice assistant."),
+            HumanMessage(content="What is your name?"),
+        ]
+    )
+    assert parser_spy.call_args.args[1] == "bear, dog, cat"
+    assert tracer.runs == snapshot
+    mocker.stop(prompt_spy)
+    mocker.stop(llm_spy)
+    mocker.stop(parser_spy)
+
+    # Test batch
+    prompt_spy = mocker.spy(prompt.__class__, "abatch")
+    llm_spy = mocker.spy(llm.__class__, "abatch")
+    parser_spy = mocker.spy(parser.__class__, "abatch")
+    tracer = FakeTracer()
+    assert await chain.abatch(
+        [
+            {"question": "What is your name?"},
+            {"question": "What is your favorite color?"},
+        ],
+        dict(callbacks=[tracer]),
+    ) == [["tomato", "lettuce", "onion"], ["bear", "dog", "cat"]]
+    assert prompt_spy.call_args.args[1] == [
+        {"question": "What is your name?"},
+        {"question": "What is your favorite color?"},
+    ]
+    assert llm_spy.call_args.args[1] == [
+        ChatPromptValue(
+            messages=[
+                SystemMessage(content="You are a nice assistant."),
+                HumanMessage(content="What is your name?"),
+            ]
+        ),
+        ChatPromptValue(
+            messages=[
+                SystemMessage(content="You are a nice assistant."),
+                HumanMessage(content="What is your favorite color?"),
+            ]
+        ),
+    ]
+    assert parser_spy.call_args.args[1] == [
+        "tomato, lettuce, onion",
+        "bear, dog, cat",
+    ]
+    assert tracer.runs == snapshot
+    mocker.stop(prompt_spy)
+    mocker.stop(llm_spy)
+    mocker.stop(parser_spy)
+
+    # Test stream
+    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
+    llm_spy = mocker.spy(llm.__class__, "astream")
+    tracer = FakeTracer()
+    assert [
+        token
+        async for token in chain.astream(
+            {"question": "What is your name?"}, dict(callbacks=[tracer])
+        )
+    ] == [["tomato"], ["lettuce"], ["onion"]]
+    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
+    assert llm_spy.call_args.args[1] == ChatPromptValue(
+        messages=[
+            SystemMessage(content="You are a nice assistant."),
+            HumanMessage(content="What is your name?"),
+        ]
+    )
+
+    prompt_spy.reset_mock()
+    llm_spy.reset_mock()
+    stream_log = [
+        part async for part in chain.astream_log({"question": "What is your name?"})
+    ]
+
+    # remove ids from logs
+    for part in stream_log:
+        for op in part.ops:
+            if (
+                isinstance(op["value"], dict)
+                and "id" in op["value"]
+                and not isinstance(op["value"]["id"], list)  # serialized lc id
+            ):
+                del op["value"]["id"]
+
+    expected = [
+        RunLogPatch(
+            {
+                "op": "replace",
+                "path": "",
+                "value": {
+                    "logs": {},
+                    "final_output": None,
+                    "streamed_output": [],
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/ChatPromptTemplate",
+                "value": {
+                    "end_time": None,
+                    "final_output": None,
+                    "metadata": {},
+                    "name": "ChatPromptTemplate",
+                    "start_time": "2023-01-01T00:00:00.000+00:00",
+                    "streamed_output": [],
+                    "streamed_output_str": [],
+                    "tags": ["seq:step:1"],
+                    "type": "prompt",
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/ChatPromptTemplate/final_output",
+                "value": ChatPromptValue(
+                    messages=[
+                        SystemMessage(content="You are a nice assistant."),
+                        HumanMessage(content="What is your name?"),
+                    ]
+                ),
+            },
+            {
+                "op": "add",
+                "path": "/logs/ChatPromptTemplate/end_time",
+                "value": "2023-01-01T00:00:00.000+00:00",
+            },
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/FakeStreamingListLLM",
+                "value": {
+                    "end_time": None,
+                    "final_output": None,
+                    "metadata": {},
+                    "name": "FakeStreamingListLLM",
+                    "start_time": "2023-01-01T00:00:00.000+00:00",
+                    "streamed_output": [],
+                    "streamed_output_str": [],
+                    "tags": ["seq:step:2"],
+                    "type": "llm",
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/FakeStreamingListLLM/final_output",
+                "value": {
+                    "generations": [
+                        [
+                            {
+                                "generation_info": None,
+                                "text": "bear, dog, cat",
+                                "type": "Generation",
+                            }
+                        ]
+                    ],
+                    "llm_output": None,
+                    "run": None,
+                },
+            },
+            {
+                "op": "add",
+                "path": "/logs/FakeStreamingListLLM/end_time",
+                "value": "2023-01-01T00:00:00.000+00:00",
+            },
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/CommaSeparatedListOutputParser",
+                "value": {
+                    "end_time": None,
+                    "final_output": None,
+                    "metadata": {},
+                    "name": "CommaSeparatedListOutputParser",
+                    "start_time": "2023-01-01T00:00:00.000+00:00",
+                    "streamed_output": [],
+                    "streamed_output_str": [],
+                    "tags": ["seq:step:3"],
+                    "type": "parser",
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/CommaSeparatedListOutputParser/streamed_output/-",
+                "value": ["bear"],
+            }
+        ),
+        RunLogPatch(
+            {"op": "add", "path": "/streamed_output/-", "value": ["bear"]},
+            {"op": "replace", "path": "/final_output", "value": ["bear"]},
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/CommaSeparatedListOutputParser/streamed_output/-",
+                "value": ["dog"],
+            }
+        ),
+        RunLogPatch(
+            {"op": "add", "path": "/streamed_output/-", "value": ["dog"]},
+            {"op": "add", "path": "/final_output/1", "value": "dog"},
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/CommaSeparatedListOutputParser/streamed_output/-",
+                "value": ["cat"],
+            }
+        ),
+        RunLogPatch(
+            {"op": "add", "path": "/streamed_output/-", "value": ["cat"]},
+            {"op": "add", "path": "/final_output/2", "value": "cat"},
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/CommaSeparatedListOutputParser/final_output",
+                "value": {"output": ["bear", "dog", "cat"]},
+            },
+            {
+                "op": "add",
+                "path": "/logs/CommaSeparatedListOutputParser/end_time",
+                "value": "2023-01-01T00:00:00.000+00:00",
+            },
+        ),
+    ]
+    assert stream_log == expected
+
+
 @freeze_time("2023-01-01")
 async def test_stream_log_retriever() -> None:
     prompt = (
@@ -4606,6 +4872,14 @@ async def test_runnable_iter_context_config() -> None:
     assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
     assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]

+    tracer = FakeTracer()
+    assert [p async for p in agen.astream_log("a", {"callbacks": [tracer]})]
+    assert len(tracer.runs) == 1
+    assert tracer.runs[0].outputs == {"output": 6}
+    assert len(tracer.runs[0].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
+
     tracer = FakeTracer()
     assert await agen.abatch(["a", "a"], {"callbacks": [tracer]}) == [6, 6]
     assert len(tracer.runs) == 2
