add faiss test for score threshold (#8143)

# What
- Add faiss vector search test for score threshold
- Fix failing faiss vector search test; filtering with list value is
wrong.

<!-- Thank you for contributing to LangChain!

Replace this comment with:
- Description: Add faiss vector search test for score threshold; Fix
failing faiss vector search test; filtering with list value is wrong.
  - Issue: None
  - Dependencies: None
  - Tag maintainer: @rlancemartin, @eyurtsev
  - Twitter handle: @MlopsJ

Please make sure you're PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
  2. an example notebook showing its use.

Maintainer responsibilities:
  - General / Misc / if you don't know who to tag: @baskaryan
  - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev
  - Models / Prompts: @hwchase17, @baskaryan
  - Memory: @hwchase17
  - Agents / Tools / Toolkits: @hinthornw
  - Tracing / Callbacks: @agola11
  - Async: @agola11

If no one reviews your PR within a few days, feel free to @-mention the
same people again.

See contribution guidelines for more information on how to write/run
tests, lint, etc:
https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
 -->
pull/8162/head
shibuiwilliam 1 year ago committed by GitHub
parent 7686dabd36
commit 8f5000146c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -370,7 +370,12 @@ class FAISS(VectorStore):
doc = self.docstore.search(_id)
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
if all(doc.metadata.get(key) == value for key, value in filter.items()):
if all(
doc.metadata.get(key) in value
if isinstance(value, list)
else doc.metadata.get(key) == value
for key, value in filter.items()
):
filtered_indices.append(i)
indices = np.array([filtered_indices])
# -1 happens when not enough docs are returned.

@ -47,6 +47,24 @@ def test_faiss_vector_sim() -> None:
assert output == [Document(page_content="foo")]
def test_faiss_vector_sim_with_score_threshold() -> None:
"""Test vector similarity."""
texts = ["foo", "bar", "baz"]
docsearch = FAISS.from_texts(texts, FakeEmbeddings())
index_to_id = docsearch.index_to_docstore_id
expected_docstore = InMemoryDocstore(
{
index_to_id[0]: Document(page_content="foo"),
index_to_id[1]: Document(page_content="bar"),
index_to_id[2]: Document(page_content="baz"),
}
)
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
query_vec = FakeEmbeddings().embed_query(text="foo")
output = docsearch.similarity_search_by_vector(query_vec, k=2, score_threshold=0.2)
assert output == [Document(page_content="foo")]
def test_similarity_search_with_score_by_vector() -> None:
"""Test vector similarity with score by vector."""
texts = ["foo", "bar", "baz"]
@ -66,6 +84,30 @@ def test_similarity_search_with_score_by_vector() -> None:
assert output[0][0] == Document(page_content="foo")
def test_similarity_search_with_score_by_vector_with_score_threshold() -> None:
"""Test vector similarity with score by vector."""
texts = ["foo", "bar", "baz"]
docsearch = FAISS.from_texts(texts, FakeEmbeddings())
index_to_id = docsearch.index_to_docstore_id
expected_docstore = InMemoryDocstore(
{
index_to_id[0]: Document(page_content="foo"),
index_to_id[1]: Document(page_content="bar"),
index_to_id[2]: Document(page_content="baz"),
}
)
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
query_vec = FakeEmbeddings().embed_query(text="foo")
output = docsearch.similarity_search_with_score_by_vector(
query_vec,
k=2,
score_threshold=0.2,
)
assert len(output) == 1
assert output[0][0] == Document(page_content="foo")
assert output[0][1] < 0.2
def test_faiss_mmr() -> None:
texts = ["foo", "foo", "fou", "foy"]
docsearch = FAISS.from_texts(texts, FakeEmbeddings())
@ -102,10 +144,9 @@ def test_faiss_mmr_with_metadatas_and_filter() -> None:
output = docsearch.max_marginal_relevance_search_with_score_by_vector(
query_vec, k=10, lambda_mult=0.1, filter={"page": 1}
)
assert len(output) == len(texts)
assert len(output) == 1
assert output[0][0] == Document(page_content="foo", metadata={"page": 1})
assert output[0][1] == 0.0
assert output[1][0] != Document(page_content="foo", metadata={"page": 1})
def test_faiss_mmr_with_metadatas_and_list_filter() -> None:
@ -116,7 +157,7 @@ def test_faiss_mmr_with_metadatas_and_list_filter() -> None:
output = docsearch.max_marginal_relevance_search_with_score_by_vector(
query_vec, k=10, lambda_mult=0.1, filter={"page": [0, 1, 2]}
)
assert len(output) == len(texts)
assert len(output) == 3
assert output[0][0] == Document(page_content="foo", metadata={"page": 0})
assert output[0][1] == 0.0
assert output[1][0] != Document(page_content="foo", metadata={"page": 0})

Loading…
Cancel
Save