From 6abb2c2c08770fd020979d49f53f8cb8cde99896 Mon Sep 17 00:00:00 2001 From: Nikhil Kumar <96546721+nikumar1206@users.noreply.github.com> Date: Thu, 10 Aug 2023 21:17:22 -0400 Subject: [PATCH] Buffer method of ConversationTokenBufferMemory should be able to return messages as string (#7057) ### Description: `ConversationTokenBufferMemory` should have a simple way of returning the conversation messages as a string. Previously, to accomplish this, you would only have the option to return memory as an array through the buffer method and call `get_buffer_string` by importing it from `langchain.schema`, or use the `load_memory_variables` method and key into `self.memory_key`. ### Maintainer @hwchase17 --------- Co-authored-by: Bagatur --- libs/langchain/langchain/memory/buffer.py | 4 +-- .../langchain/memory/token_buffer.py | 27 +++++++++++-------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/libs/langchain/langchain/memory/buffer.py b/libs/langchain/langchain/memory/buffer.py index 7eac112b45..3809577d06 100644 --- a/libs/langchain/langchain/memory/buffer.py +++ b/libs/langchain/langchain/memory/buffer.py @@ -4,7 +4,7 @@ from pydantic import root_validator from langchain.memory.chat_memory import BaseChatMemory, BaseMemory from langchain.memory.utils import get_prompt_input_key -from langchain.schema.messages import get_buffer_string +from langchain.schema.messages import BaseMessage, get_buffer_string class ConversationBufferMemory(BaseChatMemory): @@ -29,7 +29,7 @@ class ConversationBufferMemory(BaseChatMemory): ) @property - def buffer_as_messages(self) -> List[Any]: + def buffer_as_messages(self) -> List[BaseMessage]: """Exposes the buffer as a list of messages in case return_messages is False.""" return self.chat_memory.messages diff --git a/libs/langchain/langchain/memory/token_buffer.py b/libs/langchain/langchain/memory/token_buffer.py index a964f2f6b4..864ded2fc5 100644 --- a/libs/langchain/langchain/memory/token_buffer.py +++ 
b/libs/langchain/langchain/memory/token_buffer.py @@ -15,8 +15,22 @@ class ConversationTokenBufferMemory(BaseChatMemory): max_token_limit: int = 2000 @property - def buffer(self) -> List[BaseMessage]: + def buffer(self) -> Any: """String buffer of memory.""" + return self.buffer_as_messages if self.return_messages else self.buffer_as_str + + @property + def buffer_as_str(self) -> str: + """Exposes the buffer as a string in case return_messages is True.""" + return get_buffer_string( + self.chat_memory.messages, + human_prefix=self.human_prefix, + ai_prefix=self.ai_prefix, + ) + + @property + def buffer_as_messages(self) -> List[BaseMessage]: + """Exposes the buffer as a list of messages in case return_messages is False.""" return self.chat_memory.messages @property @@ -29,16 +43,7 @@ class ConversationTokenBufferMemory(BaseChatMemory): def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """Return history buffer.""" - buffer: Any = self.buffer - if self.return_messages: - final_buffer: Any = buffer - else: - final_buffer = get_buffer_string( - buffer, - human_prefix=self.human_prefix, - ai_prefix=self.ai_prefix, - ) - return {self.memory_key: final_buffer} + return {self.memory_key: self.buffer} def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: """Save context from this conversation to buffer. Pruned."""