forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
50 lines
1.5 KiB
Python
50 lines
1.5 KiB
Python
from typing import Any, Dict, List
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from langchain.schema import BaseMemory
|
|
|
|
|
|
class CombinedMemory(BaseMemory, BaseModel):
|
|
"""Class for combining multiple memories' data together."""
|
|
|
|
memories: List[BaseMemory]
|
|
"""For tracking all the memories that should be accessed."""
|
|
|
|
@property
|
|
def memory_variables(self) -> List[str]:
|
|
"""All the memory variables that this instance provides."""
|
|
"""Collected from the all the linked memories."""
|
|
|
|
memory_variables = []
|
|
|
|
for memory in self.memories:
|
|
memory_variables.extend(memory.memory_variables)
|
|
|
|
return memory_variables
|
|
|
|
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
|
"""Load all vars from sub-memories."""
|
|
memory_data: Dict[str, Any] = {}
|
|
|
|
# Collect vars from all sub-memories
|
|
for memory in self.memories:
|
|
data = memory.load_memory_variables(inputs)
|
|
memory_data = {
|
|
**memory_data,
|
|
**data,
|
|
}
|
|
|
|
return memory_data
|
|
|
|
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
|
"""Save context from this session for every memory."""
|
|
# Save context for all sub-memories
|
|
for memory in self.memories:
|
|
memory.save_context(inputs, outputs)
|
|
|
|
def clear(self) -> None:
|
|
"""Clear context from this session for every memory."""
|
|
for memory in self.memories:
|
|
memory.clear()
|