Add `clear()` method for `Memory` (#305)

a simple helper to clear the buffer in `Conversation*Memory` classes
harrison/agent_multi_inputs^2
Shobith Alva 1 year ago committed by GitHub
parent e02d6b2288
commit 19a9fa16a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -29,6 +29,10 @@ class Memory(BaseModel, ABC):
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save the context of this model run to memory."""
@abstractmethod
def clear(self) -> None:
"""Clear memory contents."""
def _get_verbosity() -> bool:
return langchain.verbose

@ -46,6 +46,10 @@ class ConversationBufferMemory(Memory, BaseModel):
ai = "AI: " + outputs[list(outputs.keys())[0]]
self.buffer += "\n" + "\n".join([human, ai])
def clear(self) -> None:
"""Clear memory contents."""
self.buffer = ""
class ConversationalBufferWindowMemory(Memory, BaseModel):
"""Buffer for storing conversation memory."""
@ -75,6 +79,10 @@ class ConversationalBufferWindowMemory(Memory, BaseModel):
ai = "AI: " + outputs[list(outputs.keys())[0]]
self.buffer.append("\n".join([human, ai]))
def clear(self) -> None:
"""Clear memory contents."""
self.buffer = []
class ConversationSummaryMemory(Memory, BaseModel):
"""Conversation summarizer to memory."""
@ -118,3 +126,7 @@ class ConversationSummaryMemory(Memory, BaseModel):
new_lines = "\n".join([human, ai])
chain = LLMChain(llm=self.llm, prompt=self.prompt)
self.buffer = chain.predict(summary=self.buffer, new_lines=new_lines)
def clear(self) -> None:
"""Clear memory contents."""
self.buffer = ""

@ -4,6 +4,7 @@ import pytest
from langchain.chains.base import Memory
from langchain.chains.conversation.base import ConversationChain
from langchain.chains.conversation.memory import (
ConversationalBufferWindowMemory,
ConversationBufferMemory,
ConversationSummaryMemory,
)
@ -66,3 +67,23 @@ def test_conversation_memory(memory: Memory) -> None:
bad_outputs = {"foo": "bar", "foo1": "bar"}
with pytest.raises(ValueError):
memory.save_context(good_inputs, bad_outputs)
@pytest.mark.parametrize(
"memory",
[
ConversationBufferMemory(memory_key="baz"),
ConversationSummaryMemory(llm=FakeLLM(), memory_key="baz"),
ConversationalBufferWindowMemory(memory_key="baz"),
],
)
def test_clearing_conversation_memory(memory: Memory) -> None:
"""Test clearing the conversation memory."""
# This is a good input because the input is not the same as baz.
good_inputs = {"foo": "bar", "baz": "foo"}
# This is a good output because these is one variable.
good_outputs = {"bar": "foo"}
memory.save_context(good_inputs, good_outputs)
memory.clear()
assert memory.load_memory_variables({}) == {"baz": ""}

Loading…
Cancel
Save