VectorStoreRetrieverMemory exclude additional input keys feature (#7941)

- Description: Added a parameter in VectorStoreRetrieverMemory which
filters the input given by the key when constructing the buffering the
document for Vector. This feature is helpful if you have certain inputs
apart from the VectorMemory's own memory_key that needs to be ignored
e.g when using combined memory, we might need to filter the memory_key
of the other memory, Please see the issue.
  - Issue: #7695
  - Tag maintainer: @rlancemartin, @eyurtsev

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Mohammad Mohtashim 2023-07-20 19:23:27 +05:00 committed by GitHub
parent d593833e4d
commit 453d4c3a99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,6 +1,6 @@
"""Class for a VectorStore-backed memory object."""
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Sequence, Union
from pydantic import Field
@ -25,6 +25,9 @@ class VectorStoreRetrieverMemory(BaseMemory):
return_docs: bool = False
"""Whether or not to return the result of querying the database directly."""
exclude_input_keys: Sequence[str] = Field(default_factory=tuple)
"""Input keys to exclude in addition to memory key when constructing the document"""
@property
def memory_variables(self) -> List[str]:
"""The list of keys emitted from the load_memory_variables method."""
@ -55,10 +58,13 @@ class VectorStoreRetrieverMemory(BaseMemory):
) -> List[Document]:
"""Format context from this conversation to buffer."""
# Each document should only include the current turn, not the chat history
filtered_inputs = {k: v for k, v in inputs.items() if k != self.memory_key}
exclude = set(self.exclude_input_keys)
exclude.add(self.memory_key)
filtered_inputs = {k: v for k, v in inputs.items() if k not in exclude}
texts = [
f"{k}: {v}"
for k, v in list(filtered_inputs.items()) + list(outputs.items())
if k not in exclude
]
page_content = "\n".join(texts)
return [Document(page_content=page_content)]