feat: Added filtering option to FAISS vectorstore (#5966)

Inspired by the filtering capability available in ChromaDB, added the
same functionality to the FAISS vectorestore as well. Since FAISS does
not have an inbuilt method of filtering used the approach suggested in
this [thread](https://github.com/facebookresearch/faiss/issues/1079)
Langchain Issue inspiration:
https://github.com/hwchase17/langchain/issues/4572

- [x] Added filtering capability to semantic similarly and MMR
- [x] Added test cases for filtering in
`tests/integration_tests/vectorstores/test_faiss.py`

#### Who can review?

Tag maintainers/contributors who might be interested:

  VectorStores / Retrievers / Memory
  - @dev2049
  - @hwchase17
searx_updates
Akhil Vempali 12 months ago committed by GitHub
parent 6e90406e0f
commit d7d629911b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -40,20 +40,12 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 2,
"id": "47f9b495-88f1-4286-8d5d-1416103931a7", "id": "47f9b495-88f1-4286-8d5d-1416103931a7",
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"OpenAI API Key: ········\n"
]
}
],
"source": [ "source": [
"import os\n", "import os\n",
"import getpass\n", "import getpass\n",
@ -66,7 +58,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 3,
"id": "aac9563e", "id": "aac9563e",
"metadata": { "metadata": {
"tags": [] "tags": []
@ -81,7 +73,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 10,
"id": "a3c3999a", "id": "a3c3999a",
"metadata": { "metadata": {
"tags": [] "tags": []
@ -99,7 +91,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 11,
"id": "5eabdb75", "id": "5eabdb75",
"metadata": { "metadata": {
"tags": [] "tags": []
@ -114,7 +106,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 12,
"id": "4b172de8", "id": "4b172de8",
"metadata": { "metadata": {
"tags": [] "tags": []
@ -150,7 +142,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 13,
"id": "186ee1d8", "id": "186ee1d8",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -160,18 +152,18 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 14,
"id": "284e04b5", "id": "284e04b5",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"(Document(page_content='In state after state, new laws have been passed, not only to suppress the vote, but to subvert entire elections. \\n\\nWe cannot let this happen. \\n\\nTonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while youre at it, pass the Disclose Act so Americans can know who is funding our elections. \\n\\nTonight, Id like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \\n\\nOne of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nations top legal minds, who will continue Justice Breyers legacy of excellence.', lookup_str='', metadata={'source': '../../state_of_the_union.txt'}, lookup_index=0),\n", "(Document(page_content='Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while youre at it, pass the Disclose Act so Americans can know who is funding our elections. \\n\\nTonight, Id like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \\n\\nOne of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nations top legal minds, who will continue Justice Breyers legacy of excellence.', metadata={'source': '../../../state_of_the_union.txt'}),\n",
" 0.3914415)" " 0.36913747)"
] ]
}, },
"execution_count": 7, "execution_count": 14,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -191,7 +183,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 15,
"id": "b558ebb7", "id": "b558ebb7",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -212,7 +204,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 16,
"id": "428a6816", "id": "428a6816",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -222,7 +214,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": 17,
"id": "56d1841c", "id": "56d1841c",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -232,7 +224,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 18,
"id": "39055525", "id": "39055525",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -242,17 +234,17 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 19,
"id": "98378c4e", "id": "98378c4e",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"Document(page_content='In state after state, new laws have been passed, not only to suppress the vote, but to subvert entire elections. \\n\\nWe cannot let this happen. \\n\\nTonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while youre at it, pass the Disclose Act so Americans can know who is funding our elections. \\n\\nTonight, Id like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \\n\\nOne of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nations top legal minds, who will continue Justice Breyers legacy of excellence.', lookup_str='', metadata={'source': '../../state_of_the_union.txt'}, lookup_index=0)" "Document(page_content='Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while youre at it, pass the Disclose Act so Americans can know who is funding our elections. \\n\\nTonight, Id like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \\n\\nOne of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nations top legal minds, who will continue Justice Breyers legacy of excellence.', metadata={'source': '../../../state_of_the_union.txt'})"
] ]
}, },
"execution_count": 13, "execution_count": 19,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -273,7 +265,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 20,
"id": "6dfd2b78", "id": "6dfd2b78",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -284,17 +276,17 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 21,
"id": "29960da7", "id": "29960da7",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"{'e0b74348-6c93-4893-8764-943139ec1d17': Document(page_content='foo', lookup_str='', metadata={}, lookup_index=0)}" "{'068c473b-d420-487a-806b-fb0ccea7f711': Document(page_content='foo', metadata={})}"
] ]
}, },
"execution_count": 8, "execution_count": 21,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -305,17 +297,17 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 22,
"id": "83392605", "id": "83392605",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"{'bdc50ae3-a1bb-4678-9260-1b0979578f40': Document(page_content='bar', lookup_str='', metadata={}, lookup_index=0)}" "{'807e0c63-13f6-4070-9774-5c6f0fbb9866': Document(page_content='bar', metadata={})}"
] ]
}, },
"execution_count": 9, "execution_count": 22,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -326,7 +318,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 23,
"id": "a3fcc1c7", "id": "a3fcc1c7",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -336,18 +328,18 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": 24,
"id": "41c51f89", "id": "41c51f89",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"{'e0b74348-6c93-4893-8764-943139ec1d17': Document(page_content='foo', lookup_str='', metadata={}, lookup_index=0),\n", "{'068c473b-d420-487a-806b-fb0ccea7f711': Document(page_content='foo', metadata={}),\n",
" 'd5211050-c777-493d-8825-4800e74cfdb6': Document(page_content='bar', lookup_str='', metadata={}, lookup_index=0)}" " '807e0c63-13f6-4070-9774-5c6f0fbb9866': Document(page_content='bar', metadata={})}"
] ]
}, },
"execution_count": 11, "execution_count": 24,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -356,13 +348,140 @@
"db1.docstore._dict" "db1.docstore._dict"
] ]
}, },
{
"attachments": {},
"cell_type": "markdown",
"id": "f4294b96",
"metadata": {},
"source": [
"## Similarity Search with filtering\n",
"FAISS vectorstore can also support filtering, since the FAISS does not natively support filtering we have to do it manually. This is done by first fetching more results than `k` and then filtering them. You can filter the documents based on metadata. You can also set the `fetch_k` parameter when calling any search method to set how many documents you want to fetch before filtering. Here is a small example:"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 25,
"id": "f80b60de", "id": "d5bf812c",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
"source": [] {
"name": "stdout",
"output_type": "stream",
"text": [
"Content: foo, Metadata: {'page': 1}, Score: 5.159960813797904e-15\n",
"Content: foo, Metadata: {'page': 2}, Score: 5.159960813797904e-15\n",
"Content: foo, Metadata: {'page': 3}, Score: 5.159960813797904e-15\n",
"Content: foo, Metadata: {'page': 4}, Score: 5.159960813797904e-15\n"
]
}
],
"source": [
"from langchain.schema import Document\n",
"list_of_documents = [\n",
" Document(page_content=\"foo\", metadata=dict(page=1)),\n",
" Document(page_content=\"bar\", metadata=dict(page=1)),\n",
" Document(page_content=\"foo\", metadata=dict(page=2)),\n",
" Document(page_content=\"barbar\", metadata=dict(page=2)),\n",
" Document(page_content=\"foo\", metadata=dict(page=3)),\n",
" Document(page_content=\"bar burr\", metadata=dict(page=3)),\n",
" Document(page_content=\"foo\", metadata=dict(page=4)),\n",
" Document(page_content=\"bar bruh\", metadata=dict(page=4))\n",
"]\n",
"db = FAISS.from_documents(list_of_documents, embeddings)\n",
"results_with_scores = db.similarity_search_with_score(\"foo\")\n",
"for doc, score in results_with_scores:\n",
" print(f\"Content: {doc.page_content}, Metadata: {doc.metadata}, Score: {score}\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "3d33c126",
"metadata": {},
"source": [
"Now we make the same query call but we filter for only `page = 1` "
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "83159330",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Content: foo, Metadata: {'page': 1}, Score: 5.159960813797904e-15\n",
"Content: bar, Metadata: {'page': 1}, Score: 0.3131446838378906\n"
]
}
],
"source": [
"results_with_scores = db.similarity_search_with_score(\"foo\", filter=dict(page=1))\n",
"for doc, score in results_with_scores:\n",
" print(f\"Content: {doc.page_content}, Metadata: {doc.metadata}, Score: {score}\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "0be136e0",
"metadata": {},
"source": [
"Same thing can be done with the `max_marginal_relevance_search` as well."
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "432c6980",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Content: foo, Metadata: {'page': 1}\n",
"Content: bar, Metadata: {'page': 1}\n"
]
}
],
"source": [
"results = db.max_marginal_relevance_search(\"foo\", filter=dict(page=1))\n",
"for doc in results:\n",
" print(f\"Content: {doc.page_content}, Metadata: {doc.metadata}\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "1b4ecd86",
"metadata": {},
"source": [
"Here is an example of how to set `fetch_k` parameter when calling `similarity_search`. Usually you would want the `fetch_k` parameter >> `k` parameter. This is because the `fetch_k` parameter is the number of documents that will be fetched before filtering. If you set `fetch_k` to a low number, you might not get enough documents to filter from."
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "1fd60fd1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Content: foo, Metadata: {'page': 1}, Score: 5.159960813797904e-15\n",
"Content: bar, Metadata: {'page': 1}, Score: 0.3131446838378906\n"
]
}
],
"source": [
"results = db.similarity_search(\"foo\", filter=dict(page=1), k=1, fetch_k=4)\n",
"for doc, score in results_with_scores:\n",
" print(f\"Content: {doc.page_content}, Metadata: {doc.metadata}, Score: {score}\")"
]
} }
], ],
"metadata": { "metadata": {
@ -381,7 +500,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.6" "version": "3.9.16"
} }
}, },
"nbformat": 4, "nbformat": 4,

