feat: Added filtering option to FAISS vectorstore (#5966)

Inspired by the filtering capability available in ChromaDB, added the
same functionality to the FAISS vectorestore as well. Since FAISS does
not have an inbuilt method of filtering used the approach suggested in
this [thread](https://github.com/facebookresearch/faiss/issues/1079)
Langchain Issue inspiration:
https://github.com/hwchase17/langchain/issues/4572

- [x] Added filtering capability to semantic similarly and MMR
- [x] Added test cases for filtering in
`tests/integration_tests/vectorstores/test_faiss.py`

#### Who can review?

Tag maintainers/contributors who might be interested:

  VectorStores / Retrievers / Memory
  - @dev2049
  - @hwchase17
searx_updates
Akhil Vempali 11 months ago committed by GitHub
parent 6e90406e0f
commit d7d629911b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -40,20 +40,12 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 2,
"id": "47f9b495-88f1-4286-8d5d-1416103931a7",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"OpenAI API Key: ········\n"
]
}
],
"outputs": [],
"source": [
"import os\n",
"import getpass\n",
@ -66,7 +58,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"id": "aac9563e",
"metadata": {
"tags": []
@ -81,7 +73,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 10,
"id": "a3c3999a",
"metadata": {
"tags": []
@ -99,7 +91,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 11,
"id": "5eabdb75",
"metadata": {
"tags": []
@ -114,7 +106,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 12,
"id": "4b172de8",
"metadata": {
"tags": []
@ -150,7 +142,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 13,
"id": "186ee1d8",
"metadata": {},
"outputs": [],
@ -160,18 +152,18 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 14,
"id": "284e04b5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Document(page_content='In state after state, new laws have been passed, not only to suppress the vote, but to subvert entire elections. \\n\\nWe cannot let this happen. \\n\\nTonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while youre at it, pass the Disclose Act so Americans can know who is funding our elections. \\n\\nTonight, Id like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \\n\\nOne of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nations top legal minds, who will continue Justice Breyers legacy of excellence.', lookup_str='', metadata={'source': '../../state_of_the_union.txt'}, lookup_index=0),\n",
" 0.3914415)"
"(Document(page_content='Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while youre at it, pass the Disclose Act so Americans can know who is funding our elections. \\n\\nTonight, Id like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \\n\\nOne of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nations top legal minds, who will continue Justice Breyers legacy of excellence.', metadata={'source': '../../../state_of_the_union.txt'}),\n",
" 0.36913747)"
]
},
"execution_count": 7,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
@ -191,7 +183,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 15,
"id": "b558ebb7",
"metadata": {},
"outputs": [],
@ -212,7 +204,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 16,
"id": "428a6816",
"metadata": {},
"outputs": [],
@ -222,7 +214,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 17,
"id": "56d1841c",
"metadata": {},
"outputs": [],
@ -232,7 +224,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 18,
"id": "39055525",
"metadata": {},
"outputs": [],
@ -242,17 +234,17 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 19,
"id": "98378c4e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Document(page_content='In state after state, new laws have been passed, not only to suppress the vote, but to subvert entire elections. \\n\\nWe cannot let this happen. \\n\\nTonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while youre at it, pass the Disclose Act so Americans can know who is funding our elections. \\n\\nTonight, Id like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \\n\\nOne of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nations top legal minds, who will continue Justice Breyers legacy of excellence.', lookup_str='', metadata={'source': '../../state_of_the_union.txt'}, lookup_index=0)"
"Document(page_content='Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while youre at it, pass the Disclose Act so Americans can know who is funding our elections. \\n\\nTonight, Id like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \\n\\nOne of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nations top legal minds, who will continue Justice Breyers legacy of excellence.', metadata={'source': '../../../state_of_the_union.txt'})"
]
},
"execution_count": 13,
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
@ -273,7 +265,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 20,
"id": "6dfd2b78",
"metadata": {},
"outputs": [],
@ -284,17 +276,17 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 21,
"id": "29960da7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'e0b74348-6c93-4893-8764-943139ec1d17': Document(page_content='foo', lookup_str='', metadata={}, lookup_index=0)}"
"{'068c473b-d420-487a-806b-fb0ccea7f711': Document(page_content='foo', metadata={})}"
]
},
"execution_count": 8,
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
@ -305,17 +297,17 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 22,
"id": "83392605",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'bdc50ae3-a1bb-4678-9260-1b0979578f40': Document(page_content='bar', lookup_str='', metadata={}, lookup_index=0)}"
"{'807e0c63-13f6-4070-9774-5c6f0fbb9866': Document(page_content='bar', metadata={})}"
]
},
"execution_count": 9,
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
@ -326,7 +318,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 23,
"id": "a3fcc1c7",
"metadata": {},
"outputs": [],
@ -336,18 +328,18 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 24,
"id": "41c51f89",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'e0b74348-6c93-4893-8764-943139ec1d17': Document(page_content='foo', lookup_str='', metadata={}, lookup_index=0),\n",
" 'd5211050-c777-493d-8825-4800e74cfdb6': Document(page_content='bar', lookup_str='', metadata={}, lookup_index=0)}"
"{'068c473b-d420-487a-806b-fb0ccea7f711': Document(page_content='foo', metadata={}),\n",
" '807e0c63-13f6-4070-9774-5c6f0fbb9866': Document(page_content='bar', metadata={})}"
]
},
"execution_count": 11,
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
@ -356,13 +348,140 @@
"db1.docstore._dict"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "f4294b96",
"metadata": {},
"source": [
"## Similarity Search with filtering\n",
"FAISS vectorstore can also support filtering, since the FAISS does not natively support filtering we have to do it manually. This is done by first fetching more results than `k` and then filtering them. You can filter the documents based on metadata. You can also set the `fetch_k` parameter when calling any search method to set how many documents you want to fetch before filtering. Here is a small example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f80b60de",
"execution_count": 25,
"id": "d5bf812c",
"metadata": {},
"outputs": [],
"source": []
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Content: foo, Metadata: {'page': 1}, Score: 5.159960813797904e-15\n",
"Content: foo, Metadata: {'page': 2}, Score: 5.159960813797904e-15\n",
"Content: foo, Metadata: {'page': 3}, Score: 5.159960813797904e-15\n",
"Content: foo, Metadata: {'page': 4}, Score: 5.159960813797904e-15\n"
]
}
],
"source": [
"from langchain.schema import Document\n",
"list_of_documents = [\n",
" Document(page_content=\"foo\", metadata=dict(page=1)),\n",
" Document(page_content=\"bar\", metadata=dict(page=1)),\n",
" Document(page_content=\"foo\", metadata=dict(page=2)),\n",
" Document(page_content=\"barbar\", metadata=dict(page=2)),\n",
" Document(page_content=\"foo\", metadata=dict(page=3)),\n",
" Document(page_content=\"bar burr\", metadata=dict(page=3)),\n",
" Document(page_content=\"foo\", metadata=dict(page=4)),\n",
" Document(page_content=\"bar bruh\", metadata=dict(page=4))\n",
"]\n",
"db = FAISS.from_documents(list_of_documents, embeddings)\n",
"results_with_scores = db.similarity_search_with_score(\"foo\")\n",
"for doc, score in results_with_scores:\n",
" print(f\"Content: {doc.page_content}, Metadata: {doc.metadata}, Score: {score}\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "3d33c126",
"metadata": {},
"source": [
"Now we make the same query call but we filter for only `page = 1` "
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "83159330",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Content: foo, Metadata: {'page': 1}, Score: 5.159960813797904e-15\n",
"Content: bar, Metadata: {'page': 1}, Score: 0.3131446838378906\n"
]
}
],
"source": [
"results_with_scores = db.similarity_search_with_score(\"foo\", filter=dict(page=1))\n",
"for doc, score in results_with_scores:\n",
" print(f\"Content: {doc.page_content}, Metadata: {doc.metadata}, Score: {score}\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "0be136e0",
"metadata": {},
"source": [
"Same thing can be done with the `max_marginal_relevance_search` as well."
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "432c6980",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Content: foo, Metadata: {'page': 1}\n",
"Content: bar, Metadata: {'page': 1}\n"
]
}
],
"source": [
"results = db.max_marginal_relevance_search(\"foo\", filter=dict(page=1))\n",
"for doc in results:\n",
" print(f\"Content: {doc.page_content}, Metadata: {doc.metadata}\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "1b4ecd86",
"metadata": {},
"source": [
"Here is an example of how to set `fetch_k` parameter when calling `similarity_search`. Usually you would want the `fetch_k` parameter >> `k` parameter. This is because the `fetch_k` parameter is the number of documents that will be fetched before filtering. If you set `fetch_k` to a low number, you might not get enough documents to filter from."
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "1fd60fd1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Content: foo, Metadata: {'page': 1}, Score: 5.159960813797904e-15\n",
"Content: bar, Metadata: {'page': 1}, Score: 0.3131446838378906\n"
]
}
],
"source": [
"results = db.similarity_search(\"foo\", filter=dict(page=1), k=1, fetch_k=4)\n",
"for doc, score in results_with_scores:\n",
" print(f\"Content: {doc.page_content}, Metadata: {doc.metadata}, Score: {score}\")"
]
}
],
"metadata": {
@ -381,7 +500,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.9.16"
}
},
"nbformat": 4,

