mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
38 lines
1.6 KiB
Python
38 lines
1.6 KiB
Python
|
"""Test for CombinedMemory class"""
|
||
|
# from langchain.prompts import PromptTemplate
|
||
|
from typing import List
|
||
|
|
||
|
import pytest
|
||
|
|
||
|
from langchain.memory import CombinedMemory, ConversationBufferMemory
|
||
|
|
||
|
|
||
|
@pytest.fixture()
def example_memory() -> List[ConversationBufferMemory]:
    """Provide three buffer memories keyed "foo", "bar" and "bar".

    The second and third entries share the same memory key, so a test can
    pick either a distinct pair (indices 0 and 1) or a clashing pair
    (indices 1 and 2).
    """
    memory_keys = ("foo", "bar", "bar")
    return [ConversationBufferMemory(memory_key=key) for key in memory_keys]
|
||
|
|
||
|
|
||
|
def test_basic_functionality(example_memory: List[ConversationBufferMemory]) -> None:
    """Check variables, load, save and clear on a two-memory CombinedMemory."""
    # Only the first two fixture entries are used; their keys do not clash.
    foo_memory, bar_memory = example_memory[:2]
    combined_memory = CombinedMemory(memories=[foo_memory, bar_memory])

    empty_state = {"foo": "", "bar": ""}
    transcript = "Human: Hello there\nAI: Hello, how can I help you?"

    # Freshly combined: both keys are exposed and both buffers are empty.
    assert combined_memory.memory_variables == ["foo", "bar"]
    assert combined_memory.load_memory_variables({}) == empty_state

    # One saved exchange must show up under every sub-memory's key.
    combined_memory.save_context(
        {"input": "Hello there"}, {"output": "Hello, how can I help you?"}
    )
    assert combined_memory.load_memory_variables({}) == {
        "foo": transcript,
        "bar": transcript,
    }

    # Clearing the combined memory empties both buffers again.
    combined_memory.clear()
    assert combined_memory.load_memory_variables({}) == empty_state
|
||
|
|
||
|
|
||
|
def test_repeated_memory_var(example_memory: List[ConversationBufferMemory]) -> None:
    """Check that combining memories with a repeated key raises ValueError."""
    # Fixture entries 1 and 2 were both created with memory_key="bar".
    clashing_memories = example_memory[1:3]
    with pytest.raises(ValueError):
        CombinedMemory(memories=clashing_memories)
|