@ -62,6 +62,31 @@ def test_input_messages() -> None:
}
async def test_input_messages_async ( ) - > None :
runnable = RunnableLambda (
lambda messages : " you said: "
+ " \n " . join ( str ( m . content ) for m in messages if isinstance ( m , HumanMessage ) )
)
store : Dict = { }
get_session_history = _get_get_session_history ( store = store )
with_history = RunnableWithMessageHistory ( runnable , get_session_history )
config : RunnableConfig = { " configurable " : { " session_id " : " 1_async " } }
output = await with_history . ainvoke ( [ HumanMessage ( content = " hello " ) ] , config )
assert output == " you said: hello "
output = await with_history . ainvoke ( [ HumanMessage ( content = " good bye " ) ] , config )
assert output == " you said: hello \n good bye "
assert store == {
" 1_async " : ChatMessageHistory (
messages = [
HumanMessage ( content = " hello " ) ,
AIMessage ( content = " you said: hello " ) ,
HumanMessage ( content = " good bye " ) ,
AIMessage ( content = " you said: hello \n good bye " ) ,
]
)
}
def test_input_dict ( ) - > None :
runnable = RunnableLambda (
lambda input : " you said: "
@ -82,6 +107,28 @@ def test_input_dict() -> None:
assert output == " you said: hello \n good bye "
async def test_input_dict_async ( ) - > None :
runnable = RunnableLambda (
lambda input : " you said: "
+ " \n " . join (
str ( m . content ) for m in input [ " messages " ] if isinstance ( m , HumanMessage )
)
)
get_session_history = _get_get_session_history ( )
with_history = RunnableWithMessageHistory (
runnable , get_session_history , input_messages_key = " messages "
)
config : RunnableConfig = { " configurable " : { " session_id " : " 2_async " } }
output = await with_history . ainvoke (
{ " messages " : [ HumanMessage ( content = " hello " ) ] } , config
)
assert output == " you said: hello "
output = await with_history . ainvoke (
{ " messages " : [ HumanMessage ( content = " good bye " ) ] } , config
)
assert output == " you said: hello \n good bye "
def test_input_dict_with_history_key ( ) - > None :
runnable = RunnableLambda (
lambda input : " you said: "
@ -104,6 +151,28 @@ def test_input_dict_with_history_key() -> None:
assert output == " you said: hello \n good bye "
async def test_input_dict_with_history_key_async ( ) - > None :
runnable = RunnableLambda (
lambda input : " you said: "
+ " \n " . join (
[ str ( m . content ) for m in input [ " history " ] if isinstance ( m , HumanMessage ) ]
+ [ input [ " input " ] ]
)
)
get_session_history = _get_get_session_history ( )
with_history = RunnableWithMessageHistory (
runnable ,
get_session_history ,
input_messages_key = " input " ,
history_messages_key = " history " ,
)
config : RunnableConfig = { " configurable " : { " session_id " : " 3_async " } }
output = await with_history . ainvoke ( { " input " : " hello " } , config )
assert output == " you said: hello "
output = await with_history . ainvoke ( { " input " : " good bye " } , config )
assert output == " you said: hello \n good bye "
def test_output_message ( ) - > None :
runnable = RunnableLambda (
lambda input : AIMessage (
@ -132,41 +201,82 @@ def test_output_message() -> None:
assert output == AIMessage ( content = " you said: hello \n good bye " )
def test_input_messages_output_message ( ) - > None :
class LengthChatModel ( BaseChatModel ) :
""" A fake chat model that returns the length of the messages passed in. """
def _generate (
self ,
messages : List [ BaseMessage ] ,
stop : Optional [ List [ str ] ] = None ,
run_manager : Optional [ CallbackManagerForLLMRun ] = None ,
* * kwargs : Any ,
) - > ChatResult :
""" Top Level call """
return ChatResult (
generations = [
ChatGeneration ( message = AIMessage ( content = str ( len ( messages ) ) ) )
async def test_output_message_async ( ) - > None :
runnable = RunnableLambda (
lambda input : AIMessage (
content = " you said: "
+ " \n " . join (
[
str ( m . content )
for m in input [ " history " ]
if isinstance ( m , HumanMessage )
]
+ [ input [ " input " ] ]
)
)
)
get_session_history = _get_get_session_history ( )
with_history = RunnableWithMessageHistory (
runnable ,
get_session_history ,
input_messages_key = " input " ,
history_messages_key = " history " ,
)
config : RunnableConfig = { " configurable " : { " session_id " : " 4_async " } }
output = await with_history . ainvoke ( { " input " : " hello " } , config )
assert output == AIMessage ( content = " you said: hello " )
output = await with_history . ainvoke ( { " input " : " good bye " } , config )
assert output == AIMessage ( content = " you said: hello \n good bye " )
class LengthChatModel ( BaseChatModel ) :
""" A fake chat model that returns the length of the messages passed in. """
def _generate (
self ,
messages : List [ BaseMessage ] ,
stop : Optional [ List [ str ] ] = None ,
run_manager : Optional [ CallbackManagerForLLMRun ] = None ,
* * kwargs : Any ,
) - > ChatResult :
""" Top Level call """
return ChatResult (
generations = [ ChatGeneration ( message = AIMessage ( content = str ( len ( messages ) ) ) ) ]
)
@property
def _llm_type ( self ) - > str :
return " length-fake-chat-model "
@property
def _llm_type ( self ) - > str :
return " length-fake-chat-model "
def test_input_messages_output_message ( ) - > None :
runnable = LengthChatModel ( )
get_session_history = _get_get_session_history ( )
with_history = RunnableWithMessageHistory (
runnable ,
get_session_history ,
)
config : RunnableConfig = { " configurable " : { " session_id " : " 4 " } }
config : RunnableConfig = { " configurable " : { " session_id " : " 5 " } }
output = with_history . invoke ( [ HumanMessage ( content = " hi " ) ] , config )
assert output . content == " 1 "
output = with_history . invoke ( [ HumanMessage ( content = " hi " ) ] , config )
assert output . content == " 3 "
async def test_input_messages_output_message_async ( ) - > None :
runnable = LengthChatModel ( )
get_session_history = _get_get_session_history ( )
with_history = RunnableWithMessageHistory (
runnable ,
get_session_history ,
)
config : RunnableConfig = { " configurable " : { " session_id " : " 5_async " } }
output = await with_history . ainvoke ( [ HumanMessage ( content = " hi " ) ] , config )
assert output . content == " 1 "
output = await with_history . ainvoke ( [ HumanMessage ( content = " hi " ) ] , config )
assert output . content == " 3 "
def test_output_messages ( ) - > None :
runnable = RunnableLambda (
lambda input : [
@ -190,13 +300,43 @@ def test_output_messages() -> None:
input_messages_key = " input " ,
history_messages_key = " history " ,
)
config : RunnableConfig = { " configurable " : { " session_id " : " 5 " } }
config : RunnableConfig = { " configurable " : { " session_id " : " 6 " } }
output = with_history . invoke ( { " input " : " hello " } , config )
assert output == [ AIMessage ( content = " you said: hello " ) ]
output = with_history . invoke ( { " input " : " good bye " } , config )
assert output == [ AIMessage ( content = " you said: hello \n good bye " ) ]
async def test_output_messages_async ( ) - > None :
runnable = RunnableLambda (
lambda input : [
AIMessage (
content = " you said: "
+ " \n " . join (
[
str ( m . content )
for m in input [ " history " ]
if isinstance ( m , HumanMessage )
]
+ [ input [ " input " ] ]
)
)
]
)
get_session_history = _get_get_session_history ( )
with_history = RunnableWithMessageHistory (
runnable , # type: ignore
get_session_history ,
input_messages_key = " input " ,
history_messages_key = " history " ,
)
config : RunnableConfig = { " configurable " : { " session_id " : " 6_async " } }
output = await with_history . ainvoke ( { " input " : " hello " } , config )
assert output == [ AIMessage ( content = " you said: hello " ) ]
output = await with_history . ainvoke ( { " input " : " good bye " } , config )
assert output == [ AIMessage ( content = " you said: hello \n good bye " ) ]
def test_output_dict ( ) - > None :
runnable = RunnableLambda (
lambda input : {
@ -223,13 +363,46 @@ def test_output_dict() -> None:
history_messages_key = " history " ,
output_messages_key = " output " ,
)
config : RunnableConfig = { " configurable " : { " session_id " : " 6 " } }
config : RunnableConfig = { " configurable " : { " session_id " : " 7 " } }
output = with_history . invoke ( { " input " : " hello " } , config )
assert output == { " output " : [ AIMessage ( content = " you said: hello " ) ] }
output = with_history . invoke ( { " input " : " good bye " } , config )
assert output == { " output " : [ AIMessage ( content = " you said: hello \n good bye " ) ] }
async def test_output_dict_async ( ) - > None :
runnable = RunnableLambda (
lambda input : {
" output " : [
AIMessage (
content = " you said: "
+ " \n " . join (
[
str ( m . content )
for m in input [ " history " ]
if isinstance ( m , HumanMessage )
]
+ [ input [ " input " ] ]
)
)
]
}
)
get_session_history = _get_get_session_history ( )
with_history = RunnableWithMessageHistory (
runnable ,
get_session_history ,
input_messages_key = " input " ,
history_messages_key = " history " ,
output_messages_key = " output " ,
)
config : RunnableConfig = { " configurable " : { " session_id " : " 7_async " } }
output = await with_history . ainvoke ( { " input " : " hello " } , config )
assert output == { " output " : [ AIMessage ( content = " you said: hello " ) ] }
output = await with_history . ainvoke ( { " input " : " good bye " } , config )
assert output == { " output " : [ AIMessage ( content = " you said: hello \n good bye " ) ] }
def test_get_input_schema_input_dict ( ) - > None :
class RunnableWithChatHistoryInput ( BaseModel ) :
input : Union [ str , BaseMessage , Sequence [ BaseMessage ] ]
@ -404,3 +577,114 @@ def test_using_custom_config_specs() -> None:
]
) ,
}
async def test_using_custom_config_specs_async ( ) - > None :
""" Test that we can configure which keys should be passed to the session factory. """
def _fake_llm ( input : Dict [ str , Any ] ) - > List [ BaseMessage ] :
messages = input [ " messages " ]
return [
AIMessage (
content = " you said: "
+ " \n " . join (
str ( m . content ) for m in messages if isinstance ( m , HumanMessage )
)
)
]
runnable = RunnableLambda ( _fake_llm )
store = { }
def get_session_history ( user_id : str , conversation_id : str ) - > ChatMessageHistory :
if ( user_id , conversation_id ) not in store :
store [ ( user_id , conversation_id ) ] = ChatMessageHistory ( )
return store [ ( user_id , conversation_id ) ]
with_message_history = RunnableWithMessageHistory (
runnable , # type: ignore
get_session_history = get_session_history ,
input_messages_key = " messages " ,
history_messages_key = " history " ,
history_factory_config = [
ConfigurableFieldSpec (
id = " user_id " ,
annotation = str ,
name = " User ID " ,
description = " Unique identifier for the user. " ,
default = " " ,
is_shared = True ,
) ,
ConfigurableFieldSpec (
id = " conversation_id " ,
annotation = str ,
name = " Conversation ID " ,
description = " Unique identifier for the conversation. " ,
default = None ,
is_shared = True ,
) ,
] ,
)
result = await with_message_history . ainvoke (
{
" messages " : [ HumanMessage ( content = " hello " ) ] ,
} ,
{ " configurable " : { " user_id " : " user1_async " , " conversation_id " : " 1_async " } } ,
)
assert result == [
AIMessage ( content = " you said: hello " ) ,
]
assert store == {
( " user1_async " , " 1_async " ) : ChatMessageHistory (
messages = [
HumanMessage ( content = " hello " ) ,
AIMessage ( content = " you said: hello " ) ,
]
)
}
result = await with_message_history . ainvoke (
{
" messages " : [ HumanMessage ( content = " goodbye " ) ] ,
} ,
{ " configurable " : { " user_id " : " user1_async " , " conversation_id " : " 1_async " } } ,
)
assert result == [
AIMessage ( content = " you said: goodbye " ) ,
]
assert store == {
( " user1_async " , " 1_async " ) : ChatMessageHistory (
messages = [
HumanMessage ( content = " hello " ) ,
AIMessage ( content = " you said: hello " ) ,
HumanMessage ( content = " goodbye " ) ,
AIMessage ( content = " you said: goodbye " ) ,
]
)
}
result = await with_message_history . ainvoke (
{
" messages " : [ HumanMessage ( content = " meow " ) ] ,
} ,
{ " configurable " : { " user_id " : " user2_async " , " conversation_id " : " 1_async " } } ,
)
assert result == [
AIMessage ( content = " you said: meow " ) ,
]
assert store == {
( " user1_async " , " 1_async " ) : ChatMessageHistory (
messages = [
HumanMessage ( content = " hello " ) ,
AIMessage ( content = " you said: hello " ) ,
HumanMessage ( content = " goodbye " ) ,
AIMessage ( content = " you said: goodbye " ) ,
]
) ,
( " user2_async " , " 1_async " ) : ChatMessageHistory (
messages = [
HumanMessage ( content = " meow " ) ,
AIMessage ( content = " you said: meow " ) ,
]
) ,
}