Add ScaNN support in vectorstore. (#8251)

Description: Add ScaNN vectorstore to langchain.
ScaNN is an open-source, high-performance vector similarity library
optimized for AVX2-enabled CPUs.
https://github.com/google-research/google-research/tree/master/scann

- Dependencies: scann

Python notebook to illustrate the usage:
docs/extras/integrations/vectorstores/scann.ipynb
Integration test:
libs/langchain/tests/integration_tests/vectorstores/test_scann.py
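A minimal usage sketch (mirroring the notebook and the class docstring; the example texts, paths, and index name are placeholders, and it assumes `scann` plus the HuggingFace embedding dependencies are installed):

```python
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import ScaNN

embeddings = HuggingFaceEmbeddings()

# Build an in-memory ScaNN index from raw texts and run a similarity search.
db = ScaNN.from_texts(["foo", "bar", "baz"], embeddings)
print(db.similarity_search("foo", k=1))

# The index can be saved to and restored from a local folder.
db.save_local("/tmp/scann_db", "demo_index")
restored_db = ScaNN.load_local("/tmp/scann_db", embeddings, index_name="demo_index")
```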

@rlancemartin, @eyurtsev for review.

Thanks!
Ruiqi Guo 1 year ago committed by GitHub
parent 5b7ff215e8
commit 6aee589eec

@@ -0,0 +1,190 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "e4afbbb6",
"metadata": {},
"source": [
"# ScaNN\n",
"\n",
"ScaNN (Scalable Nearest Neighbors) is a method for efficient vector similarity search at scale.\n",
"\n",
"ScaNN includes search space pruning and quantization for Maximum Inner Product Search and also supports other distance functions such as Euclidean distance. The implementation is optimized for x86 processors with AVX2 support. See its [Google Research github](https://github.com/google-research/google-research/tree/master/scann) for more details."
]
},
{
"cell_type": "markdown",
"id": "082f593e",
"metadata": {},
"source": [
"## Installation\n",
"Install ScaNN through pip. Alternatively, you can follow instructions on the [ScaNN Website](https://github.com/google-research/google-research/tree/master/scann#building-from-source) to install from source."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a35e4f09",
"metadata": {},
"outputs": [],
"source": [
"!pip install scann"
]
},
{
"cell_type": "markdown",
"id": "44bf38a8",
"metadata": {},
"source": [
"## Retrieval Demo\n",
"\n",
"Below we show how to use ScaNN in conjunction with Huggingface Embeddings."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "377bc723",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Document(page_content='Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while youre at it, pass the Disclose Act so Americans can know who is funding our elections. \\n\\nTonight, Id like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \\n\\nOne of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nations top legal minds, who will continue Justice Breyers legacy of excellence.', metadata={'source': 'state_of_the_union.txt'})"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.embeddings import HuggingFaceEmbeddings\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.vectorstores import ScaNN\n",
"from langchain.document_loaders import TextLoader\n",
"\n",
"loader = TextLoader(\"state_of_the_union.txt\")\n",
"documents = loader.load()\n",
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
"docs = text_splitter.split_documents(documents)\n",
"\n",
"from langchain.embeddings import TensorflowHubEmbeddings\n",
"embeddings = HuggingFaceEmbeddings()\n",
"\n",
"db = ScaNN.from_documents(docs, embeddings)\n",
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"docs = db.similarity_search(query)\n",
"\n",
"docs[0]"
]
},
{
"cell_type": "markdown",
"id": "9ad5b151",
"metadata": {},
"source": [
"## RetrievalQA Demo\n",
"\n",
"Next, we demonstrate using ScaNN in conjunction with Google PaLM API.\n",
"\n",
"You can obtain an API key from https://developers.generativeai.google/tutorials/setup"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "fc27ad51",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import RetrievalQA\n",
"from langchain.chat_models import google_palm\n",
"\n",
"palm_client = google_palm.ChatGooglePalm(google_api_key='YOUR_GOOGLE_PALM_API_KEY')\n",
"\n",
"qa = RetrievalQA.from_chain_type(\n",
" llm=palm_client,\n",
" chain_type=\"stuff\",\n",
" retriever=db.as_retriever(search_kwargs={'k': 10})\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "5b77f919",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The president said that Ketanji Brown Jackson is one of our nation's top legal minds, who will continue Justice Breyer's legacy of excellence.\n"
]
}
],
"source": [
"print(qa.run('What did the president say about Ketanji Brown Jackson?'))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "0c6deec6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The president did not mention Michael Phelps in his speech.\n"
]
}
],
"source": [
"print(qa.run('What did the president say about Michael Phelps?'))"
]
},
{
"cell_type": "markdown",
"id": "8a49f4a6",
"metadata": {},
"source": [
"## Save and loading local retrieval index"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "6b7496b9",
"metadata": {},
"outputs": [],
"source": [
"db.save_local('/tmp/db', 'state_of_union')\n",
"restored_db = ScaNN.load_local('/tmp/db', embeddings, index_name='state_of_union')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@@ -54,6 +54,7 @@ from langchain.vectorstores.pinecone import Pinecone
from langchain.vectorstores.qdrant import Qdrant
from langchain.vectorstores.redis import Redis
from langchain.vectorstores.rocksetdb import Rockset
from langchain.vectorstores.scann import ScaNN
from langchain.vectorstores.singlestoredb import SingleStoreDB
from langchain.vectorstores.sklearn import SKLearnVectorStore
from langchain.vectorstores.starrocks import StarRocks
@@ -106,6 +107,7 @@ __all__ = [
"Qdrant",
"Redis",
"Rockset",
"ScaNN",
"SKLearnVectorStore",
"SingleStoreDB",
"StarRocks",

@@ -0,0 +1,544 @@
"""Wrapper around ScaNN vector database."""
from __future__ import annotations
import operator
import pickle
import uuid
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
import numpy as np
from langchain.docstore.base import AddableMixin, Docstore
from langchain.docstore.document import Document
from langchain.docstore.in_memory import InMemoryDocstore
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import DistanceStrategy
def normalize(x: np.ndarray) -> np.ndarray:
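# Scale each row to unit L2 norm; the norm is clamped to at least 1e-12 to avoid division by zero.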
x /= np.clip(np.linalg.norm(x, axis=-1, keepdims=True), 1e-12, None)
return x
def dependable_scann_import() -> Any:
"""
Import scann if available, otherwise raise error.
"""
try:
import scann
except ImportError:
raise ImportError(
"Could not import scann python package. "
"Please install it with `pip install scann` "
)
return scann
class ScaNN(VectorStore):
"""Wrapper around ScaNN vector database.
To use, you should have the ``scann`` python package installed.
Example:
.. code-block:: python
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import ScaNN
db = ScaNN.from_texts(
['foo', 'bar', 'barz', 'qux'],
HuggingFaceEmbeddings())
db.similarity_search('foo?', k=1)
"""
def __init__(
self,
embedding: Embeddings,
index: Any,
docstore: Docstore,
index_to_docstore_id: Dict[int, str],
relevance_score_fn: Optional[Callable[[float], float]] = None,
normalize_L2: bool = False,
distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
scann_config: Optional[str] = None,
):
"""Initialize with necessary components."""
self.embedding = embedding
self.index = index
self.docstore = docstore
self.index_to_docstore_id = index_to_docstore_id
self.distance_strategy = distance_strategy
self.override_relevance_score_fn = relevance_score_fn
self._normalize_L2 = normalize_L2
self._scann_config = scann_config
def __add(
self,
texts: Iterable[str],
embeddings: Iterable[List[float]],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> List[str]:
if not isinstance(self.docstore, AddableMixin):
raise ValueError(
"If trying to add texts, the underlying docstore should support "
f"adding items, which {self.docstore} does not"
)
raise NotImplementedError("Updates are not available in ScaNN, yet.")
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore.
Args:
texts: Iterable of strings to add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts.
ids: Optional list of unique IDs.
Returns:
List of ids from adding the texts into the vectorstore.
"""
# Embed and create the documents.
embeddings = self.embedding.embed_documents(list(texts))
return self.__add(texts, embeddings, metadatas=metadatas, ids=ids, **kwargs)
def add_embeddings(
self,
text_embeddings: Iterable[Tuple[str, List[float]]],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore.
Args:
text_embeddings: Iterable pairs of string and embedding to
add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts.
ids: Optional list of unique IDs.
Returns:
List of ids from adding the texts into the vectorstore.
"""
if not isinstance(self.docstore, AddableMixin):
raise ValueError(
"If trying to add texts, the underlying docstore should support "
f"adding items, which {self.docstore} does not"
)
# Embed and create the documents.
texts, embeddings = zip(*text_embeddings)
return self.__add(texts, embeddings, metadatas=metadatas, ids=ids, **kwargs)
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
"""Delete by vector ID or other criteria.
Args:
ids: List of ids to delete.
**kwargs: Other keyword arguments that subclasses might use.
Returns:
Optional[bool]: True if deletion is successful,
False otherwise, None if not implemented.
"""
raise NotImplementedError("Deletions are not available in ScaNN, yet.")
def similarity_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs most similar to query.
Args:
embedding: Embedding vector to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
Defaults to 20.
**kwargs: kwargs to be passed to similarity search. Can include:
score_threshold: Optional, a floating point value between 0 and 1 to
filter the resulting set of retrieved docs
Returns:
List of documents most similar to the query text and L2 distance
in float for each. Lower score represents more similarity.
"""
vector = np.array([embedding], dtype=np.float32)
if self._normalize_L2:
vector = normalize(vector)
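# When a metadata filter is given, over-fetch fetch_k candidates so enough results remain after filtering.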
indices, scores = self.index.search_batched(
vector, k if filter is None else fetch_k
)
docs = []
for j, i in enumerate(indices[0]):
if i == -1:
# This happens when not enough docs are returned.
continue
_id = self.index_to_docstore_id[i]
doc = self.docstore.search(_id)
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
if filter is not None:
filter = {
key: [value] if not isinstance(value, list) else value
for key, value in filter.items()
}
if all(doc.metadata.get(key) in value for key, value in filter.items()):
docs.append((doc, scores[0][j]))
else:
docs.append((doc, scores[0][j]))
score_threshold = kwargs.get("score_threshold")
if score_threshold is not None:
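# For similarity scores (MIPS, Jaccard) higher is better, so keep scores >= threshold; for distances, keep scores <= threshold.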
cmp = (
operator.ge
if self.distance_strategy
in (DistanceStrategy.MAX_INNER_PRODUCT, DistanceStrategy.JACCARD)
else operator.le
)
docs = [
(doc, similarity)
for doc, similarity in docs
if cmp(similarity, score_threshold)
]
return docs[:k]
def similarity_search_with_score(
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs most similar to query.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
Defaults to 20.
Returns:
List of documents most similar to the query text with
L2 distance in float. Lower score represents more similarity.
"""
embedding = self.embedding.embed_query(query)
docs = self.similarity_search_with_score_by_vector(
embedding,
k,
filter=filter,
fetch_k=fetch_k,
**kwargs,
)
return docs
def similarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Document]:
"""Return docs most similar to embedding vector.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
Defaults to 20.
Returns:
List of Documents most similar to the embedding.
"""
docs_and_scores = self.similarity_search_with_score_by_vector(
embedding,
k,
filter=filter,
fetch_k=fetch_k,
**kwargs,
)
return [doc for doc, _ in docs_and_scores]
def similarity_search(
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Document]:
"""Return docs most similar to query.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
Defaults to 20.
Returns:
List of Documents most similar to the query.
"""
docs_and_scores = self.similarity_search_with_score(
query, k, filter=filter, fetch_k=fetch_k, **kwargs
)
return [doc for doc, _ in docs_and_scores]
@classmethod
def __from(
cls,
texts: List[str],
embeddings: List[List[float]],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
normalize_L2: bool = False,
**kwargs: Any,
) -> ScaNN:
scann = dependable_scann_import()
distance_strategy = kwargs.get(
"distance_strategy", DistanceStrategy.EUCLIDEAN_DISTANCE
)
scann_config = kwargs.get("scann_config", None)
vector = np.array(embeddings, dtype=np.float32)
if normalize_L2:
vector = normalize(vector)
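# If a serialized ScaNN config is supplied, build the searcher directly from it; otherwise fall back to an exact (brute-force) scorer over the chosen distance.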
if scann_config is not None:
index = scann.scann_ops_pybind.create_searcher(vector, scann_config)
else:
if distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
index = (
scann.scann_ops_pybind.builder(vector, 1, "dot_product")
.score_brute_force()
.build()
)
else:
# Default to L2, currently other metric types not initialized.
index = (
scann.scann_ops_pybind.builder(vector, 1, "squared_l2")
.score_brute_force()
.build()
)
documents = []
if ids is None:
ids = [str(uuid.uuid4()) for _ in texts]
for i, text in enumerate(texts):
metadata = metadatas[i] if metadatas else {}
documents.append(Document(page_content=text, metadata=metadata))
index_to_id = dict(enumerate(ids))
if len(index_to_id) != len(documents):
raise Exception(
f"{len(index_to_id)} ids provided for {len(documents)} documents."
" Each document should have an id."
)
docstore = InMemoryDocstore(dict(zip(index_to_id.values(), documents)))
return cls(
embedding,
index,
docstore,
index_to_id,
normalize_L2=normalize_L2,
**kwargs,
)
@classmethod
def from_texts(
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> ScaNN:
"""Construct ScaNN wrapper from raw documents.
This is a user friendly interface that:
1. Embeds documents.
2. Creates an in memory docstore
3. Initializes the ScaNN database
This is intended to be a quick way to get started.
Example:
.. code-block:: python
from langchain import ScaNN
from langchain.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
scann = ScaNN.from_texts(texts, embeddings)
"""
embeddings = embedding.embed_documents(texts)
return cls.__from(
texts,
embeddings,
embedding,
metadatas=metadatas,
ids=ids,
**kwargs,
)
@classmethod
def from_embeddings(
cls,
text_embeddings: List[Tuple[str, List[float]]],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> ScaNN:
"""Construct ScaNN wrapper from raw documents.
This is a user friendly interface that:
1. Embeds documents.
2. Creates an in memory docstore
3. Initializes the ScaNN database
This is intended to be a quick way to get started.
Example:
.. code-block:: python
from langchain import ScaNN
from langchain.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
text_embeddings = embeddings.embed_documents(texts)
text_embedding_pairs = list(zip(texts, text_embeddings))
scann = ScaNN.from_embeddings(text_embedding_pairs, embeddings)
"""
texts = [t[0] for t in text_embeddings]
embeddings = [t[1] for t in text_embeddings]
return cls.__from(
texts,
embeddings,
embedding,
metadatas=metadatas,
ids=ids,
**kwargs,
)
def save_local(self, folder_path: str, index_name: str = "index") -> None:
"""Save ScaNN index, docstore, and index_to_docstore_id to disk.
Args:
folder_path: folder path to save index, docstore,
and index_to_docstore_id to.
"""
path = Path(folder_path)
scann_path = path / "{index_name}.scann".format(index_name=index_name)
scann_path.mkdir(exist_ok=True, parents=True)
# save index separately since it is not picklable
self.index.serialize(str(scann_path))
# save docstore and index_to_docstore_id
with open(path / "{index_name}.pkl".format(index_name=index_name), "wb") as f:
pickle.dump((self.docstore, self.index_to_docstore_id), f)
@classmethod
def load_local(
cls,
folder_path: str,
embedding: Embeddings,
index_name: str = "index",
**kwargs: Any,
) -> ScaNN:
"""Load ScaNN index, docstore, and index_to_docstore_id from disk.
Args:
folder_path: folder path to load index, docstore,
and index_to_docstore_id from.
embedding: Embeddings to use when generating queries.
index_name: name of the index files to load.
"""
path = Path(folder_path)
scann_path = path / "{index_name}.scann".format(index_name=index_name)
scann_path.mkdir(exist_ok=True, parents=True)
# load index separately since it is not picklable
scann = dependable_scann_import()
index = scann.scann_ops_pybind.load_searcher(str(scann_path))
# load docstore and index_to_docstore_id
with open(path / "{index_name}.pkl".format(index_name=index_name), "rb") as f:
docstore, index_to_docstore_id = pickle.load(f)
return cls(embedding, index, docstore, index_to_docstore_id, **kwargs)
def _select_relevance_score_fn(self) -> Callable[[float], float]:
"""
The 'correct' relevance function
may differ depending on a few things, including:
- the distance / similarity metric used by the VectorStore
- the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
- embedding dimensionality
- etc.
"""
if self.override_relevance_score_fn is not None:
return self.override_relevance_score_fn
# Default strategy is to rely on distance strategy provided in
# vectorstore constructor
if self.distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
return self._max_inner_product_relevance_score_fn
elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
# Default behavior is to use euclidean distance relevancy
return self._euclidean_relevance_score_fn
else:
raise ValueError(
"Unknown distance strategy, must be cosine, max_inner_product,"
" or euclidean"
)
def _similarity_search_with_relevance_scores(
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs and their similarity scores on a scale from 0 to 1."""
# Pop score threshold so that only relevancy scores, not raw scores, are
# filtered.
score_threshold = kwargs.pop("score_threshold", None)
relevance_score_fn = self._select_relevance_score_fn()
if relevance_score_fn is None:
raise ValueError(
"normalize_score_fn must be provided to"
" ScaNN constructor to normalize scores"
)
docs_and_scores = self.similarity_search_with_score(
query,
k=k,
filter=filter,
fetch_k=fetch_k,
**kwargs,
)
docs_and_rel_scores = [
(doc, relevance_score_fn(score)) for doc, score in docs_and_scores
]
if score_threshold is not None:
docs_and_rel_scores = [
(doc, similarity)
for doc, similarity in docs_and_rel_scores
if similarity >= score_threshold
]
return docs_and_rel_scores

@@ -0,0 +1,262 @@
"""Test ScaNN functionality."""
import datetime
import tempfile
import numpy as np
import pytest
from langchain.docstore.document import Document
from langchain.docstore.in_memory import InMemoryDocstore
from langchain.vectorstores.scann import ScaNN, dependable_scann_import, normalize
from langchain.vectorstores.utils import DistanceStrategy
from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings,
FakeEmbeddings,
)
def test_scann() -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
index_to_id = docsearch.index_to_docstore_id
expected_docstore = InMemoryDocstore(
{
index_to_id[0]: Document(page_content="foo"),
index_to_id[1]: Document(page_content="bar"),
index_to_id[2]: Document(page_content="baz"),
}
)
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")]
def test_scann_vector_mips_l2() -> None:
"""Test vector similarity with MIPS and L2."""
texts = ["foo", "bar", "baz"]
euclidean_search = ScaNN.from_texts(texts, FakeEmbeddings())
output = euclidean_search.similarity_search_with_score("foo", k=1)
expected_euclidean = [(Document(page_content="foo", metadata={}), 0.0)]
assert output == expected_euclidean
mips_search = ScaNN.from_texts(
texts,
FakeEmbeddings(),
distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
normalize_L2=True,
)
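# With L2-normalized embeddings, the inner product of a vector with itself is 1.0, so the top score is 1.0.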
output = mips_search.similarity_search_with_score("foo", k=1)
expected_mips = [(Document(page_content="foo", metadata={}), 1.0)]
assert output == expected_mips
def test_scann_with_config() -> None:
"""Test ScaNN with approximate search config."""
texts = [str(i) for i in range(10000)]
# Create a config with dimension = 10, k = 10.
# Tree: search 10 leaves in a search tree of 100 leaves.
# Quantization: uses a 16-centroid quantizer for every 2 dimensions.
# Reordering: reorder top 100 results.
scann_config = (
dependable_scann_import()
.scann_ops_pybind.builder(np.zeros(shape=(0, 10)), 10, "squared_l2")
.tree(num_leaves=100, num_leaves_to_search=10)
.score_ah(2)
.reorder(100)
.create_config()
)
mips_search = ScaNN.from_texts(
texts,
ConsistentFakeEmbeddings(),
scann_config=scann_config,
distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
normalize_L2=True,
)
output = mips_search.similarity_search_with_score("42", k=1)
expected = [(Document(page_content="42", metadata={}), 0.0)]
assert output == expected
def test_scann_vector_sim() -> None:
"""Test vector similarity."""
texts = ["foo", "bar", "baz"]
docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
index_to_id = docsearch.index_to_docstore_id
expected_docstore = InMemoryDocstore(
{
index_to_id[0]: Document(page_content="foo"),
index_to_id[1]: Document(page_content="bar"),
index_to_id[2]: Document(page_content="baz"),
}
)
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
query_vec = FakeEmbeddings().embed_query(text="foo")
output = docsearch.similarity_search_by_vector(query_vec, k=1)
assert output == [Document(page_content="foo")]
def test_scann_vector_sim_with_score_threshold() -> None:
"""Test vector similarity."""
texts = ["foo", "bar", "baz"]
docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
index_to_id = docsearch.index_to_docstore_id
expected_docstore = InMemoryDocstore(
{
index_to_id[0]: Document(page_content="foo"),
index_to_id[1]: Document(page_content="bar"),
index_to_id[2]: Document(page_content="baz"),
}
)
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
query_vec = FakeEmbeddings().embed_query(text="foo")
output = docsearch.similarity_search_by_vector(query_vec, k=2, score_threshold=0.2)
assert output == [Document(page_content="foo")]
def test_similarity_search_with_score_by_vector() -> None:
"""Test vector similarity with score by vector."""
texts = ["foo", "bar", "baz"]
docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
index_to_id = docsearch.index_to_docstore_id
expected_docstore = InMemoryDocstore(
{
index_to_id[0]: Document(page_content="foo"),
index_to_id[1]: Document(page_content="bar"),
index_to_id[2]: Document(page_content="baz"),
}
)
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
query_vec = FakeEmbeddings().embed_query(text="foo")
output = docsearch.similarity_search_with_score_by_vector(query_vec, k=1)
assert len(output) == 1
assert output[0][0] == Document(page_content="foo")
def test_similarity_search_with_score_by_vector_with_score_threshold() -> None:
"""Test vector similarity with score by vector."""
texts = ["foo", "bar", "baz"]
docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
index_to_id = docsearch.index_to_docstore_id
expected_docstore = InMemoryDocstore(
{
index_to_id[0]: Document(page_content="foo"),
index_to_id[1]: Document(page_content="bar"),
index_to_id[2]: Document(page_content="baz"),
}
)
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
query_vec = FakeEmbeddings().embed_query(text="foo")
output = docsearch.similarity_search_with_score_by_vector(
query_vec,
k=2,
score_threshold=0.2,
)
assert len(output) == 1
assert output[0][0] == Document(page_content="foo")
assert output[0][1] < 0.2
def test_scann_with_metadatas() -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = ScaNN.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
expected_docstore = InMemoryDocstore(
{
docsearch.index_to_docstore_id[0]: Document(
page_content="foo", metadata={"page": 0}
),
docsearch.index_to_docstore_id[1]: Document(
page_content="bar", metadata={"page": 1}
),
docsearch.index_to_docstore_id[2]: Document(
page_content="baz", metadata={"page": 2}
),
}
)
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo", metadata={"page": 0})]
def test_scann_with_metadatas_and_filter() -> None:
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = ScaNN.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
expected_docstore = InMemoryDocstore(
{
docsearch.index_to_docstore_id[0]: Document(
page_content="foo", metadata={"page": 0}
),
docsearch.index_to_docstore_id[1]: Document(
page_content="bar", metadata={"page": 1}
),
docsearch.index_to_docstore_id[2]: Document(
page_content="baz", metadata={"page": 2}
),
}
)
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
output = docsearch.similarity_search("foo", k=1, filter={"page": 1})
assert output == [Document(page_content="bar", metadata={"page": 1})]
def test_scann_with_metadatas_and_list_filter() -> None:
texts = ["foo", "bar", "baz", "foo", "qux"]
metadatas = [{"page": i} if i <= 3 else {"page": 3} for i in range(len(texts))]
docsearch = ScaNN.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
expected_docstore = InMemoryDocstore(
{
docsearch.index_to_docstore_id[0]: Document(
page_content="foo", metadata={"page": 0}
),
docsearch.index_to_docstore_id[1]: Document(
page_content="bar", metadata={"page": 1}
),
docsearch.index_to_docstore_id[2]: Document(
page_content="baz", metadata={"page": 2}
),
docsearch.index_to_docstore_id[3]: Document(
page_content="foo", metadata={"page": 3}
),
docsearch.index_to_docstore_id[4]: Document(
page_content="qux", metadata={"page": 3}
),
}
)
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
output = docsearch.similarity_search("foor", k=1, filter={"page": [0, 1, 2]})
assert output == [Document(page_content="foo", metadata={"page": 0})]
def test_scann_search_not_found() -> None:
"""Test what happens when document is not found."""
texts = ["foo", "bar", "baz"]
docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
# Get rid of the docstore to purposefully induce errors.
docsearch.docstore = InMemoryDocstore({})
with pytest.raises(ValueError):
docsearch.similarity_search("foo")
def test_scann_local_save_load() -> None:
"""Test end to end serialization."""
texts = ["foo", "bar", "baz"]
docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
temp_timestamp = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S")
with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder:
docsearch.save_local(temp_folder)
new_docsearch = ScaNN.load_local(temp_folder, FakeEmbeddings())
assert new_docsearch.index is not None
def test_scann_normalize_l2() -> None:
"""Test normalize L2."""
texts = ["foo", "bar", "baz"]
emb = np.array(FakeEmbeddings().embed_documents(texts))
# Test norm is 1.
np.testing.assert_allclose(1, np.linalg.norm(normalize(emb), axis=-1))
# Test that there is no NaN after normalization.
np.testing.assert_array_equal(False, np.isnan(normalize(np.zeros(10))))