From c1dc784a3dddc0376d3cf507be8db832de3306db Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Fri, 10 Mar 2023 11:27:15 -0800
Subject: [PATCH] buffer memory old version (#1581)

bring back an older version of memory since people seem to be using it more widely
---
 langchain/chains/conversation/memory.py |  6 ++-
 langchain/memory/__init__.py            |  6 ++-
 langchain/memory/buffer.py              | 60 +++++++++++++++++++++++--
 3 files changed, 66 insertions(+), 6 deletions(-)

diff --git a/langchain/chains/conversation/memory.py b/langchain/chains/conversation/memory.py
index f4a079b1..7aad58f8 100644
--- a/langchain/chains/conversation/memory.py
+++ b/langchain/chains/conversation/memory.py
@@ -1,6 +1,9 @@
 """Memory modules for conversation prompts."""
 
-from langchain.memory.buffer import ConversationBufferMemory
+from langchain.memory.buffer import (
+    ConversationBufferMemory,
+    ConversationStringBufferMemory,
+)
 from langchain.memory.buffer_window import ConversationBufferWindowMemory
 from langchain.memory.combined import CombinedMemory
 from langchain.memory.entity import ConversationEntityMemory
@@ -18,4 +21,5 @@ __all__ = [
     "ConversationEntityMemory",
     "ConversationBufferMemory",
     "CombinedMemory",
+    "ConversationStringBufferMemory",
 ]
diff --git a/langchain/memory/__init__.py b/langchain/memory/__init__.py
index cc4f456b..3b030039 100644
--- a/langchain/memory/__init__.py
+++ b/langchain/memory/__init__.py
@@ -1,4 +1,7 @@
-from langchain.memory.buffer import ConversationBufferMemory
+from langchain.memory.buffer import (
+    ConversationBufferMemory,
+    ConversationStringBufferMemory,
+)
 from langchain.memory.buffer_window import ConversationBufferWindowMemory
 from langchain.memory.chat_memory import ChatMessageHistory
 from langchain.memory.combined import CombinedMemory
@@ -18,4 +21,5 @@ __all__ = [
     "ConversationEntityMemory",
     "ConversationSummaryMemory",
     "ChatMessageHistory",
+    "ConversationStringBufferMemory",
 ]
diff --git a/langchain/memory/buffer.py b/langchain/memory/buffer.py
index 1c7fa3bf..8ca19e9b 100644
--- a/langchain/memory/buffer.py
+++ b/langchain/memory/buffer.py
@@ -1,9 +1,9 @@
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
-from pydantic import BaseModel
+from pydantic import BaseModel, root_validator
 
-from langchain.memory.chat_memory import BaseChatMemory
-from langchain.memory.utils import get_buffer_string
+from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
+from langchain.memory.utils import get_buffer_string, get_prompt_input_key
 
 
 class ConversationBufferMemory(BaseChatMemory, BaseModel):
@@ -36,3 +36,55 @@ class ConversationBufferMemory(BaseChatMemory, BaseModel):
     def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         """Return history buffer."""
         return {self.memory_key: self.buffer}
+
+
+class ConversationStringBufferMemory(BaseMemory, BaseModel):
+    """Buffer for storing conversation memory."""
+
+    human_prefix: str = "Human"
+    ai_prefix: str = "AI"
+    """Prefix to use for AI generated responses."""
+    buffer: str = ""
+    output_key: Optional[str] = None
+    input_key: Optional[str] = None
+    memory_key: str = "history"  #: :meta private:
+
+    @root_validator()
+    def validate_chains(cls, values: Dict) -> Dict:
+        """Validate that return messages is not True."""
+        if values.get("return_messages", False):
+            raise ValueError(
+                "return_messages must be False for ConversationStringBufferMemory"
+            )
+        return values
+
+    @property
+    def memory_variables(self) -> List[str]:
+        """Will always return list of memory variables.
+        :meta private:
+        """
+        return [self.memory_key]
+
+    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+        """Return history buffer."""
+        return {self.memory_key: self.buffer}
+
+    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
+        """Save context from this conversation to buffer."""
+        if self.input_key is None:
+            prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
+        else:
+            prompt_input_key = self.input_key
+        if self.output_key is None:
+            if len(outputs) != 1:
+                raise ValueError(f"One output key expected, got {outputs.keys()}")
+            output_key = list(outputs.keys())[0]
+        else:
+            output_key = self.output_key
+        human = f"{self.human_prefix}: " + inputs[prompt_input_key]
+        ai = f"{self.ai_prefix}: " + outputs[output_key]
+        self.buffer += "\n" + "\n".join([human, ai])
+
+    def clear(self) -> None:
+        """Clear memory contents."""
+        self.buffer = ""
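
For reference, a minimal usage sketch (not part of the patch) exercising the restored ConversationStringBufferMemory directly, based only on the methods added above. The "input"/"output" keys and the conversation strings are illustrative; with a single non-memory key, get_prompt_input_key is assumed to resolve it as the prompt input.

# Minimal usage sketch of the restored string-buffer memory.
from langchain.memory import ConversationStringBufferMemory

memory = ConversationStringBufferMemory()

# save_context appends a "Human: ..." / "AI: ..." pair to the string buffer.
memory.save_context({"input": "Hi there!"}, {"output": "Hello! How can I help?"})
memory.save_context({"input": "What is LangChain?"}, {"output": "A framework for LLM apps."})

# load_memory_variables returns the accumulated transcript under the "history" key.
print(memory.load_memory_variables({})["history"])
# Expected (roughly; a leading newline comes from how the buffer is concatenated):
#
# Human: Hi there!
# AI: Hello! How can I help?
# Human: What is LangChain?
# AI: A framework for LLM apps.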