core[patch]: Runnable with message history to use add_messages (#17958)

This PR updates RunnableWithMessageHistory to use add_messages which
will save on round-trips for any chat
history abstractions that implement the optimization. If the
optimization isn't
implemented, add_messages automatically invokes add_message serially.
This commit is contained in:
Eugene Yurtsev 2024-02-23 21:19:38 -05:00 committed by GitHub
parent 1c1bb1152e
commit 68527b809d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -97,9 +97,9 @@ class RunnableWithMessageHistory(RunnableBindingBase):
messages: List[BaseMessage] = Field(default_factory=list) messages: List[BaseMessage] = Field(default_factory=list)
def add_message(self, message: BaseMessage) -> None: def add_messages(self, messages: List[BaseMessage]) -> None:
\"\"\"Add a self-created message to the store\"\"\" \"\"\"Add a list of messages to the store\"\"\"
self.messages.append(message) self.messages.extend(messages)
def clear(self) -> None: def clear(self) -> None:
self.messages = [] self.messages = []
@ -420,7 +420,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
return await run_in_executor(config, self._enter_history, input, config) return await run_in_executor(config, self._enter_history, input, config)
def _exit_history(self, run: Run, config: RunnableConfig) -> None: def _exit_history(self, run: Run, config: RunnableConfig) -> None:
hist = config["configurable"]["message_history"] hist: BaseChatMessageHistory = config["configurable"]["message_history"]
# Get the input messages # Get the input messages
inputs = load(run.inputs) inputs = load(run.inputs)
@ -436,9 +436,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
# Get the output messages # Get the output messages
output_val = load(run.outputs) output_val = load(run.outputs)
output_messages = self._get_output_messages(output_val) output_messages = self._get_output_messages(output_val)
hist.add_messages(input_messages + output_messages)
for m in input_messages + output_messages:
hist.add_message(m)
def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig: def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
config = super()._merge_configs(*configs) config = super()._merge_configs(*configs)