"""Test ScaNN functionality."""
import datetime
import tempfile

import numpy as np
import pytest
from langchain_core.documents import Document

from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores.scann import (
    ScaNN,
    dependable_scann_import,
    normalize,
)
from langchain_community.vectorstores.utils import DistanceStrategy

from tests.integration_tests.vectorstores.fake_embeddings import (
    ConsistentFakeEmbeddings,
    FakeEmbeddings,
)
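
# NOTE: these are integration tests; a minimal way to run this module locally
# (assuming the optional `scann` package is installable on your platform) is:
#
#   pip install scann
#   pytest tests/integration_tests/vectorstores/test_scann.py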


def test_scann() -> None:
    """Test end to end construction and search."""
    texts = ["foo", "bar", "baz"]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
    index_to_id = docsearch.index_to_docstore_id
    expected_docstore = InMemoryDocstore(
        {
            index_to_id[0]: Document(page_content="foo"),
            index_to_id[1]: Document(page_content="bar"),
            index_to_id[2]: Document(page_content="baz"),
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
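    # The fake embeddings used throughout this module are deterministic, so
    # the nearest neighbor of the query is stable and an exact-match
    # assertion on k=1 is safe.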
    output = docsearch.similarity_search("foo", k=1)
    assert output == [Document(page_content="foo")]


def test_scann_vector_mips_l2() -> None:
    """Test vector similarity with MIPS and L2."""
    texts = ["foo", "bar", "baz"]
    euclidean_search = ScaNN.from_texts(texts, FakeEmbeddings())
    output = euclidean_search.similarity_search_with_score("foo", k=1)
    expected_euclidean = [(Document(page_content="foo", metadata={}), 0.0)]
    assert output == expected_euclidean
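    # With normalize_L2=True all vectors are unit length, so
    # ||a - b||^2 = 2 - 2 * <a, b>: the max-inner-product score of 1.0 below
    # is the MIPS analogue of the exact-match L2 distance of 0.0 above.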
    mips_search = ScaNN.from_texts(
        texts,
        FakeEmbeddings(),
        distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
        normalize_L2=True,
    )
    output = mips_search.similarity_search_with_score("foo", k=1)
    expected_mips = [(Document(page_content="foo", metadata={}), 1.0)]
    assert output == expected_mips


def test_scann_with_config() -> None:
    """Test ScaNN with an approximate search config."""
    texts = [str(i) for i in range(10000)]
    # Create a config with dimension = 10, k = 10.
    # Tree: search 10 leaves in a search tree of 100 leaves.
    # Quantization: use a 16-centroid quantizer for every 2 dimensions.
    # Reordering: reorder the top 100 results.
    scann_config = (
        dependable_scann_import()
        .scann_ops_pybind.builder(np.zeros(shape=(0, 10)), 10, "squared_l2")
        .tree(num_leaves=100, num_leaves_to_search=10)
        .score_ah(2)
        .reorder(100)
        .create_config()
    )
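    # The handcrafted config must agree with the embedding width: the fake
    # embeddings used here are 10-dimensional, matching the
    # np.zeros(shape=(0, 10)) placeholder passed to the builder.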
    mips_search = ScaNN.from_texts(
        texts,
        ConsistentFakeEmbeddings(),
        scann_config=scann_config,
        distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
        normalize_L2=True,
    )
    output = mips_search.similarity_search_with_score("42", k=1)
    expected = [(Document(page_content="42", metadata={}), 0.0)]
    assert output == expected


def test_scann_vector_sim() -> None:
    """Test vector similarity."""
    texts = ["foo", "bar", "baz"]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
    index_to_id = docsearch.index_to_docstore_id
    expected_docstore = InMemoryDocstore(
        {
            index_to_id[0]: Document(page_content="foo"),
            index_to_id[1]: Document(page_content="bar"),
            index_to_id[2]: Document(page_content="baz"),
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
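    # similarity_search_by_vector accepts a precomputed query embedding, so
    # the query text is embedded manually here instead of inside the store.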
    query_vec = FakeEmbeddings().embed_query(text="foo")
    output = docsearch.similarity_search_by_vector(query_vec, k=1)
    assert output == [Document(page_content="foo")]


def test_scann_vector_sim_with_score_threshold() -> None:
    """Test vector similarity with a score threshold."""
    texts = ["foo", "bar", "baz"]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
    index_to_id = docsearch.index_to_docstore_id
    expected_docstore = InMemoryDocstore(
        {
            index_to_id[0]: Document(page_content="foo"),
            index_to_id[1]: Document(page_content="bar"),
            index_to_id[2]: Document(page_content="baz"),
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    query_vec = FakeEmbeddings().embed_query(text="foo")
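    # k=2 would normally return two hits, but score_threshold=0.2 filters out
    # the second-nearest neighbor, whose L2 distance exceeds the cutoff.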
    output = docsearch.similarity_search_by_vector(query_vec, k=2, score_threshold=0.2)
    assert output == [Document(page_content="foo")]


def test_similarity_search_with_score_by_vector() -> None:
    """Test vector similarity with score by vector."""
    texts = ["foo", "bar", "baz"]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
    index_to_id = docsearch.index_to_docstore_id
    expected_docstore = InMemoryDocstore(
        {
            index_to_id[0]: Document(page_content="foo"),
            index_to_id[1]: Document(page_content="bar"),
            index_to_id[2]: Document(page_content="baz"),
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    query_vec = FakeEmbeddings().embed_query(text="foo")
    output = docsearch.similarity_search_with_score_by_vector(query_vec, k=1)
    assert len(output) == 1
    assert output[0][0] == Document(page_content="foo")


def test_similarity_search_with_score_by_vector_with_score_threshold() -> None:
    """Test vector similarity with score by vector and a score threshold."""
    texts = ["foo", "bar", "baz"]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
    index_to_id = docsearch.index_to_docstore_id
    expected_docstore = InMemoryDocstore(
        {
            index_to_id[0]: Document(page_content="foo"),
            index_to_id[1]: Document(page_content="bar"),
            index_to_id[2]: Document(page_content="baz"),
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    query_vec = FakeEmbeddings().embed_query(text="foo")
    output = docsearch.similarity_search_with_score_by_vector(
        query_vec,
        k=2,
        score_threshold=0.2,
    )
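    # Only the exact match survives the cutoff; its distance must itself fall
    # below the 0.2 threshold, which the final assertion checks directly.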
    assert len(output) == 1
    assert output[0][0] == Document(page_content="foo")
    assert output[0][1] < 0.2


def test_scann_with_metadatas() -> None:
    """Test end to end construction and search with metadata."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
    expected_docstore = InMemoryDocstore(
        {
            docsearch.index_to_docstore_id[0]: Document(
                page_content="foo", metadata={"page": 0}
            ),
            docsearch.index_to_docstore_id[1]: Document(
                page_content="bar", metadata={"page": 1}
            ),
            docsearch.index_to_docstore_id[2]: Document(
                page_content="baz", metadata={"page": 2}
            ),
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    output = docsearch.similarity_search("foo", k=1)
    assert output == [Document(page_content="foo", metadata={"page": 0})]


def test_scann_with_metadatas_and_filter() -> None:
    """Test end to end construction and search with a metadata filter."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
    expected_docstore = InMemoryDocstore(
        {
            docsearch.index_to_docstore_id[0]: Document(
                page_content="foo", metadata={"page": 0}
            ),
            docsearch.index_to_docstore_id[1]: Document(
                page_content="bar", metadata={"page": 1}
            ),
            docsearch.index_to_docstore_id[2]: Document(
                page_content="baz", metadata={"page": 2}
            ),
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
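    # The filter applies to document metadata, so the closest text "foo"
    # (page 0) is excluded and "bar" (page 1) is returned instead.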
    output = docsearch.similarity_search("foo", k=1, filter={"page": 1})
    assert output == [Document(page_content="bar", metadata={"page": 1})]


def test_scann_with_metadatas_and_list_filter() -> None:
    """Test end to end construction and search with a list-valued filter."""
    texts = ["foo", "bar", "baz", "foo", "qux"]
    metadatas = [{"page": i} if i <= 3 else {"page": 3} for i in range(len(texts))]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
    expected_docstore = InMemoryDocstore(
        {
            docsearch.index_to_docstore_id[0]: Document(
                page_content="foo", metadata={"page": 0}
            ),
            docsearch.index_to_docstore_id[1]: Document(
                page_content="bar", metadata={"page": 1}
            ),
            docsearch.index_to_docstore_id[2]: Document(
                page_content="baz", metadata={"page": 2}
            ),
            docsearch.index_to_docstore_id[3]: Document(
                page_content="foo", metadata={"page": 3}
            ),
            docsearch.index_to_docstore_id[4]: Document(
                page_content="qux", metadata={"page": 3}
            ),
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
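    # A list-valued filter matches metadata equal to any listed value, so the
    # duplicate "foo" at page 3 is excluded. The misspelled query "foor" is
    # fine here because the fake query embedding does not depend on the text.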
    output = docsearch.similarity_search("foor", k=1, filter={"page": [0, 1, 2]})
    assert output == [Document(page_content="foo", metadata={"page": 0})]


def test_scann_search_not_found() -> None:
    """Test what happens when a document is not found."""
    texts = ["foo", "bar", "baz"]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
    # Get rid of the docstore to purposefully induce errors.
    docsearch.docstore = InMemoryDocstore({})
    with pytest.raises(ValueError):
        docsearch.similarity_search("foo")


def test_scann_local_save_load() -> None:
    """Test end to end serialization."""
    texts = ["foo", "bar", "baz"]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
    temp_timestamp = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S")
    with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder:
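        # Round-trip the store through disk; a non-None index after reload is
        # the minimal signal that deserialization succeeded.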
        docsearch.save_local(temp_folder)
        new_docsearch = ScaNN.load_local(temp_folder, FakeEmbeddings())
        assert new_docsearch.index is not None


def test_scann_normalize_l2() -> None:
    """Test normalize L2."""
    texts = ["foo", "bar", "baz"]
    emb = np.array(FakeEmbeddings().embed_documents(texts))
    # Test norm is 1.
    np.testing.assert_allclose(1, np.linalg.norm(normalize(emb), axis=-1))
    # A zero vector has zero norm; normalize is expected to guard the
    # division so that no NaNs appear in the output.
    np.testing.assert_array_equal(False, np.isnan(normalize(np.zeros(10))))