mirror of https://github.com/hwchase17/langchain
Add ScaNN support in vectorstore. (#8251)
Description: Add ScaNN vectorstore to langchain. ScaNN is an open-source, high-performance vector similarity library optimized for AVX2-enabled CPUs. https://github.com/google-research/google-research/tree/master/scann - Dependencies: scann Python notebook to illustrate the usage: docs/extras/integrations/vectorstores/scann.ipynb Integration test: libs/langchain/tests/integration_tests/vectorstores/test_scann.py @rlancemartin, @eyurtsev for review. Thanks!
pull/8744/head
parent
5b7ff215e8
commit
6aee589eec
@ -0,0 +1,544 @@
|
||||
"""Wrapper around ScaNN vector database."""
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
import pickle
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.docstore.base import AddableMixin, Docstore
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.docstore.in_memory import InMemoryDocstore
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.vectorstores.utils import DistanceStrategy
|
||||
|
||||
|
||||
def normalize(x: np.ndarray) -> np.ndarray:
    """Return ``x`` L2-normalized along the last axis.

    Args:
        x: Array of vectors; the last axis is treated as the vector dimension.

    Returns:
        An array of the same shape whose vectors have unit L2 norm. Norms are
        clipped to a minimum of 1e-12, so zero vectors come back as zeros
        instead of NaN.
    """
    # Compute the norm separately and divide out-of-place: the original used
    # `x /=`, which silently mutated the caller's array.
    norms = np.clip(np.linalg.norm(x, axis=-1, keepdims=True), 1e-12, None)
    return x / norms
|
||||
|
||||
|
||||
def dependable_scann_import() -> Any:
    """Import and return the ``scann`` module.

    Returns:
        The imported ``scann`` module.

    Raises:
        ImportError: If the ``scann`` package is not installed. The original
            import failure is chained as ``__cause__`` so the underlying
            reason (e.g. a missing shared library) is not lost.
    """
    try:
        import scann
    except ImportError as e:
        # Chain the original error instead of discarding it.
        raise ImportError(
            "Could not import scann python package. "
            "Please install it with `pip install scann` "
        ) from e
    return scann
|
||||
|
||||
|
||||
class ScaNN(VectorStore):
    """Wrapper around ScaNN vector database.

    To use, you should have the ``scann`` python package installed.

    Example:
        .. code-block:: python

            from langchain.embeddings import HuggingFaceEmbeddings
            from langchain.vectorstores import ScaNN

            db = ScaNN.from_texts(
                ['foo', 'bar', 'barz', 'qux'],
                HuggingFaceEmbeddings())
            db.similarity_search('foo?', k=1)
    """

    def __init__(
        self,
        embedding: Embeddings,
        index: Any,
        docstore: Docstore,
        index_to_docstore_id: Dict[int, str],
        relevance_score_fn: Optional[Callable[[float], float]] = None,
        normalize_L2: bool = False,
        distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
        scann_config: Optional[str] = None,
    ):
        """Initialize with necessary components.

        Args:
            embedding: Embedding function to use for queries and documents.
            index: A built ScaNN searcher object.
            docstore: Docstore holding the indexed documents.
            index_to_docstore_id: Mapping from ScaNN row index to docstore id.
            relevance_score_fn: Optional override used to convert raw distance
                scores into [0, 1] relevance scores.
            normalize_L2: Whether vectors are L2-normalized before indexing
                and before each search.
            distance_strategy: Distance metric the index was built with.
            scann_config: Optional serialized ScaNN config the searcher was
                built from.
        """
        self.embedding = embedding
        self.index = index
        self.docstore = docstore
        self.index_to_docstore_id = index_to_docstore_id
        self.distance_strategy = distance_strategy
        self.override_relevance_score_fn = relevance_score_fn
        self._normalize_L2 = normalize_L2
        self._scann_config = scann_config

    def __add(
        self,
        texts: Iterable[str],
        embeddings: Iterable[List[float]],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        # A built ScaNN index is immutable. Validate the docstore first so
        # callers get the most specific error, then report the limitation.
        if not isinstance(self.docstore, AddableMixin):
            raise ValueError(
                "If trying to add texts, the underlying docstore should support "
                f"adding items, which {self.docstore} does not"
            )
        raise NotImplementedError("Updates are not available in ScaNN, yet.")

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            ids: Optional list of unique IDs.

        Returns:
            List of ids from adding the texts into the vectorstore.

        Raises:
            NotImplementedError: ScaNN does not support index updates.
        """
        # Embed and create the documents.
        embeddings = self.embedding.embed_documents(list(texts))
        return self.__add(texts, embeddings, metadatas=metadatas, ids=ids, **kwargs)

    def add_embeddings(
        self,
        text_embeddings: Iterable[Tuple[str, List[float]]],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            text_embeddings: Iterable pairs of string and embedding to
                add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            ids: Optional list of unique IDs.

        Returns:
            List of ids from adding the texts into the vectorstore.

        Raises:
            NotImplementedError: ScaNN does not support index updates.
        """
        if not isinstance(self.docstore, AddableMixin):
            raise ValueError(
                "If trying to add texts, the underlying docstore should support "
                f"adding items, which {self.docstore} does not"
            )
        # Embed and create the documents.
        texts, embeddings = zip(*text_embeddings)

        return self.__add(texts, embeddings, metadatas=metadatas, ids=ids, **kwargs)

    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
        """Delete by vector ID or other criteria.

        Args:
            ids: List of ids to delete.
            **kwargs: Other keyword arguments that subclasses might use.

        Returns:
            Optional[bool]: True if deletion is successful,
            False otherwise, None if not implemented.

        Raises:
            NotImplementedError: ScaNN does not support deletions.
        """
        raise NotImplementedError("Deletions are not available in ScaNN, yet.")

    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to query.

        Args:
            embedding: Embedding vector to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
            fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
                Defaults to 20.
            **kwargs: kwargs to be passed to similarity search. Can include:
                score_threshold: Optional, a floating point value between 0 to 1 to
                    filter the resulting set of retrieved docs

        Returns:
            List of documents most similar to the query text and L2 distance
            in float for each. Lower score represents more similarity.
        """
        vector = np.array([embedding], dtype=np.float32)
        if self._normalize_L2:
            vector = normalize(vector)
        # Over-fetch when a metadata filter is present so that post-filtering
        # can still yield up to k results.
        indices, scores = self.index.search_batched(
            vector, k if filter is None else fetch_k
        )
        # Normalize filter values to lists once, up front. The original
        # rebuilt this dict (and rebound `filter`) on every loop iteration.
        if filter is not None:
            filter = {
                key: [value] if not isinstance(value, list) else value
                for key, value in filter.items()
            }
        docs = []
        for j, i in enumerate(indices[0]):
            if i == -1:
                # This happens when not enough docs are returned.
                continue
            _id = self.index_to_docstore_id[i]
            doc = self.docstore.search(_id)
            if not isinstance(doc, Document):
                raise ValueError(f"Could not find document for id {_id}, got {doc}")
            # Keep a hit if no filter was given, or every filter key matches
            # one of the allowed values.
            if filter is None or all(
                doc.metadata.get(key) in value for key, value in filter.items()
            ):
                docs.append((doc, scores[0][j]))

        score_threshold = kwargs.get("score_threshold")
        if score_threshold is not None:
            # For similarity-style metrics (higher is better) keep scores >=
            # threshold; for distance-style metrics keep scores <= threshold.
            cmp = (
                operator.ge
                if self.distance_strategy
                in (DistanceStrategy.MAX_INNER_PRODUCT, DistanceStrategy.JACCARD)
                else operator.le
            )
            docs = [
                (doc, similarity)
                for doc, similarity in docs
                if cmp(similarity, score_threshold)
            ]
        return docs[:k]

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
            fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
                Defaults to 20.

        Returns:
            List of documents most similar to the query text with
            L2 distance in float. Lower score represents more similarity.
        """
        embedding = self.embedding.embed_query(query)
        docs = self.similarity_search_with_score_by_vector(
            embedding,
            k,
            filter=filter,
            fetch_k=fetch_k,
            **kwargs,
        )
        return docs

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs most similar to embedding vector.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
            fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
                Defaults to 20.

        Returns:
            List of Documents most similar to the embedding.
        """
        docs_and_scores = self.similarity_search_with_score_by_vector(
            embedding,
            k,
            filter=filter,
            fetch_k=fetch_k,
            **kwargs,
        )
        return [doc for doc, _ in docs_and_scores]

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter: (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
            fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
                Defaults to 20.

        Returns:
            List of Documents most similar to the query.
        """
        docs_and_scores = self.similarity_search_with_score(
            query, k, filter=filter, fetch_k=fetch_k, **kwargs
        )
        return [doc for doc, _ in docs_and_scores]

    @classmethod
    def __from(
        cls,
        texts: List[str],
        embeddings: List[List[float]],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        normalize_L2: bool = False,
        **kwargs: Any,
    ) -> ScaNN:
        # Shared constructor: builds the ScaNN searcher, the docstore, and
        # the index -> id mapping from pre-computed embeddings.
        scann = dependable_scann_import()
        distance_strategy = kwargs.get(
            "distance_strategy", DistanceStrategy.EUCLIDEAN_DISTANCE
        )
        scann_config = kwargs.get("scann_config")

        vector = np.array(embeddings, dtype=np.float32)
        if normalize_L2:
            vector = normalize(vector)
        if scann_config is not None:
            # A user-supplied config fully determines the searcher.
            index = scann.scann_ops_pybind.create_searcher(vector, scann_config)
        else:
            if distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
                index = (
                    scann.scann_ops_pybind.builder(vector, 1, "dot_product")
                    .score_brute_force()
                    .build()
                )
            else:
                # Default to L2, currently other metric types not initialized.
                index = (
                    scann.scann_ops_pybind.builder(vector, 1, "squared_l2")
                    .score_brute_force()
                    .build()
                )
        documents = []
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in texts]
        for i, text in enumerate(texts):
            metadata = metadatas[i] if metadatas else {}
            documents.append(Document(page_content=text, metadata=metadata))
        index_to_id = dict(enumerate(ids))

        if len(index_to_id) != len(documents):
            # ValueError (was a bare Exception): callers catching Exception
            # still work, and the error type now reflects the bad-input cause.
            raise ValueError(
                f"{len(index_to_id)} ids provided for {len(documents)} documents."
                " Each document should have an id."
            )

        docstore = InMemoryDocstore(dict(zip(index_to_id.values(), documents)))
        return cls(
            embedding,
            index,
            docstore,
            index_to_id,
            normalize_L2=normalize_L2,
            **kwargs,
        )

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> ScaNN:
        """Construct ScaNN wrapper from raw documents.

        This is a user friendly interface that:
            1. Embeds documents.
            2. Creates an in memory docstore
            3. Initializes the ScaNN database

        This is intended to be a quick way to get started.

        Example:
            .. code-block:: python

                from langchain import ScaNN
                from langchain.embeddings import OpenAIEmbeddings

                embeddings = OpenAIEmbeddings()
                scann = ScaNN.from_texts(texts, embeddings)
        """
        embeddings = embedding.embed_documents(texts)
        return cls.__from(
            texts,
            embeddings,
            embedding,
            metadatas=metadatas,
            ids=ids,
            **kwargs,
        )

    @classmethod
    def from_embeddings(
        cls,
        text_embeddings: List[Tuple[str, List[float]]],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> ScaNN:
        """Construct ScaNN wrapper from raw documents.

        This is a user friendly interface that:
            1. Embeds documents.
            2. Creates an in memory docstore
            3. Initializes the ScaNN database

        This is intended to be a quick way to get started.

        Example:
            .. code-block:: python

                from langchain import ScaNN
                from langchain.embeddings import OpenAIEmbeddings

                embeddings = OpenAIEmbeddings()
                text_embeddings = embeddings.embed_documents(texts)
                text_embedding_pairs = list(zip(texts, text_embeddings))
                scann = ScaNN.from_embeddings(text_embedding_pairs, embeddings)
        """
        texts = [t[0] for t in text_embeddings]
        embeddings = [t[1] for t in text_embeddings]
        return cls.__from(
            texts,
            embeddings,
            embedding,
            metadatas=metadatas,
            ids=ids,
            **kwargs,
        )

    def save_local(self, folder_path: str, index_name: str = "index") -> None:
        """Save ScaNN index, docstore, and index_to_docstore_id to disk.

        Args:
            folder_path: folder path to save index, docstore,
                and index_to_docstore_id to.
            index_name: base name for the saved index files.
        """
        path = Path(folder_path)
        scann_path = path / f"{index_name}.scann"
        scann_path.mkdir(exist_ok=True, parents=True)

        # save index separately since it is not picklable
        self.index.serialize(str(scann_path))

        # save docstore and index_to_docstore_id
        with open(path / f"{index_name}.pkl", "wb") as f:
            pickle.dump((self.docstore, self.index_to_docstore_id), f)

    @classmethod
    def load_local(
        cls,
        folder_path: str,
        embedding: Embeddings,
        index_name: str = "index",
        **kwargs: Any,
    ) -> ScaNN:
        """Load ScaNN index, docstore, and index_to_docstore_id from disk.

        Args:
            folder_path: folder path to load index, docstore,
                and index_to_docstore_id from.
            embedding: Embeddings to use when generating queries
            index_name: for saving with a specific index file name
        """
        path = Path(folder_path)
        scann_path = path / f"{index_name}.scann"
        # NOTE(review): mkdir on load mirrors save_local; presumably it only
        # guards against a missing directory before load_searcher runs.
        scann_path.mkdir(exist_ok=True, parents=True)
        # load index separately since it is not picklable
        scann = dependable_scann_import()
        index = scann.scann_ops_pybind.load_searcher(str(scann_path))

        # load docstore and index_to_docstore_id
        with open(path / f"{index_name}.pkl", "rb") as f:
            docstore, index_to_docstore_id = pickle.load(f)
        return cls(embedding, index, docstore, index_to_docstore_id, **kwargs)

    def _select_relevance_score_fn(self) -> Callable[[float], float]:
        """
        The 'correct' relevance function
        may differ depending on a few things, including:
        - the distance / similarity metric used by the VectorStore
        - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
        - embedding dimensionality
        - etc.
        """
        if self.override_relevance_score_fn is not None:
            return self.override_relevance_score_fn

        # Default strategy is to rely on distance strategy provided in
        # vectorstore constructor
        if self.distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
            return self._max_inner_product_relevance_score_fn
        elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
            # Default behavior is to use euclidean distance relevancy
            return self._euclidean_relevance_score_fn
        else:
            # Message fixed: cosine was listed but is not handled above.
            raise ValueError(
                "Unknown distance strategy, must be max_inner_product"
                " or euclidean"
            )

    def _similarity_search_with_relevance_scores(
        self,
        query: str,
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs and their similarity scores on a scale from 0 to 1."""
        # Pop score threshold so that only relevancy scores, not raw scores, are
        # filtered.
        score_threshold = kwargs.pop("score_threshold", None)
        relevance_score_fn = self._select_relevance_score_fn()
        # Defensive: _select_relevance_score_fn raises rather than returning
        # None, but keep the guard in case of overrides.
        if relevance_score_fn is None:
            raise ValueError(
                "normalize_score_fn must be provided to"
                " ScaNN constructor to normalize scores"
            )
        docs_and_scores = self.similarity_search_with_score(
            query,
            k=k,
            filter=filter,
            fetch_k=fetch_k,
            **kwargs,
        )
        docs_and_rel_scores = [
            (doc, relevance_score_fn(score)) for doc, score in docs_and_scores
        ]
        if score_threshold is not None:
            docs_and_rel_scores = [
                (doc, similarity)
                for doc, similarity in docs_and_rel_scores
                if similarity >= score_threshold
            ]
        return docs_and_rel_scores
|
@ -0,0 +1,262 @@
|
||||
"""Test ScaNN functionality."""
|
||||
import datetime
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.docstore.in_memory import InMemoryDocstore
|
||||
from langchain.vectorstores.scann import ScaNN, dependable_scann_import, normalize
|
||||
from langchain.vectorstores.utils import DistanceStrategy
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||
ConsistentFakeEmbeddings,
|
||||
FakeEmbeddings,
|
||||
)
|
||||
|
||||
|
||||
def test_scann() -> None:
    """Test end to end construction and search."""
    texts = ["foo", "bar", "baz"]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
    index_to_id = docsearch.index_to_docstore_id
    # The docstore should map each generated id to its document, in order.
    expected_docstore = InMemoryDocstore(
        {index_to_id[pos]: Document(page_content=text) for pos, text in enumerate(texts)}
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    assert docsearch.similarity_search("foo", k=1) == [Document(page_content="foo")]
|
||||
|
||||
|
||||
def test_scann_vector_mips_l2() -> None:
    """Test vector similarity with MIPS and L2."""
    texts = ["foo", "bar", "baz"]
    expected_doc = Document(page_content="foo", metadata={})

    # Euclidean (default): an exact match has distance 0.
    euclidean_search = ScaNN.from_texts(texts, FakeEmbeddings())
    euclidean_output = euclidean_search.similarity_search_with_score("foo", k=1)
    assert euclidean_output == [(expected_doc, 0.0)]

    # MIPS on normalized vectors: an exact match has inner product 1.
    mips_search = ScaNN.from_texts(
        texts,
        FakeEmbeddings(),
        distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
        normalize_L2=True,
    )
    mips_output = mips_search.similarity_search_with_score("foo", k=1)
    assert mips_output == [(expected_doc, 1.0)]
|
||||
|
||||
|
||||
def test_scann_with_config() -> None:
    """Test ScaNN with approximate search config."""
    texts = [str(i) for i in range(10000)]
    # Build an approximate-search config for dimension 10, k = 10:
    # partition into 100 leaves and probe 10 of them; quantize with a
    # 16-centroid quantizer per 2 dimensions; exactly re-score the top 100.
    scann = dependable_scann_import()
    scann_config = (
        scann.scann_ops_pybind.builder(np.zeros(shape=(0, 10)), 10, "squared_l2")
        .tree(num_leaves=100, num_leaves_to_search=10)
        .score_ah(2)
        .reorder(100)
        .create_config()
    )

    mips_search = ScaNN.from_texts(
        texts,
        ConsistentFakeEmbeddings(),
        scann_config=scann_config,
        distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
        normalize_L2=True,
    )
    output = mips_search.similarity_search_with_score("42", k=1)
    assert output == [(Document(page_content="42", metadata={}), 0.0)]
|
||||
|
||||
|
||||
def test_scann_vector_sim() -> None:
    """Test vector similarity."""
    texts = ["foo", "bar", "baz"]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
    index_to_id = docsearch.index_to_docstore_id
    expected_docstore = InMemoryDocstore(
        {index_to_id[pos]: Document(page_content=text) for pos, text in enumerate(texts)}
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    # Search by raw embedding vector instead of query text.
    query_vec = FakeEmbeddings().embed_query(text="foo")
    assert docsearch.similarity_search_by_vector(query_vec, k=1) == [
        Document(page_content="foo")
    ]
|
||||
|
||||
|
||||
def test_scann_vector_sim_with_score_threshold() -> None:
    """Test vector similarity."""
    texts = ["foo", "bar", "baz"]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
    index_to_id = docsearch.index_to_docstore_id
    expected_docstore = InMemoryDocstore(
        {index_to_id[pos]: Document(page_content=text) for pos, text in enumerate(texts)}
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    query_vec = FakeEmbeddings().embed_query(text="foo")
    # k=2 is requested, but the score threshold should prune all but the
    # exact match.
    output = docsearch.similarity_search_by_vector(query_vec, k=2, score_threshold=0.2)
    assert output == [Document(page_content="foo")]
|
||||
|
||||
|
||||
def test_similarity_search_with_score_by_vector() -> None:
    """Test vector similarity with score by vector."""
    texts = ["foo", "bar", "baz"]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
    index_to_id = docsearch.index_to_docstore_id
    expected_docstore = InMemoryDocstore(
        {index_to_id[pos]: Document(page_content=text) for pos, text in enumerate(texts)}
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    query_vec = FakeEmbeddings().embed_query(text="foo")
    output = docsearch.similarity_search_with_score_by_vector(query_vec, k=1)
    # Exactly one (document, score) pair, and the document is the match.
    assert len(output) == 1
    top_doc, _ = output[0]
    assert top_doc == Document(page_content="foo")
|
||||
|
||||
|
||||
def test_similarity_search_with_score_by_vector_with_score_threshold() -> None:
    """Test vector similarity with score by vector."""
    texts = ["foo", "bar", "baz"]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
    index_to_id = docsearch.index_to_docstore_id
    expected_docstore = InMemoryDocstore(
        {index_to_id[pos]: Document(page_content=text) for pos, text in enumerate(texts)}
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    query_vec = FakeEmbeddings().embed_query(text="foo")
    output = docsearch.similarity_search_with_score_by_vector(
        query_vec,
        k=2,
        score_threshold=0.2,
    )
    # Only the exact match survives the threshold, and its distance is
    # strictly below it.
    assert len(output) == 1
    top_doc, top_score = output[0]
    assert top_doc == Document(page_content="foo")
    assert top_score < 0.2
|
||||
|
||||
|
||||
def test_scann_with_metadatas() -> None:
    """Test end to end construction and search."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
    # Each stored document should carry its page metadata.
    expected_docstore = InMemoryDocstore(
        {
            docsearch.index_to_docstore_id[i]: Document(
                page_content=text, metadata={"page": i}
            )
            for i, text in enumerate(texts)
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    output = docsearch.similarity_search("foo", k=1)
    assert output == [Document(page_content="foo", metadata={"page": 0})]
|
||||
|
||||
|
||||
def test_scann_with_metadatas_and_filter() -> None:
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
    expected_docstore = InMemoryDocstore(
        {
            docsearch.index_to_docstore_id[i]: Document(
                page_content=text, metadata={"page": i}
            )
            for i, text in enumerate(texts)
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    # The metadata filter overrides pure similarity: "foo" is the closest
    # text, but only page 1 ("bar") is allowed through.
    output = docsearch.similarity_search("foo", k=1, filter={"page": 1})
    assert output == [Document(page_content="bar", metadata={"page": 1})]
|
||||
|
||||
|
||||
def test_scann_with_metadatas_and_list_filter() -> None:
    texts = ["foo", "bar", "baz", "foo", "qux"]
    # Pages 0, 1, 2, 3, 3: the last two texts share page 3.
    metadatas = [{"page": min(i, 3)} for i in range(len(texts))]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
    expected_docstore = InMemoryDocstore(
        {
            docsearch.index_to_docstore_id[i]: Document(
                page_content=text, metadata={"page": min(i, 3)}
            )
            for i, text in enumerate(texts)
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    # A list-valued filter matches any of the listed pages.
    output = docsearch.similarity_search("foor", k=1, filter={"page": [0, 1, 2]})
    assert output == [Document(page_content="foo", metadata={"page": 0})]
|
||||
|
||||
|
||||
def test_scann_search_not_found() -> None:
    """Test what happens when document is not found."""
    docsearch = ScaNN.from_texts(["foo", "bar", "baz"], FakeEmbeddings())
    # Swap in an empty docstore to purposefully induce lookup errors.
    docsearch.docstore = InMemoryDocstore({})
    with pytest.raises(ValueError):
        docsearch.similarity_search("foo")
|
||||
|
||||
|
||||
def test_scann_local_save_load() -> None:
    """Test end to end serialization."""
    texts = ["foo", "bar", "baz"]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
    # datetime.utcnow() is deprecated (Python 3.12); an aware UTC timestamp
    # produces the identical "%Y%m%d-%H%M%S" string.
    temp_timestamp = datetime.datetime.now(datetime.timezone.utc).strftime(
        "%Y%m%d-%H%M%S"
    )
    with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder:
        docsearch.save_local(temp_folder)
        new_docsearch = ScaNN.load_local(temp_folder, FakeEmbeddings())
        assert new_docsearch.index is not None
|
||||
|
||||
|
||||
def test_scann_normalize_l2() -> None:
    """Test normalize L2."""
    embeddings = np.array(FakeEmbeddings().embed_documents(["foo", "bar", "baz"]))
    # Every normalized vector must have unit norm.
    np.testing.assert_allclose(1, np.linalg.norm(normalize(embeddings), axis=-1))
    # Normalizing the zero vector must not produce NaN.
    np.testing.assert_array_equal(False, np.isnan(normalize(np.zeros(10))))
|
Loading…
Reference in New Issue