From eb67f07e3202e5d6f2e29279eeecea97c12c2c43 Mon Sep 17 00:00:00 2001 From: David Duong Date: Tue, 28 Nov 2023 23:02:31 +0100 Subject: [PATCH] Track RunnableAssign as a separate run trace (#13972) Addressing incorrect order being sent to callbacks / tracers, due to the nature of threading --------- Co-authored-by: Nuno Campos --- .../langchain_core/runnables/passthrough.py | 103 +++++++++++++++--- .../unit_tests/runnables/test_runnable.py | 41 +++++++ 2 files changed, 130 insertions(+), 14 deletions(-) diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index f9eb3bdebc..1b7242f28e 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -5,6 +5,7 @@ import asyncio import inspect import threading from typing import ( + TYPE_CHECKING, Any, AsyncIterator, Awaitable, @@ -31,11 +32,18 @@ from langchain_core.runnables.config import ( acall_func_with_variable_args, call_func_with_variable_args, get_executor_for_config, + patch_config, ) from langchain_core.runnables.utils import AddableDict, ConfigurableFieldSpec from langchain_core.utils.aiter import atee, py_anext from langchain_core.utils.iter import safetee +if TYPE_CHECKING: + from langchain_core.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, + ) + def identity(x: Other) -> Other: """An identity function""" @@ -345,18 +353,52 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): def config_specs(self) -> List[ConfigurableFieldSpec]: return self.mapper.config_specs + def _invoke( + self, + input: Dict[str, Any], + run_manager: CallbackManagerForChainRun, + config: RunnableConfig, + **kwargs: Any, + ) -> Dict[str, Any]: + assert isinstance( + input, dict + ), "The input to RunnablePassthrough.assign() must be a dict." + + return { + **input, + **self.mapper.invoke( + input, + patch_config(config, callbacks=run_manager.get_child()), + **kwargs, + ), + } + def invoke( self, input: Dict[str, Any], config: Optional[RunnableConfig] = None, **kwargs: Any, + ) -> Dict[str, Any]: + return self._call_with_config(self._invoke, input, config, **kwargs) + + async def _ainvoke( + self, + input: Dict[str, Any], + run_manager: AsyncCallbackManagerForChainRun, + config: RunnableConfig, + **kwargs: Any, ) -> Dict[str, Any]: assert isinstance( input, dict ), "The input to RunnablePassthrough.assign() must be a dict." + return { **input, - **self.mapper.invoke(input, config, **kwargs), + **await self.mapper.ainvoke( + input, + patch_config(config, callbacks=run_manager.get_child()), + **kwargs, + ), } async def ainvoke( @@ -365,26 +407,30 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Dict[str, Any]: - assert isinstance( - input, dict - ), "The input to RunnablePassthrough.assign() must be a dict." - return { - **input, - **await self.mapper.ainvoke(input, config, **kwargs), - } + return await self._acall_with_config(self._ainvoke, input, config, **kwargs) - def transform( + def _transform( self, input: Iterator[Dict[str, Any]], - config: Optional[RunnableConfig] = None, + run_manager: CallbackManagerForChainRun, + config: RunnableConfig, **kwargs: Any, ) -> Iterator[Dict[str, Any]]: # collect mapper keys mapper_keys = set(self.mapper.steps.keys()) # create two streams, one for the map and one for the passthrough for_passthrough, for_map = safetee(input, 2, lock=threading.Lock()) + # create map output stream - map_output = self.mapper.transform(for_map, config, **kwargs) + map_output = self.mapper.transform( + for_map, + patch_config( + config, + callbacks=run_manager.get_child(), + ), + **kwargs, + ) + # get executor to start map output stream in background with get_executor_for_config(config or {}) as executor: # start map output stream @@ -409,10 +455,21 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): for chunk in map_output: yield chunk - async def atransform( + def transform( + self, + input: Iterator[Dict[str, Any]], + config: Optional[RunnableConfig] = None, + **kwargs: Any | None, + ) -> Iterator[Dict[str, Any]]: + yield from self._transform_stream_with_config( + input, self._transform, config, **kwargs + ) + + async def _atransform( self, input: AsyncIterator[Dict[str, Any]], - config: Optional[RunnableConfig] = None, + run_manager: AsyncCallbackManagerForChainRun, + config: RunnableConfig, **kwargs: Any, ) -> AsyncIterator[Dict[str, Any]]: # collect mapper keys @@ -420,7 +477,14 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): # create two streams, one for the map and one for the passthrough for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock()) # create map output stream - map_output = self.mapper.atransform(for_map, config, **kwargs) + map_output = self.mapper.atransform( + for_map, + patch_config( + config, + callbacks=run_manager.get_child(), + ), + **kwargs, + ) # start map output stream first_map_chunk_task: asyncio.Task = asyncio.create_task( py_anext(map_output, None), # type: ignore[arg-type] @@ -441,6 +505,17 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): async for chunk in map_output: yield chunk + async def atransform( + self, + input: AsyncIterator[Dict[str, Any]], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> AsyncIterator[Dict[str, Any]]: + async for chunk in self._atransform_stream_with_config( + input, self._atransform, config, **kwargs + ): + yield chunk + def stream( self, input: Dict[str, Any], diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 161279945d..9d0bc740bd 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -4146,3 +4146,44 @@ async def test_ainvoke_on_returned_runnable() -> None: return idchain assert await RunnableLambda(func).ainvoke({}) + + +def test_invoke_stream_passthrough_assign_trace() -> None: + def idchain_sync(__input: dict) -> bool: + return False + + chain = RunnablePassthrough.assign(urls=idchain_sync) + + tracer = FakeTracer() + chain.invoke({"example": [1, 2, 3]}, dict(callbacks=[tracer])) + + assert tracer.runs[0].name == "RunnableAssign" + assert tracer.runs[0].child_runs[0].name == "RunnableParallel" + + tracer = FakeTracer() + for item in chain.stream({"example": [1, 2, 3]}, dict(callbacks=[tracer])): + pass + + assert tracer.runs[0].name == "RunnableAssign" + assert tracer.runs[0].child_runs[0].name == "RunnableParallel" + + +@pytest.mark.asyncio +async def test_ainvoke_astream_passthrough_assign_trace() -> None: + def idchain_sync(__input: dict) -> bool: + return False + + chain = RunnablePassthrough.assign(urls=idchain_sync) + + tracer = FakeTracer() + await chain.ainvoke({"example": [1, 2, 3]}, dict(callbacks=[tracer])) + + assert tracer.runs[0].name == "RunnableAssign" + assert tracer.runs[0].child_runs[0].name == "RunnableParallel" + + tracer = FakeTracer() + async for item in chain.astream({"example": [1, 2, 3]}, dict(callbacks=[tracer])): + pass + + assert tracer.runs[0].name == "RunnableAssign" + assert tracer.runs[0].child_runs[0].name == "RunnableParallel"