buffer memory old version (#1581)

bring back an older version of memory since people seem to be using it
more widely
tool-patch
Harrison Chase 1 year ago committed by GitHub
parent 5b0e747f9a
commit c1dc784a3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,9 @@
"""Memory modules for conversation prompts."""
from langchain.memory.buffer import ConversationBufferMemory
from langchain.memory.buffer import (
ConversationBufferMemory,
ConversationStringBufferMemory,
)
from langchain.memory.buffer_window import ConversationBufferWindowMemory
from langchain.memory.combined import CombinedMemory
from langchain.memory.entity import ConversationEntityMemory
@ -18,4 +21,5 @@ __all__ = [
"ConversationEntityMemory",
"ConversationBufferMemory",
"CombinedMemory",
"ConversationStringBufferMemory",
]

@ -1,4 +1,7 @@
from langchain.memory.buffer import ConversationBufferMemory
from langchain.memory.buffer import (
ConversationBufferMemory,
ConversationStringBufferMemory,
)
from langchain.memory.buffer_window import ConversationBufferWindowMemory
from langchain.memory.chat_memory import ChatMessageHistory
from langchain.memory.combined import CombinedMemory
@ -18,4 +21,5 @@ __all__ = [
"ConversationEntityMemory",
"ConversationSummaryMemory",
"ChatMessageHistory",
"ConversationStringBufferMemory",
]

@ -1,9 +1,9 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from pydantic import BaseModel, root_validator
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.utils import get_buffer_string
from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
from langchain.memory.utils import get_buffer_string, get_prompt_input_key
class ConversationBufferMemory(BaseChatMemory, BaseModel):
@ -36,3 +36,55 @@ class ConversationBufferMemory(BaseChatMemory, BaseModel):
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer."""
return {self.memory_key: self.buffer}
class ConversationStringBufferMemory(BaseMemory, BaseModel):
"""Buffer for storing conversation memory."""
human_prefix: str = "Human"
ai_prefix: str = "AI"
"""Prefix to use for AI generated responses."""
buffer: str = ""
output_key: Optional[str] = None
input_key: Optional[str] = None
memory_key: str = "history" #: :meta private:
@root_validator()
def validate_chains(cls, values: Dict) -> Dict:
"""Validate that return messages is not True."""
if values.get("return_messages", False):
raise ValueError(
"return_messages must be False for ConversationStringBufferMemory"
)
return values
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Return history buffer."""
return {self.memory_key: self.buffer}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
if self.input_key is None:
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
else:
prompt_input_key = self.input_key
if self.output_key is None:
if len(outputs) != 1:
raise ValueError(f"One output key expected, got {outputs.keys()}")
output_key = list(outputs.keys())[0]
else:
output_key = self.output_key
human = f"{self.human_prefix}: " + inputs[prompt_input_key]
ai = f"{self.ai_prefix}: " + outputs[output_key]
self.buffer += "\n" + "\n".join([human, ai])
def clear(self) -> None:
"""Clear memory contents."""
self.buffer = ""

Loading…
Cancel
Save