|
|
|
@ -4,6 +4,7 @@ import pytest
|
|
|
|
|
from langchain.chains.base import Memory
|
|
|
|
|
from langchain.chains.conversation.base import ConversationChain
|
|
|
|
|
from langchain.chains.conversation.memory import (
|
|
|
|
|
ConversationalBufferWindowMemory,
|
|
|
|
|
ConversationBufferMemory,
|
|
|
|
|
ConversationSummaryMemory,
|
|
|
|
|
)
|
|
|
|
@ -66,3 +67,23 @@ def test_conversation_memory(memory: Memory) -> None:
|
|
|
|
|
bad_outputs = {"foo": "bar", "foo1": "bar"}
|
|
|
|
|
with pytest.raises(ValueError):
|
|
|
|
|
memory.save_context(good_inputs, bad_outputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
|
|
|
"memory",
|
|
|
|
|
[
|
|
|
|
|
ConversationBufferMemory(memory_key="baz"),
|
|
|
|
|
ConversationSummaryMemory(llm=FakeLLM(), memory_key="baz"),
|
|
|
|
|
ConversationalBufferWindowMemory(memory_key="baz"),
|
|
|
|
|
],
|
|
|
|
|
)
|
|
|
|
|
def test_clearing_conversation_memory(memory: Memory) -> None:
|
|
|
|
|
"""Test clearing the conversation memory."""
|
|
|
|
|
# This is a good input because the input is not the same as baz.
|
|
|
|
|
good_inputs = {"foo": "bar", "baz": "foo"}
|
|
|
|
|
# This is a good output because these is one variable.
|
|
|
|
|
good_outputs = {"bar": "foo"}
|
|
|
|
|
memory.save_context(good_inputs, good_outputs)
|
|
|
|
|
|
|
|
|
|
memory.clear()
|
|
|
|
|
assert memory.load_memory_variables({}) == {"baz": ""}
|
|
|
|
|