Improve output of Runnable.astream_log() (#11391)

- Make logs a dictionary keyed by run name (with a `:<n>` counter suffix for repeated names)
- Ensure no output shows up in the lc_serializable dict format
- Fix up `__repr__` for RunLog and RunLogPatch
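With this change, consumers of `astream_log()` can address sub-run logs by name instead of list position. A minimal sketch of the new shape (the chain, the fake LLM, and the `add` patch-summing helper from `langchain.schema.runnable.utils` are illustrative assumptions, not part of this diff):

```python
import asyncio
from typing import cast

from langchain.callbacks.tracers.log_stream import RunLog
from langchain.llms.fake import FakeListLLM
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable.utils import add


async def demo() -> None:
    prompt = ChatPromptTemplate.from_template("{question}")
    chain = prompt | FakeListLLM(responses=["foo"])

    # Each patch carries JSON Patch ops; summing them builds the RunLog.
    patches = [p async for p in chain.astream_log({"question": "What is your name?"})]
    state = cast(RunLog, add(patches)).state

    # logs is now keyed by run name; a second run of the same class within
    # one trace would appear under e.g. "FakeListLLM:2".
    assert sorted(state["logs"]) == ["ChatPromptTemplate", "FakeListLLM"]


asyncio.run(demo())
```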

Nuno Campos, commit 4d66756d93 (parent a30f98f534)

@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import math
 import threading
+from collections import defaultdict
 from typing import (
     Any,
     AsyncIterator,
@@ -19,6 +20,7 @@ from anyio import create_memory_object_stream
 
 from langchain.callbacks.tracers.base import BaseTracer
 from langchain.callbacks.tracers.schemas import Run
+from langchain.load.load import load
 from langchain.schema.output import ChatGenerationChunk, GenerationChunk
@@ -55,7 +57,7 @@ class RunState(TypedDict):
     """Final output of the run, usually the result of aggregating streamed_output.
     Only available after the run has finished successfully."""
 
-    logs: list[LogEntry]
+    logs: Dict[str, LogEntry]
     """List of sub-runs contained in this run, if any, in the order they were started.
     If filters were supplied, this list will contain only the runs that matched the
     filters."""
@@ -85,7 +87,8 @@ class RunLogPatch:
     def __repr__(self) -> str:
         from pprint import pformat
 
-        return f"RunLogPatch(ops={pformat(self.ops)})"
+        # 1:-1 to get rid of the [] around the list
+        return f"RunLogPatch({pformat(self.ops)[1:-1]})"
 
     def __eq__(self, other: object) -> bool:
         return isinstance(other, RunLogPatch) and self.ops == other.ops
@@ -112,7 +115,7 @@ class RunLog(RunLogPatch):
     def __repr__(self) -> str:
         from pprint import pformat
 
-        return f"RunLog(state={pformat(self.state)})"
+        return f"RunLog({pformat(self.state)})"
 
 
 class LogStreamCallbackHandler(BaseTracer):
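With the tweaked `__repr__` methods, a patch now prints as its constructor call, without the `ops=`/`state=` wrappers or the brackets around the ops list. Roughly (long op lists may wrap, since `pformat` is used):

```python
from langchain.callbacks.tracers.log_stream import RunLogPatch

patch = RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": "foo"})
print(repr(patch))
# RunLogPatch({'op': 'add', 'path': '/streamed_output/-', 'value': 'foo'})
```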
@@ -143,7 +146,8 @@ class LogStreamCallbackHandler(BaseTracer):
         self.lock = threading.Lock()
         self.send_stream = send_stream
         self.receive_stream = receive_stream
-        self._index_map: Dict[UUID, int] = {}
+        self._key_map_by_run_id: Dict[UUID, str] = {}
+        self._counter_map_by_name: Dict[str, int] = defaultdict(int)
 
     def __aiter__(self) -> AsyncIterator[RunLogPatch]:
         return self.receive_stream.__aiter__()
@@ -196,7 +200,7 @@ class LogStreamCallbackHandler(BaseTracer):
                     id=str(run.id),
                     streamed_output=[],
                     final_output=None,
-                    logs=[],
+                    logs={},
                 ),
             }
         )
@@ -207,14 +211,18 @@ class LogStreamCallbackHandler(BaseTracer):
 
         # Determine previous index, increment by 1
         with self.lock:
-            self._index_map[run.id] = max(self._index_map.values(), default=-1) + 1
+            self._counter_map_by_name[run.name] += 1
+            count = self._counter_map_by_name[run.name]
+            self._key_map_by_run_id[run.id] = (
+                run.name if count == 1 else f"{run.name}:{count}"
+            )
 
         # Add the run to the stream
         self.send_stream.send_nowait(
             RunLogPatch(
                 {
                     "op": "add",
-                    "path": f"/logs/{self._index_map[run.id]}",
+                    "path": f"/logs/{self._key_map_by_run_id[run.id]}",
                     "value": LogEntry(
                         id=str(run.id),
                         name=run.name,
@@ -233,7 +241,7 @@ class LogStreamCallbackHandler(BaseTracer):
     def _on_run_update(self, run: Run) -> None:
         """Finish a run."""
         try:
-            index = self._index_map.get(run.id)
+            index = self._key_map_by_run_id.get(run.id)
 
             if index is None:
                 return
@@ -243,7 +251,8 @@ class LogStreamCallbackHandler(BaseTracer):
                         {
                             "op": "add",
                             "path": f"/logs/{index}/final_output",
-                            "value": run.outputs,
+                            # to undo the dumpd done by some runnables / tracer / etc
+                            "value": load(run.outputs),
                         },
                         {
                             "op": "add",
@@ -259,7 +268,7 @@ class LogStreamCallbackHandler(BaseTracer):
                     {
                         "op": "replace",
                         "path": "/final_output",
-                        "value": run.outputs,
+                        "value": load(run.outputs),
                     }
                 )
             )
@@ -273,7 +282,7 @@ class LogStreamCallbackHandler(BaseTracer):
         chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
     ) -> None:
         """Process new LLM token."""
-        index = self._index_map.get(run.id)
+        index = self._key_map_by_run_id.get(run.id)
 
         if index is None:
             return

@@ -1239,7 +1239,7 @@ async def test_prompt() -> None:
     assert len(stream_log[0].ops) == 1
     assert stream_log[0].ops[0]["op"] == "replace"
     assert stream_log[0].ops[0]["path"] == ""
-    assert stream_log[0].ops[0]["value"]["logs"] == []
+    assert stream_log[0].ops[0]["value"]["logs"] == {}
     assert stream_log[0].ops[0]["value"]["final_output"] is None
     assert stream_log[0].ops[0]["value"]["streamed_output"] == []
     assert isinstance(stream_log[0].ops[0]["value"]["id"], str)
@@ -1249,40 +1249,12 @@ async def test_prompt() -> None:
             {
                 "op": "replace",
                 "path": "/final_output",
-                "value": {
-                    "id": ["langchain", "prompts", "chat", "ChatPromptValue"],
-                    "kwargs": {
-                        "messages": [
-                            {
-                                "id": [
-                                    "langchain",
-                                    "schema",
-                                    "messages",
-                                    "SystemMessage",
-                                ],
-                                "kwargs": {"content": "You are a nice " "assistant."},
-                                "lc": 1,
-                                "type": "constructor",
-                            },
-                            {
-                                "id": [
-                                    "langchain",
-                                    "schema",
-                                    "messages",
-                                    "HumanMessage",
-                                ],
-                                "kwargs": {
-                                    "additional_kwargs": {},
-                                    "content": "What is your " "name?",
-                                },
-                                "lc": 1,
-                                "type": "constructor",
-                            },
-                        ]
-                    },
-                    "lc": 1,
-                    "type": "constructor",
-                },
+                "value": ChatPromptValue(
+                    messages=[
+                        SystemMessage(content="You are a nice assistant."),
+                        HumanMessage(content="What is your name?"),
+                    ]
+                ),
             }
         ),
         RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": expected}),
@@ -1525,7 +1497,7 @@ async def test_prompt_with_llm(
                 "op": "replace",
                 "path": "",
                 "value": {
-                    "logs": [],
+                    "logs": {},
                     "final_output": None,
                     "streamed_output": [],
                 },
@@ -1534,7 +1506,7 @@ async def test_prompt_with_llm(
         RunLogPatch(
             {
                 "op": "add",
-                "path": "/logs/0",
+                "path": "/logs/ChatPromptTemplate",
                 "value": {
                     "end_time": None,
                     "final_output": None,
@@ -1550,55 +1522,24 @@ async def test_prompt_with_llm(
         RunLogPatch(
             {
                 "op": "add",
-                "path": "/logs/0/final_output",
-                "value": {
-                    "id": ["langchain", "prompts", "chat", "ChatPromptValue"],
-                    "kwargs": {
-                        "messages": [
-                            {
-                                "id": [
-                                    "langchain",
-                                    "schema",
-                                    "messages",
-                                    "SystemMessage",
-                                ],
-                                "kwargs": {
-                                    "additional_kwargs": {},
-                                    "content": "You are a nice " "assistant.",
-                                },
-                                "lc": 1,
-                                "type": "constructor",
-                            },
-                            {
-                                "id": [
-                                    "langchain",
-                                    "schema",
-                                    "messages",
-                                    "HumanMessage",
-                                ],
-                                "kwargs": {
-                                    "additional_kwargs": {},
-                                    "content": "What is your " "name?",
-                                },
-                                "lc": 1,
-                                "type": "constructor",
-                            },
-                        ]
-                    },
-                    "lc": 1,
-                    "type": "constructor",
-                },
+                "path": "/logs/ChatPromptTemplate/final_output",
+                "value": ChatPromptValue(
+                    messages=[
+                        SystemMessage(content="You are a nice assistant."),
+                        HumanMessage(content="What is your name?"),
+                    ]
+                ),
             },
             {
                 "op": "add",
-                "path": "/logs/0/end_time",
+                "path": "/logs/ChatPromptTemplate/end_time",
                 "value": "2023-01-01T00:00:00.000",
             },
         ),
         RunLogPatch(
             {
                 "op": "add",
-                "path": "/logs/1",
+                "path": "/logs/FakeListLLM",
                 "value": {
                     "end_time": None,
                     "final_output": None,
@@ -1614,7 +1555,7 @@ async def test_prompt_with_llm(
         RunLogPatch(
             {
                 "op": "add",
-                "path": "/logs/1/final_output",
+                "path": "/logs/FakeListLLM/final_output",
                 "value": {
                     "generations": [[{"generation_info": None, "text": "foo"}]],
                     "llm_output": None,
@@ -1623,7 +1564,7 @@ async def test_prompt_with_llm(
                 },
             },
             {
                 "op": "add",
-                "path": "/logs/1/end_time",
+                "path": "/logs/FakeListLLM/end_time",
                 "value": "2023-01-01T00:00:00.000",
             },
         ),
@@ -1634,6 +1575,192 @@ async def test_prompt_with_llm(
     ]
 
 
+@pytest.mark.asyncio
+@freeze_time("2023-01-01")
+async def test_stream_log_retriever() -> None:
+    prompt = (
+        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+        + "{documents}"
+        + "{question}"
+    )
+    llm = FakeListLLM(responses=["foo", "bar"])
+
+    chain: Runnable = (
+        {"documents": FakeRetriever(), "question": itemgetter("question")}
+        | prompt
+        | {"one": llm, "two": llm}
+    )
+
+    stream_log = [
+        part async for part in chain.astream_log({"question": "What is your name?"})
+    ]
+
+    # remove ids from logs
+    for part in stream_log:
+        for op in part.ops:
+            if (
+                isinstance(op["value"], dict)
+                and "id" in op["value"]
+                and not isinstance(op["value"]["id"], list)  # serialized lc id
+            ):
+                del op["value"]["id"]
+
+    assert stream_log[:-9] == [
+        RunLogPatch(
+            {
+                "op": "replace",
+                "path": "",
+                "value": {
+                    "logs": {},
+                    "final_output": None,
+                    "streamed_output": [],
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/RunnableMap",
+                "value": {
+                    "end_time": None,
+                    "final_output": None,
+                    "metadata": {},
+                    "name": "RunnableMap",
+                    "start_time": "2023-01-01T00:00:00.000",
+                    "streamed_output_str": [],
+                    "tags": ["seq:step:1"],
+                    "type": "chain",
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/RunnableLambda",
+                "value": {
+                    "end_time": None,
+                    "final_output": None,
+                    "metadata": {},
+                    "name": "RunnableLambda",
+                    "start_time": "2023-01-01T00:00:00.000",
+                    "streamed_output_str": [],
+                    "tags": ["map:key:question"],
+                    "type": "chain",
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/RunnableLambda/final_output",
+                "value": {"output": "What is your name?"},
+            },
+            {
+                "op": "add",
+                "path": "/logs/RunnableLambda/end_time",
+                "value": "2023-01-01T00:00:00.000",
+            },
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/Retriever",
+                "value": {
+                    "end_time": None,
+                    "final_output": None,
+                    "metadata": {},
+                    "name": "Retriever",
+                    "start_time": "2023-01-01T00:00:00.000",
+                    "streamed_output_str": [],
+                    "tags": ["map:key:documents"],
+                    "type": "retriever",
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/Retriever/final_output",
+                "value": {
+                    "documents": [
+                        Document(page_content="foo"),
+                        Document(page_content="bar"),
+                    ]
+                },
+            },
+            {
+                "op": "add",
+                "path": "/logs/Retriever/end_time",
+                "value": "2023-01-01T00:00:00.000",
+            },
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/RunnableMap/final_output",
+                "value": {
+                    "documents": [
+                        Document(page_content="foo"),
+                        Document(page_content="bar"),
+                    ],
+                    "question": "What is your name?",
+                },
+            },
+            {
+                "op": "add",
+                "path": "/logs/RunnableMap/end_time",
+                "value": "2023-01-01T00:00:00.000",
+            },
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/ChatPromptTemplate",
+                "value": {
+                    "end_time": None,
+                    "final_output": None,
+                    "metadata": {},
+                    "name": "ChatPromptTemplate",
+                    "start_time": "2023-01-01T00:00:00.000",
+                    "streamed_output_str": [],
+                    "tags": ["seq:step:2"],
+                    "type": "prompt",
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/ChatPromptTemplate/final_output",
+                "value": ChatPromptValue(
+                    messages=[
+                        SystemMessage(content="You are a nice assistant."),
+                        HumanMessage(
+                            content="[Document(page_content='foo'), Document(page_content='bar')]"  # noqa: E501
+                        ),
+                        HumanMessage(content="What is your name?"),
+                    ]
+                ),
+            },
+            {
+                "op": "add",
+                "path": "/logs/ChatPromptTemplate/end_time",
+                "value": "2023-01-01T00:00:00.000",
+            },
+        ),
+    ]
+
+    assert sorted(cast(RunLog, add(stream_log)).state["logs"]) == [
+        "ChatPromptTemplate",
+        "FakeListLLM",
+        "FakeListLLM:2",
+        "Retriever",
+        "RunnableLambda",
+        "RunnableMap",
+        "RunnableMap:2",
+    ]
+
+
 @pytest.mark.asyncio
 @freeze_time("2023-01-01")
 async def test_prompt_with_llm_and_async_lambda(
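The patch streams above compose via JSON Patch: summing `RunLogPatch` objects (as `add(stream_log)` does in the new test) applies their ops in order to build the final `RunState`. A rough equivalent using the `jsonpatch` library directly, with illustrative paths:

```python
import jsonpatch

# The first patch always replaces the root with an empty RunState.
state = jsonpatch.apply_patch(
    None,
    [
        {
            "op": "replace",
            "path": "",
            "value": {"logs": {}, "final_output": None, "streamed_output": []},
        }
    ],
)

# Later patches add log entries under the new name-based keys.
state = jsonpatch.apply_patch(
    state,
    [{"op": "add", "path": "/logs/FakeListLLM", "value": {"name": "FakeListLLM"}}],
)
assert list(state["logs"]) == ["FakeListLLM"]
```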
@@ -2291,14 +2418,18 @@ async def test_map_astream() -> None:
     assert isinstance(final_state.state["id"], str)
     assert len(final_state.ops) == len(streamed_ops)
     assert len(final_state.state["logs"]) == 5
-    assert final_state.state["logs"][0]["name"] == "ChatPromptTemplate"
-    assert final_state.state["logs"][0]["final_output"] == dumpd(
-        prompt.invoke({"question": "What is your name?"})
-    )
-    assert final_state.state["logs"][1]["name"] == "RunnableMap"
-    assert sorted(log["name"] for log in final_state.state["logs"][2:]) == [
+    assert (
+        final_state.state["logs"]["ChatPromptTemplate"]["name"] == "ChatPromptTemplate"
+    )
+    assert final_state.state["logs"]["ChatPromptTemplate"][
+        "final_output"
+    ] == prompt.invoke({"question": "What is your name?"})
+    assert final_state.state["logs"]["RunnableMap"]["name"] == "RunnableMap"
+    assert sorted(final_state.state["logs"]) == [
+        "ChatPromptTemplate",
         "FakeListChatModel",
         "FakeStreamingListLLM",
+        "RunnableMap",
         "RunnablePassthrough",
     ]
@@ -2316,7 +2447,7 @@ async def test_map_astream() -> None:
     assert final_state.state["final_output"] == final_value
     assert len(final_state.state["streamed_output"]) == len(streamed_chunks)
     assert len(final_state.state["logs"]) == 1
-    assert final_state.state["logs"][0]["name"] == "FakeListChatModel"
+    assert final_state.state["logs"]["FakeListChatModel"]["name"] == "FakeListChatModel"
 
     # Test astream_log with exclude filters
     final_state = None
@@ -2332,13 +2463,17 @@ async def test_map_astream() -> None:
     assert final_state.state["final_output"] == final_value
     assert len(final_state.state["streamed_output"]) == len(streamed_chunks)
     assert len(final_state.state["logs"]) == 4
-    assert final_state.state["logs"][0]["name"] == "ChatPromptTemplate"
-    assert final_state.state["logs"][0]["final_output"] == dumpd(
+    assert (
+        final_state.state["logs"]["ChatPromptTemplate"]["name"] == "ChatPromptTemplate"
+    )
+    assert final_state.state["logs"]["ChatPromptTemplate"]["final_output"] == (
         prompt.invoke({"question": "What is your name?"})
     )
-    assert final_state.state["logs"][1]["name"] == "RunnableMap"
-    assert sorted(log["name"] for log in final_state.state["logs"][2:]) == [
+    assert final_state.state["logs"]["RunnableMap"]["name"] == "RunnableMap"
+    assert sorted(final_state.state["logs"]) == [
+        "ChatPromptTemplate",
         "FakeStreamingListLLM",
+        "RunnableMap",
         "RunnablePassthrough",
     ]
