|
|
|
@ -1,7 +1,9 @@
|
|
|
|
|
import warnings
|
|
|
|
|
from typing import Any, Dict, List, Set
|
|
|
|
|
|
|
|
|
|
from pydantic import validator
|
|
|
|
|
|
|
|
|
|
from langchain.memory.chat_memory import BaseChatMemory
|
|
|
|
|
from langchain.schema import BaseMemory
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -27,6 +29,19 @@ class CombinedMemory(BaseMemory):
|
|
|
|
|
|
|
|
|
|
return value
|
|
|
|
|
|
|
|
|
|
@validator("memories")
|
|
|
|
|
def check_input_key(cls, value: List[BaseMemory]) -> List[BaseMemory]:
|
|
|
|
|
"""Check that if memories are of type BaseChatMemory that input keys exist."""
|
|
|
|
|
for val in value:
|
|
|
|
|
if isinstance(val, BaseChatMemory):
|
|
|
|
|
if val.input_key is None:
|
|
|
|
|
warnings.warn(
|
|
|
|
|
"When using CombinedMemory, "
|
|
|
|
|
"input keys should be so the input is known. "
|
|
|
|
|
f" Was not set on {val}"
|
|
|
|
|
)
|
|
|
|
|
return value
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def memory_variables(self) -> List[str]:
|
|
|
|
|
"""All the memory variables that this instance provides."""
|
|
|
|
|