core[patch]: fix no current event loop for sql history in async mode (#22933)

- **Description:** When using
RunnableWithMessageHistory/SQLChatMessageHistory in async mode, we
get the following error:
```
Error in RootListenersTracer.on_chain_end callback: RuntimeError("There is no current event loop in thread 'asyncio_3'.")
```
which is thrown by
ddfbca38df/libs/community/langchain_community/chat_message_histories/sql.py (L259),
and no message history will be added to the database.

In this patch, a new _aexit_history function, which will be called in
async mode, is added, and in turn aadd_messages will be called.

In this patch, we use `afunc` attribute of a Runnable to check if the
end listener should be run in async mode or not.

  - **Issue:** #22021, #22022 
  - **Dependencies:** N/A
This commit is contained in:
mackong 2024-06-21 22:39:47 +08:00 committed by GitHub
parent 1c2b9cc9ab
commit 360a70c8a8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 345 additions and 35 deletions

View File

@ -1,4 +1,3 @@
import asyncio
import contextlib import contextlib
import json import json
import logging import logging
@ -252,13 +251,7 @@ class SQLChatMessageHistory(BaseChatMessageHistory):
await session.commit() await session.commit()
def add_messages(self, messages: Sequence[BaseMessage]) -> None: def add_messages(self, messages: Sequence[BaseMessage]) -> None:
# The method RunnableWithMessageHistory._exit_history() call # Add all messages in one transaction
# add_message method by mistake and not aadd_message.
# See https://github.com/langchain-ai/langchain/issues/22021
if self.async_mode:
loop = asyncio.get_event_loop()
loop.run_until_complete(self.aadd_messages(messages))
else:
with self._make_sync_session() as session: with self._make_sync_session() as session:
for message in messages: for message in messages:
session.add(self.converter.to_sql_model(message, self.session_id)) session.add(self.converter.to_sql_model(message, self.session_id))

View File

@ -16,6 +16,7 @@ from typing import (
from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.load.load import load from langchain_core.load.load import load
from langchain_core.pydantic_v1 import BaseModel from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import RunnableBranch
from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
from langchain_core.runnables.passthrough import RunnablePassthrough from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.utils import ( from langchain_core.runnables.utils import (
@ -306,8 +307,17 @@ class RunnableWithMessageHistory(RunnableBindingBase):
history_chain = RunnablePassthrough.assign( history_chain = RunnablePassthrough.assign(
**{messages_key: history_chain} **{messages_key: history_chain}
).with_config(run_name="insert_history") ).with_config(run_name="insert_history")
bound = ( bound: Runnable = (
history_chain | runnable.with_listeners(on_end=self._exit_history) history_chain
| RunnableBranch(
(
RunnableLambda(
self._is_not_async, afunc=self._is_async
).with_config(run_name="RunnableWithMessageHistoryInAsyncMode"),
runnable.with_alisteners(on_end=self._aexit_history),
),
runnable.with_listeners(on_end=self._exit_history),
)
).with_config(run_name="RunnableWithMessageHistory") ).with_config(run_name="RunnableWithMessageHistory")
if history_factory_config: if history_factory_config:
@ -367,6 +377,12 @@ class RunnableWithMessageHistory(RunnableBindingBase):
else: else:
return super_schema return super_schema
def _is_not_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool:
return False
async def _is_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool:
return True
def _get_input_messages( def _get_input_messages(
self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict] self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
) -> List[BaseMessage]: ) -> List[BaseMessage]:
@ -483,6 +499,23 @@ class RunnableWithMessageHistory(RunnableBindingBase):
output_messages = self._get_output_messages(output_val) output_messages = self._get_output_messages(output_val)
hist.add_messages(input_messages + output_messages) hist.add_messages(input_messages + output_messages)
async def _aexit_history(self, run: Run, config: RunnableConfig) -> None:
    """Async counterpart of ``_exit_history``: persist the exchange to history.

    Registered as an async ``on_end`` listener so that async-only message
    history backends (e.g. ``SQLChatMessageHistory`` in async mode) are
    updated through their native async API (``aadd_messages``) instead of
    ``add_messages``, which may need an event loop it cannot obtain inside
    a listener thread.  See langchain issue #22021.

    Args:
        run: The finished run whose serialized inputs/outputs are replayed.
        config: The merged config; ``configurable.message_history`` holds
            the ``BaseChatMessageHistory`` bound for this session.
    """
    hist: BaseChatMessageHistory = config["configurable"]["message_history"]

    # Get the input messages
    inputs = load(run.inputs)
    input_messages = self._get_input_messages(inputs)

    # If historic messages were prepended to the input messages, remove them to
    # avoid adding duplicate messages to history.
    if not self.history_messages_key:
        # Reuse the already-bound `hist` and read it via the async accessor so
        # a purely-async backend is not forced through its sync `.messages`
        # property from async code.
        historic_messages = await hist.aget_messages()
        input_messages = input_messages[len(historic_messages) :]

    # Get the output messages
    output_val = load(run.outputs)
    output_messages = self._get_output_messages(output_val)
    await hist.aadd_messages(input_messages + output_messages)
def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig: def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
config = super()._merge_configs(*configs) config = super()._merge_configs(*configs)
expected_keys = [field_spec.id for field_spec in self.history_factory_config] expected_keys = [field_spec.id for field_spec in self.history_factory_config]

View File

@ -62,6 +62,31 @@ def test_input_messages() -> None:
} }
async def test_input_messages_async() -> None:
    """ainvoke with a plain message list records inputs and outputs."""
    chain = RunnableLambda(
        lambda messages: "you said: "
        + "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage))
    )
    store: Dict = {}
    with_history = RunnableWithMessageHistory(
        chain, _get_get_session_history(store=store)
    )
    config: RunnableConfig = {"configurable": {"session_id": "1_async"}}

    result = await with_history.ainvoke([HumanMessage(content="hello")], config)
    assert result == "you said: hello"
    result = await with_history.ainvoke([HumanMessage(content="good bye")], config)
    assert result == "you said: hello\ngood bye"

    # Both turns (human + AI) must have been persisted by the async listener.
    assert store == {
        "1_async": ChatMessageHistory(
            messages=[
                HumanMessage(content="hello"),
                AIMessage(content="you said: hello"),
                HumanMessage(content="good bye"),
                AIMessage(content="you said: hello\ngood bye"),
            ]
        )
    }
def test_input_dict() -> None: def test_input_dict() -> None:
runnable = RunnableLambda( runnable = RunnableLambda(
lambda input: "you said: " lambda input: "you said: "
@ -82,6 +107,28 @@ def test_input_dict() -> None:
assert output == "you said: hello\ngood bye" assert output == "you said: hello\ngood bye"
async def test_input_dict_async() -> None:
    """ainvoke with dict input and ``input_messages_key`` accumulates history."""
    chain = RunnableLambda(
        lambda input: "you said: "
        + "\n".join(
            str(m.content) for m in input["messages"] if isinstance(m, HumanMessage)
        )
    )
    with_history = RunnableWithMessageHistory(
        chain, _get_get_session_history(), input_messages_key="messages"
    )
    config: RunnableConfig = {"configurable": {"session_id": "2_async"}}

    result = await with_history.ainvoke(
        {"messages": [HumanMessage(content="hello")]}, config
    )
    assert result == "you said: hello"
    result = await with_history.ainvoke(
        {"messages": [HumanMessage(content="good bye")]}, config
    )
    assert result == "you said: hello\ngood bye"
def test_input_dict_with_history_key() -> None: def test_input_dict_with_history_key() -> None:
runnable = RunnableLambda( runnable = RunnableLambda(
lambda input: "you said: " lambda input: "you said: "
@ -104,6 +151,28 @@ def test_input_dict_with_history_key() -> None:
assert output == "you said: hello\ngood bye" assert output == "you said: hello\ngood bye"
async def test_input_dict_with_history_key_async() -> None:
    """ainvoke with separate ``history_messages_key`` threads history through."""
    chain = RunnableLambda(
        lambda input: "you said: "
        + "\n".join(
            [str(m.content) for m in input["history"] if isinstance(m, HumanMessage)]
            + [input["input"]]
        )
    )
    with_history = RunnableWithMessageHistory(
        chain,
        _get_get_session_history(),
        input_messages_key="input",
        history_messages_key="history",
    )
    config: RunnableConfig = {"configurable": {"session_id": "3_async"}}

    result = await with_history.ainvoke({"input": "hello"}, config)
    assert result == "you said: hello"
    result = await with_history.ainvoke({"input": "good bye"}, config)
    assert result == "you said: hello\ngood bye"
def test_output_message() -> None: def test_output_message() -> None:
runnable = RunnableLambda( runnable = RunnableLambda(
lambda input: AIMessage( lambda input: AIMessage(
@ -132,7 +201,34 @@ def test_output_message() -> None:
assert output == AIMessage(content="you said: hello\ngood bye") assert output == AIMessage(content="you said: hello\ngood bye")
async def test_output_message_async() -> None:
    """ainvoke with a runnable that returns a single ``AIMessage``."""
    chain = RunnableLambda(
        lambda input: AIMessage(
            content="you said: "
            + "\n".join(
                [
                    str(m.content)
                    for m in input["history"]
                    if isinstance(m, HumanMessage)
                ]
                + [input["input"]]
            )
        )
    )
    with_history = RunnableWithMessageHistory(
        chain,
        _get_get_session_history(),
        input_messages_key="input",
        history_messages_key="history",
    )
    config: RunnableConfig = {"configurable": {"session_id": "4_async"}}

    result = await with_history.ainvoke({"input": "hello"}, config)
    assert result == AIMessage(content="you said: hello")
    result = await with_history.ainvoke({"input": "good bye"}, config)
    assert result == AIMessage(content="you said: hello\ngood bye")
class LengthChatModel(BaseChatModel): class LengthChatModel(BaseChatModel):
"""A fake chat model that returns the length of the messages passed in.""" """A fake chat model that returns the length of the messages passed in."""
@ -145,28 +241,42 @@ def test_input_messages_output_message() -> None:
) -> ChatResult: ) -> ChatResult:
"""Top Level call""" """Top Level call"""
return ChatResult( return ChatResult(
generations=[ generations=[ChatGeneration(message=AIMessage(content=str(len(messages))))]
ChatGeneration(message=AIMessage(content=str(len(messages))))
]
) )
@property @property
def _llm_type(self) -> str: def _llm_type(self) -> str:
return "length-fake-chat-model" return "length-fake-chat-model"
def test_input_messages_output_message() -> None:
runnable = LengthChatModel() runnable = LengthChatModel()
get_session_history = _get_get_session_history() get_session_history = _get_get_session_history()
with_history = RunnableWithMessageHistory( with_history = RunnableWithMessageHistory(
runnable, runnable,
get_session_history, get_session_history,
) )
config: RunnableConfig = {"configurable": {"session_id": "4"}} config: RunnableConfig = {"configurable": {"session_id": "5"}}
output = with_history.invoke([HumanMessage(content="hi")], config) output = with_history.invoke([HumanMessage(content="hi")], config)
assert output.content == "1" assert output.content == "1"
output = with_history.invoke([HumanMessage(content="hi")], config) output = with_history.invoke([HumanMessage(content="hi")], config)
assert output.content == "3" assert output.content == "3"
async def test_input_messages_output_message_async() -> None:
    """ainvoke with a chat model: message count grows as history accumulates."""
    with_history = RunnableWithMessageHistory(
        LengthChatModel(),
        _get_get_session_history(),
    )
    config: RunnableConfig = {"configurable": {"session_id": "5_async"}}

    # First call sees only the new human message.
    result = await with_history.ainvoke([HumanMessage(content="hi")], config)
    assert result.content == "1"
    # Second call sees the prior human + AI turn plus the new human message.
    result = await with_history.ainvoke([HumanMessage(content="hi")], config)
    assert result.content == "3"
def test_output_messages() -> None: def test_output_messages() -> None:
runnable = RunnableLambda( runnable = RunnableLambda(
lambda input: [ lambda input: [
@ -190,13 +300,43 @@ def test_output_messages() -> None:
input_messages_key="input", input_messages_key="input",
history_messages_key="history", history_messages_key="history",
) )
config: RunnableConfig = {"configurable": {"session_id": "5"}} config: RunnableConfig = {"configurable": {"session_id": "6"}}
output = with_history.invoke({"input": "hello"}, config) output = with_history.invoke({"input": "hello"}, config)
assert output == [AIMessage(content="you said: hello")] assert output == [AIMessage(content="you said: hello")]
output = with_history.invoke({"input": "good bye"}, config) output = with_history.invoke({"input": "good bye"}, config)
assert output == [AIMessage(content="you said: hello\ngood bye")] assert output == [AIMessage(content="you said: hello\ngood bye")]
async def test_output_messages_async() -> None:
    """ainvoke with a runnable that returns a list of messages."""
    chain = RunnableLambda(
        lambda input: [
            AIMessage(
                content="you said: "
                + "\n".join(
                    [
                        str(m.content)
                        for m in input["history"]
                        if isinstance(m, HumanMessage)
                    ]
                    + [input["input"]]
                )
            )
        ]
    )
    with_history = RunnableWithMessageHistory(
        chain,  # type: ignore
        _get_get_session_history(),
        input_messages_key="input",
        history_messages_key="history",
    )
    config: RunnableConfig = {"configurable": {"session_id": "6_async"}}

    result = await with_history.ainvoke({"input": "hello"}, config)
    assert result == [AIMessage(content="you said: hello")]
    result = await with_history.ainvoke({"input": "good bye"}, config)
    assert result == [AIMessage(content="you said: hello\ngood bye")]
def test_output_dict() -> None: def test_output_dict() -> None:
runnable = RunnableLambda( runnable = RunnableLambda(
lambda input: { lambda input: {
@ -223,13 +363,46 @@ def test_output_dict() -> None:
history_messages_key="history", history_messages_key="history",
output_messages_key="output", output_messages_key="output",
) )
config: RunnableConfig = {"configurable": {"session_id": "6"}} config: RunnableConfig = {"configurable": {"session_id": "7"}}
output = with_history.invoke({"input": "hello"}, config) output = with_history.invoke({"input": "hello"}, config)
assert output == {"output": [AIMessage(content="you said: hello")]} assert output == {"output": [AIMessage(content="you said: hello")]}
output = with_history.invoke({"input": "good bye"}, config) output = with_history.invoke({"input": "good bye"}, config)
assert output == {"output": [AIMessage(content="you said: hello\ngood bye")]} assert output == {"output": [AIMessage(content="you said: hello\ngood bye")]}
async def test_output_dict_async() -> None:
    """ainvoke with dict output selected via ``output_messages_key``."""
    chain = RunnableLambda(
        lambda input: {
            "output": [
                AIMessage(
                    content="you said: "
                    + "\n".join(
                        [
                            str(m.content)
                            for m in input["history"]
                            if isinstance(m, HumanMessage)
                        ]
                        + [input["input"]]
                    )
                )
            ]
        }
    )
    with_history = RunnableWithMessageHistory(
        chain,
        _get_get_session_history(),
        input_messages_key="input",
        history_messages_key="history",
        output_messages_key="output",
    )
    config: RunnableConfig = {"configurable": {"session_id": "7_async"}}

    result = await with_history.ainvoke({"input": "hello"}, config)
    assert result == {"output": [AIMessage(content="you said: hello")]}
    result = await with_history.ainvoke({"input": "good bye"}, config)
    assert result == {"output": [AIMessage(content="you said: hello\ngood bye")]}
def test_get_input_schema_input_dict() -> None: def test_get_input_schema_input_dict() -> None:
class RunnableWithChatHistoryInput(BaseModel): class RunnableWithChatHistoryInput(BaseModel):
input: Union[str, BaseMessage, Sequence[BaseMessage]] input: Union[str, BaseMessage, Sequence[BaseMessage]]
@ -404,3 +577,114 @@ def test_using_custom_config_specs() -> None:
] ]
), ),
} }
async def test_using_custom_config_specs_async() -> None:
    """Test that we can configure which keys should be passed to the session factory."""

    def _fake_llm(input: Dict[str, Any]) -> List[BaseMessage]:
        # Echo back the human messages, prefixed — stands in for a real model.
        messages = input["messages"]
        return [
            AIMessage(
                content="you said: "
                + "\n".join(
                    str(m.content) for m in messages if isinstance(m, HumanMessage)
                )
            )
        ]

    store = {}

    def get_session_history(user_id: str, conversation_id: str) -> ChatMessageHistory:
        # Sessions are keyed by (user_id, conversation_id) rather than a
        # single session_id, matching history_factory_config below.
        if (user_id, conversation_id) not in store:
            store[(user_id, conversation_id)] = ChatMessageHistory()
        return store[(user_id, conversation_id)]

    with_message_history = RunnableWithMessageHistory(
        RunnableLambda(_fake_llm),  # type: ignore
        get_session_history=get_session_history,
        input_messages_key="messages",
        history_messages_key="history",
        history_factory_config=[
            ConfigurableFieldSpec(
                id="user_id",
                annotation=str,
                name="User ID",
                description="Unique identifier for the user.",
                default="",
                is_shared=True,
            ),
            ConfigurableFieldSpec(
                id="conversation_id",
                annotation=str,
                name="Conversation ID",
                description="Unique identifier for the conversation.",
                default=None,
                is_shared=True,
            ),
        ],
    )

    user1_config = {
        "configurable": {"user_id": "user1_async", "conversation_id": "1_async"}
    }

    result = await with_message_history.ainvoke(
        {"messages": [HumanMessage(content="hello")]},
        user1_config,
    )
    assert result == [
        AIMessage(content="you said: hello"),
    ]
    assert store == {
        ("user1_async", "1_async"): ChatMessageHistory(
            messages=[
                HumanMessage(content="hello"),
                AIMessage(content="you said: hello"),
            ]
        )
    }

    result = await with_message_history.ainvoke(
        {"messages": [HumanMessage(content="goodbye")]},
        user1_config,
    )
    assert result == [
        AIMessage(content="you said: goodbye"),
    ]
    user1_history = ChatMessageHistory(
        messages=[
            HumanMessage(content="hello"),
            AIMessage(content="you said: hello"),
            HumanMessage(content="goodbye"),
            AIMessage(content="you said: goodbye"),
        ]
    )
    assert store == {("user1_async", "1_async"): user1_history}

    # A different user_id must get its own independent history.
    result = await with_message_history.ainvoke(
        {"messages": [HumanMessage(content="meow")]},
        {"configurable": {"user_id": "user2_async", "conversation_id": "1_async"}},
    )
    assert result == [
        AIMessage(content="you said: meow"),
    ]
    assert store == {
        ("user1_async", "1_async"): user1_history,
        ("user2_async", "1_async"): ChatMessageHistory(
            messages=[
                HumanMessage(content="meow"),
                AIMessage(content="you said: meow"),
            ]
        ),
    }