From 19a9fa16a9df812f04f5b7dd9359b0f7b2a0435b Mon Sep 17 00:00:00 2001 From: Shobith Alva Date: Sun, 11 Dec 2022 07:09:06 -0800 Subject: [PATCH] Add `clear()` method for `Memory` (#305) a simple helper to clear the buffer in `Conversation*Memory` classes --- langchain/chains/base.py | 4 ++++ langchain/chains/conversation/memory.py | 12 +++++++++++ tests/unit_tests/chains/test_conversation.py | 21 ++++++++++++++++++++ 3 files changed, 37 insertions(+) diff --git a/langchain/chains/base.py b/langchain/chains/base.py index 44f0bdbf95..e041cdea29 100644 --- a/langchain/chains/base.py +++ b/langchain/chains/base.py @@ -29,6 +29,10 @@ class Memory(BaseModel, ABC): def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: """Save the context of this model run to memory.""" + @abstractmethod + def clear(self) -> None: + """Clear memory contents.""" + def _get_verbosity() -> bool: return langchain.verbose diff --git a/langchain/chains/conversation/memory.py b/langchain/chains/conversation/memory.py index 4ab6cbbd95..ebdbb7efcf 100644 --- a/langchain/chains/conversation/memory.py +++ b/langchain/chains/conversation/memory.py @@ -46,6 +46,10 @@ class ConversationBufferMemory(Memory, BaseModel): ai = "AI: " + outputs[list(outputs.keys())[0]] self.buffer += "\n" + "\n".join([human, ai]) + def clear(self) -> None: + """Clear memory contents.""" + self.buffer = "" + class ConversationalBufferWindowMemory(Memory, BaseModel): """Buffer for storing conversation memory.""" @@ -75,6 +79,10 @@ class ConversationalBufferWindowMemory(Memory, BaseModel): ai = "AI: " + outputs[list(outputs.keys())[0]] self.buffer.append("\n".join([human, ai])) + def clear(self) -> None: + """Clear memory contents.""" + self.buffer = [] + class ConversationSummaryMemory(Memory, BaseModel): """Conversation summarizer to memory.""" @@ -118,3 +126,7 @@ class ConversationSummaryMemory(Memory, BaseModel): new_lines = "\n".join([human, ai]) chain = LLMChain(llm=self.llm, prompt=self.prompt) self.buffer = chain.predict(summary=self.buffer, new_lines=new_lines) + + def clear(self) -> None: + """Clear memory contents.""" + self.buffer = "" diff --git a/tests/unit_tests/chains/test_conversation.py b/tests/unit_tests/chains/test_conversation.py index 82653986b0..fd7eb55fb1 100644 --- a/tests/unit_tests/chains/test_conversation.py +++ b/tests/unit_tests/chains/test_conversation.py @@ -4,6 +4,7 @@ import pytest from langchain.chains.base import Memory from langchain.chains.conversation.base import ConversationChain from langchain.chains.conversation.memory import ( + ConversationalBufferWindowMemory, ConversationBufferMemory, ConversationSummaryMemory, ) @@ -66,3 +67,23 @@ def test_conversation_memory(memory: Memory) -> None: bad_outputs = {"foo": "bar", "foo1": "bar"} with pytest.raises(ValueError): memory.save_context(good_inputs, bad_outputs) + + +@pytest.mark.parametrize( + "memory", + [ + ConversationBufferMemory(memory_key="baz"), + ConversationSummaryMemory(llm=FakeLLM(), memory_key="baz"), + ConversationalBufferWindowMemory(memory_key="baz"), + ], +) +def test_clearing_conversation_memory(memory: Memory) -> None: + """Test clearing the conversation memory.""" + # This is a good input because the input is not the same as baz. + good_inputs = {"foo": "bar", "baz": "foo"} + # This is a good output because these is one variable. + good_outputs = {"bar": "foo"} + memory.save_context(good_inputs, good_outputs) + + memory.clear() + assert memory.load_memory_variables({}) == {"baz": ""}