@ -180,13 +180,20 @@ class FAISS(VectorStore):
return self.__add(texts, embeddings, metadatas=metadatas, ids=ids, **kwargs)
def similarity_search_with_score_by_vector(
self, embedding: List[float], k: int = 4
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
fetch_k: int = 20,
) -> List[Tuple[Document, float]]:
"""Return docs most similar to query.
Args:
embedding: Embedding vector to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
Defaults to 20.
Returns:
List of documents most similar to the query text and L2 distance
@ -196,7 +203,7 @@ class FAISS(VectorStore):
vector = np.array([embedding], dtype=np.float32)
if self._normalize_L2:
faiss.normalize_L2(vector)
scores, indices = self.index.search(vector, k)
scores, indices = self.index.search(vector, k if filter is None else fetch_k)
docs = []
for j, i in enumerate(indices[0]):
if i == -1:
@ -206,54 +213,96 @@ class FAISS(VectorStore):
doc = self.docstore.search(_id)
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
docs.append((doc, scores[0][j]))
return docs
if filter is not None:
if all(doc.metadata.get(key) == value for key, value in filter.items()):
docs.append((doc, scores[0][j]))
else:
docs.append((doc, scores[0][j]))
return docs[:k]
def similarity_search_with_score(
self, query: str, k: int = 4
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs most similar to query.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
Defaults to 20.
Returns:
List of documents most similar to the query text with
L2 distance in float. Lower score represents more similarity.
"""
embedding = self.embedding_function(query)
docs = self.similarity_search_with_score_by_vector(embedding, k)
docs = self.similarity_search_with_score_by_vector(
embedding,
k,
filter=filter,
fetch_k=fetch_k,
**kwargs,
)
return docs
def similarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Document]:
"""Return docs most similar to embedding vector.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
Defaults to 20.
Returns:
List of Documents most similar to the embedding.
"""
docs_and_scores = self.similarity_search_with_score_by_vector(embedding, k)
docs_and_scores = self.similarity_search_with_score_by_vector(
embedding,
k,
filter=filter,
fetch_k=fetch_k,
**kwargs,
)
return [doc for doc, _ in docs_and_scores]
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Document]:
"""Return docs most similar to query.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
Defaults to 20.
Returns:
List of Documents most similar to the query.
"""
docs_and_scores = self.similarity_search_with_score(query, k)
docs_and_scores = self.similarity_search_with_score(
query, k, filter=filter, fetch_k=fetch_k, **kwargs
)
return [doc for doc, _ in docs_and_scores]
def max_marginal_relevance_search_by_vector(
@ -262,6 +311,7 @@ class FAISS(VectorStore):
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
@ -272,7 +322,8 @@ class FAISS(VectorStore):
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
fetch_k: Number of Documents to fetch before filtering to
pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
@ -280,7 +331,23 @@ class FAISS(VectorStore):
Returns:
List of Documents selected by maximal marginal relevance.
"""
_, indices = self.index.search(np.array([embedding], dtype=np.float32), fetch_k)
_, indices = self.index.search(
np.array([embedding], dtype=np.float32),
fetch_k if filter is None else fetch_k * 2,
)
if filter is not None:
filtered_indices = []
for i in indices[0]:
if i == -1:
# This happens when not enough docs are returned.
continue
_id = self.index_to_docstore_id[i]
doc = self.docstore.search(_id)
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
if all(doc.metadata.get(key) == value for key, value in filter.items()):
filtered_indices.append(i)
indices = np.array([filtered_indices])
# -1 happens when not enough docs are returned.
embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1]
mmr_selected = maximal_marginal_relevance(
@ -308,6 +375,7 @@ class FAISS(VectorStore):
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
@ -318,7 +386,8 @@ class FAISS(VectorStore):
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
fetch_k: Number of Documents to fetch before filtering (if needed) to
pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
@ -328,7 +397,12 @@ class FAISS(VectorStore):
"""
embedding = self.embedding_function(query)
docs = self.max_marginal_relevance_search_by_vector(
embedding, k, fetch_k, lambda_mult=lambda_mult
embedding,
k,
fetch_k,
lambda_mult=lambda_mult,
filter=filter,
**kwargs,
)
return docs
@ -522,6 +596,8 @@ class FAISS(VectorStore):
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs and their similarity scores on a scale from 0 to 1."""
@ -530,5 +606,11 @@ class FAISS(VectorStore):
"normalize_score_fn must be provided to"
" FAISS constructor to normalize scores"
)
docs_and_scores = self.similarity_search_with_score(query, k=k)
docs_and_scores = self.similarity_search_with_score(
query,
k=k,
filter=filter,
fetch_k=fetch_k,
**kwargs,
)
return [(doc, self.relevance_score_fn(score)) for doc, score in docs_and_scores]

@ -74,6 +74,28 @@ def test_faiss_with_metadatas() -> None:
assert output == [Document(page_content="foo", metadata={"page": 0})]
def test_faiss_with_metadatas_and_filter() -> None:
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = FAISS.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
expected_docstore = InMemoryDocstore(
{
docsearch.index_to_docstore_id[0]: Document(
page_content="foo", metadata={"page": 0}
),
docsearch.index_to_docstore_id[1]: Document(
page_content="bar", metadata={"page": 1}
),
docsearch.index_to_docstore_id[2]: Document(
page_content="baz", metadata={"page": 2}
),
}
)
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
output = docsearch.similarity_search("foo", k=1, filter={"page": 1})
assert output == []
def test_faiss_search_not_found() -> None:
"""Test what happens when document is not found."""
texts = ["foo", "bar", "baz"]

Loading…
Cancel
Save