Track RunnableAssign as a separate run trace (#13972)

Fixes events being sent to callbacks / tracers in the wrong order, a side
effect of running the mapper in background threads.

---------

Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
David Duong 2023-11-28 23:02:31 +01:00 committed by GitHub
parent 0f255bb6c4
commit eb67f07e32
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 130 additions and 14 deletions

View File

@ -5,6 +5,7 @@ import asyncio
import inspect
import threading
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Awaitable,
@ -31,11 +32,18 @@ from langchain_core.runnables.config import (
acall_func_with_variable_args,
call_func_with_variable_args,
get_executor_for_config,
patch_config,
)
from langchain_core.runnables.utils import AddableDict, ConfigurableFieldSpec
from langchain_core.utils.aiter import atee, py_anext
from langchain_core.utils.iter import safetee
if TYPE_CHECKING:
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
def identity(x: Other) -> Other:
"""An identity function"""
@ -345,18 +353,52 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
def config_specs(self) -> List[ConfigurableFieldSpec]:
    """Return the configurable field specs of the underlying mapper."""
    mapper_specs = self.mapper.config_specs
    return mapper_specs
def _invoke(
    self,
    input: Dict[str, Any],
    run_manager: CallbackManagerForChainRun,
    config: RunnableConfig,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Run the mapper on ``input`` and layer its output over the input.

    Args:
        input: The dict that is passed through and also fed to the mapper.
        run_manager: Run manager of the current run; its child callbacks
            are handed to the mapper.
        config: Config forwarded to the mapper, with its callbacks replaced
            by the child callbacks of ``run_manager``.
        **kwargs: Extra keyword arguments forwarded to the mapper.

    Returns:
        A new dict holding the entries of ``input`` followed by the
        mapper's entries (mapper values win on key clashes).

    Raises:
        AssertionError: If ``input`` is not a dict.
    """
    assert isinstance(
        input, dict
    ), "The input to RunnablePassthrough.assign() must be a dict."

    # Snapshot the passthrough part first, before the mapper gets to run.
    passthrough = {**input}
    # The mapper reports to callbacks derived from this run's manager.
    child_config = patch_config(config, callbacks=run_manager.get_child())
    assigned = self.mapper.invoke(input, child_config, **kwargs)
    return {**passthrough, **assigned}
def invoke(
    self,
    input: Dict[str, Any],
    config: Optional[RunnableConfig] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Invoke the runnable by delegating to ``_call_with_config``.

    Args:
        input: The dict handed to ``_invoke``.
        config: Optional config forwarded to ``_call_with_config``.
        **kwargs: Extra keyword arguments forwarded unchanged.

    Returns:
        Whatever ``_call_with_config`` returns for ``_invoke``.
    """
    outcome = self._call_with_config(self._invoke, input, config, **kwargs)
    return outcome
async def _ainvoke(
    self,
    input: Dict[str, Any],
    run_manager: AsyncCallbackManagerForChainRun,
    config: RunnableConfig,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Asynchronously run the mapper on ``input`` and merge its output.

    Args:
        input: The dict that is passed through and also fed to the mapper.
        run_manager: Async run manager of the current run; its child
            callbacks are handed to the mapper.
        config: Config forwarded to the mapper, with its callbacks replaced
            by the child callbacks of ``run_manager``.
        **kwargs: Extra keyword arguments forwarded to the mapper.

    Returns:
        A new dict holding the entries of ``input`` followed by the
        mapper's entries (mapper values win on key clashes).

    Raises:
        AssertionError: If ``input`` is not a dict.
    """
    assert isinstance(
        input, dict
    ), "The input to RunnablePassthrough.assign() must be a dict."
    return {
        **input,
        # The mapper is run exactly once, through the awaited call. A
        # synchronous ``mapper.invoke`` here would block the event loop and
        # execute the mapper a second time with the unpatched config.
        **await self.mapper.ainvoke(
            input,
            patch_config(config, callbacks=run_manager.get_child()),
            **kwargs,
        ),
    }
async def ainvoke(
@ -365,26 +407,30 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
assert isinstance(
input, dict
), "The input to RunnablePassthrough.assign() must be a dict."
return {
**input,
**await self.mapper.ainvoke(input, config, **kwargs),
}
return await self._acall_with_config(self._ainvoke, input, config, **kwargs)
def transform(
def _transform(
self,
input: Iterator[Dict[str, Any]],
config: Optional[RunnableConfig] = None,
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Iterator[Dict[str, Any]]:
# collect mapper keys
mapper_keys = set(self.mapper.steps.keys())
# create two streams, one for the map and one for the passthrough
for_passthrough, for_map = safetee(input, 2, lock=threading.Lock())
# create map output stream
map_output = self.mapper.transform(for_map, config, **kwargs)
map_output = self.mapper.transform(
for_map,
patch_config(
config,
callbacks=run_manager.get_child(),
),
**kwargs,
)
# get executor to start map output stream in background
with get_executor_for_config(config or {}) as executor:
# start map output stream
@ -409,10 +455,21 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
for chunk in map_output:
yield chunk
async def atransform(
def transform(
    self,
    input: Iterator[Dict[str, Any]],
    config: Optional[RunnableConfig] = None,
    **kwargs: Any,
) -> Iterator[Dict[str, Any]]:
    """Stream-transform ``input`` via ``_transform_stream_with_config``.

    Args:
        input: Iterator of input dict chunks.
        config: Optional config forwarded to
            ``_transform_stream_with_config``.
        **kwargs: Extra keyword arguments forwarded unchanged. Annotated
            as ``Any`` to match the other methods' signatures.

    Yields:
        The chunks produced by ``_transform_stream_with_config`` when it
        is given ``_transform`` as the worker.
    """
    yield from self._transform_stream_with_config(
        input, self._transform, config, **kwargs
    )
async def _atransform(
self,
input: AsyncIterator[Dict[str, Any]],
config: Optional[RunnableConfig] = None,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
# collect mapper keys
@ -420,7 +477,14 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
# create two streams, one for the map and one for the passthrough
for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock())
# create map output stream
map_output = self.mapper.atransform(for_map, config, **kwargs)
map_output = self.mapper.atransform(
for_map,
patch_config(
config,
callbacks=run_manager.get_child(),
),
**kwargs,
)
# start map output stream
first_map_chunk_task: asyncio.Task = asyncio.create_task(
py_anext(map_output, None), # type: ignore[arg-type]
@ -441,6 +505,17 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
async for chunk in map_output:
yield chunk
async def atransform(
    self,
    input: AsyncIterator[Dict[str, Any]],
    config: Optional[RunnableConfig] = None,
    **kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
    """Async stream-transform ``input`` via ``_atransform_stream_with_config``.

    Args:
        input: Async iterator of input dict chunks.
        config: Optional config forwarded to
            ``_atransform_stream_with_config``.
        **kwargs: Extra keyword arguments forwarded unchanged.

    Yields:
        The chunks produced by ``_atransform_stream_with_config`` when it
        is given ``_atransform`` as the worker.
    """
    wrapped_stream = self._atransform_stream_with_config(
        input, self._atransform, config, **kwargs
    )
    async for piece in wrapped_stream:
        yield piece
def stream(
self,
input: Dict[str, Any],

View File

@ -4146,3 +4146,44 @@ async def test_ainvoke_on_returned_runnable() -> None:
return idchain
assert await RunnableLambda(func).ainvoke({})
def test_invoke_stream_passthrough_assign_trace() -> None:
    """Check the trace shape of ``RunnablePassthrough.assign`` (sync paths).

    For both ``invoke`` and ``stream`` the first recorded run must be named
    "RunnableAssign" and its first child run "RunnableParallel".
    """

    def idchain_sync(__input: dict) -> bool:
        return False

    assign_chain = RunnablePassthrough.assign(urls=idchain_sync)

    # invoke path
    invoke_tracer = FakeTracer()
    assign_chain.invoke({"example": [1, 2, 3]}, {"callbacks": [invoke_tracer]})
    root_run = invoke_tracer.runs[0]
    assert root_run.name == "RunnableAssign"
    assert root_run.child_runs[0].name == "RunnableParallel"

    # stream path: drain the stream, the chunks themselves are not inspected
    stream_tracer = FakeTracer()
    for _ in assign_chain.stream(
        {"example": [1, 2, 3]}, {"callbacks": [stream_tracer]}
    ):
        pass
    root_run = stream_tracer.runs[0]
    assert root_run.name == "RunnableAssign"
    assert root_run.child_runs[0].name == "RunnableParallel"
@pytest.mark.asyncio
async def test_ainvoke_astream_passthrough_assign_trace() -> None:
    """Check the trace shape of ``RunnablePassthrough.assign`` (async paths).

    For both ``ainvoke`` and ``astream`` the first recorded run must be
    named "RunnableAssign" and its first child run "RunnableParallel".
    """

    def idchain_sync(__input: dict) -> bool:
        return False

    assign_chain = RunnablePassthrough.assign(urls=idchain_sync)

    # ainvoke path
    ainvoke_tracer = FakeTracer()
    await assign_chain.ainvoke(
        {"example": [1, 2, 3]}, {"callbacks": [ainvoke_tracer]}
    )
    root_run = ainvoke_tracer.runs[0]
    assert root_run.name == "RunnableAssign"
    assert root_run.child_runs[0].name == "RunnableParallel"

    # astream path: drain the stream, the chunks themselves are not inspected
    astream_tracer = FakeTracer()
    async for _ in assign_chain.astream(
        {"example": [1, 2, 3]}, {"callbacks": [astream_tracer]}
    ):
        pass
    root_run = astream_tracer.runs[0]
    assert root_run.name == "RunnableAssign"
    assert root_run.child_runs[0].name == "RunnableParallel"