community: add support for callable filters in FAISS (#16190)

- **Description:**
Filtering in a FAISS vectorstores is very inflexible and doesn't allow
that many use case. I think supporting callable like this enables a lot:
regular expressions, condition on multiple keys etc. **Note** I had to
manually alter a test. I don't understand if it was falty to begin with
or if there is something funky going on.
- **Issue:** None
- **Dependencies:** None
- **Twitter handle:** None

Signed-off-by: thiswillbeyourgithub <26625900+thiswillbeyourgithub@users.noreply.github.com>
pull/15265/head^2
thiswillbeyourgithub 5 months ago committed by GitHub
parent 1703fe2361
commit 1d082359ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -416,7 +416,7 @@
"metadata": {},
"source": [
"## Similarity Search with filtering\n",
"FAISS vectorstore can also support filtering, since the FAISS does not natively support filtering we have to do it manually. This is done by first fetching more results than `k` and then filtering them. You can filter the documents based on metadata. You can also set the `fetch_k` parameter when calling any search method to set how many documents you want to fetch before filtering. Here is a small example:"
"FAISS vectorstore can also support filtering, since the FAISS does not natively support filtering we have to do it manually. This is done by first fetching more results than `k` and then filtering them. This filter is either a callble that takes as input a metadata dict and returns a bool, or a metadata dict where each missing key is ignored and each present k must be in a list of values. You can also set the `fetch_k` parameter when calling any search method to set how many documents you want to fetch before filtering. Here is a small example:"
]
},
{
@ -480,6 +480,8 @@
],
"source": [
"results_with_scores = db.similarity_search_with_score(\"foo\", filter=dict(page=1))\n",
"# Or with a callable:\n",
"# results_with_scores = db.similarity_search_with_score(\"foo\", filter=lambda d: d[\"page\"] == 1)\n",
"for doc, score in results_with_scores:\n",
" print(f\"Content: {doc.page_content}, Metadata: {doc.metadata}, Score: {score}\")"
]

@ -273,7 +273,7 @@ class FAISS(VectorStore):
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
@ -282,7 +282,9 @@ class FAISS(VectorStore):
Args:
embedding: Embedding vector to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
filter (Optional[Union[Callable, Dict[str, Any]]]): Filter by metadata.
Defaults to None. If a callable, it must take as input the
metadata dict of Document and return a bool.
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
Defaults to 20.
**kwargs: kwargs to be passed to similarity search. Can include:
@ -299,6 +301,27 @@ class FAISS(VectorStore):
faiss.normalize_L2(vector)
scores, indices = self.index.search(vector, k if filter is None else fetch_k)
docs = []
if filter is not None:
if isinstance(filter, dict):
def filter_func(metadata):
if all(
metadata.get(key) in value
if isinstance(value, list)
else metadata.get(key) == value
for key, value in filter.items()
):
return True
return False
elif callable(filter):
filter_func = filter
else:
raise ValueError(
"filter must be a dict of metadata or "
f"a callable, not {type(filter)}"
)
for j, i in enumerate(indices[0]):
if i == -1:
# This happens when not enough docs are returned.
@ -307,13 +330,8 @@ class FAISS(VectorStore):
doc = self.docstore.search(_id)
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
if filter is not None:
filter = {
key: [value] if not isinstance(value, list) else value
for key, value in filter.items()
}
if all(doc.metadata.get(key) in value for key, value in filter.items()):
docs.append((doc, scores[0][j]))
if filter is not None and filter_func(doc.metadata):
docs.append((doc, scores[0][j]))
else:
docs.append((doc, scores[0][j]))
@ -336,7 +354,7 @@ class FAISS(VectorStore):
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
@ -345,7 +363,10 @@ class FAISS(VectorStore):
Args:
embedding: Embedding vector to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
filter (Optional[Dict[str, Any]]): Filter by metadata.
Defaults to None. If a callable, it must take as input the
metadata dict of Document and return a bool.
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
Defaults to 20.
**kwargs: kwargs to be passed to similarity search. Can include:
@ -372,7 +393,7 @@ class FAISS(VectorStore):
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
@ -381,7 +402,10 @@ class FAISS(VectorStore):
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
filter (Optional[Dict[str, str]]): Filter by metadata.
Defaults to None. If a callable, it must take as input the
metadata dict of Document and return a bool.
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
Defaults to 20.
@ -403,7 +427,7 @@ class FAISS(VectorStore):
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
@ -412,7 +436,10 @@ class FAISS(VectorStore):
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
filter (Optional[Dict[str, str]]): Filter by metadata.
Defaults to None. If a callable, it must take as input the
metadata dict of Document and return a bool.
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
Defaults to 20.
@ -443,7 +470,10 @@ class FAISS(VectorStore):
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
filter (Optional[Dict[str, str]]): Filter by metadata.
Defaults to None. If a callable, it must take as input the
metadata dict of Document and return a bool.
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
Defaults to 20.
@ -463,7 +493,7 @@ class FAISS(VectorStore):
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Document]:
@ -472,7 +502,10 @@ class FAISS(VectorStore):
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
filter (Optional[Dict[str, str]]): Filter by metadata.
Defaults to None. If a callable, it must take as input the
metadata dict of Document and return a bool.
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
Defaults to 20.
@ -492,7 +525,7 @@ class FAISS(VectorStore):
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Document]:
@ -517,7 +550,7 @@ class FAISS(VectorStore):
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Document]:
@ -545,7 +578,7 @@ class FAISS(VectorStore):
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
) -> List[Tuple[Document, float]]:
"""Return docs and their similarity scores selected using the maximal marginal
relevance.
@ -572,6 +605,24 @@ class FAISS(VectorStore):
)
if filter is not None:
filtered_indices = []
if isinstance(filter, dict):
def filter_func(metadata):
if all(
metadata.get(key) in value
if isinstance(value, list)
else metadata.get(key) == value
for key, value in filter.items()
):
return True
return False
elif callable(filter):
filter_func = filter
else:
raise ValueError(
"filter must be a dict of metadata or "
f"a callable, not {type(filter)}"
)
for i in indices[0]:
if i == -1:
# This happens when not enough docs are returned.
@ -580,12 +631,7 @@ class FAISS(VectorStore):
doc = self.docstore.search(_id)
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
if all(
doc.metadata.get(key) in value
if isinstance(value, list)
else doc.metadata.get(key) == value
for key, value in filter.items()
):
if filter_func(doc.metadata):
filtered_indices.append(i)
indices = np.array([filtered_indices])
# -1 happens when not enough docs are returned.
@ -617,7 +663,7 @@ class FAISS(VectorStore):
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
) -> List[Tuple[Document, float]]:
"""Return docs and their similarity scores selected using the maximal marginal
relevance asynchronously.
@ -655,7 +701,7 @@ class FAISS(VectorStore):
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
@ -686,7 +732,7 @@ class FAISS(VectorStore):
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance asynchronously.
@ -719,7 +765,7 @@ class FAISS(VectorStore):
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
@ -756,7 +802,7 @@ class FAISS(VectorStore):
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance asynchronously.
@ -1110,7 +1156,7 @@ class FAISS(VectorStore):
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
@ -1139,7 +1185,7 @@ class FAISS(VectorStore):
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
filter: Optional[Union[Callable, Dict[str, Any]]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Tuple[Document, float]]:

