diff --git a/libs/langchain/langchain/callbacks/tracers/log_stream.py b/libs/langchain/langchain/callbacks/tracers/log_stream.py
index 08e1eedddb..2440d5feca 100644
--- a/libs/langchain/langchain/callbacks/tracers/log_stream.py
+++ b/libs/langchain/langchain/callbacks/tracers/log_stream.py
@@ -2,6 +2,7 @@ from __future__ import annotations

 import math
 import threading
+from collections import defaultdict
 from typing import (
     Any,
     AsyncIterator,
@@ -19,6 +20,7 @@ from anyio import create_memory_object_stream

 from langchain.callbacks.tracers.base import BaseTracer
 from langchain.callbacks.tracers.schemas import Run
+from langchain.load.load import load
 from langchain.schema.output import ChatGenerationChunk, GenerationChunk


@@ -55,7 +57,7 @@ class RunState(TypedDict):
     """Final output of the run, usually the result of aggregating streamed_output.
     Only available after the run has finished successfully."""

-    logs: list[LogEntry]
+    logs: Dict[str, LogEntry]
     """List of sub-runs contained in this run, if any, in the order they were started.
     If filters were supplied, this list will contain only the runs that matched
     the filters."""
@@ -85,7 +87,8 @@ class RunLogPatch:
     def __repr__(self) -> str:
         from pprint import pformat

-        return f"RunLogPatch(ops={pformat(self.ops)})"
+        # 1:-1 to get rid of the [] around the list
+        return f"RunLogPatch({pformat(self.ops)[1:-1]})"

     def __eq__(self, other: object) -> bool:
         return isinstance(other, RunLogPatch) and self.ops == other.ops
@@ -112,7 +115,7 @@ class RunLog(RunLogPatch):
     def __repr__(self) -> str:
         from pprint import pformat

-        return f"RunLog(state={pformat(self.state)})"
+        return f"RunLog({pformat(self.state)})"


 class LogStreamCallbackHandler(BaseTracer):
@@ -143,7 +146,8 @@ class LogStreamCallbackHandler(BaseTracer):
         self.lock = threading.Lock()
         self.send_stream = send_stream
         self.receive_stream = receive_stream
-        self._index_map: Dict[UUID, int] = {}
+        self._key_map_by_run_id: Dict[UUID, str] = {}
+        self._counter_map_by_name: Dict[str, int] = defaultdict(int)

     def __aiter__(self) -> AsyncIterator[RunLogPatch]:
         return self.receive_stream.__aiter__()
@@ -196,7 +200,7 @@ class LogStreamCallbackHandler(BaseTracer):
                         id=str(run.id),
                         streamed_output=[],
                         final_output=None,
-                        logs=[],
+                        logs={},
                     ),
                 }
             )
@@ -207,14 +211,18 @@ class LogStreamCallbackHandler(BaseTracer):

         # Determine previous index, increment by 1
         with self.lock:
-            self._index_map[run.id] = max(self._index_map.values(), default=-1) + 1
+            self._counter_map_by_name[run.name] += 1
+            count = self._counter_map_by_name[run.name]
+            self._key_map_by_run_id[run.id] = (
+                run.name if count == 1 else f"{run.name}:{count}"
+            )

         # Add the run to the stream
         self.send_stream.send_nowait(
             RunLogPatch(
                 {
                     "op": "add",
-                    "path": f"/logs/{self._index_map[run.id]}",
+                    "path": f"/logs/{self._key_map_by_run_id[run.id]}",
                     "value": LogEntry(
                         id=str(run.id),
                         name=run.name,
@@ -233,7 +241,7 @@ class LogStreamCallbackHandler(BaseTracer):
     def _on_run_update(self, run: Run) -> None:
         """Finish a run."""
         try:
-            index = self._index_map.get(run.id)
+            index = self._key_map_by_run_id.get(run.id)

             if index is None:
                 return
@@ -243,7 +251,8 @@ class LogStreamCallbackHandler(BaseTracer):
                     {
                         "op": "add",
                         "path": f"/logs/{index}/final_output",
-                        "value": run.outputs,
+                        # to undo the dumpd done by some runnables / tracer / etc
+                        "value": load(run.outputs),
                     },
                     {
                         "op": "add",
@@ -259,7 +268,7 @@ class LogStreamCallbackHandler(BaseTracer):
                 {
                     "op": "replace",
                     "path": "/final_output",
-                    "value": run.outputs,
+                    "value": load(run.outputs),
                 }
             )
         )
@@ -273,7 +282,7 @@ class LogStreamCallbackHandler(BaseTracer):
         chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
     ) -> None:
         """Process new LLM token."""
-        index = self._index_map.get(run.id)
+        index = self._key_map_by_run_id.get(run.id)

         if index is None:
             return
diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py
index 875e6965bf..680df3c4d3 100644
--- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py
+++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py
@@ -1239,7 +1239,7 @@ async def test_prompt() -> None:
     assert len(stream_log[0].ops) == 1
     assert stream_log[0].ops[0]["op"] == "replace"
     assert stream_log[0].ops[0]["path"] == ""
-    assert stream_log[0].ops[0]["value"]["logs"] == []
+    assert stream_log[0].ops[0]["value"]["logs"] == {}
     assert stream_log[0].ops[0]["value"]["final_output"] is None
     assert stream_log[0].ops[0]["value"]["streamed_output"] == []
     assert isinstance(stream_log[0].ops[0]["value"]["id"], str)
@@ -1249,40 +1249,12 @@ async def test_prompt() -> None:
             {
                 "op": "replace",
                 "path": "/final_output",
-                "value": {
-                    "id": ["langchain", "prompts", "chat", "ChatPromptValue"],
-                    "kwargs": {
-                        "messages": [
-                            {
-                                "id": [
-                                    "langchain",
-                                    "schema",
-                                    "messages",
-                                    "SystemMessage",
-                                ],
-                                "kwargs": {"content": "You are a nice " "assistant."},
-                                "lc": 1,
-                                "type": "constructor",
-                            },
-                            {
-                                "id": [
-                                    "langchain",
-                                    "schema",
-                                    "messages",
-                                    "HumanMessage",
-                                ],
-                                "kwargs": {
-                                    "additional_kwargs": {},
-                                    "content": "What is your " "name?",
-                                },
-                                "lc": 1,
-                                "type": "constructor",
-                            },
-                        ]
-                    },
-                    "lc": 1,
-                    "type": "constructor",
-                },
+                "value": ChatPromptValue(
+                    messages=[
+                        SystemMessage(content="You are a nice assistant."),
+                        HumanMessage(content="What is your name?"),
+                    ]
+                ),
             }
         ),
         RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": expected}),
@@ -1525,7 +1497,7 @@ async def test_prompt_with_llm(
                 "op": "replace",
                 "path": "",
                 "value": {
-                    "logs": [],
+                    "logs": {},
                     "final_output": None,
                     "streamed_output": [],
                 },
@@ -1534,7 +1506,7 @@ async def test_prompt_with_llm(
         RunLogPatch(
             {
                 "op": "add",
-                "path": "/logs/0",
+                "path": "/logs/ChatPromptTemplate",
                 "value": {
                     "end_time": None,
                     "final_output": None,
@@ -1550,55 +1522,24 @@ async def test_prompt_with_llm(
         RunLogPatch(
             {
                 "op": "add",
-                "path": "/logs/0/final_output",
-                "value": {
-                    "id": ["langchain", "prompts", "chat", "ChatPromptValue"],
-                    "kwargs": {
-                        "messages": [
-                            {
-                                "id": [
-                                    "langchain",
-                                    "schema",
-                                    "messages",
-                                    "SystemMessage",
-                                ],
-                                "kwargs": {
-                                    "additional_kwargs": {},
-                                    "content": "You are a nice " "assistant.",
-                                },
-                                "lc": 1,
-                                "type": "constructor",
-                            },
-                            {
-                                "id": [
-                                    "langchain",
-                                    "schema",
-                                    "messages",
-                                    "HumanMessage",
-                                ],
-                                "kwargs": {
-                                    "additional_kwargs": {},
-                                    "content": "What is your " "name?",
-                                },
-                                "lc": 1,
-                                "type": "constructor",
-                            },
-                        ]
-                    },
-                    "lc": 1,
-                    "type": "constructor",
-                },
+                "path": "/logs/ChatPromptTemplate/final_output",
+                "value": ChatPromptValue(
+                    messages=[
+                        SystemMessage(content="You are a nice assistant."),
+                        HumanMessage(content="What is your name?"),
+                    ]
+                ),
             },
             {
                 "op": "add",
-                "path": "/logs/0/end_time",
+                "path": "/logs/ChatPromptTemplate/end_time",
                 "value": "2023-01-01T00:00:00.000",
             },
         ),
         RunLogPatch(
             {
                 "op": "add",
-                "path": "/logs/1",
+                "path": "/logs/FakeListLLM",
                 "value": {
                     "end_time": None,
                     "final_output": None,
@@ -1614,7 +1555,7 @@ async def test_prompt_with_llm(
         RunLogPatch(
             {
                 "op": "add",
-                "path": "/logs/1/final_output",
+                "path": "/logs/FakeListLLM/final_output",
                 "value": {
                     "generations": [[{"generation_info": None, "text": "foo"}]],
                     "llm_output": None,
@@ -1623,7 +1564,7 @@ async def test_prompt_with_llm(
             },
             {
                 "op": "add",
-                "path": "/logs/1/end_time",
+                "path": "/logs/FakeListLLM/end_time",
                 "value": "2023-01-01T00:00:00.000",
             },
         ),
@@ -1634,6 +1575,192 @@ async def test_prompt_with_llm(
     ]


+@pytest.mark.asyncio
+@freeze_time("2023-01-01")
+async def test_stream_log_retriever() -> None:
+    prompt = (
+        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+        + "{documents}"
+        + "{question}"
+    )
+    llm = FakeListLLM(responses=["foo", "bar"])
+
+    chain: Runnable = (
+        {"documents": FakeRetriever(), "question": itemgetter("question")}
+        | prompt
+        | {"one": llm, "two": llm}
+    )
+
+    stream_log = [
+        part async for part in chain.astream_log({"question": "What is your name?"})
+    ]
+
+    # remove ids from logs
+    for part in stream_log:
+        for op in part.ops:
+            if (
+                isinstance(op["value"], dict)
+                and "id" in op["value"]
+                and not isinstance(op["value"]["id"], list)  # serialized lc id
+            ):
+                del op["value"]["id"]
+
+    assert stream_log[:-9] == [
+        RunLogPatch(
+            {
+                "op": "replace",
+                "path": "",
+                "value": {
+                    "logs": {},
+                    "final_output": None,
+                    "streamed_output": [],
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/RunnableMap",
+                "value": {
+                    "end_time": None,
+                    "final_output": None,
+                    "metadata": {},
+                    "name": "RunnableMap",
+                    "start_time": "2023-01-01T00:00:00.000",
+                    "streamed_output_str": [],
+                    "tags": ["seq:step:1"],
+                    "type": "chain",
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/RunnableLambda",
+                "value": {
+                    "end_time": None,
+                    "final_output": None,
+                    "metadata": {},
+                    "name": "RunnableLambda",
+                    "start_time": "2023-01-01T00:00:00.000",
+                    "streamed_output_str": [],
+                    "tags": ["map:key:question"],
+                    "type": "chain",
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/RunnableLambda/final_output",
+                "value": {"output": "What is your name?"},
+            },
+            {
+                "op": "add",
+                "path": "/logs/RunnableLambda/end_time",
+                "value": "2023-01-01T00:00:00.000",
+            },
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/Retriever",
+                "value": {
+                    "end_time": None,
+                    "final_output": None,
+                    "metadata": {},
+                    "name": "Retriever",
+                    "start_time": "2023-01-01T00:00:00.000",
+                    "streamed_output_str": [],
+                    "tags": ["map:key:documents"],
+                    "type": "retriever",
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/Retriever/final_output",
+                "value": {
+                    "documents": [
+                        Document(page_content="foo"),
+                        Document(page_content="bar"),
+                    ]
+                },
+            },
+            {
+                "op": "add",
+                "path": "/logs/Retriever/end_time",
+                "value": "2023-01-01T00:00:00.000",
+            },
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/RunnableMap/final_output",
+                "value": {
+                    "documents": [
+                        Document(page_content="foo"),
+                        Document(page_content="bar"),
+                    ],
+                    "question": "What is your name?",
+                },
+            },
+            {
+                "op": "add",
+                "path": "/logs/RunnableMap/end_time",
+                "value": "2023-01-01T00:00:00.000",
+            },
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/ChatPromptTemplate",
+                "value": {
+                    "end_time": None,
+                    "final_output": None,
+                    "metadata": {},
+                    "name": "ChatPromptTemplate",
+                    "start_time": "2023-01-01T00:00:00.000",
+                    "streamed_output_str": [],
+                    "tags": ["seq:step:2"],
+                    "type": "prompt",
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/ChatPromptTemplate/final_output",
+                "value": ChatPromptValue(
+                    messages=[
+                        SystemMessage(content="You are a nice assistant."),
+                        HumanMessage(
+                            content="[Document(page_content='foo'), Document(page_content='bar')]"  # noqa: E501
+                        ),
+                        HumanMessage(content="What is your name?"),
+                    ]
+                ),
+            },
+            {
+                "op": "add",
+                "path": "/logs/ChatPromptTemplate/end_time",
+                "value": "2023-01-01T00:00:00.000",
+            },
+        ),
+    ]
+
+    assert sorted(cast(RunLog, add(stream_log)).state["logs"]) == [
+        "ChatPromptTemplate",
+        "FakeListLLM",
+        "FakeListLLM:2",
+        "Retriever",
+        "RunnableLambda",
+        "RunnableMap",
+        "RunnableMap:2",
+    ]
+
+
 @pytest.mark.asyncio
 @freeze_time("2023-01-01")
 async def test_prompt_with_llm_and_async_lambda(
@@ -2291,14 +2418,18 @@ async def test_map_astream() -> None:
     assert isinstance(final_state.state["id"], str)
     assert len(final_state.ops) == len(streamed_ops)
     assert len(final_state.state["logs"]) == 5
-    assert final_state.state["logs"][0]["name"] == "ChatPromptTemplate"
-    assert final_state.state["logs"][0]["final_output"] == dumpd(
-        prompt.invoke({"question": "What is your name?"})
-    )
-    assert final_state.state["logs"][1]["name"] == "RunnableMap"
-    assert sorted(log["name"] for log in final_state.state["logs"][2:]) == [
+    assert (
+        final_state.state["logs"]["ChatPromptTemplate"]["name"] == "ChatPromptTemplate"
+    )
+    assert final_state.state["logs"]["ChatPromptTemplate"][
+        "final_output"
+    ] == prompt.invoke({"question": "What is your name?"})
+    assert final_state.state["logs"]["RunnableMap"]["name"] == "RunnableMap"
+    assert sorted(final_state.state["logs"]) == [
+        "ChatPromptTemplate",
         "FakeListChatModel",
         "FakeStreamingListLLM",
+        "RunnableMap",
         "RunnablePassthrough",
     ]

@@ -2316,7 +2447,7 @@ async def test_map_astream() -> None:
     assert final_state.state["final_output"] == final_value
     assert len(final_state.state["streamed_output"]) == len(streamed_chunks)
     assert len(final_state.state["logs"]) == 1
-    assert final_state.state["logs"][0]["name"] == "FakeListChatModel"
+    assert final_state.state["logs"]["FakeListChatModel"]["name"] == "FakeListChatModel"

     # Test astream_log with exclude filters
     final_state = None
@@ -2332,13 +2463,17 @@ async def test_map_astream() -> None:
     assert final_state.state["final_output"] == final_value
     assert len(final_state.state["streamed_output"]) == len(streamed_chunks)
     assert len(final_state.state["logs"]) == 4
-    assert final_state.state["logs"][0]["name"] == "ChatPromptTemplate"
-    assert final_state.state["logs"][0]["final_output"] == dumpd(
+    assert (
+        final_state.state["logs"]["ChatPromptTemplate"]["name"] == "ChatPromptTemplate"
+    )
+    assert final_state.state["logs"]["ChatPromptTemplate"]["final_output"] == (
        prompt.invoke({"question": "What is your name?"})
     )
-    assert final_state.state["logs"][1]["name"] == "RunnableMap"
-    assert sorted(log["name"] for log in final_state.state["logs"][2:]) == [
+    assert final_state.state["logs"]["RunnableMap"]["name"] == "RunnableMap"
+    assert sorted(final_state.state["logs"]) == [
+        "ChatPromptTemplate",
         "FakeStreamingListLLM",
+        "RunnableMap",
         "RunnablePassthrough",
     ]
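
A minimal consumption sketch (not part of the patch) of the name-keyed log paths this change introduces, assuming a `chain` built like the Runnable assembled in test_stream_log_retriever above; the function name and print output are illustrative only:

    import asyncio

    async def dump_log_paths(chain) -> None:
        # `chain` is assumed to be a Runnable, e.g. the retriever chain from the test above.
        async for patch in chain.astream_log({"question": "What is your name?"}):
            for op in patch.ops:
                # Paths are now keyed by run name, e.g. "/logs/ChatPromptTemplate/final_output";
                # a repeated run name gets a counter suffix, e.g. "/logs/FakeListLLM:2".
                print(op["op"], op["path"])

    # asyncio.run(dump_log_paths(chain))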