@ -180,13 +180,20 @@ class FAISS(VectorStore):
return self.__add(texts, embeddings, metadatas=metadatas, ids=ids, **kwargs) return self.__add(texts, embeddings, metadatas=metadatas, ids=ids, **kwargs)
def similarity_search_with_score_by_vector( def similarity_search_with_score_by_vector(
self, embedding: List[float], k: int = 4 self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
fetch_k: int = 20,
) -> List[Tuple[Document, float]]: ) -> List[Tuple[Document, float]]:
"""Return docs most similar to query. """Return docs most similar to query.
Args: Args:
embedding: Embedding vector to look up documents similar to. embedding: Embedding vector to look up documents similar to.
k: Number of Documents to return. Defaults to 4. k: Number of Documents to return. Defaults to 4.
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
Defaults to 20.
Returns: Returns:
List of documents most similar to the query text and L2 distance List of documents most similar to the query text and L2 distance
@ -196,7 +203,7 @@ class FAISS(VectorStore):
vector = np.array([embedding], dtype=np.float32) vector = np.array([embedding], dtype=np.float32)
if self._normalize_L2: if self._normalize_L2:
faiss.normalize_L2(vector) faiss.normalize_L2(vector)
scores, indices = self.index.search(vector, k) scores, indices = self.index.search(vector, k if filter is None else fetch_k)
docs = [] docs = []
for j, i in enumerate(indices[0]): for j, i in enumerate(indices[0]):
if i == -1: if i == -1:
@ -206,54 +213,96 @@ class FAISS(VectorStore):
doc = self.docstore.search(_id) doc = self.docstore.search(_id)
if not isinstance(doc, Document): if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}") raise ValueError(f"Could not find document for id {_id}, got {doc}")
docs.append((doc, scores[0][j])) if filter is not None:
return docs if all(doc.metadata.get(key) == value for key, value in filter.items()):
docs.append((doc, scores[0][j]))
else:
docs.append((doc, scores[0][j]))
return docs[:k]
def similarity_search_with_score( def similarity_search_with_score(
self, query: str, k: int = 4 self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Tuple[Document, float]]: ) -> List[Tuple[Document, float]]:
"""Return docs most similar to query. """Return docs most similar to query.
Args: Args:
query: Text to look up documents similar to. query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4. k: Number of Documents to return. Defaults to 4.
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
Defaults to 20.
Returns: Returns:
List of documents most similar to the query text with List of documents most similar to the query text with
L2 distance in float. Lower score represents more similarity. L2 distance in float. Lower score represents more similarity.
""" """
embedding = self.embedding_function(query) embedding = self.embedding_function(query)
docs = self.similarity_search_with_score_by_vector(embedding, k) docs = self.similarity_search_with_score_by_vector(
embedding,
k,
filter=filter,
fetch_k=fetch_k,
**kwargs,
)
return docs return docs
def similarity_search_by_vector( def similarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Return docs most similar to embedding vector. """Return docs most similar to embedding vector.
Args: Args:
embedding: Embedding to look up documents similar to. embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4. k: Number of Documents to return. Defaults to 4.
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
Defaults to 20.
Returns: Returns:
List of Documents most similar to the embedding. List of Documents most similar to the embedding.
""" """
docs_and_scores = self.similarity_search_with_score_by_vector(embedding, k) docs_and_scores = self.similarity_search_with_score_by_vector(
embedding,
k,
filter=filter,
fetch_k=fetch_k,
**kwargs,
)
return [doc for doc, _ in docs_and_scores] return [doc for doc, _ in docs_and_scores]
def similarity_search( def similarity_search(
self, query: str, k: int = 4, **kwargs: Any self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Return docs most similar to query. """Return docs most similar to query.
Args: Args:
query: Text to look up documents similar to. query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4. k: Number of Documents to return. Defaults to 4.
filter: (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
Defaults to 20.
Returns: Returns:
List of Documents most similar to the query. List of Documents most similar to the query.
""" """
docs_and_scores = self.similarity_search_with_score(query, k) docs_and_scores = self.similarity_search_with_score(
query, k, filter=filter, fetch_k=fetch_k, **kwargs
)
return [doc for doc, _ in docs_and_scores] return [doc for doc, _ in docs_and_scores]
def max_marginal_relevance_search_by_vector( def max_marginal_relevance_search_by_vector(
@ -262,6 +311,7 @@ class FAISS(VectorStore):
k: int = 4, k: int = 4,
fetch_k: int = 20, fetch_k: int = 20,
lambda_mult: float = 0.5, lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Return docs selected using the maximal marginal relevance. """Return docs selected using the maximal marginal relevance.
@ -272,7 +322,8 @@ class FAISS(VectorStore):
Args: Args:
embedding: Embedding to look up documents similar to. embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4. k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm. fetch_k: Number of Documents to fetch before filtering to
pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity. to maximum diversity and 1 to minimum diversity.
@ -280,7 +331,23 @@ class FAISS(VectorStore):
Returns: Returns:
List of Documents selected by maximal marginal relevance. List of Documents selected by maximal marginal relevance.
""" """
_, indices = self.index.search(np.array([embedding], dtype=np.float32), fetch_k) _, indices = self.index.search(
np.array([embedding], dtype=np.float32),
fetch_k if filter is None else fetch_k * 2,
)
if filter is not None:
filtered_indices = []
for i in indices[0]:
if i == -1:
# This happens when not enough docs are returned.
continue
_id = self.index_to_docstore_id[i]
doc = self.docstore.search(_id)
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
if all(doc.metadata.get(key) == value for key, value in filter.items()):
filtered_indices.append(i)
indices = np.array([filtered_indices])
# -1 happens when not enough docs are returned. # -1 happens when not enough docs are returned.
embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1] embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1]
mmr_selected = maximal_marginal_relevance( mmr_selected = maximal_marginal_relevance(
@ -308,6 +375,7 @@ class FAISS(VectorStore):
k: int = 4, k: int = 4,
fetch_k: int = 20, fetch_k: int = 20,
lambda_mult: float = 0.5, lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Return docs selected using the maximal marginal relevance. """Return docs selected using the maximal marginal relevance.
@ -318,7 +386,8 @@ class FAISS(VectorStore):
Args: Args:
query: Text to look up documents similar to. query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4. k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm. fetch_k: Number of Documents to fetch before filtering (if needed) to
pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity. to maximum diversity and 1 to minimum diversity.
@ -328,7 +397,12 @@ class FAISS(VectorStore):
""" """
embedding = self.embedding_function(query) embedding = self.embedding_function(query)
docs = self.max_marginal_relevance_search_by_vector( docs = self.max_marginal_relevance_search_by_vector(
embedding, k, fetch_k, lambda_mult=lambda_mult embedding,
k,
fetch_k,
lambda_mult=lambda_mult,
filter=filter,
**kwargs,
) )
return docs return docs
@ -522,6 +596,8 @@ class FAISS(VectorStore):
self, self,
query: str, query: str,
k: int = 4, k: int = 4,
filter: Optional[Dict[str, Any]] = None,
fetch_k: int = 20,
**kwargs: Any, **kwargs: Any,
) -> List[Tuple[Document, float]]: ) -> List[Tuple[Document, float]]:
"""Return docs and their similarity scores on a scale from 0 to 1.""" """Return docs and their similarity scores on a scale from 0 to 1."""
@ -530,5 +606,11 @@ class FAISS(VectorStore):
"normalize_score_fn must be provided to" "normalize_score_fn must be provided to"
" FAISS constructor to normalize scores" " FAISS constructor to normalize scores"
) )
docs_and_scores = self.similarity_search_with_score(query, k=k) docs_and_scores = self.similarity_search_with_score(
query,
k=k,
filter=filter,
fetch_k=fetch_k,
**kwargs,
)
return [(doc, self.relevance_score_fn(score)) for doc, score in docs_and_scores] return [(doc, self.relevance_score_fn(score)) for doc, score in docs_and_scores]

@ -74,6 +74,28 @@ def test_faiss_with_metadatas() -> None:
assert output == [Document(page_content="foo", metadata={"page": 0})] assert output == [Document(page_content="foo", metadata={"page": 0})]
def test_faiss_with_metadatas_and_filter() -> None:
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = FAISS.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
expected_docstore = InMemoryDocstore(
{
docsearch.index_to_docstore_id[0]: Document(
page_content="foo", metadata={"page": 0}
),
docsearch.index_to_docstore_id[1]: Document(
page_content="bar", metadata={"page": 1}
),
docsearch.index_to_docstore_id[2]: Document(
page_content="baz", metadata={"page": 2}
),
}
)
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
output = docsearch.similarity_search("foo", k=1, filter={"page": 1})
assert output == []
def test_faiss_search_not_found() -> None: def test_faiss_search_not_found() -> None:
"""Test what happens when document is not found.""" """Test what happens when document is not found."""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]

Loading…
Cancel
Save