@ -3981,6 +3981,140 @@ async def test_runnable_branch_abatch() -> None:
assert await branch . abatch ( [ 1 , 10 , 0 ] ) == [ 2 , 100 , - 1 ]
def test_runnable_branch_stream ( ) - > None :
""" Verify that stream works for RunnableBranch. """
llm_res = " i ' m a textbot "
# sleep to better simulate a real stream
llm = FakeStreamingListLLM ( responses = [ llm_res ] , sleep = 0.01 )
branch = RunnableBranch [ str , Any ] (
( lambda x : x == " hello " , llm ) ,
lambda x : x ,
)
assert list ( branch . stream ( " hello " ) ) == list ( llm_res )
assert list ( branch . stream ( " bye " ) ) == [ " bye " ]
def test_runnable_branch_stream_with_callbacks ( ) - > None :
""" Verify that stream works for RunnableBranch when using callbacks. """
tracer = FakeTracer ( )
def raise_value_error ( x : str ) - > Any :
""" Raise a value error. """
raise ValueError ( f " x is { x } " )
llm_res = " i ' m a textbot "
# sleep to better simulate a real stream
llm = FakeStreamingListLLM ( responses = [ llm_res ] , sleep = 0.01 )
branch = RunnableBranch [ str , Any ] (
( lambda x : x == " error " , raise_value_error ) ,
( lambda x : x == " hello " , llm ) ,
lambda x : x ,
)
config : RunnableConfig = { " callbacks " : [ tracer ] }
assert list ( branch . stream ( " hello " , config = config ) ) == list ( llm_res )
assert len ( tracer . runs ) == 1
assert tracer . runs [ 0 ] . error is None
assert tracer . runs [ 0 ] . outputs == { " output " : llm_res }
# Verify that the chain on error is invoked
with pytest . raises ( ValueError ) :
for _ in branch . stream ( " error " , config = config ) :
pass
assert len ( tracer . runs ) == 2
assert " ValueError( ' x is error ' ) " in str ( tracer . runs [ 1 ] . error )
assert tracer . runs [ 1 ] . outputs is None
assert list ( branch . stream ( " bye " , config = config ) ) == [ " bye " ]
assert len ( tracer . runs ) == 3
assert tracer . runs [ 2 ] . error is None
assert tracer . runs [ 2 ] . outputs == { " output " : " bye " }
async def test_runnable_branch_astream ( ) - > None :
""" Verify that astream works for RunnableBranch. """
llm_res = " i ' m a textbot "
# sleep to better simulate a real stream
llm = FakeStreamingListLLM ( responses = [ llm_res ] , sleep = 0.01 )
branch = RunnableBranch [ str , Any ] (
( lambda x : x == " hello " , llm ) ,
lambda x : x ,
)
assert [ _ async for _ in branch . astream ( " hello " ) ] == list ( llm_res )
assert [ _ async for _ in branch . astream ( " bye " ) ] == [ " bye " ]
# Verify that the async variant is used if available
async def condition ( x : str ) - > bool :
return x == " hello "
async def repeat ( x : str ) - > str :
return x + x
async def reverse ( x : str ) - > str :
return x [ : : - 1 ]
branch = RunnableBranch [ str , Any ] ( ( condition , repeat ) , llm )
assert [ _ async for _ in branch . astream ( " hello " ) ] == [ " hello " * 2 ]
assert [ _ async for _ in branch . astream ( " bye " ) ] == list ( llm_res )
branch = RunnableBranch [ str , Any ] ( ( condition , llm ) , reverse )
assert [ _ async for _ in branch . astream ( " hello " ) ] == list ( llm_res )
assert [ _ async for _ in branch . astream ( " bye " ) ] == [ " eyb " ]
async def test_runnable_branch_astream_with_callbacks ( ) - > None :
""" Verify that astream works for RunnableBranch when using callbacks. """
tracer = FakeTracer ( )
def raise_value_error ( x : str ) - > Any :
""" Raise a value error. """
raise ValueError ( f " x is { x } " )
llm_res = " i ' m a textbot "
# sleep to better simulate a real stream
llm = FakeStreamingListLLM ( responses = [ llm_res ] , sleep = 0.01 )
branch = RunnableBranch [ str , Any ] (
( lambda x : x == " error " , raise_value_error ) ,
( lambda x : x == " hello " , llm ) ,
lambda x : x ,
)
config : RunnableConfig = { " callbacks " : [ tracer ] }
assert [ _ async for _ in branch . astream ( " hello " , config = config ) ] == list ( llm_res )
assert len ( tracer . runs ) == 1
assert tracer . runs [ 0 ] . error is None
assert tracer . runs [ 0 ] . outputs == { " output " : llm_res }
# Verify that the chain on error is invoked
with pytest . raises ( ValueError ) :
async for _ in branch . astream ( " error " , config = config ) :
pass
assert len ( tracer . runs ) == 2
assert " ValueError( ' x is error ' ) " in str ( tracer . runs [ 1 ] . error )
assert tracer . runs [ 1 ] . outputs is None
assert [ _ async for _ in branch . astream ( " bye " , config = config ) ] == [ " bye " ]
assert len ( tracer . runs ) == 3
assert tracer . runs [ 2 ] . error is None
assert tracer . runs [ 2 ] . outputs == { " output " : " bye " }
@pytest.mark.skipif (
sys . version_info < ( 3 , 9 ) , reason = " Requires python version >= 3.9 to run. "
)