@ -1,10 +1,12 @@
from __future__ import annotations
import asyncio
import collections
import inspect
import threading
from abc import ABC , abstractmethod
from concurrent . futures import FIRST_COMPLETED , wait
from contextvars import copy_context
from copy import deepcopy
from functools import wraps
from itertools import groupby , tee
@ -15,6 +17,7 @@ from typing import (
AsyncIterator ,
Awaitable ,
Callable ,
Coroutine ,
Dict ,
Generic ,
Iterator ,
@ -48,6 +51,7 @@ from langchain_core.runnables.config import (
merge_configs ,
patch_config ,
run_in_executor ,
var_child_runnable_config ,
)
from langchain_core . runnables . graph import Graph
from langchain_core . runnables . utils import (
@ -58,6 +62,7 @@ from langchain_core.runnables.utils import (
Input ,
Output ,
accepts_config ,
accepts_context ,
accepts_run_manager ,
gather_with_concurrency ,
get_function_first_arg_dict_keys ,
@ -950,8 +955,19 @@ class Runnable(Generic[Input, Output], ABC):
name = config . get ( " run_name " ) or self . get_name ( ) ,
)
try :
output = call_func_with_variable_args (
func , input , config , run_manager , * * kwargs
child_config = patch_config ( config , callbacks = run_manager . get_child ( ) )
context = copy_context ( )
context . run ( var_child_runnable_config . set , child_config )
output = cast (
Output ,
context . run (
call_func_with_variable_args ,
func , # type: ignore[arg-type]
input , # type: ignore[arg-type]
config ,
run_manager ,
* * kwargs ,
) ,
)
except BaseException as e :
run_manager . on_chain_error ( e )
@ -986,9 +1002,16 @@ class Runnable(Generic[Input, Output], ABC):
name = config . get ( " run_name " ) or self . get_name ( ) ,
)
try :
output = await acall_func_with_variable_args (
child_config = patch_config ( config , callbacks = run_manager . get_child ( ) )
context = copy_context ( )
context . run ( var_child_runnable_config . set , child_config )
coro = acall_func_with_variable_args (
func , input , config , run_manager , * * kwargs
)
if accepts_context ( asyncio . create_task ) :
output : Output = await asyncio . create_task ( coro , context = context ) # type: ignore
else :
output = await coro
except BaseException as e :
await run_manager . on_chain_error ( e )
raise
@ -1178,24 +1201,29 @@ class Runnable(Generic[Input, Output], ABC):
name = config . get ( " run_name " ) or self . get_name ( ) ,
)
try :
child_config = patch_config ( config , callbacks = run_manager . get_child ( ) )
if accepts_config ( transformer ) :
kwargs [ " config " ] = patch_config (
config , callbacks = run_manager . get_child ( )
)
kwargs [ " config " ] = child_config
if accepts_run_manager ( transformer ) :
kwargs [ " run_manager " ] = run_manager
iterator = transformer ( input_for_transform , * * kwargs ) # type: ignore[call-arg]
for chunk in iterator :
yield chunk
if final_output_supported :
if final_output is None :
final_output = chunk
else :
try :
final_output = final_output + chunk # type: ignore
except TypeError :
final_output = None
final_output_supported = False
context = copy_context ( )
context . run ( var_child_runnable_config . set , child_config )
iterator = context . run ( transformer , input_for_transform , * * kwargs ) # type: ignore[arg-type]
try :
while True :
chunk : Output = context . run ( next , iterator ) # type: ignore
yield chunk
if final_output_supported :
if final_output is None :
final_output = chunk
else :
try :
final_output = final_output + chunk # type: ignore
except TypeError :
final_output = None
final_output_supported = False
except StopIteration :
pass
for ichunk in input_for_tracing :
if final_input_supported :
if final_input is None :
@ -1254,24 +1282,35 @@ class Runnable(Generic[Input, Output], ABC):
name = config . get ( " run_name " ) or self . get_name ( ) ,
)
try :
child_config = patch_config ( config , callbacks = run_manager . get_child ( ) )
if accepts_config ( transformer ) :
kwargs [ " config " ] = patch_config (
config , callbacks = run_manager . get_child ( )
)
kwargs [ " config " ] = child_config
if accepts_run_manager ( transformer ) :
kwargs [ " run_manager " ] = run_manager
iterator = transformer ( input_for_transform , * * kwargs ) # type: ignore[call-arg]
async for chunk in iterator :
yield chunk
if final_output_supported :
if final_output is None :
final_output = chunk
context = copy_context ( )
context . run ( var_child_runnable_config . set , child_config )
iterator = context . run ( transformer , input_for_transform , * * kwargs ) # type: ignore[arg-type]
try :
while True :
if accepts_context ( asyncio . create_task ) :
chunk : Output = await asyncio . create_task ( # type: ignore[call-arg]
py_anext ( iterator ) , # type: ignore[arg-type]
context = context ,
)
else :
try :
final_output = final_output + chunk # type: ignore
except TypeError :
final_output = None
final_output_supported = False
chunk = cast ( Output , await py_anext ( iterator ) )
yield chunk
if final_output_supported :
if final_output is None :
final_output = chunk
else :
try :
final_output = final_output + chunk # type: ignore
except TypeError :
final_output = None
final_output_supported = False
except StopAsyncIteration :
pass
async for ichunk in input_for_tracing :
if final_input_supported :
if final_input is None :
@ -1472,7 +1511,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
. . code - block : : python
from langchain_core . output_parsers . json import SimpleJsonOutputParser
from langchain _core . chat_models . openai import ChatOpenAI
from langchain . chat_models . openai import ChatOpenAI
prompt = PromptTemplate . from_template (
' In JSON format, give me a list of {topic} and their '
@ -2482,17 +2521,25 @@ class RunnableGenerator(Runnable[Input, Output]):
) - > None :
if atransform is not None :
self . _atransform = atransform
func_for_name : Callable = atransform
if inspect . isasyncgenfunction ( transform ) :
self . _atransform = transform
func_for_name = transform
elif inspect . isgeneratorfunction ( transform ) :
self . _transform = transform
func_for_name = transform
else :
raise TypeError (
" Expected a generator function type for `transform`. "
f " Instead got an unsupported type: { type ( transform ) } "
)
try :
self . name = func_for_name . __name__
except AttributeError :
pass
@property
def InputType ( self ) - > Any :
func = getattr ( self , " _transform " , None ) or getattr ( self , " _atransform " )
@ -2646,12 +2693,14 @@ class RunnableLambda(Runnable[Input, Output]):
func : Union [
Union [
Callable [ [ Input ] , Output ] ,
Callable [ [ Input ] , Iterator [ Output ] ] ,
Callable [ [ Input , RunnableConfig ] , Output ] ,
Callable [ [ Input , CallbackManagerForChainRun ] , Output ] ,
Callable [ [ Input , CallbackManagerForChainRun , RunnableConfig ] , Output ] ,
] ,
Union [
Callable [ [ Input ] , Awaitable [ Output ] ] ,
Callable [ [ Input ] , AsyncIterator [ Output ] ] ,
Callable [ [ Input , RunnableConfig ] , Awaitable [ Output ] ] ,
Callable [ [ Input , AsyncCallbackManagerForChainRun ] , Awaitable [ Output ] ] ,
Callable [
@ -2663,6 +2712,7 @@ class RunnableLambda(Runnable[Input, Output]):
afunc : Optional [
Union [
Callable [ [ Input ] , Awaitable [ Output ] ] ,
Callable [ [ Input ] , AsyncIterator [ Output ] ] ,
Callable [ [ Input , RunnableConfig ] , Awaitable [ Output ] ] ,
Callable [ [ Input , AsyncCallbackManagerForChainRun ] , Awaitable [ Output ] ] ,
Callable [
@ -2685,7 +2735,7 @@ class RunnableLambda(Runnable[Input, Output]):
self . afunc = afunc
func_for_name : Callable = afunc
if inspect . iscoroutinefunction ( func ) :
if inspect . iscoroutinefunction ( func ) or inspect . isasyncgenfunction ( func ) :
if afunc is not None :
raise TypeError (
" Func was provided as a coroutine function, but afunc was "
@ -2767,11 +2817,16 @@ class RunnableLambda(Runnable[Input, Output]):
func = getattr ( self , " func " , None ) or getattr ( self , " afunc " )
try :
sig = inspect . signature ( func )
return (
sig . return_annotation
if sig . return_annotation != inspect . Signature . empty
else Any
)
if sig . return_annotation != inspect . Signature . empty :
# unwrap iterator types
if getattr ( sig . return_annotation , " __origin__ " , None ) in (
collections . abc . Iterator ,
collections . abc . AsyncIterator ,
) :
return getattr ( sig . return_annotation , " __args__ " , ( Any , ) ) [ 0 ]
return sig . return_annotation
else :
return Any
except ValueError :
return Any
@ -2848,9 +2903,26 @@ class RunnableLambda(Runnable[Input, Output]):
config : RunnableConfig ,
* * kwargs : Any ,
) - > Output :
output = call_func_with_variable_args (
self . func , input , config , run_manager , * * kwargs
)
if inspect . isgeneratorfunction ( self . func ) :
output : Optional [ Output ] = None
for chunk in call_func_with_variable_args (
cast ( Callable [ [ Input ] , Iterator [ Output ] ] , self . func ) ,
input ,
config ,
run_manager ,
* * kwargs ,
) :
if output is None :
output = chunk
else :
try :
output = output + chunk # type: ignore[operator]
except TypeError :
output = chunk
else :
output = call_func_with_variable_args (
self . func , input , config , run_manager , * * kwargs
)
# If the output is a runnable, invoke it
if isinstance ( output , Runnable ) :
recursion_limit = config [ " recursion_limit " ]
@ -2866,7 +2938,7 @@ class RunnableLambda(Runnable[Input, Output]):
recursion_limit = recursion_limit - 1 ,
) ,
)
return output
return cast( Output , output)
async def _ainvoke (
self ,
@ -2878,16 +2950,69 @@ class RunnableLambda(Runnable[Input, Output]):
if hasattr ( self , " afunc " ) :
afunc = self . afunc
else :
if inspect . isgeneratorfunction ( self . func ) :
def func (
input : Input ,
run_manager : AsyncCallbackManagerForChainRun ,
config : RunnableConfig ,
) - > Output :
output : Optional [ Output ] = None
for chunk in call_func_with_variable_args (
cast ( Callable [ [ Input ] , Iterator [ Output ] ] , self . func ) ,
input ,
config ,
run_manager . get_sync ( ) ,
* * kwargs ,
) :
if output is None :
output = chunk
else :
try :
output = output + chunk # type: ignore[operator]
except TypeError :
output = chunk
return cast ( Output , output )
else :
def func (
input : Input ,
run_manager : AsyncCallbackManagerForChainRun ,
config : RunnableConfig ,
) - > Output :
return call_func_with_variable_args (
self . func , input , config , run_manager . get_sync ( ) , * * kwargs
)
@wraps ( self . func )
@wraps ( func )
async def f ( * args , * * kwargs ) : # type: ignore[no-untyped-def]
return await run_in_executor ( config , self . func , * args , * * kwargs )
return await run_in_executor ( config , func , * args , * * kwargs )
afunc = f
output = await acall_func_with_variable_args (
afunc , input , config , run_manager , * * kwargs
)
if inspect . isasyncgenfunction ( afunc ) :
output : Optional [ Output ] = None
async for chunk in cast (
AsyncIterator [ Output ] ,
acall_func_with_variable_args (
cast ( Callable , afunc ) ,
input ,
config ,
run_manager ,
* * kwargs ,
) ,
) :
if output is None :
output = chunk
else :
try :
output = output + chunk # type: ignore[operator]
except TypeError :
output = chunk
else :
output = await acall_func_with_variable_args (
cast ( Callable , afunc ) , input , config , run_manager , * * kwargs
)
# If the output is a runnable, invoke it
if isinstance ( output , Runnable ) :
recursion_limit = config [ " recursion_limit " ]
@ -2903,7 +3028,7 @@ class RunnableLambda(Runnable[Input, Output]):
recursion_limit = recursion_limit - 1 ,
) ,
)
return output
return cast( Output , output)
def _config (
self , config : Optional [ RunnableConfig ] , callable : Callable [ . . . , Any ]
@ -2972,9 +3097,23 @@ class RunnableLambda(Runnable[Input, Output]):
except TypeError :
final = ichunk
output = call_func_with_variable_args (
self . func , cast ( Input , final ) , config , run_manager , * * kwargs
)
if inspect . isgeneratorfunction ( self . func ) :
output : Optional [ Output ] = None
for chunk in call_func_with_variable_args (
self . func , cast ( Input , final ) , config , run_manager , * * kwargs
) :
yield chunk
if output is None :
output = chunk
else :
try :
output = output + chunk
except TypeError :
output = chunk
else :
output = call_func_with_variable_args (
self . func , cast ( Input , final ) , config , run_manager , * * kwargs
)
# If the output is a runnable, use its stream output
if isinstance ( output , Runnable ) :
@ -2993,9 +3132,9 @@ class RunnableLambda(Runnable[Input, Output]):
) ,
) :
yield chunk
el se:
el if not inspect . isgeneratorfunction ( self. func ) :
# Otherwise, just yield it
yield output
yield cast( Output , output)
def transform (
self ,
@ -3030,6 +3169,7 @@ class RunnableLambda(Runnable[Input, Output]):
input : AsyncIterator [ Input ] ,
run_manager : AsyncCallbackManagerForChainRun ,
config : RunnableConfig ,
* * kwargs : Any ,
) - > AsyncIterator [ Output ] :
final : Optional [ Input ] = None
async for ichunk in input :
@ -3044,16 +3184,51 @@ class RunnableLambda(Runnable[Input, Output]):
if hasattr ( self , " afunc " ) :
afunc = self . afunc
else :
if inspect . isgeneratorfunction ( self . func ) :
raise TypeError (
" Cannot stream from a generator function asynchronously. "
" Use .stream() instead. "
)
def func (
input : Input ,
run_manager : AsyncCallbackManagerForChainRun ,
config : RunnableConfig ,
) - > Output :
return call_func_with_variable_args (
self . func , input , config , run_manager . get_sync ( ) , * * kwargs
)
@wraps ( self . func )
@wraps ( func )
async def f ( * args , * * kwargs ) : # type: ignore[no-untyped-def]
return await run_in_executor ( config , self . func , * args , * * kwargs )
return await run_in_executor ( config , func , * args , * * kwargs )
afunc = f
output = await acall_func_with_variable_args (
afunc , cast ( Input , final ) , config , run_manager
)
if inspect . isasyncgenfunction ( afunc ) :
output : Optional [ Output ] = None
async for chunk in cast (
AsyncIterator [ Output ] ,
acall_func_with_variable_args (
cast ( Callable , afunc ) ,
cast ( Input , final ) ,
config ,
run_manager ,
* * kwargs ,
) ,
) :
yield chunk
if output is None :
output = chunk
else :
try :
output = output + chunk # type: ignore[operator]
except TypeError :
output = chunk
else :
output = await acall_func_with_variable_args (
cast ( Callable , afunc ) , cast ( Input , final ) , config , run_manager , * * kwargs
)
# If the output is a runnable, use its astream output
if isinstance ( output , Runnable ) :
@ -3072,9 +3247,9 @@ class RunnableLambda(Runnable[Input, Output]):
) ,
) :
yield chunk
el se:
el if not in sp ect. isasyncgenfunction ( afunc ) :
# Otherwise, just yield it
yield output
yield cast( Output , output)
async def atransform (
self ,
@ -3699,3 +3874,69 @@ def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
f " Expected a Runnable, callable or dict. "
f " Instead got an unsupported type: { type ( thing ) } "
)
@overload
def chain (
func : Callable [ [ Input ] , Coroutine [ Any , Any , Output ] ] ,
) - > Runnable [ Input , Output ] :
. . .
@overload
def chain (
func : Callable [ [ Input ] , Iterator [ Output ] ] ,
) - > Runnable [ Input , Output ] :
. . .
@overload
def chain (
func : Callable [ [ Input ] , AsyncIterator [ Output ] ] ,
) - > Runnable [ Input , Output ] :
. . .
@overload
def chain (
func : Callable [ [ Input ] , Output ] ,
) - > Runnable [ Input , Output ] :
. . .
def chain (
func : Union [
Callable [ [ Input ] , Output ] ,
Callable [ [ Input ] , Iterator [ Output ] ] ,
Callable [ [ Input ] , Coroutine [ Any , Any , Output ] ] ,
Callable [ [ Input ] , AsyncIterator [ Output ] ] ,
] ,
) - > Runnable [ Input , Output ] :
""" Decorate a function to make it a Runnable.
Sets the name of the runnable to the name of the function .
Any runnables called by the function will be traced as dependencies .
Args :
func : A callable .
Returns :
A Runnable .
Example :
. . code - block : : python
from langchain_core . runnables import chain
from langchain_core . prompts import PromptTemplate
from langchain . llms import OpenAI
@chain
def my_func ( fields ) :
prompt = PromptTemplate ( " Hello, {name} ! " )
llm = OpenAI ( )
formatted = prompt . invoke ( * * fields )
for chunk in llm . stream ( formatted ) :
yield chunk
"""
return RunnableLambda ( func )