@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
from abc import ABC , abstractmethod
from concurrent . futures import ThreadPoolExecutor
from itertools import tee
from typing import (
Any ,
AsyncIterator ,
@ -29,6 +30,7 @@ from pydantic import Field
from langchain . callbacks . base import BaseCallbackManager , Callbacks
from langchain . load . dump import dumpd
from langchain . load . serializable import Serializable
from langchain . utils . aiter import atee , py_anext
async def _gated_coro ( semaphore : asyncio . Semaphore , coro : Coroutine ) - > Any :
@ -92,6 +94,8 @@ class Runnable(Generic[Input, Output], ABC):
) - > RunnableSequence [ Other , Output ] :
return RunnableSequence ( first = _coerce_to_runnable ( other ) , last = self )
""" --- Public API --- """
@abstractmethod
def invoke ( self , input : Input , config : Optional [ RunnableConfig ] = None ) - > Output :
. . .
@ -99,6 +103,10 @@ class Runnable(Generic[Input, Output], ABC):
async def ainvoke (
self , input : Input , config : Optional [ RunnableConfig ] = None
) - > Output :
"""
Default implementation of ainvoke , which calls invoke in a thread pool .
Subclasses should override this method if they can run asynchronously .
"""
return await asyncio . get_running_loop ( ) . run_in_executor (
None , self . invoke , input , config
)
@ -110,6 +118,10 @@ class Runnable(Generic[Input, Output], ABC):
* ,
max_concurrency : Optional [ int ] = None ,
) - > List [ Output ] :
"""
Default implementation of batch , which calls invoke N times .
Subclasses should override this method if they can batch more efficiently .
"""
configs = self . _get_config_list ( config , len ( inputs ) )
# If there's only one input, don't bother with the executor
@ -126,6 +138,10 @@ class Runnable(Generic[Input, Output], ABC):
* ,
max_concurrency : Optional [ int ] = None ,
) - > List [ Output ] :
"""
Default implementation of abatch , which calls ainvoke N times .
Subclasses should override this method if they can batch more efficiently .
"""
configs = self . _get_config_list ( config , len ( inputs ) )
coros = map ( self . ainvoke , inputs , configs )
@ -134,22 +150,90 @@ class Runnable(Generic[Input, Output], ABC):
def stream (
self , input : Input , config : Optional [ RunnableConfig ] = None
) - > Iterator [ Output ] :
"""
Default implementation of stream , which calls invoke .
Subclasses should override this method if they support streaming output .
"""
yield self . invoke ( input , config )
async def astream (
self , input : Input , config : Optional [ RunnableConfig ] = None
) - > AsyncIterator [ Output ] :
"""
Default implementation of astream , which calls ainvoke .
Subclasses should override this method if they support streaming output .
"""
yield await self . ainvoke ( input , config )
def transform (
self , input : Iterator [ Input ] , config : Optional [ RunnableConfig ] = None
) - > Iterator [ Output ] :
"""
Default implementation of transform , which buffers input and then calls stream .
Subclasses should override this method if they can start producing output while
input is still being generated .
"""
final : Union [ Input , None ] = None
for chunk in input :
if final is None :
final = chunk
else :
# Make a best effort to gather, for any type that supports `+`
# This method should throw an error if gathering fails.
final + = chunk # type: ignore[operator]
if final :
yield from self . stream ( final , config )
async def atransform (
self , input : AsyncIterator [ Input ] , config : Optional [ RunnableConfig ] = None
) - > AsyncIterator [ Output ] :
"""
Default implementation of atransform , which buffers input and calls astream .
Subclasses should override this method if they can start producing output while
input is still being generated .
"""
final : Union [ Input , None ] = None
async for chunk in input :
if final is None :
final = chunk
else :
# Make a best effort to gather, for any type that supports `+`
# This method should throw an error if gathering fails.
final + = chunk # type: ignore[operator]
if final :
async for output in self . astream ( final , config ) :
yield output
def bind ( self , * * kwargs : Any ) - > Runnable [ Input , Output ] :
"""
Bind arguments to a Runnable , returning a new Runnable .
"""
return RunnableBinding ( bound = self , kwargs = kwargs )
def with_fallbacks (
self ,
fallbacks : Sequence [ Runnable [ Input , Output ] ] ,
* ,
exceptions_to_handle : Tuple [ Type [ BaseException ] ] = ( Exception , ) ,
) - > RunnableWithFallbacks [ Input , Output ] :
return RunnableWithFallbacks (
runnable = self ,
fallbacks = fallbacks ,
exceptions_to_handle = exceptions_to_handle ,
)
""" --- Helper methods for Subclasses --- """
def _get_config_list (
self , config : Optional [ Union [ RunnableConfig , List [ RunnableConfig ] ] ] , length : int
) - > List [ RunnableConfig ] :
"""
Helper method to get a list of configs from a single config or a list of
configs , useful for subclasses overriding batch ( ) or abatch ( ) .
"""
if isinstance ( config , list ) and len ( config ) != length :
raise ValueError (
f " config must be a list of the same length as inputs, "
@ -169,6 +253,8 @@ class Runnable(Generic[Input, Output], ABC):
config : Optional [ RunnableConfig ] ,
run_type : Optional [ str ] = None ,
) - > Output :
""" Helper method to transform an Input value to an Output value,
with callbacks . Use this method to implement invoke ( ) in subclasses . """
from langchain . callbacks . manager import CallbackManager
config = config or { }
@ -200,6 +286,8 @@ class Runnable(Generic[Input, Output], ABC):
config : Optional [ RunnableConfig ] ,
run_type : Optional [ str ] = None ,
) - > Output :
""" Helper method to transform an Input value to an Output value,
with callbacks . Use this method to implement ainvoke ( ) in subclasses . """
from langchain . callbacks . manager import AsyncCallbackManager
config = config or { }
@ -224,20 +312,154 @@ class Runnable(Generic[Input, Output], ABC):
)
return output
def with_fallbacks (
def _transform_stream_with_config (
self ,
fallbacks : Sequence [ Runnable [ Input , Output ] ] ,
* ,
exceptions_to_handle : Tuple [ Type [ BaseException ] ] = ( Exception , ) ,
) - > RunnableWithFallbacks [ Input , Output ] :
return RunnableWithFallbacks (
runnable = self ,
fallbacks = fallbacks ,
exceptions_to_handle = exceptions_to_handle ,
input : Iterator [ Input ] ,
transformer : Callable [ [ Iterator [ Input ] ] , Iterator [ Output ] ] ,
config : Optional [ RunnableConfig ] ,
run_type : Optional [ str ] = None ,
) - > Iterator [ Output ] :
""" Helper method to transform an Iterator of Input values into an Iterator of
Output values , with callbacks .
Use this to implement ` stream ( ) ` or ` transform ( ) ` in Runnable subclasses . """
from langchain . callbacks . manager import CallbackManager
# tee the input so we can iterate over it twice
input_for_tracing , input_for_transform = tee ( input , 2 )
# Start the input iterator to ensure the input runnable starts before this one
final_input : Optional [ Input ] = next ( input_for_tracing , None )
final_input_supported = True
final_output : Optional [ Output ] = None
final_output_supported = True
config = config or { }
callback_manager = CallbackManager . configure (
inheritable_callbacks = config . get ( " callbacks " ) ,
inheritable_tags = config . get ( " tags " ) ,
inheritable_metadata = config . get ( " metadata " ) ,
)
run_manager = callback_manager . on_chain_start (
dumpd ( self ) ,
{ " input " : " " } ,
run_type = run_type ,
)
try :
for chunk in transformer ( input_for_transform ) :
yield chunk
if final_output_supported :
if final_output is None :
final_output = chunk
else :
try :
final_output + = chunk # type: ignore[operator]
except TypeError :
final_output = None
final_output_supported = False
for ichunk in input_for_tracing :
if final_input_supported :
if final_input is None :
final_input = ichunk
else :
try :
final_input + = ichunk # type: ignore[operator]
except TypeError :
final_input = None
final_input_supported = False
except Exception as e :
run_manager . on_chain_error (
e ,
inputs = final_input
if isinstance ( final_input , dict )
else { " input " : final_input } ,
)
raise
else :
run_manager . on_chain_end (
final_output
if isinstance ( final_output , dict )
else { " output " : final_output } ,
inputs = final_input
if isinstance ( final_input , dict )
else { " input " : final_input } ,
)
async def _atransform_stream_with_config (
self ,
input : AsyncIterator [ Input ] ,
transformer : Callable [ [ AsyncIterator [ Input ] ] , AsyncIterator [ Output ] ] ,
config : Optional [ RunnableConfig ] ,
run_type : Optional [ str ] = None ,
) - > AsyncIterator [ Output ] :
""" Helper method to transform an Async Iterator of Input values into an Async
Iterator of Output values , with callbacks .
Use this to implement ` astream ( ) ` or ` atransform ( ) ` in Runnable subclasses . """
from langchain . callbacks . manager import AsyncCallbackManager
# tee the input so we can iterate over it twice
input_for_tracing , input_for_transform = atee ( input , 2 )
# Start the input iterator to ensure the input runnable starts before this one
final_input : Optional [ Input ] = await py_anext ( input_for_tracing , None )
final_input_supported = True
final_output : Optional [ Output ] = None
final_output_supported = True
config = config or { }
callback_manager = AsyncCallbackManager . configure (
inheritable_callbacks = config . get ( " callbacks " ) ,
inheritable_tags = config . get ( " tags " ) ,
inheritable_metadata = config . get ( " metadata " ) ,
)
run_manager = await callback_manager . on_chain_start (
dumpd ( self ) ,
{ " input " : " " } ,
run_type = run_type ,
)
try :
async for chunk in transformer ( input_for_transform ) :
yield chunk
if final_output_supported :
if final_output is None :
final_output = chunk
else :
try :
final_output + = chunk # type: ignore[operator]
except TypeError :
final_output = None
final_output_supported = False
async for ichunk in input_for_tracing :
if final_input_supported :
if final_input is None :
final_input = ichunk
else :
try :
final_input + = ichunk # type: ignore[operator]
except TypeError :
final_input = None
final_input_supported = False
except Exception as e :
await run_manager . on_chain_error (
e ,
inputs = final_input
if isinstance ( final_input , dict )
else { " input " : final_input } ,
)
raise
else :
await run_manager . on_chain_end (
final_output
if isinstance ( final_output , dict )
else { " output " : final_output } ,
inputs = final_input
if isinstance ( final_input , dict )
else { " input " : final_input } ,
)
class RunnableWithFallbacks ( Serializable , Runnable [ Input , Output ] ) :
"""
A Runnable that can fallback to other Runnables if it fails .
"""
runnable : Runnable [ Input , Output ]
fallbacks : Sequence [ Runnable [ Input , Output ] ]
exceptions_to_handle : Tuple [ Type [ BaseException ] ] = ( Exception , )
@ -467,6 +689,10 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
class RunnableSequence ( Serializable , Runnable [ Input , Output ] ) :
"""
A sequence of runnables , where the output of each is the input of the next .
"""
first : Runnable [ Input , Any ]
middle : List [ Runnable [ Any , Any ] ] = Field ( default_factory = list )
last : Runnable [ Any , Output ]
@ -738,9 +964,18 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
dumpd ( self ) , input if isinstance ( input , dict ) else { " input " : input }
)
steps = [ self . first ] + self . middle + [ self . last ]
streaming_start_index = 0
for i in range ( len ( steps ) - 1 , 0 , - 1 ) :
if type ( steps [ i ] ) . transform != Runnable . transform :
streaming_start_index = i - 1
else :
break
# invoke the first steps
try :
for step in [ self . first ] + self . middle :
for step in steps [ 0 : streaming_start_index ] :
input = step . invoke (
input ,
# mark each step as a child run
@ -750,15 +985,20 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
run_manager . on_chain_error ( e )
raise
# stream the last step
# stream the last step s
final : Union [ Output , None ] = None
final_supported = True
try :
for output in self . last . stream (
input ,
# mark the last step as a child run
_patch_config ( config , run_manager . get_child ( ) ) ,
) :
# stream the first of the last steps with non-streaming input
final_pipeline = steps [ streaming_start_index ] . stream (
input , _patch_config ( config , run_manager . get_child ( ) )
)
# stream the rest of the last steps with streaming input
for step in steps [ streaming_start_index + 1 : ] :
final_pipeline = step . transform (
final_pipeline , _patch_config ( config , run_manager . get_child ( ) )
)
for output in final_pipeline :
yield output
# Accumulate output if possible, otherwise disable accumulation
if final_supported :
@ -801,9 +1041,18 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
dumpd ( self ) , input if isinstance ( input , dict ) else { " input " : input }
)
steps = [ self . first ] + self . middle + [ self . last ]
streaming_start_index = len ( steps ) - 1
for i in range ( len ( steps ) - 1 , 0 , - 1 ) :
if type ( steps [ i ] ) . transform != Runnable . transform :
streaming_start_index = i - 1
else :
break
# invoke the first steps
try :
for step in [ self . first ] + self . middle :
for step in steps [ 0 : streaming_start_index ] :
input = await step . ainvoke (
input ,
# mark each step as a child run
@ -813,15 +1062,20 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
await run_manager . on_chain_error ( e )
raise
# stream the last step
# stream the last step s
final : Union [ Output , None ] = None
final_supported = True
try :
async for output in self . last . astream (
input ,
# mark the last step as a child run
_patch_config ( config , run_manager . get_child ( ) ) ,
) :
# stream the first of the last steps with non-streaming input
final_pipeline = steps [ streaming_start_index ] . astream (
input , _patch_config ( config , run_manager . get_child ( ) )
)
# stream the rest of the last steps with streaming input
for step in steps [ streaming_start_index + 1 : ] :
final_pipeline = step . atransform (
final_pipeline , _patch_config ( config , run_manager . get_child ( ) )
)
async for output in final_pipeline :
yield output
# Accumulate output if possible, otherwise disable accumulation
if final_supported :
@ -845,6 +1099,11 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
class RunnableMap ( Serializable , Runnable [ Input , Dict [ str , Any ] ] ) :
"""
A runnable that runs a mapping of runnables in parallel ,
and returns a mapping of their outputs .
"""
steps : Mapping [ str , Runnable [ Input , Any ] ]
def __init__ (
@ -957,6 +1216,10 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
class RunnableLambda ( Runnable [ Input , Output ] ) :
"""
A runnable that runs a callable .
"""
def __init__ ( self , func : Callable [ [ Input ] , Output ] ) - > None :
if callable ( func ) :
self . func = func
@ -977,6 +1240,10 @@ class RunnableLambda(Runnable[Input, Output]):
class RunnablePassthrough ( Serializable , Runnable [ Input , Input ] ) :
"""
A runnable that passes through the input .
"""
@property
def lc_serializable ( self ) - > bool :
return True
@ -986,6 +1253,10 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
class RunnableBinding ( Serializable , Runnable [ Input , Output ] ) :
"""
A runnable that binds a runnable to a set of kwargs .
"""
bound : Runnable [ Input , Output ]
kwargs : Mapping [ str , Any ]
@ -1041,6 +1312,17 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
async for item in self . bound . astream ( input , config , * * self . kwargs ) :
yield item
def transform (
self , input : Iterator [ Input ] , config : Optional [ RunnableConfig ] = None
) - > Iterator [ Output ] :
yield from self . bound . transform ( input , config , * * self . kwargs )
async def atransform (
self , input : AsyncIterator [ Input ] , config : Optional [ RunnableConfig ] = None
) - > AsyncIterator [ Output ] :
async for item in self . bound . atransform ( input , config , * * self . kwargs ) :
yield item
class RouterInput ( TypedDict ) :
key : str
@ -1050,6 +1332,11 @@ class RouterInput(TypedDict):
class RouterRunnable (
Serializable , Generic [ Input , Output ] , Runnable [ RouterInput , Output ]
) :
"""
A runnable that routes to a set of runnables based on Input [ ' key ' ] .
Returns the output of the selected runnable .
"""
runnables : Mapping [ str , Runnable [ Input , Output ] ]
def __init__ ( self , runnables : Mapping [ str , Runnable [ Input , Output ] ] ) - > None :