Improve output of Runnable.astream_log() (#11391)

- Make logs a dictionary keyed by run name, with a counter suffix for repeated names (see the sketch after this list)
- Ensure no output shows up in lc_serializable format; outputs are load()-ed back into their original objects
- Fix up repr for RunLog and RunLogPatch
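
For context, a minimal sketch of consuming the new output shape (the chain and input are illustrative, the run names are the ones used in the updated tests, and folding patches with `+` mirrors the `add(stream_log)` helper the tests use):

```python
from functools import reduce
from operator import add

async def inspect_astream_log(chain) -> None:
    # Illustrative input; any Runnable works here.
    patches = [p async for p in chain.astream_log({"question": "What is your name?"})]

    # RunLogPatch supports "+", so folding the patches yields the final RunLog
    # (same idea as the add(stream_log) helper used in the updated tests).
    run_log = reduce(add, patches)

    # logs is now a dict keyed by run name; repeated names get a ":<count>" suffix.
    logs = run_log.state["logs"]
    prompt_entry = logs["ChatPromptTemplate"]  # first (or only) run with that name
    second_llm = logs.get("FakeListLLM:2")     # second run of a repeated name, if any

    # final_output is no longer in lc_serializable (dumpd) form, e.g. a prompt's
    # final_output is a ChatPromptValue rather than a serialized dict.
    print(prompt_entry["final_output"], second_llm)
```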

pull/11304/head
Nuno Campos 10 months ago committed by GitHub
parent a30f98f534
commit 4d66756d93

@@ -2,6 +2,7 @@ from __future__ import annotations
import math
import threading
from collections import defaultdict
from typing import (
Any,
AsyncIterator,
@@ -19,6 +20,7 @@ from anyio import create_memory_object_stream
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run
from langchain.load.load import load
from langchain.schema.output import ChatGenerationChunk, GenerationChunk
@@ -55,7 +57,7 @@ class RunState(TypedDict):
"""Final output of the run, usually the result of aggregating streamed_output.
Only available after the run has finished successfully."""
logs: list[LogEntry]
logs: Dict[str, LogEntry]
"""List of sub-runs contained in this run, if any, in the order they were started.
If filters were supplied, this list will contain only the runs that matched the
filters."""
@@ -85,7 +87,8 @@ class RunLogPatch:
def __repr__(self) -> str:
from pprint import pformat
return f"RunLogPatch(ops={pformat(self.ops)})"
# 1:-1 to get rid of the [] around the list
return f"RunLogPatch({pformat(self.ops)[1:-1]})"
def __eq__(self, other: object) -> bool:
return isinstance(other, RunLogPatch) and self.ops == other.ops
@@ -112,7 +115,7 @@ class RunLog(RunLogPatch):
def __repr__(self) -> str:
from pprint import pformat
return f"RunLog(state={pformat(self.state)})"
return f"RunLog({pformat(self.state)})"
class LogStreamCallbackHandler(BaseTracer):
@@ -143,7 +146,8 @@ class LogStreamCallbackHandler(BaseTracer):
self.lock = threading.Lock()
self.send_stream = send_stream
self.receive_stream = receive_stream
self._index_map: Dict[UUID, int] = {}
self._key_map_by_run_id: Dict[UUID, str] = {}
self._counter_map_by_name: Dict[str, int] = defaultdict(int)
def __aiter__(self) -> AsyncIterator[RunLogPatch]:
return self.receive_stream.__aiter__()
@@ -196,7 +200,7 @@ class LogStreamCallbackHandler(BaseTracer):
id=str(run.id),
streamed_output=[],
final_output=None,
logs=[],
logs={},
),
}
)
@@ -207,14 +211,18 @@ class LogStreamCallbackHandler(BaseTracer):
# Determine previous index, increment by 1
with self.lock:
self._index_map[run.id] = max(self._index_map.values(), default=-1) + 1
self._counter_map_by_name[run.name] += 1
count = self._counter_map_by_name[run.name]
self._key_map_by_run_id[run.id] = (
run.name if count == 1 else f"{run.name}:{count}"
)
# Add the run to the stream
self.send_stream.send_nowait(
RunLogPatch(
{
"op": "add",
"path": f"/logs/{self._index_map[run.id]}",
"path": f"/logs/{self._key_map_by_run_id[run.id]}",
"value": LogEntry(
id=str(run.id),
name=run.name,
@@ -233,7 +241,7 @@ class LogStreamCallbackHandler(BaseTracer):
def _on_run_update(self, run: Run) -> None:
"""Finish a run."""
try:
index = self._index_map.get(run.id)
index = self._key_map_by_run_id.get(run.id)
if index is None:
return
@@ -243,7 +251,8 @@ class LogStreamCallbackHandler(BaseTracer):
{
"op": "add",
"path": f"/logs/{index}/final_output",
"value": run.outputs,
# to undo the dumpd done by some runnables / tracer / etc
"value": load(run.outputs),
},
{
"op": "add",
@@ -259,7 +268,7 @@ class LogStreamCallbackHandler(BaseTracer):
{
"op": "replace",
"path": "/final_output",
"value": run.outputs,
"value": load(run.outputs),
}
)
)
@@ -273,7 +282,7 @@ class LogStreamCallbackHandler(BaseTracer):
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
) -> None:
"""Process new LLM token."""
index = self._index_map.get(run.id)
index = self._key_map_by_run_id.get(run.id)
if index is None:
return

@@ -1239,7 +1239,7 @@ async def test_prompt() -> None:
assert len(stream_log[0].ops) == 1
assert stream_log[0].ops[0]["op"] == "replace"
assert stream_log[0].ops[0]["path"] == ""
assert stream_log[0].ops[0]["value"]["logs"] == []
assert stream_log[0].ops[0]["value"]["logs"] == {}
assert stream_log[0].ops[0]["value"]["final_output"] is None
assert stream_log[0].ops[0]["value"]["streamed_output"] == []
assert isinstance(stream_log[0].ops[0]["value"]["id"], str)
@@ -1249,40 +1249,12 @@ async def test_prompt() -> None:
{
"op": "replace",
"path": "/final_output",
"value": {
"id": ["langchain", "prompts", "chat", "ChatPromptValue"],
"kwargs": {
"messages": [
{
"id": [
"langchain",
"schema",
"messages",
"SystemMessage",
],
"kwargs": {"content": "You are a nice " "assistant."},
"lc": 1,
"type": "constructor",
},
{
"id": [
"langchain",
"schema",
"messages",
"HumanMessage",
],
"kwargs": {
"additional_kwargs": {},
"content": "What is your " "name?",
},
"lc": 1,
"type": "constructor",
},
]
},
"lc": 1,
"type": "constructor",
},
"value": ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
),
}
),
RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": expected}),
@@ -1525,7 +1497,7 @@ async def test_prompt_with_llm(
"op": "replace",
"path": "",
"value": {
"logs": [],
"logs": {},
"final_output": None,
"streamed_output": [],
},
@@ -1534,7 +1506,7 @@ async def test_prompt_with_llm(
RunLogPatch(
{
"op": "add",
"path": "/logs/0",
"path": "/logs/ChatPromptTemplate",
"value": {
"end_time": None,
"final_output": None,
@@ -1550,55 +1522,24 @@ async def test_prompt_with_llm(
RunLogPatch(
{
"op": "add",
"path": "/logs/0/final_output",
"value": {
"id": ["langchain", "prompts", "chat", "ChatPromptValue"],
"kwargs": {
"messages": [
{
"id": [
"langchain",
"schema",
"messages",
"SystemMessage",
],
"kwargs": {
"additional_kwargs": {},
"content": "You are a nice " "assistant.",
},
"lc": 1,
"type": "constructor",
},
{
"id": [
"langchain",
"schema",
"messages",
"HumanMessage",
],
"kwargs": {
"additional_kwargs": {},
"content": "What is your " "name?",
},
"lc": 1,
"type": "constructor",
},
]
},
"lc": 1,
"type": "constructor",
},
"path": "/logs/ChatPromptTemplate/final_output",
"value": ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
),
},
{
"op": "add",
"path": "/logs/0/end_time",
"path": "/logs/ChatPromptTemplate/end_time",
"value": "2023-01-01T00:00:00.000",
},
),
RunLogPatch(
{
"op": "add",
"path": "/logs/1",
"path": "/logs/FakeListLLM",
"value": {
"end_time": None,
"final_output": None,
@@ -1614,7 +1555,7 @@ async def test_prompt_with_llm(
RunLogPatch(
{
"op": "add",
"path": "/logs/1/final_output",
"path": "/logs/FakeListLLM/final_output",
"value": {
"generations": [[{"generation_info": None, "text": "foo"}]],
"llm_output": None,
@@ -1623,7 +1564,7 @@ async def test_prompt_with_llm(
},
{
"op": "add",
"path": "/logs/1/end_time",
"path": "/logs/FakeListLLM/end_time",
"value": "2023-01-01T00:00:00.000",
},
),
@@ -1634,6 +1575,192 @@ async def test_prompt_with_llm(
]
@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_stream_log_retriever() -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{documents}"
+ "{question}"
)
llm = FakeListLLM(responses=["foo", "bar"])
chain: Runnable = (
{"documents": FakeRetriever(), "question": itemgetter("question")}
| prompt
| {"one": llm, "two": llm}
)
stream_log = [
part async for part in chain.astream_log({"question": "What is your name?"})
]
# remove ids from logs
for part in stream_log:
for op in part.ops:
if (
isinstance(op["value"], dict)
and "id" in op["value"]
and not isinstance(op["value"]["id"], list) # serialized lc id
):
del op["value"]["id"]
assert stream_log[:-9] == [
RunLogPatch(
{
"op": "replace",
"path": "",
"value": {
"logs": {},
"final_output": None,
"streamed_output": [],
},
}
),
RunLogPatch(
{
"op": "add",
"path": "/logs/RunnableMap",
"value": {
"end_time": None,
"final_output": None,
"metadata": {},
"name": "RunnableMap",
"start_time": "2023-01-01T00:00:00.000",
"streamed_output_str": [],
"tags": ["seq:step:1"],
"type": "chain",
},
}
),
RunLogPatch(
{
"op": "add",
"path": "/logs/RunnableLambda",
"value": {
"end_time": None,
"final_output": None,
"metadata": {},
"name": "RunnableLambda",
"start_time": "2023-01-01T00:00:00.000",
"streamed_output_str": [],
"tags": ["map:key:question"],
"type": "chain",
},
}
),
RunLogPatch(
{
"op": "add",
"path": "/logs/RunnableLambda/final_output",
"value": {"output": "What is your name?"},
},
{
"op": "add",
"path": "/logs/RunnableLambda/end_time",
"value": "2023-01-01T00:00:00.000",
},
),
RunLogPatch(
{
"op": "add",
"path": "/logs/Retriever",
"value": {
"end_time": None,
"final_output": None,
"metadata": {},
"name": "Retriever",
"start_time": "2023-01-01T00:00:00.000",
"streamed_output_str": [],
"tags": ["map:key:documents"],
"type": "retriever",
},
}
),
RunLogPatch(
{
"op": "add",
"path": "/logs/Retriever/final_output",
"value": {
"documents": [
Document(page_content="foo"),
Document(page_content="bar"),
]
},
},
{
"op": "add",
"path": "/logs/Retriever/end_time",
"value": "2023-01-01T00:00:00.000",
},
),
RunLogPatch(
{
"op": "add",
"path": "/logs/RunnableMap/final_output",
"value": {
"documents": [
Document(page_content="foo"),
Document(page_content="bar"),
],
"question": "What is your name?",
},
},
{
"op": "add",
"path": "/logs/RunnableMap/end_time",
"value": "2023-01-01T00:00:00.000",
},
),
RunLogPatch(
{
"op": "add",
"path": "/logs/ChatPromptTemplate",
"value": {
"end_time": None,
"final_output": None,
"metadata": {},
"name": "ChatPromptTemplate",
"start_time": "2023-01-01T00:00:00.000",
"streamed_output_str": [],
"tags": ["seq:step:2"],
"type": "prompt",
},
}
),
RunLogPatch(
{
"op": "add",
"path": "/logs/ChatPromptTemplate/final_output",
"value": ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(
content="[Document(page_content='foo'), Document(page_content='bar')]" # noqa: E501
),
HumanMessage(content="What is your name?"),
]
),
},
{
"op": "add",
"path": "/logs/ChatPromptTemplate/end_time",
"value": "2023-01-01T00:00:00.000",
},
),
]
assert sorted(cast(RunLog, add(stream_log)).state["logs"]) == [
"ChatPromptTemplate",
"FakeListLLM",
"FakeListLLM:2",
"Retriever",
"RunnableLambda",
"RunnableMap",
"RunnableMap:2",
]
@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_prompt_with_llm_and_async_lambda(
@@ -2291,14 +2418,18 @@ async def test_map_astream() -> None:
assert isinstance(final_state.state["id"], str)
assert len(final_state.ops) == len(streamed_ops)
assert len(final_state.state["logs"]) == 5
assert final_state.state["logs"][0]["name"] == "ChatPromptTemplate"
assert final_state.state["logs"][0]["final_output"] == dumpd(
prompt.invoke({"question": "What is your name?"})
)
assert final_state.state["logs"][1]["name"] == "RunnableMap"
assert sorted(log["name"] for log in final_state.state["logs"][2:]) == [
assert (
final_state.state["logs"]["ChatPromptTemplate"]["name"] == "ChatPromptTemplate"
)
assert final_state.state["logs"]["ChatPromptTemplate"][
"final_output"
] == prompt.invoke({"question": "What is your name?"})
assert final_state.state["logs"]["RunnableMap"]["name"] == "RunnableMap"
assert sorted(final_state.state["logs"]) == [
"ChatPromptTemplate",
"FakeListChatModel",
"FakeStreamingListLLM",
"RunnableMap",
"RunnablePassthrough",
]
@@ -2316,7 +2447,7 @@ async def test_map_astream() -> None:
assert final_state.state["final_output"] == final_value
assert len(final_state.state["streamed_output"]) == len(streamed_chunks)
assert len(final_state.state["logs"]) == 1
assert final_state.state["logs"][0]["name"] == "FakeListChatModel"
assert final_state.state["logs"]["FakeListChatModel"]["name"] == "FakeListChatModel"
# Test astream_log with exclude filters
final_state = None
@@ -2332,13 +2463,17 @@ async def test_map_astream() -> None:
assert final_state.state["final_output"] == final_value
assert len(final_state.state["streamed_output"]) == len(streamed_chunks)
assert len(final_state.state["logs"]) == 4
assert final_state.state["logs"][0]["name"] == "ChatPromptTemplate"
assert final_state.state["logs"][0]["final_output"] == dumpd(
assert (
final_state.state["logs"]["ChatPromptTemplate"]["name"] == "ChatPromptTemplate"
)
assert final_state.state["logs"]["ChatPromptTemplate"]["final_output"] == (
prompt.invoke({"question": "What is your name?"})
)
assert final_state.state["logs"][1]["name"] == "RunnableMap"
assert sorted(log["name"] for log in final_state.state["logs"][2:]) == [
assert final_state.state["logs"]["RunnableMap"]["name"] == "RunnableMap"
assert sorted(final_state.state["logs"]) == [
"ChatPromptTemplate",
"FakeStreamingListLLM",
"RunnableMap",
"RunnablePassthrough",
]
