mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
Track RunnableAssign as a separate run trace (#13972)
Addressing incorrect order being sent to callbacks / tracers, due to the nature of threading --------- Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
parent
0f255bb6c4
commit
eb67f07e32
@ -5,6 +5,7 @@ import asyncio
|
||||
import inspect
|
||||
import threading
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
@ -31,11 +32,18 @@ from langchain_core.runnables.config import (
|
||||
acall_func_with_variable_args,
|
||||
call_func_with_variable_args,
|
||||
get_executor_for_config,
|
||||
patch_config,
|
||||
)
|
||||
from langchain_core.runnables.utils import AddableDict, ConfigurableFieldSpec
|
||||
from langchain_core.utils.aiter import atee, py_anext
|
||||
from langchain_core.utils.iter import safetee
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
|
||||
|
||||
def identity(x: Other) -> Other:
|
||||
"""An identity function"""
|
||||
@ -345,18 +353,52 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
return self.mapper.config_specs
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
assert isinstance(
|
||||
input, dict
|
||||
), "The input to RunnablePassthrough.assign() must be a dict."
|
||||
|
||||
return {
|
||||
**input,
|
||||
**self.mapper.invoke(
|
||||
input,
|
||||
patch_config(config, callbacks=run_manager.get_child()),
|
||||
**kwargs,
|
||||
),
|
||||
}
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
return self._call_with_config(self._invoke, input, config, **kwargs)
|
||||
|
||||
async def _ainvoke(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
assert isinstance(
|
||||
input, dict
|
||||
), "The input to RunnablePassthrough.assign() must be a dict."
|
||||
|
||||
return {
|
||||
**input,
|
||||
**self.mapper.invoke(input, config, **kwargs),
|
||||
**await self.mapper.ainvoke(
|
||||
input,
|
||||
patch_config(config, callbacks=run_manager.get_child()),
|
||||
**kwargs,
|
||||
),
|
||||
}
|
||||
|
||||
async def ainvoke(
|
||||
@ -365,26 +407,30 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
assert isinstance(
|
||||
input, dict
|
||||
), "The input to RunnablePassthrough.assign() must be a dict."
|
||||
return {
|
||||
**input,
|
||||
**await self.mapper.ainvoke(input, config, **kwargs),
|
||||
}
|
||||
return await self._acall_with_config(self._ainvoke, input, config, **kwargs)
|
||||
|
||||
def transform(
|
||||
def _transform(
|
||||
self,
|
||||
input: Iterator[Dict[str, Any]],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
# collect mapper keys
|
||||
mapper_keys = set(self.mapper.steps.keys())
|
||||
# create two streams, one for the map and one for the passthrough
|
||||
for_passthrough, for_map = safetee(input, 2, lock=threading.Lock())
|
||||
|
||||
# create map output stream
|
||||
map_output = self.mapper.transform(for_map, config, **kwargs)
|
||||
map_output = self.mapper.transform(
|
||||
for_map,
|
||||
patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(),
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# get executor to start map output stream in background
|
||||
with get_executor_for_config(config or {}) as executor:
|
||||
# start map output stream
|
||||
@ -409,10 +455,21 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
for chunk in map_output:
|
||||
yield chunk
|
||||
|
||||
async def atransform(
|
||||
def transform(
|
||||
self,
|
||||
input: Iterator[Dict[str, Any]],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any | None,
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
yield from self._transform_stream_with_config(
|
||||
input, self._transform, config, **kwargs
|
||||
)
|
||||
|
||||
async def _atransform(
|
||||
self,
|
||||
input: AsyncIterator[Dict[str, Any]],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
# collect mapper keys
|
||||
@ -420,7 +477,14 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
# create two streams, one for the map and one for the passthrough
|
||||
for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock())
|
||||
# create map output stream
|
||||
map_output = self.mapper.atransform(for_map, config, **kwargs)
|
||||
map_output = self.mapper.atransform(
|
||||
for_map,
|
||||
patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(),
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
# start map output stream
|
||||
first_map_chunk_task: asyncio.Task = asyncio.create_task(
|
||||
py_anext(map_output, None), # type: ignore[arg-type]
|
||||
@ -441,6 +505,17 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
async for chunk in map_output:
|
||||
yield chunk
|
||||
|
||||
async def atransform(
|
||||
self,
|
||||
input: AsyncIterator[Dict[str, Any]],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
async for chunk in self._atransform_stream_with_config(
|
||||
input, self._atransform, config, **kwargs
|
||||
):
|
||||
yield chunk
|
||||
|
||||
def stream(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
|
@ -4146,3 +4146,44 @@ async def test_ainvoke_on_returned_runnable() -> None:
|
||||
return idchain
|
||||
|
||||
assert await RunnableLambda(func).ainvoke({})
|
||||
|
||||
|
||||
def test_invoke_stream_passthrough_assign_trace() -> None:
|
||||
def idchain_sync(__input: dict) -> bool:
|
||||
return False
|
||||
|
||||
chain = RunnablePassthrough.assign(urls=idchain_sync)
|
||||
|
||||
tracer = FakeTracer()
|
||||
chain.invoke({"example": [1, 2, 3]}, dict(callbacks=[tracer]))
|
||||
|
||||
assert tracer.runs[0].name == "RunnableAssign"
|
||||
assert tracer.runs[0].child_runs[0].name == "RunnableParallel"
|
||||
|
||||
tracer = FakeTracer()
|
||||
for item in chain.stream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
|
||||
pass
|
||||
|
||||
assert tracer.runs[0].name == "RunnableAssign"
|
||||
assert tracer.runs[0].child_runs[0].name == "RunnableParallel"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ainvoke_astream_passthrough_assign_trace() -> None:
|
||||
def idchain_sync(__input: dict) -> bool:
|
||||
return False
|
||||
|
||||
chain = RunnablePassthrough.assign(urls=idchain_sync)
|
||||
|
||||
tracer = FakeTracer()
|
||||
await chain.ainvoke({"example": [1, 2, 3]}, dict(callbacks=[tracer]))
|
||||
|
||||
assert tracer.runs[0].name == "RunnableAssign"
|
||||
assert tracer.runs[0].child_runs[0].name == "RunnableParallel"
|
||||
|
||||
tracer = FakeTracer()
|
||||
async for item in chain.astream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
|
||||
pass
|
||||
|
||||
assert tracer.runs[0].name == "RunnableAssign"
|
||||
assert tracer.runs[0].child_runs[0].name == "RunnableParallel"
|
||||
|
Loading…
Reference in New Issue
Block a user