@ -42,6 +42,7 @@ from langchain.schema.runnable.config import (
ensure_config ,
get_async_callback_manager_for_config ,
get_callback_manager_for_config ,
get_config_list ,
get_executor_for_config ,
patch_config ,
)
@ -110,7 +111,7 @@ class Runnable(Generic[Input, Output], ABC):
Default implementation of batch , which calls invoke N times .
Subclasses should override this method if they can batch more efficiently .
"""
configs = self . _ get_config_list( config , len ( inputs ) )
configs = get_config_list( config , len ( inputs ) )
# If there's only one input, don't bother with the executor
if len ( inputs ) == 1 :
@ -129,7 +130,7 @@ class Runnable(Generic[Input, Output], ABC):
Default implementation of abatch , which calls ainvoke N times .
Subclasses should override this method if they can batch more efficiently .
"""
configs = self . _ get_config_list( config , len ( inputs ) )
configs = get_config_list( config , len ( inputs ) )
coros = map ( partial ( self . ainvoke , * * kwargs ) , inputs , configs )
return await gather_with_concurrency ( configs [ 0 ] . get ( " max_concurrency " ) , * coros )
@ -210,7 +211,20 @@ class Runnable(Generic[Input, Output], ABC):
"""
Bind arguments to a Runnable , returning a new Runnable .
"""
return RunnableBinding ( bound = self , kwargs = kwargs )
return RunnableBinding ( bound = self , kwargs = kwargs , config = { } )
def with_config (
self ,
config : Optional [ RunnableConfig ] = None ,
# Sadly Unpack is not well supported by mypy so this will have to be untyped
* * kwargs : Any ,
) - > Runnable [ Input , Output ] :
"""
Bind config to a Runnable , returning a new Runnable .
"""
return RunnableBinding (
bound = self , config = { * * ( config or { } ) , * * kwargs } , kwargs = { }
)
def map ( self ) - > Runnable [ List [ Input ] , List [ Output ] ] :
"""
@ -233,27 +247,6 @@ class Runnable(Generic[Input, Output], ABC):
""" --- Helper methods for Subclasses --- """
def _get_config_list (
self , config : Optional [ Union [ RunnableConfig , List [ RunnableConfig ] ] ] , length : int
) - > List [ RunnableConfig ] :
"""
Helper method to get a list of configs from a single config or a list of
configs , useful for subclasses overriding batch ( ) or abatch ( ) .
"""
if length < 1 :
raise ValueError ( f " length must be >= 1, but got { length } " )
if isinstance ( config , list ) and len ( config ) != length :
raise ValueError (
f " config must be a list of the same length as inputs, "
f " but got { len ( config ) } configs for { length } inputs "
)
return (
list ( map ( ensure_config , config ) )
if isinstance ( config , list )
else [ patch_config ( config , deep_copy_locals = True ) for _ in range ( length ) ]
)
def _call_with_config (
self ,
func : Union [
@ -273,6 +266,7 @@ class Runnable(Generic[Input, Output], ABC):
dumpd ( self ) ,
input ,
run_type = run_type ,
name = config . get ( " run_name " ) ,
)
try :
if accepts_run_manager_and_config ( func ) :
@ -314,6 +308,7 @@ class Runnable(Generic[Input, Output], ABC):
dumpd ( self ) ,
input ,
run_type = run_type ,
name = config . get ( " run_name " ) ,
)
try :
if accepts_run_manager_and_config ( func ) :
@ -371,6 +366,7 @@ class Runnable(Generic[Input, Output], ABC):
dumpd ( self ) ,
{ " input " : " " } ,
run_type = run_type ,
name = config . get ( " run_name " ) ,
)
try :
if accepts_run_manager_and_config ( transformer ) :
@ -451,6 +447,7 @@ class Runnable(Generic[Input, Output], ABC):
dumpd ( self ) ,
{ " input " : " " } ,
run_type = run_type ,
name = config . get ( " run_name " ) ,
)
try :
# mypy can't quite work out thew type guard here, but this is safe,
@ -526,7 +523,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
config = ensure_config ( config )
callback_manager = get_callback_manager_for_config ( config )
# start the root run
run_manager = callback_manager . on_chain_start ( dumpd ( self ) , input )
run_manager = callback_manager . on_chain_start (
dumpd ( self ) , input , name = config . get ( " run_name " )
)
first_error = None
for runnable in self . runnables :
try :
@ -558,7 +557,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
config = ensure_config ( config )
callback_manager = get_async_callback_manager_for_config ( config )
# start the root run
run_manager = await callback_manager . on_chain_start ( dumpd ( self ) , input )
run_manager = await callback_manager . on_chain_start (
dumpd ( self ) , input , name = config . get ( " run_name " )
)
first_error = None
for runnable in self . runnables :
@ -590,7 +591,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
from langchain . callbacks . manager import CallbackManager
# setup callbacks
configs = self . _ get_config_list( config , len ( inputs ) )
configs = get_config_list( config , len ( inputs ) )
callback_managers = [
CallbackManager . configure (
inheritable_callbacks = config . get ( " callbacks " ) ,
@ -606,9 +607,11 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
# start the root runs, one per input
run_managers = [
cm . on_chain_start (
dumpd ( self ) , input if isinstance ( input , dict ) else { " input " : input }
dumpd ( self ) ,
input if isinstance ( input , dict ) else { " input " : input } ,
name = config . get ( " run_name " ) ,
)
for cm , input in zip ( callback_managers , inputs )
for cm , input , config in zip ( callback_managers , input s, config s)
]
first_error = None
@ -648,7 +651,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
from langchain . callbacks . manager import AsyncCallbackManager
# setup callbacks
configs = self . _ get_config_list( config , len ( inputs ) )
configs = get_config_list( config , len ( inputs ) )
callback_managers = [
AsyncCallbackManager . configure (
inheritable_callbacks = config . get ( " callbacks " ) ,
@ -664,8 +667,12 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
# start the root runs, one per input
run_managers : List [ AsyncCallbackManagerForChainRun ] = await asyncio . gather (
* (
cm . on_chain_start ( dumpd ( self ) , input )
for cm , input in zip ( callback_managers , inputs )
cm . on_chain_start (
dumpd ( self ) ,
input ,
name = config . get ( " run_name " ) ,
)
for cm , input , config in zip ( callback_managers , inputs , configs )
)
)
@ -770,7 +777,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
config = ensure_config ( config )
callback_manager = get_callback_manager_for_config ( config )
# start the root run
run_manager = callback_manager . on_chain_start ( dumpd ( self ) , input )
run_manager = callback_manager . on_chain_start (
dumpd ( self ) , input , name = config . get ( " run_name " )
)
# invoke all steps in sequence
try :
@ -798,7 +807,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
config = ensure_config ( config )
callback_manager = get_async_callback_manager_for_config ( config )
# start the root run
run_manager = await callback_manager . on_chain_start ( dumpd ( self ) , input )
run_manager = await callback_manager . on_chain_start (
dumpd ( self ) , input , name = config . get ( " run_name " )
)
# invoke all steps in sequence
try :
@ -825,7 +836,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
from langchain . callbacks . manager import CallbackManager
# setup callbacks
configs = self . _ get_config_list( config , len ( inputs ) )
configs = get_config_list( config , len ( inputs ) )
callback_managers = [
CallbackManager . configure (
inheritable_callbacks = config . get ( " callbacks " ) ,
@ -840,8 +851,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
]
# start the root runs, one per input
run_managers = [
cm . on_chain_start ( dumpd ( self ) , input )
for cm , input in zip ( callback_managers , inputs )
cm . on_chain_start (
dumpd ( self ) ,
input ,
name = config . get ( " run_name " ) ,
)
for cm , input , config in zip ( callback_managers , inputs , configs )
]
# invoke
@ -876,7 +891,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
)
# setup callbacks
configs = self . _ get_config_list( config , len ( inputs ) )
configs = get_config_list( config , len ( inputs ) )
callback_managers = [
AsyncCallbackManager . configure (
inheritable_callbacks = config . get ( " callbacks " ) ,
@ -892,8 +907,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
# start the root runs, one per input
run_managers : List [ AsyncCallbackManagerForChainRun ] = await asyncio . gather (
* (
cm . on_chain_start ( dumpd ( self ) , input )
for cm , input in zip ( callback_managers , inputs )
cm . on_chain_start (
dumpd ( self ) ,
input ,
name = config . get ( " run_name " ) ,
)
for cm , input , config in zip ( callback_managers , inputs , configs )
)
)
@ -929,7 +948,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
config = ensure_config ( config )
callback_manager = get_callback_manager_for_config ( config )
# start the root run
run_manager = callback_manager . on_chain_start ( dumpd ( self ) , input )
run_manager = callback_manager . on_chain_start (
dumpd ( self ) , input , name = config . get ( " run_name " )
)
steps = [ self . first ] + self . middle + [ self . last ]
streaming_start_index = 0
@ -996,7 +1017,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
config = ensure_config ( config )
callback_manager = get_async_callback_manager_for_config ( config )
# start the root run
run_manager = await callback_manager . on_chain_start ( dumpd ( self ) , input )
run_manager = await callback_manager . on_chain_start (
dumpd ( self ) , input , name = config . get ( " run_name " )
)
steps = [ self . first ] + self . middle + [ self . last ]
streaming_start_index = len ( steps ) - 1
@ -1127,7 +1150,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
local_metadata = None ,
)
# start the root run
run_manager = callback_manager . on_chain_start ( dumpd ( self ) , input )
run_manager = callback_manager . on_chain_start (
dumpd ( self ) , input , name = config . get ( " run_name " )
)
# gather results from all steps
try :
@ -1166,7 +1191,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
config = ensure_config ( config )
callback_manager = get_async_callback_manager_for_config ( config )
# start the root run
run_manager = await callback_manager . on_chain_start ( dumpd ( self ) , input )
run_manager = await callback_manager . on_chain_start (
dumpd ( self ) , input , name = config . get ( " run_name " )
)
# gather results from all steps
try :
@ -1479,6 +1506,8 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
kwargs : Mapping [ str , Any ]
config : Mapping [ str , Any ] = Field ( default_factory = dict )
class Config :
arbitrary_types_allowed = True
@ -1490,8 +1519,31 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
def lc_namespace ( self ) - > List [ str ] :
return self . __class__ . __module__ . split ( " . " ) [ : - 1 ]
def _merge_config ( self , config : Optional [ RunnableConfig ] ) - > RunnableConfig :
copy = cast ( RunnableConfig , dict ( self . config ) )
if config :
for key in config :
# Even though the keys aren't literals this is correct
# because both dicts are same type
copy [ key ] = config [ key ] or copy . get ( key ) # type: ignore
return copy
def bind ( self , * * kwargs : Any ) - > Runnable [ Input , Output ] :
return self . __class__ ( bound = self . bound , kwargs = { * * self . kwargs , * * kwargs } )
return self . __class__ (
bound = self . bound , config = self . config , kwargs = { * * self . kwargs , * * kwargs }
)
def with_config (
self ,
config : Optional [ RunnableConfig ] = None ,
# Sadly Unpack is not well supported by mypy so this will have to be untyped
* * kwargs : Any ,
) - > Runnable [ Input , Output ] :
return self . __class__ (
bound = self . bound ,
kwargs = self . kwargs ,
config = { * * self . config , * * ( config or { } ) , * * kwargs } ,
)
def invoke (
self ,
@ -1499,7 +1551,11 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config : Optional [ RunnableConfig ] = None ,
* * kwargs : Optional [ Any ] ,
) - > Output :
return self . bound . invoke ( input , config , * * { * * self . kwargs , * * kwargs } )
return self . bound . invoke (
input ,
self . _merge_config ( config ) ,
* * { * * self . kwargs , * * kwargs } ,
)
async def ainvoke (
self ,
@ -1507,7 +1563,11 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config : Optional [ RunnableConfig ] = None ,
* * kwargs : Optional [ Any ] ,
) - > Output :
return await self . bound . ainvoke ( input , config , * * { * * self . kwargs , * * kwargs } )
return await self . bound . ainvoke (
input ,
self . _merge_config ( config ) ,
* * { * * self . kwargs , * * kwargs } ,
)
def batch (
self ,
@ -1515,7 +1575,16 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config : Optional [ Union [ RunnableConfig , List [ RunnableConfig ] ] ] = None ,
* * kwargs : Optional [ Any ] ,
) - > List [ Output ] :
return self . bound . batch ( inputs , config , * * { * * self . kwargs , * * kwargs } )
if isinstance ( config , list ) :
configs = cast (
List [ RunnableConfig ] , [ self . _merge_config ( conf ) for conf in config ]
)
else :
configs = [
patch_config ( self . _merge_config ( config ) , deep_copy_locals = True )
for _ in range ( len ( inputs ) )
]
return self . bound . batch ( inputs , configs , * * { * * self . kwargs , * * kwargs } )
async def abatch (
self ,
@ -1523,7 +1592,16 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config : Optional [ Union [ RunnableConfig , List [ RunnableConfig ] ] ] = None ,
* * kwargs : Optional [ Any ] ,
) - > List [ Output ] :
return await self . bound . abatch ( inputs , config , * * { * * self . kwargs , * * kwargs } )
if isinstance ( config , list ) :
configs = cast (
List [ RunnableConfig ] , [ self . _merge_config ( conf ) for conf in config ]
)
else :
configs = [
patch_config ( self . _merge_config ( config ) , deep_copy_locals = True )
for _ in range ( len ( inputs ) )
]
return await self . bound . abatch ( inputs , configs , * * { * * self . kwargs , * * kwargs } )
def stream (
self ,
@ -1531,7 +1609,11 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config : Optional [ RunnableConfig ] = None ,
* * kwargs : Optional [ Any ] ,
) - > Iterator [ Output ] :
yield from self . bound . stream ( input , config , * * { * * self . kwargs , * * kwargs } )
yield from self . bound . stream (
input ,
self . _merge_config ( config ) ,
* * { * * self . kwargs , * * kwargs } ,
)
async def astream (
self ,
@ -1540,7 +1622,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
* * kwargs : Optional [ Any ] ,
) - > AsyncIterator [ Output ] :
async for item in self . bound . astream (
input , config , * * { * * self . kwargs , * * kwargs }
input ,
self . _merge_config ( config ) ,
* * { * * self . kwargs , * * kwargs } ,
) :
yield item
@ -1550,7 +1634,11 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config : Optional [ RunnableConfig ] = None ,
* * kwargs : Any ,
) - > Iterator [ Output ] :
yield from self . bound . transform ( input , config , * * { * * self . kwargs , * * kwargs } )
yield from self . bound . transform (
input ,
self . _merge_config ( config ) ,
* * { * * self . kwargs , * * kwargs } ,
)
async def atransform (
self ,
@ -1559,11 +1647,16 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
* * kwargs : Any ,
) - > AsyncIterator [ Output ] :
async for item in self . bound . atransform (
input , config , * * { * * self . kwargs , * * kwargs }
input ,
self . _merge_config ( config ) ,
* * { * * self . kwargs , * * kwargs } ,
) :
yield item
RunnableBinding . update_forward_refs ( RunnableConfig = RunnableConfig )
def coerce_to_runnable (
thing : Union [
Runnable [ Input , Output ] ,