Harrison/combined memory (#3935)

Co-authored-by: engkheng <60956360+outday29@users.noreply.github.com>
pull/3942/head
Harrison Chase 1 year ago committed by GitHub
parent c4cb55a0c5
commit cd3f8582cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,4 +1,6 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Set
from pydantic import validator
from langchain.schema import BaseMemory
@ -9,6 +11,22 @@ class CombinedMemory(BaseMemory):
memories: List[BaseMemory]
"""For tracking all the memories that should be accessed."""
@validator("memories")
def check_repeated_memory_variable(
cls, value: List[BaseMemory]
) -> List[BaseMemory]:
all_variables: Set[str] = set()
for val in value:
overlap = all_variables.intersection(val.memory_variables)
if overlap:
raise ValueError(
f"The same variables {overlap} are found in multiple"
"memory object, which is not allowed by CombinedMemory."
)
all_variables |= set(val.memory_variables)
return value
@property
def memory_variables(self) -> List[str]:
"""All the memory variables that this instance provides."""

@ -0,0 +1,37 @@
"""Test for CombinedMemory class"""
# from langchain.prompts import PromptTemplate
from typing import List
import pytest
from langchain.memory import CombinedMemory, ConversationBufferMemory
@pytest.fixture()
def example_memory() -> List[ConversationBufferMemory]:
example_1 = ConversationBufferMemory(memory_key="foo")
example_2 = ConversationBufferMemory(memory_key="bar")
example_3 = ConversationBufferMemory(memory_key="bar")
return [example_1, example_2, example_3]
def test_basic_functionality(example_memory: List[ConversationBufferMemory]) -> None:
"""Test basic functionality of methods exposed by class"""
combined_memory = CombinedMemory(memories=[example_memory[0], example_memory[1]])
assert combined_memory.memory_variables == ["foo", "bar"]
assert combined_memory.load_memory_variables({}) == {"foo": "", "bar": ""}
combined_memory.save_context(
{"input": "Hello there"}, {"output": "Hello, how can I help you?"}
)
assert combined_memory.load_memory_variables({}) == {
"foo": "Human: Hello there\nAI: Hello, how can I help you?",
"bar": "Human: Hello there\nAI: Hello, how can I help you?",
}
combined_memory.clear()
assert combined_memory.load_memory_variables({}) == {"foo": "", "bar": ""}
def test_repeated_memory_var(example_memory: List[ConversationBufferMemory]) -> None:
"""Test raising error when repeated memory variables found"""
with pytest.raises(ValueError):
CombinedMemory(memories=[example_memory[1], example_memory[2]])
Loading…
Cancel
Save