@ -307,6 +307,9 @@ def test_faiss_mmr_with_metadatas_and_filter() -> None:
assert len(output) == 1
assert output[0][0] == Document(page_content="foo", metadata={"page": 1})
assert output[0][1] == 0.0
assert output == docsearch.max_marginal_relevance_search_with_score_by_vector(
query_vec, k=10, lambda_mult=0.1, filter=lambda di: di["page"] == 1
)
@pytest.mark.requires("faiss")
@ -321,6 +324,12 @@ async def test_faiss_async_mmr_with_metadatas_and_filter() -> None:
assert len(output) == 1
assert output[0][0] == Document(page_content="foo", metadata={"page": 1})
assert output[0][1] == 0.0
assert (
output
== await docsearch.amax_marginal_relevance_search_with_score_by_vector(
query_vec, k=10, lambda_mult=0.1, filter=lambda di: di["page"] == 1
)
)
@pytest.mark.requires("faiss")
@ -336,6 +345,9 @@ def test_faiss_mmr_with_metadatas_and_list_filter() -> None:
assert output[0][0] == Document(page_content="foo", metadata={"page": 0})
assert output[0][1] == 0.0
assert output[1][0] != Document(page_content="foo", metadata={"page": 0})
assert output == docsearch.max_marginal_relevance_search_with_score_by_vector(
query_vec, k=10, lambda_mult=0.1, filter=lambda di: di["page"] in [0, 1, 2]
)
@pytest.mark.requires("faiss")
@ -351,6 +363,11 @@ async def test_faiss_async_mmr_with_metadatas_and_list_filter() -> None:
assert output[0][0] == Document(page_content="foo", metadata={"page": 0})
assert output[0][1] == 0.0
assert output[1][0] != Document(page_content="foo", metadata={"page": 0})
assert output == (
await docsearch.amax_marginal_relevance_search_with_score_by_vector(
query_vec, k=10, lambda_mult=0.1, filter=lambda di: di["page"] in [0, 1, 2]
)
)
@pytest.mark.requires("faiss")
@ -421,7 +438,11 @@ def test_faiss_with_metadatas_and_filter() -> None:
)
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
output = docsearch.similarity_search("foo", k=1, filter={"page": 1})
assert output == [Document(page_content="bar", metadata={"page": 1})]
assert output == [Document(page_content="foo", metadata={"page": 0})]
assert output != [Document(page_content="bar", metadata={"page": 1})]
assert output == docsearch.similarity_search(
"foo", k=1, filter=lambda di: di["page"] == 1
)
@pytest.mark.requires("faiss")
@ -444,7 +465,11 @@ async def test_faiss_async_with_metadatas_and_filter() -> None:
)
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
output = await docsearch.asimilarity_search("foo", k=1, filter={"page": 1})
assert output == [Document(page_content="bar", metadata={"page": 1})]
assert output == [Document(page_content="foo", metadata={"page": 0})]
assert output != [Document(page_content="bar", metadata={"page": 1})]
assert output == await docsearch.asimilarity_search(
"foo", k=1, filter=lambda di: di["page"] == 1
)
@pytest.mark.requires("faiss")
@ -474,6 +499,9 @@ def test_faiss_with_metadatas_and_list_filter() -> None:
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
output = docsearch.similarity_search("foor", k=1, filter={"page": [0, 1, 2]})
assert output == [Document(page_content="foo", metadata={"page": 0})]
assert output == docsearch.similarity_search(
"foor", k=1, filter=lambda di: di["page"] in [0, 1, 2]
)
@pytest.mark.requires("faiss")
@ -503,6 +531,9 @@ async def test_faiss_async_with_metadatas_and_list_filter() -> None:
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
output = await docsearch.asimilarity_search("foor", k=1, filter={"page": [0, 1, 2]})
assert output == [Document(page_content="foo", metadata={"page": 0})]
assert output == await docsearch.asimilarity_search(
"foor", k=1, filter=lambda di: di["page"] in [0, 1, 2]
)
@pytest.mark.requires("faiss")

Loading…
Cancel
Save