mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Update RunnableWithMessageHistory (#14351)
This PR updates RunnableWithMessage history to support user specific configuration for the factory. It extends support to passing multiple named arguments into the factory if the factory takes more than a single argument.
This commit is contained in:
parent
8a126c5d04
commit
76905aa043
@ -4,13 +4,10 @@ import asyncio
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
@ -28,6 +25,9 @@ if TYPE_CHECKING:
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langchain_core.tracers.schemas import Run
|
||||
|
||||
import inspect
|
||||
from typing import Callable, Dict, Union
|
||||
|
||||
MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]]
|
||||
GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]
|
||||
|
||||
@ -38,8 +38,10 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
Base runnable must have inputs and outputs that can be converted to a list of
|
||||
BaseMessages.
|
||||
|
||||
RunnableWithMessageHistory must always be called with a config that contains session_id, e.g.:
|
||||
``{"configurable": {"session_id": "<SESSION_ID>"}}``
|
||||
RunnableWithMessageHistory must always be called with a config that contains
|
||||
session_id, e.g.:
|
||||
|
||||
``{"configurable": {"session_id": "<SESSION_ID>"}}`
|
||||
|
||||
Example (dict input):
|
||||
.. code-block:: python
|
||||
@ -79,12 +81,66 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
)
|
||||
# -> "The inverse of cosine is called arccosine ..."
|
||||
|
||||
|
||||
Here's an example that uses an in memory chat history, and a factory that
|
||||
takes in two keys (user_id and conversation id) to create a chat history instance.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
store = {}
|
||||
|
||||
def get_session_history(
|
||||
user_id: str, conversation_id: str
|
||||
) -> ChatMessageHistory:
|
||||
if (user_id, conversation_id) not in store:
|
||||
store[(user_id, conversation_id)] = ChatMessageHistory()
|
||||
return store[(user_id, conversation_id)]
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages([
|
||||
("system", "You're an assistant who's good at {ability}"),
|
||||
MessagesPlaceholder(variable_name="history"),
|
||||
("human", "{question}"),
|
||||
])
|
||||
|
||||
chain = prompt | ChatAnthropic(model="claude-2")
|
||||
|
||||
with_message_history = RunnableWithMessageHistory(
|
||||
chain,
|
||||
get_session_history=get_session_history,
|
||||
input_messages_key="messages",
|
||||
history_messages_key="history",
|
||||
history_factory_config=[
|
||||
ConfigurableFieldSpec(
|
||||
id="user_id",
|
||||
annotation=str,
|
||||
name="User ID",
|
||||
description="Unique identifier for the user.",
|
||||
default="",
|
||||
is_shared=True,
|
||||
),
|
||||
ConfigurableFieldSpec(
|
||||
id="conversation_id",
|
||||
annotation=str,
|
||||
name="Conversation ID",
|
||||
description="Unique identifier for the conversation.",
|
||||
default="",
|
||||
is_shared=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
chain_with_history.invoke(
|
||||
{"ability": "math", "question": "What does cosine mean?"},
|
||||
config={"configurable": {"user_id": "123", "conversation_id": "1"}}
|
||||
)
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
get_session_history: GetSessionHistoryCallable
|
||||
input_messages_key: Optional[str] = None
|
||||
output_messages_key: Optional[str] = None
|
||||
history_messages_key: Optional[str] = None
|
||||
history_factory_config: Sequence[ConfigurableFieldSpec]
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
@ -102,6 +158,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
input_messages_key: Optional[str] = None,
|
||||
output_messages_key: Optional[str] = None,
|
||||
history_messages_key: Optional[str] = None,
|
||||
history_factory_config: Optional[Sequence[ConfigurableFieldSpec]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize RunnableWithMessageHistory.
|
||||
@ -121,10 +178,10 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
- A BaseMessage or sequence of BaseMessages
|
||||
- A dict with a key for a BaseMessage or sequence of BaseMessages
|
||||
|
||||
get_session_history: Function that returns a new BaseChatMessageHistory
|
||||
given a session id. Should take a single
|
||||
positional argument `session_id` which is a string and a named argument
|
||||
`user_id` which can be a string or None. e.g.:
|
||||
get_session_history: Function that returns a new BaseChatMessageHistory.
|
||||
This function should either take a single positional argument
|
||||
`session_id` of type string and return a corresponding
|
||||
chat message history instance.
|
||||
|
||||
```python
|
||||
def get_session_history(
|
||||
@ -135,12 +192,29 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
...
|
||||
```
|
||||
|
||||
Or it should take keyword arguments that match the keys of
|
||||
`session_history_config_specs` and return a corresponding
|
||||
chat message history instance.
|
||||
|
||||
```python
|
||||
def get_session_history(
|
||||
*,
|
||||
user_id: str,
|
||||
thread_id: str,
|
||||
) -> BaseChatMessageHistory:
|
||||
...
|
||||
```
|
||||
|
||||
input_messages_key: Must be specified if the base runnable accepts a dict
|
||||
as input.
|
||||
output_messages_key: Must be specified if the base runnable returns a dict
|
||||
as output.
|
||||
history_messages_key: Must be specified if the base runnable accepts a dict
|
||||
as input and expects a separate key for historical messages.
|
||||
history_factory_config: Configure fields that should be passed to the
|
||||
chat history factory. See ``ConfigurableFieldSpec`` for more details.
|
||||
Specifying these allows you to pass multiple config keys
|
||||
into the get_session_history factory.
|
||||
**kwargs: Arbitrary additional kwargs to pass to parent class
|
||||
``RunnableBindingBase`` init.
|
||||
""" # noqa: E501
|
||||
@ -155,20 +229,12 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
bound = (
|
||||
history_chain | runnable.with_listeners(on_end=self._exit_history)
|
||||
).with_config(run_name="RunnableWithMessageHistory")
|
||||
super().__init__(
|
||||
get_session_history=get_session_history,
|
||||
input_messages_key=input_messages_key,
|
||||
output_messages_key=output_messages_key,
|
||||
bound=bound,
|
||||
history_messages_key=history_messages_key,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
return get_unique_config_specs(
|
||||
super().config_specs
|
||||
+ [
|
||||
if history_factory_config:
|
||||
_config_specs = history_factory_config
|
||||
else:
|
||||
# If not provided, then we'll use the default session_id field
|
||||
_config_specs = [
|
||||
ConfigurableFieldSpec(
|
||||
id="session_id",
|
||||
annotation=str,
|
||||
@ -178,6 +244,21 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
is_shared=True,
|
||||
),
|
||||
]
|
||||
|
||||
super().__init__(
|
||||
get_session_history=get_session_history,
|
||||
input_messages_key=input_messages_key,
|
||||
output_messages_key=output_messages_key,
|
||||
bound=bound,
|
||||
history_messages_key=history_messages_key,
|
||||
history_factory_config=_config_specs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
return get_unique_config_specs(
|
||||
super().config_specs + list(self.history_factory_config)
|
||||
)
|
||||
|
||||
def get_input_schema(
|
||||
@ -278,16 +359,46 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
|
||||
def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
config = super()._merge_configs(*configs)
|
||||
# extract session_id
|
||||
if "session_id" not in config.get("configurable", {}):
|
||||
expected_keys = [field_spec.id for field_spec in self.history_factory_config]
|
||||
|
||||
configurable = config.get("configurable", {})
|
||||
|
||||
missing_keys = set(expected_keys) - set(configurable.keys())
|
||||
|
||||
if missing_keys:
|
||||
example_input = {self.input_messages_key: "foo"}
|
||||
example_config = {"configurable": {"session_id": "123"}}
|
||||
example_configurable = {
|
||||
missing_key: "[your-value-here]" for missing_key in missing_keys
|
||||
}
|
||||
example_config = {"configurable": example_configurable}
|
||||
raise ValueError(
|
||||
"session_id is required."
|
||||
" Pass it in as part of the config argument to .invoke() or .stream()"
|
||||
f"\neg. chain.invoke({example_input}, {example_config})"
|
||||
f"Missing keys {sorted(missing_keys)} in config['configurable'] "
|
||||
f"Expected keys are {sorted(expected_keys)}."
|
||||
f"When using via .invoke() or .stream(), pass in a config; "
|
||||
f"e.g., chain.invoke({example_input}, {example_config})"
|
||||
)
|
||||
# attach message_history
|
||||
session_id = config["configurable"]["session_id"]
|
||||
config["configurable"]["message_history"] = self.get_session_history(session_id)
|
||||
|
||||
parameter_names = _get_parameter_names(self.get_session_history)
|
||||
|
||||
if len(expected_keys) == 1:
|
||||
# If arity = 1, then invoke function by positional arguments
|
||||
message_history = self.get_session_history(configurable[expected_keys[0]])
|
||||
else:
|
||||
# otherwise verify that names of keys patch and invoke by named arguments
|
||||
if set(expected_keys) != set(parameter_names):
|
||||
raise ValueError(
|
||||
f"Expected keys {sorted(expected_keys)} do not match parameter "
|
||||
f"names {sorted(parameter_names)} of get_session_history."
|
||||
)
|
||||
|
||||
message_history = self.get_session_history(
|
||||
**{key: configurable[key] for key in expected_keys}
|
||||
)
|
||||
config["configurable"]["message_history"] = message_history
|
||||
return config
|
||||
|
||||
|
||||
def _get_parameter_names(callable_: GetSessionHistoryCallable) -> List[str]:
|
||||
"""Get the parameter names of the callable."""
|
||||
sig = inspect.signature(callable_)
|
||||
return list(sig.parameters.keys())
|
||||
|
@ -1,10 +1,11 @@
|
||||
from typing import Any, Callable, Sequence, Union
|
||||
from typing import Any, Callable, Dict, List, Sequence, Union
|
||||
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables.base import RunnableLambda
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
from langchain_core.runnables.utils import ConfigurableFieldSpec
|
||||
from tests.unit_tests.fake.memory import ChatMessageHistory
|
||||
|
||||
|
||||
@ -130,7 +131,7 @@ def test_output_messages() -> None:
|
||||
)
|
||||
get_session_history = _get_get_session_history()
|
||||
with_history = RunnableWithMessageHistory(
|
||||
runnable,
|
||||
runnable, # type: ignore
|
||||
get_session_history,
|
||||
input_messages_key="input",
|
||||
history_messages_key="history",
|
||||
@ -238,3 +239,114 @@ def test_get_input_schema_input_messages() -> None:
|
||||
with_history.get_input_schema().schema()
|
||||
== RunnableWithChatHistoryInput.schema()
|
||||
)
|
||||
|
||||
|
||||
def test_using_custom_config_specs() -> None:
|
||||
"""Test that we can configure which keys should be passed to the session factory."""
|
||||
|
||||
def _fake_llm(input: Dict[str, Any]) -> List[BaseMessage]:
|
||||
messages = input["messages"]
|
||||
return [
|
||||
AIMessage(
|
||||
content="you said: "
|
||||
+ "\n".join(
|
||||
str(m.content) for m in messages if isinstance(m, HumanMessage)
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
runnable = RunnableLambda(_fake_llm)
|
||||
store = {}
|
||||
|
||||
def get_session_history(user_id: str, conversation_id: str) -> ChatMessageHistory:
|
||||
if (user_id, conversation_id) not in store:
|
||||
store[(user_id, conversation_id)] = ChatMessageHistory()
|
||||
return store[(user_id, conversation_id)]
|
||||
|
||||
with_message_history = RunnableWithMessageHistory(
|
||||
runnable, # type: ignore
|
||||
get_session_history=get_session_history,
|
||||
input_messages_key="messages",
|
||||
history_messages_key="history",
|
||||
history_factory_config=[
|
||||
ConfigurableFieldSpec(
|
||||
id="user_id",
|
||||
annotation=str,
|
||||
name="User ID",
|
||||
description="Unique identifier for the user.",
|
||||
default="",
|
||||
is_shared=True,
|
||||
),
|
||||
ConfigurableFieldSpec(
|
||||
id="conversation_id",
|
||||
annotation=str,
|
||||
name="Conversation ID",
|
||||
description="Unique identifier for the conversation.",
|
||||
default=None,
|
||||
is_shared=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
result = with_message_history.invoke(
|
||||
{
|
||||
"messages": [HumanMessage(content="hello")],
|
||||
},
|
||||
{"configurable": {"user_id": "user1", "conversation_id": "1"}},
|
||||
)
|
||||
assert result == [
|
||||
AIMessage(content="you said: hello"),
|
||||
]
|
||||
assert store == {
|
||||
("user1", "1"): ChatMessageHistory(
|
||||
messages=[
|
||||
HumanMessage(content="hello"),
|
||||
AIMessage(content="you said: hello"),
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
result = with_message_history.invoke(
|
||||
{
|
||||
"messages": [HumanMessage(content="goodbye")],
|
||||
},
|
||||
{"configurable": {"user_id": "user1", "conversation_id": "1"}},
|
||||
)
|
||||
assert result == [
|
||||
AIMessage(content="you said: goodbye"),
|
||||
]
|
||||
assert store == {
|
||||
("user1", "1"): ChatMessageHistory(
|
||||
messages=[
|
||||
HumanMessage(content="hello"),
|
||||
AIMessage(content="you said: hello"),
|
||||
HumanMessage(content="goodbye"),
|
||||
AIMessage(content="you said: goodbye"),
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
result = with_message_history.invoke(
|
||||
{
|
||||
"messages": [HumanMessage(content="meow")],
|
||||
},
|
||||
{"configurable": {"user_id": "user2", "conversation_id": "1"}},
|
||||
)
|
||||
assert result == [
|
||||
AIMessage(content="you said: meow"),
|
||||
]
|
||||
assert store == {
|
||||
("user1", "1"): ChatMessageHistory(
|
||||
messages=[
|
||||
HumanMessage(content="hello"),
|
||||
AIMessage(content="you said: hello"),
|
||||
HumanMessage(content="goodbye"),
|
||||
AIMessage(content="you said: goodbye"),
|
||||
]
|
||||
),
|
||||
("user2", "1"): ChatMessageHistory(
|
||||
messages=[
|
||||
HumanMessage(content="meow"),
|
||||
AIMessage(content="you said: meow"),
|
||||
]
|
||||
),
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user