@ -5,6 +5,7 @@ import asyncio
import inspect
import threading
from typing import (
TYPE_CHECKING ,
Any ,
AsyncIterator ,
Awaitable ,
@ -31,11 +32,18 @@ from langchain_core.runnables.config import (
acall_func_with_variable_args ,
call_func_with_variable_args ,
get_executor_for_config ,
patch_config ,
)
from langchain_core . runnables . utils import AddableDict , ConfigurableFieldSpec
from langchain_core . utils . aiter import atee , py_anext
from langchain_core . utils . iter import safetee
if TYPE_CHECKING :
from langchain_core . callbacks . manager import (
AsyncCallbackManagerForChainRun ,
CallbackManagerForChainRun ,
)
def identity ( x : Other ) - > Other :
""" An identity function """
@ -345,46 +353,84 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
def config_specs ( self ) - > List [ ConfigurableFieldSpec ] :
return self . mapper . config_specs
def invoke(
def _ invoke(
self ,
input : Dict [ str , Any ] ,
config : Optional [ RunnableConfig ] = None ,
run_manager : CallbackManagerForChainRun ,
config : RunnableConfig ,
* * kwargs : Any ,
) - > Dict [ str , Any ] :
assert isinstance (
input , dict
) , " The input to RunnablePassthrough.assign() must be a dict. "
return {
* * input ,
* * self . mapper . invoke ( input , config , * * kwargs ) ,
* * self . mapper . invoke (
input ,
patch_config ( config , callbacks = run_manager . get_child ( ) ) ,
* * kwargs ,
) ,
}
async def ainvoke (
def invoke(
self ,
input : Dict [ str , Any ] ,
config : Optional [ RunnableConfig ] = None ,
* * kwargs : Any ,
) - > Dict [ str , Any ] :
return self . _call_with_config ( self . _invoke , input , config , * * kwargs )
async def _ainvoke (
self ,
input : Dict [ str , Any ] ,
run_manager : AsyncCallbackManagerForChainRun ,
config : RunnableConfig ,
* * kwargs : Any ,
) - > Dict [ str , Any ] :
assert isinstance (
input , dict
) , " The input to RunnablePassthrough.assign() must be a dict. "
return {
* * input ,
* * await self . mapper . ainvoke ( input , config , * * kwargs ) ,
* * await self . mapper . ainvoke (
input ,
patch_config ( config , callbacks = run_manager . get_child ( ) ) ,
* * kwargs ,
) ,
}
def transform (
async def ainvoke (
self ,
input : Iterator[ Dict[ str , Any ] ] ,
input : Dict[ str , Any ] ,
config : Optional [ RunnableConfig ] = None ,
* * kwargs : Any ,
) - > Dict [ str , Any ] :
return await self . _acall_with_config ( self . _ainvoke , input , config , * * kwargs )
def _transform (
self ,
input : Iterator [ Dict [ str , Any ] ] ,
run_manager : CallbackManagerForChainRun ,
config : RunnableConfig ,
* * kwargs : Any ,
) - > Iterator [ Dict [ str , Any ] ] :
# collect mapper keys
mapper_keys = set ( self . mapper . steps . keys ( ) )
# create two streams, one for the map and one for the passthrough
for_passthrough , for_map = safetee ( input , 2 , lock = threading . Lock ( ) )
# create map output stream
map_output = self . mapper . transform ( for_map , config , * * kwargs )
map_output = self . mapper . transform (
for_map ,
patch_config (
config ,
callbacks = run_manager . get_child ( ) ,
) ,
* * kwargs ,
)
# get executor to start map output stream in background
with get_executor_for_config ( config or { } ) as executor :
# start map output stream
@ -409,10 +455,21 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
for chunk in map_output :
yield chunk
async def a transform(
def transform(
self ,
input : Async Iterator[ Dict [ str , Any ] ] ,
input : Iterator[ Dict [ str , Any ] ] ,
config : Optional [ RunnableConfig ] = None ,
* * kwargs : Any | None ,
) - > Iterator [ Dict [ str , Any ] ] :
yield from self . _transform_stream_with_config (
input , self . _transform , config , * * kwargs
)
async def _atransform (
self ,
input : AsyncIterator [ Dict [ str , Any ] ] ,
run_manager : AsyncCallbackManagerForChainRun ,
config : RunnableConfig ,
* * kwargs : Any ,
) - > AsyncIterator [ Dict [ str , Any ] ] :
# collect mapper keys
@ -420,7 +477,14 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
# create two streams, one for the map and one for the passthrough
for_passthrough , for_map = atee ( input , 2 , lock = asyncio . Lock ( ) )
# create map output stream
map_output = self . mapper . atransform ( for_map , config , * * kwargs )
map_output = self . mapper . atransform (
for_map ,
patch_config (
config ,
callbacks = run_manager . get_child ( ) ,
) ,
* * kwargs ,
)
# start map output stream
first_map_chunk_task : asyncio . Task = asyncio . create_task (
py_anext ( map_output , None ) , # type: ignore[arg-type]
@ -441,6 +505,17 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
async for chunk in map_output :
yield chunk
async def atransform (
self ,
input : AsyncIterator [ Dict [ str , Any ] ] ,
config : Optional [ RunnableConfig ] = None ,
* * kwargs : Any ,
) - > AsyncIterator [ Dict [ str , Any ] ] :
async for chunk in self . _atransform_stream_with_config (
input , self . _atransform , config , * * kwargs
) :
yield chunk
def stream (
self ,
input : Dict [ str , Any ] ,