From 7897483819c9031aca2cffe26d181eb5b1b4d28b Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 30 Oct 2023 15:55:16 +0000 Subject: [PATCH] Allow astream_log to be used inside atrace_as_chain_group (#12558) --- libs/langchain/langchain/callbacks/manager.py | 12 +++++++ .../langchain/callbacks/tracers/log_stream.py | 8 +++-- .../langchain/schema/runnable/base.py | 2 +- .../schema/runnable/test_runnable.py | 36 ++++++++++++++++++- 4 files changed, 53 insertions(+), 5 deletions(-) diff --git a/libs/langchain/langchain/callbacks/manager.py b/libs/langchain/langchain/callbacks/manager.py index 3ee6f45896..3231892ff7 100644 --- a/libs/langchain/langchain/callbacks/manager.py +++ b/libs/langchain/langchain/callbacks/manager.py @@ -1832,6 +1832,18 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager): self.parent_run_manager = parent_run_manager self.ended = False + def copy(self) -> AsyncCallbackManagerForChainGroup: + return self.__class__( + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + parent_run_manager=self.parent_run_manager, + ) + async def on_chain_end( self, outputs: Union[Dict[str, Any], Any], **kwargs: Any ) -> None: diff --git a/libs/langchain/langchain/callbacks/tracers/log_stream.py b/libs/langchain/langchain/callbacks/tracers/log_stream.py index d45b576a54..1bca4098f8 100644 --- a/libs/langchain/langchain/callbacks/tracers/log_stream.py +++ b/libs/langchain/langchain/callbacks/tracers/log_stream.py @@ -157,12 +157,13 @@ class LogStreamCallbackHandler(BaseTracer): self.receive_stream = receive_stream self._key_map_by_run_id: Dict[UUID, str] = {} self._counter_map_by_name: Dict[str, int] = defaultdict(int) + self.root_id: Optional[UUID] = None def __aiter__(self) -> AsyncIterator[RunLogPatch]: return self.receive_stream.__aiter__() def 
include_run(self, run: Run) -> bool: - if run.parent_run_id is None: + if run.id == self.root_id: return False run_tags = run.tags or [] @@ -199,7 +200,8 @@ class LogStreamCallbackHandler(BaseTracer): def _on_run_create(self, run: Run) -> None: """Start a run.""" - if run.parent_run_id is None: + if self.root_id is None: + self.root_id = run.id self.send_stream.send_nowait( RunLogPatch( { @@ -273,7 +275,7 @@ class LogStreamCallbackHandler(BaseTracer): ) ) finally: - if run.parent_run_id is None: + if run.id == self.root_id: self.send_stream.send_nowait( RunLogPatch( { diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index c95d5254d9..025e0030e9 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -463,7 +463,7 @@ class Runnable(Generic[Input, Output], ABC): config["callbacks"] = callbacks + [stream] elif isinstance(callbacks, BaseCallbackManager): callbacks = callbacks.copy() - callbacks.inheritable_handlers.append(stream) + callbacks.add_handler(stream, inherit=True) config["callbacks"] = callbacks else: raise ValueError( diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 3db0fbcdd6..a0fc7e8823 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -19,7 +19,7 @@ from freezegun import freeze_time from pytest_mock import MockerFixture from syrupy import SnapshotAssertion -from langchain.callbacks.manager import Callbacks, collect_runs +from langchain.callbacks.manager import Callbacks, atrace_as_chain_group, collect_runs from langchain.callbacks.tracers.base import BaseTracer from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch from langchain.callbacks.tracers.schemas import Run @@ -1442,6 +1442,40 @@ async def test_prompt() 
-> None: }, ) + # nested inside atrace_as_chain_group + + async with atrace_as_chain_group("a_group") as manager: + stream_log_nested = [ + part + async for part in prompt.astream_log( + {"question": "What is your name?"}, config={"callbacks": manager} + ) + ] + + assert len(stream_log_nested[0].ops) == 1 + assert stream_log_nested[0].ops[0]["op"] == "replace" + assert stream_log_nested[0].ops[0]["path"] == "" + assert stream_log_nested[0].ops[0]["value"]["logs"] == {} + assert stream_log_nested[0].ops[0]["value"]["final_output"] is None + assert stream_log_nested[0].ops[0]["value"]["streamed_output"] == [] + assert isinstance(stream_log_nested[0].ops[0]["value"]["id"], str) + + assert stream_log_nested[1:] == [ + RunLogPatch( + { + "op": "replace", + "path": "/final_output", + "value": ChatPromptValue( + messages=[ + SystemMessage(content="You are a nice assistant."), + HumanMessage(content="What is your name?"), + ] + ), + } + ), + RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": expected}), + ] + def test_prompt_template_params() -> None: prompt = ChatPromptTemplate.from_template(