@ -1,4 +1,5 @@
from typing import Any , Dict , List , Optional
from operator import itemgetter
from typing import Any , Dict , List , Optional , Union
from uuid import UUID
import pytest
@ -176,6 +177,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
tags = [ ] ,
callbacks = None ,
_locals = { } ,
recursion_limit = 10 ,
) ,
) ,
mocker . call (
@ -185,6 +187,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
tags = [ ] ,
callbacks = None ,
_locals = { } ,
recursion_limit = 10 ,
) ,
) ,
]
@ -438,6 +441,50 @@ async def test_prompt_with_llm(
)
@pytest.mark.asyncio
@freeze_time ( " 2023-01-01 " )
async def test_prompt_with_llm_and_async_lambda (
mocker : MockerFixture , snapshot : SnapshotAssertion
) - > None :
prompt = (
SystemMessagePromptTemplate . from_template ( " You are a nice assistant. " )
+ " {question} "
)
llm = FakeListLLM ( responses = [ " foo " , " bar " ] )
async def passthrough ( input : Any ) - > Any :
return input
chain = prompt | llm | passthrough
assert isinstance ( chain , RunnableSequence )
assert chain . first == prompt
assert chain . middle == [ llm ]
assert chain . last == RunnableLambda ( func = passthrough )
assert dumps ( chain , pretty = True ) == snapshot
# Test invoke
prompt_spy = mocker . spy ( prompt . __class__ , " ainvoke " )
llm_spy = mocker . spy ( llm . __class__ , " ainvoke " )
tracer = FakeTracer ( )
assert (
await chain . ainvoke (
{ " question " : " What is your name? " } , dict ( callbacks = [ tracer ] )
)
== " foo "
)
assert prompt_spy . call_args . args [ 1 ] == { " question " : " What is your name? " }
assert llm_spy . call_args . args [ 1 ] == ChatPromptValue (
messages = [
SystemMessage ( content = " You are a nice assistant. " ) ,
HumanMessage ( content = " What is your name? " ) ,
]
)
assert tracer . runs == snapshot
mocker . stop ( prompt_spy )
mocker . stop ( llm_spy )
@freeze_time ( " 2023-01-01 " )
def test_prompt_with_chat_model_and_parser (
mocker : MockerFixture , snapshot : SnapshotAssertion
@ -722,6 +769,105 @@ async def test_router_runnable(
assert len ( router_run . child_runs ) == 2
@pytest.mark.asyncio
@freeze_time ( " 2023-01-01 " )
async def test_higher_order_lambda_runnable (
mocker : MockerFixture , snapshot : SnapshotAssertion
) - > None :
math_chain = ChatPromptTemplate . from_template (
" You are a math genius. Answer the question: {question} "
) | FakeListLLM ( responses = [ " 4 " ] )
english_chain = ChatPromptTemplate . from_template (
" You are an english major. Answer the question: {question} "
) | FakeListLLM ( responses = [ " 2 " ] )
input_map : Runnable = RunnableMap (
{ # type: ignore[arg-type]
" key " : lambda x : x [ " key " ] ,
" input " : { " question " : lambda x : x [ " question " ] } ,
}
)
def router ( input : Dict [ str , Any ] ) - > Runnable :
if input [ " key " ] == " math " :
return itemgetter ( " input " ) | math_chain
elif input [ " key " ] == " english " :
return itemgetter ( " input " ) | english_chain
else :
raise ValueError ( f " Unknown key: { input [ ' key ' ] } " )
chain : Runnable = input_map | router
assert dumps ( chain , pretty = True ) == snapshot
result = chain . invoke ( { " key " : " math " , " question " : " 2 + 2 " } )
assert result == " 4 "
result2 = chain . batch (
[ { " key " : " math " , " question " : " 2 + 2 " } , { " key " : " english " , " question " : " 2 + 2 " } ]
)
assert result2 == [ " 4 " , " 2 " ]
result = await chain . ainvoke ( { " key " : " math " , " question " : " 2 + 2 " } )
assert result == " 4 "
result2 = await chain . abatch (
[ { " key " : " math " , " question " : " 2 + 2 " } , { " key " : " english " , " question " : " 2 + 2 " } ]
)
assert result2 == [ " 4 " , " 2 " ]
# Test invoke
math_spy = mocker . spy ( math_chain . __class__ , " invoke " )
tracer = FakeTracer ( )
assert (
chain . invoke ( { " key " : " math " , " question " : " 2 + 2 " } , dict ( callbacks = [ tracer ] ) )
== " 4 "
)
assert math_spy . call_args . args [ 1 ] == {
" key " : " math " ,
" input " : { " question " : " 2 + 2 " } ,
}
assert len ( [ r for r in tracer . runs if r . parent_run_id is None ] ) == 1
parent_run = next ( r for r in tracer . runs if r . parent_run_id is None )
assert len ( parent_run . child_runs ) == 2
router_run = parent_run . child_runs [ 1 ]
assert router_run . name == " RunnableLambda "
assert len ( router_run . child_runs ) == 1
math_run = router_run . child_runs [ 0 ]
assert math_run . name == " RunnableSequence "
assert len ( math_run . child_runs ) == 3
# Test ainvoke
async def arouter ( input : Dict [ str , Any ] ) - > Runnable :
if input [ " key " ] == " math " :
return itemgetter ( " input " ) | math_chain
elif input [ " key " ] == " english " :
return itemgetter ( " input " ) | english_chain
else :
raise ValueError ( f " Unknown key: { input [ ' key ' ] } " )
achain : Runnable = input_map | arouter
math_spy = mocker . spy ( math_chain . __class__ , " ainvoke " )
tracer = FakeTracer ( )
assert (
await achain . ainvoke (
{ " key " : " math " , " question " : " 2 + 2 " } , dict ( callbacks = [ tracer ] )
)
== " 4 "
)
assert math_spy . call_args . args [ 1 ] == {
" key " : " math " ,
" input " : { " question " : " 2 + 2 " } ,
}
assert len ( [ r for r in tracer . runs if r . parent_run_id is None ] ) == 1
parent_run = next ( r for r in tracer . runs if r . parent_run_id is None )
assert len ( parent_run . child_runs ) == 2
router_run = parent_run . child_runs [ 1 ]
assert router_run . name == " RunnableLambda "
assert len ( router_run . child_runs ) == 1
math_run = router_run . child_runs [ 0 ]
assert math_run . name == " RunnableSequence "
assert len ( math_run . child_runs ) == 3
@freeze_time ( " 2023-01-01 " )
def test_seq_prompt_map ( mocker : MockerFixture , snapshot : SnapshotAssertion ) - > None :
passthrough = mocker . Mock ( side_effect = lambda x : x )
@ -1136,3 +1282,17 @@ def test_each(snapshot: SnapshotAssertion) -> None:
" test " ,
" this " ,
]
def test_recursive_lambda ( ) - > None :
def _simple_recursion ( x : int ) - > Union [ int , Runnable ] :
if x < 10 :
return RunnableLambda ( lambda * args : _simple_recursion ( x + 1 ) )
else :
return x
runnable = RunnableLambda ( _simple_recursion )
assert runnable . invoke ( 5 ) == 10
with pytest . raises ( RecursionError ) :
runnable . invoke ( 0 , { " recursion_limit " : 9 } )