Mirror of https://github.com/hwchase17/langchain, synced 2024-11-06 03:20:49 +00:00
Add ScaNN support in vectorstore. (#8251)
Description: Add a ScaNN vectorstore to LangChain. ScaNN is an open-source, high-performance vector similarity library optimized for AVX2-enabled CPUs: https://github.com/google-research/google-research/tree/master/scann
Dependencies: scann
Python notebook illustrating the usage: docs/extras/integrations/vectorstores/scann.ipynb
Integration test: libs/langchain/tests/integration_tests/vectorstores/test_scann.py
@rlancemartin, @eyurtsev for review. Thanks!
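For quick reference, a minimal usage sketch of the new integration (assuming the scann package is installed; the example texts and the choice of HuggingFaceEmbeddings simply mirror the notebook below):

    # Minimal sketch: build an in-memory ScaNN index and query it.
    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.vectorstores import ScaNN

    # Embed a few placeholder strings and index them with ScaNN.
    db = ScaNN.from_texts(["foo", "bar", "baz"], HuggingFaceEmbeddings())

    # Return the single most similar document to the query.
    print(db.similarity_search("foo", k=1))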
This commit is contained in:
parent 5b7ff215e8
commit 6aee589eec
docs/extras/integrations/vectorstores/scann.ipynb (190 lines, new file)
@@ -0,0 +1,190 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e4afbbb6",
   "metadata": {},
   "source": [
    "# ScaNN\n",
    "\n",
    "ScaNN (Scalable Nearest Neighbors) is a method for efficient vector similarity search at scale.\n",
    "\n",
    "ScaNN includes search space pruning and quantization for Maximum Inner Product Search and also supports other distance functions such as Euclidean distance. The implementation is optimized for x86 processors with AVX2 support. See its [Google Research GitHub](https://github.com/google-research/google-research/tree/master/scann) for more details."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "082f593e",
   "metadata": {},
   "source": [
    "## Installation\n",
    "Install ScaNN through pip. Alternatively, you can follow the instructions on the [ScaNN website](https://github.com/google-research/google-research/tree/master/scann#building-from-source) to install from source."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a35e4f09",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install scann"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44bf38a8",
   "metadata": {},
   "source": [
    "## Retrieval Demo\n",
    "\n",
    "Below we show how to use ScaNN in conjunction with Hugging Face embeddings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "377bc723",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Document(page_content='Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. \\n\\nTonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \\n\\nOne of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.', metadata={'source': 'state_of_the_union.txt'})"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from langchain.embeddings import HuggingFaceEmbeddings\n",
    "from langchain.text_splitter import CharacterTextSplitter\n",
    "from langchain.vectorstores import ScaNN\n",
    "from langchain.document_loaders import TextLoader\n",
    "\n",
    "loader = TextLoader(\"state_of_the_union.txt\")\n",
    "documents = loader.load()\n",
    "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
    "docs = text_splitter.split_documents(documents)\n",
    "\n",
    "from langchain.embeddings import TensorflowHubEmbeddings\n",
    "embeddings = HuggingFaceEmbeddings()\n",
    "\n",
    "db = ScaNN.from_documents(docs, embeddings)\n",
    "query = \"What did the president say about Ketanji Brown Jackson\"\n",
    "docs = db.similarity_search(query)\n",
    "\n",
    "docs[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ad5b151",
   "metadata": {},
   "source": [
    "## RetrievalQA Demo\n",
    "\n",
    "Next, we demonstrate using ScaNN in conjunction with the Google PaLM API.\n",
    "\n",
    "You can obtain an API key from https://developers.generativeai.google/tutorials/setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fc27ad51",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chains import RetrievalQA\n",
    "from langchain.chat_models import google_palm\n",
    "\n",
    "palm_client = google_palm.ChatGooglePalm(google_api_key='YOUR_GOOGLE_PALM_API_KEY')\n",
    "\n",
    "qa = RetrievalQA.from_chain_type(\n",
    "    llm=palm_client,\n",
    "    chain_type=\"stuff\",\n",
    "    retriever=db.as_retriever(search_kwargs={'k': 10})\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5b77f919",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The president said that Ketanji Brown Jackson is one of our nation's top legal minds, who will continue Justice Breyer's legacy of excellence.\n"
     ]
    }
   ],
   "source": [
    "print(qa.run('What did the president say about Ketanji Brown Jackson?'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "0c6deec6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The president did not mention Michael Phelps in his speech.\n"
     ]
    }
   ],
   "source": [
    "print(qa.run('What did the president say about Michael Phelps?'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a49f4a6",
   "metadata": {},
   "source": [
    "## Saving and loading a local retrieval index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "6b7496b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "db.save_local('/tmp/db', 'state_of_union')\n",
    "restored_db = ScaNN.load_local('/tmp/db', embeddings, index_name='state_of_union')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
libs/langchain/langchain/vectorstores/__init__.py
@@ -54,6 +54,7 @@ from langchain.vectorstores.pinecone import Pinecone
 from langchain.vectorstores.qdrant import Qdrant
 from langchain.vectorstores.redis import Redis
 from langchain.vectorstores.rocksetdb import Rockset
+from langchain.vectorstores.scann import ScaNN
 from langchain.vectorstores.singlestoredb import SingleStoreDB
 from langchain.vectorstores.sklearn import SKLearnVectorStore
 from langchain.vectorstores.starrocks import StarRocks
@@ -106,6 +107,7 @@ __all__ = [
     "Qdrant",
     "Redis",
     "Rockset",
+    "ScaNN",
     "SKLearnVectorStore",
     "SingleStoreDB",
     "StarRocks",
libs/langchain/langchain/vectorstores/scann.py (544 lines, new file)
@@ -0,0 +1,544 @@
"""Wrapper around ScaNN vector database."""
from __future__ import annotations

import operator
import pickle
import uuid
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

import numpy as np

from langchain.docstore.base import AddableMixin, Docstore
from langchain.docstore.document import Document
from langchain.docstore.in_memory import InMemoryDocstore
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import DistanceStrategy


def normalize(x: np.ndarray) -> np.ndarray:
    x /= np.clip(np.linalg.norm(x, axis=-1, keepdims=True), 1e-12, None)
    return x


def dependable_scann_import() -> Any:
    """
    Import scann if available, otherwise raise error.
    """
    try:
        import scann
    except ImportError:
        raise ImportError(
            "Could not import scann python package. "
            "Please install it with `pip install scann` "
        )
    return scann


class ScaNN(VectorStore):
    """Wrapper around ScaNN vector database.

    To use, you should have the ``scann`` python package installed.

    Example:
        .. code-block:: python

            from langchain.embeddings import HuggingFaceEmbeddings
            from langchain.vectorstores import ScaNN

            db = ScaNN.from_texts(
                ['foo', 'bar', 'barz', 'qux'],
                HuggingFaceEmbeddings())
            db.similarity_search('foo?', k=1)
    """

    def __init__(
        self,
        embedding: Embeddings,
        index: Any,
        docstore: Docstore,
        index_to_docstore_id: Dict[int, str],
        relevance_score_fn: Optional[Callable[[float], float]] = None,
        normalize_L2: bool = False,
        distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
        scann_config: Optional[str] = None,
    ):
        """Initialize with necessary components."""
        self.embedding = embedding
        self.index = index
        self.docstore = docstore
        self.index_to_docstore_id = index_to_docstore_id
        self.distance_strategy = distance_strategy
        self.override_relevance_score_fn = relevance_score_fn
        self._normalize_L2 = normalize_L2
        self._scann_config = scann_config

    def __add(
        self,
        texts: Iterable[str],
        embeddings: Iterable[List[float]],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        if not isinstance(self.docstore, AddableMixin):
            raise ValueError(
                "If trying to add texts, the underlying docstore should support "
                f"adding items, which {self.docstore} does not"
            )
        raise NotImplementedError("Updates are not available in ScaNN, yet.")

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            ids: Optional list of unique IDs.

        Returns:
            List of ids from adding the texts into the vectorstore.
        """
        # Embed and create the documents.
        embeddings = self.embedding.embed_documents(list(texts))
        return self.__add(texts, embeddings, metadatas=metadatas, ids=ids, **kwargs)

    def add_embeddings(
        self,
        text_embeddings: Iterable[Tuple[str, List[float]]],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            text_embeddings: Iterable pairs of string and embedding to
                add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            ids: Optional list of unique IDs.

        Returns:
            List of ids from adding the texts into the vectorstore.
        """
        if not isinstance(self.docstore, AddableMixin):
            raise ValueError(
                "If trying to add texts, the underlying docstore should support "
                f"adding items, which {self.docstore} does not"
            )
        # Embed and create the documents.
        texts, embeddings = zip(*text_embeddings)

        return self.__add(texts, embeddings, metadatas=metadatas, ids=ids, **kwargs)

    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
        """Delete by vector ID or other criteria.

        Args:
            ids: List of ids to delete.
            **kwargs: Other keyword arguments that subclasses might use.

        Returns:
            Optional[bool]: True if deletion is successful,
            False otherwise, None if not implemented.
        """

        raise NotImplementedError("Deletions are not available in ScaNN, yet.")

    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to query.

        Args:
            embedding: Embedding vector to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
            fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
                Defaults to 20.
            **kwargs: kwargs to be passed to similarity search. Can include:
                score_threshold: Optional, a floating point value between 0 to 1 to
                    filter the resulting set of retrieved docs

        Returns:
            List of documents most similar to the query text and L2 distance
            in float for each. Lower score represents more similarity.
        """
        vector = np.array([embedding], dtype=np.float32)
        if self._normalize_L2:
            vector = normalize(vector)
        indices, scores = self.index.search_batched(
            vector, k if filter is None else fetch_k
        )
        docs = []
        for j, i in enumerate(indices[0]):
            if i == -1:
                # This happens when not enough docs are returned.
                continue
            _id = self.index_to_docstore_id[i]
            doc = self.docstore.search(_id)
            if not isinstance(doc, Document):
                raise ValueError(f"Could not find document for id {_id}, got {doc}")
            if filter is not None:
                filter = {
                    key: [value] if not isinstance(value, list) else value
                    for key, value in filter.items()
                }
                if all(doc.metadata.get(key) in value for key, value in filter.items()):
                    docs.append((doc, scores[0][j]))
            else:
                docs.append((doc, scores[0][j]))

        score_threshold = kwargs.get("score_threshold")
        if score_threshold is not None:
            cmp = (
                operator.ge
                if self.distance_strategy
                in (DistanceStrategy.MAX_INNER_PRODUCT, DistanceStrategy.JACCARD)
                else operator.le
            )
            docs = [
                (doc, similarity)
                for doc, similarity in docs
                if cmp(similarity, score_threshold)
            ]
        return docs[:k]

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
            fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
                Defaults to 20.

        Returns:
            List of documents most similar to the query text with
            L2 distance in float. Lower score represents more similarity.
        """
        embedding = self.embedding.embed_query(query)
        docs = self.similarity_search_with_score_by_vector(
            embedding,
            k,
            filter=filter,
            fetch_k=fetch_k,
            **kwargs,
        )
        return docs

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs most similar to embedding vector.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
            fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
                Defaults to 20.

        Returns:
            List of Documents most similar to the embedding.
        """
        docs_and_scores = self.similarity_search_with_score_by_vector(
            embedding,
            k,
            filter=filter,
            fetch_k=fetch_k,
            **kwargs,
        )
        return [doc for doc, _ in docs_and_scores]

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter: (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
            fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
                Defaults to 20.

        Returns:
            List of Documents most similar to the query.
        """
        docs_and_scores = self.similarity_search_with_score(
            query, k, filter=filter, fetch_k=fetch_k, **kwargs
        )
        return [doc for doc, _ in docs_and_scores]

    @classmethod
    def __from(
        cls,
        texts: List[str],
        embeddings: List[List[float]],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        normalize_L2: bool = False,
        **kwargs: Any,
    ) -> ScaNN:
        scann = dependable_scann_import()
        distance_strategy = kwargs.get(
            "distance_strategy", DistanceStrategy.EUCLIDEAN_DISTANCE
        )
        scann_config = kwargs.get("scann_config", None)

        vector = np.array(embeddings, dtype=np.float32)
        if normalize_L2:
            vector = normalize(vector)
        if scann_config is not None:
            index = scann.scann_ops_pybind.create_searcher(vector, scann_config)
        else:
            if distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
                index = (
                    scann.scann_ops_pybind.builder(vector, 1, "dot_product")
                    .score_brute_force()
                    .build()
                )
            else:
                # Default to L2, currently other metric types not initialized.
                index = (
                    scann.scann_ops_pybind.builder(vector, 1, "squared_l2")
                    .score_brute_force()
                    .build()
                )
        documents = []
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in texts]
        for i, text in enumerate(texts):
            metadata = metadatas[i] if metadatas else {}
            documents.append(Document(page_content=text, metadata=metadata))
        index_to_id = dict(enumerate(ids))

        if len(index_to_id) != len(documents):
            raise Exception(
                f"{len(index_to_id)} ids provided for {len(documents)} documents."
                " Each document should have an id."
            )

        docstore = InMemoryDocstore(dict(zip(index_to_id.values(), documents)))
        return cls(
            embedding,
            index,
            docstore,
            index_to_id,
            normalize_L2=normalize_L2,
            **kwargs,
        )

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> ScaNN:
        """Construct ScaNN wrapper from raw documents.

        This is a user friendly interface that:
            1. Embeds documents.
            2. Creates an in memory docstore
            3. Initializes the ScaNN database

        This is intended to be a quick way to get started.

        Example:
            .. code-block:: python

                from langchain import ScaNN
                from langchain.embeddings import OpenAIEmbeddings

                embeddings = OpenAIEmbeddings()
                scann = ScaNN.from_texts(texts, embeddings)
        """
        embeddings = embedding.embed_documents(texts)
        return cls.__from(
            texts,
            embeddings,
            embedding,
            metadatas=metadatas,
            ids=ids,
            **kwargs,
        )

    @classmethod
    def from_embeddings(
        cls,
        text_embeddings: List[Tuple[str, List[float]]],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> ScaNN:
        """Construct ScaNN wrapper from raw documents.

        This is a user friendly interface that:
            1. Embeds documents.
            2. Creates an in memory docstore
            3. Initializes the ScaNN database

        This is intended to be a quick way to get started.

        Example:
            .. code-block:: python

                from langchain import ScaNN
                from langchain.embeddings import OpenAIEmbeddings

                embeddings = OpenAIEmbeddings()
                text_embeddings = embeddings.embed_documents(texts)
                text_embedding_pairs = list(zip(texts, text_embeddings))
                scann = ScaNN.from_embeddings(text_embedding_pairs, embeddings)
        """
        texts = [t[0] for t in text_embeddings]
        embeddings = [t[1] for t in text_embeddings]
        return cls.__from(
            texts,
            embeddings,
            embedding,
            metadatas=metadatas,
            ids=ids,
            **kwargs,
        )

    def save_local(self, folder_path: str, index_name: str = "index") -> None:
        """Save ScaNN index, docstore, and index_to_docstore_id to disk.

        Args:
            folder_path: folder path to save index, docstore,
                and index_to_docstore_id to.
        """
        path = Path(folder_path)
        scann_path = path / "{index_name}.scann".format(index_name=index_name)
        scann_path.mkdir(exist_ok=True, parents=True)

        # save index separately since it is not picklable
        self.index.serialize(str(scann_path))

        # save docstore and index_to_docstore_id
        with open(path / "{index_name}.pkl".format(index_name=index_name), "wb") as f:
            pickle.dump((self.docstore, self.index_to_docstore_id), f)

    @classmethod
    def load_local(
        cls,
        folder_path: str,
        embedding: Embeddings,
        index_name: str = "index",
        **kwargs: Any,
    ) -> ScaNN:
        """Load ScaNN index, docstore, and index_to_docstore_id from disk.

        Args:
            folder_path: folder path to load index, docstore,
                and index_to_docstore_id from.
            embeddings: Embeddings to use when generating queries
            index_name: for saving with a specific index file name
        """
        path = Path(folder_path)
        scann_path = path / "{index_name}.scann".format(index_name=index_name)
        scann_path.mkdir(exist_ok=True, parents=True)
        # load index separately since it is not picklable
        scann = dependable_scann_import()
        index = scann.scann_ops_pybind.load_searcher(str(scann_path))

        # load docstore and index_to_docstore_id
        with open(path / "{index_name}.pkl".format(index_name=index_name), "rb") as f:
            docstore, index_to_docstore_id = pickle.load(f)
        return cls(embedding, index, docstore, index_to_docstore_id, **kwargs)

    def _select_relevance_score_fn(self) -> Callable[[float], float]:
        """
        The 'correct' relevance function
        may differ depending on a few things, including:
        - the distance / similarity metric used by the VectorStore
        - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
        - embedding dimensionality
        - etc.
        """
        if self.override_relevance_score_fn is not None:
            return self.override_relevance_score_fn

        # Default strategy is to rely on distance strategy provided in
        # vectorstore constructor
        if self.distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
            return self._max_inner_product_relevance_score_fn
        elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
            # Default behavior is to use euclidean distance relevancy
            return self._euclidean_relevance_score_fn
        else:
            raise ValueError(
                "Unknown distance strategy, must be cosine, max_inner_product,"
                " or euclidean"
            )

    def _similarity_search_with_relevance_scores(
        self,
        query: str,
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs and their similarity scores on a scale from 0 to 1."""
        # Pop score threshold so that only relevancy scores, not raw scores, are
        # filtered.
        score_threshold = kwargs.pop("score_threshold", None)
        relevance_score_fn = self._select_relevance_score_fn()
        if relevance_score_fn is None:
            raise ValueError(
                "normalize_score_fn must be provided to"
                " ScaNN constructor to normalize scores"
            )
        docs_and_scores = self.similarity_search_with_score(
            query,
            k=k,
            filter=filter,
            fetch_k=fetch_k,
            **kwargs,
        )
        docs_and_rel_scores = [
            (doc, relevance_score_fn(score)) for doc, score in docs_and_scores
        ]
        if score_threshold is not None:
            docs_and_rel_scores = [
                (doc, similarity)
                for doc, similarity in docs_and_rel_scores
                if similarity >= score_threshold
            ]
        return docs_and_rel_scores
libs/langchain/tests/integration_tests/vectorstores/test_scann.py (262 lines, new file)
@@ -0,0 +1,262 @@
"""Test ScaNN functionality."""
import datetime
import tempfile

import numpy as np
import pytest

from langchain.docstore.document import Document
from langchain.docstore.in_memory import InMemoryDocstore
from langchain.vectorstores.scann import ScaNN, dependable_scann_import, normalize
from langchain.vectorstores.utils import DistanceStrategy
from tests.integration_tests.vectorstores.fake_embeddings import (
    ConsistentFakeEmbeddings,
    FakeEmbeddings,
)


def test_scann() -> None:
    """Test end to end construction and search."""
    texts = ["foo", "bar", "baz"]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
    index_to_id = docsearch.index_to_docstore_id
    expected_docstore = InMemoryDocstore(
        {
            index_to_id[0]: Document(page_content="foo"),
            index_to_id[1]: Document(page_content="bar"),
            index_to_id[2]: Document(page_content="baz"),
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    output = docsearch.similarity_search("foo", k=1)
    assert output == [Document(page_content="foo")]


def test_scann_vector_mips_l2() -> None:
    """Test vector similarity with MIPS and L2."""
    texts = ["foo", "bar", "baz"]
    euclidean_search = ScaNN.from_texts(texts, FakeEmbeddings())
    output = euclidean_search.similarity_search_with_score("foo", k=1)
    expected_euclidean = [(Document(page_content="foo", metadata={}), 0.0)]
    assert output == expected_euclidean

    mips_search = ScaNN.from_texts(
        texts,
        FakeEmbeddings(),
        distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
        normalize_L2=True,
    )
    output = mips_search.similarity_search_with_score("foo", k=1)
    expected_mips = [(Document(page_content="foo", metadata={}), 1.0)]
    assert output == expected_mips


def test_scann_with_config() -> None:
    """Test ScaNN with approximate search config."""
    texts = [str(i) for i in range(10000)]
    # Create a config with dimension = 10, k = 10.
    # Tree: search 10 leaves in a search tree of 100 leaves.
    # Quantization: uses 16-centroid quantizer every 2 dimension.
    # Reordering: reorder top 100 results.
    scann_config = (
        dependable_scann_import()
        .scann_ops_pybind.builder(np.zeros(shape=(0, 10)), 10, "squared_l2")
        .tree(num_leaves=100, num_leaves_to_search=10)
        .score_ah(2)
        .reorder(100)
        .create_config()
    )

    mips_search = ScaNN.from_texts(
        texts,
        ConsistentFakeEmbeddings(),
        scann_config=scann_config,
        distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
        normalize_L2=True,
    )
    output = mips_search.similarity_search_with_score("42", k=1)
    expected = [(Document(page_content="42", metadata={}), 0.0)]
    assert output == expected


def test_scann_vector_sim() -> None:
    """Test vector similarity."""
    texts = ["foo", "bar", "baz"]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
    index_to_id = docsearch.index_to_docstore_id
    expected_docstore = InMemoryDocstore(
        {
            index_to_id[0]: Document(page_content="foo"),
            index_to_id[1]: Document(page_content="bar"),
            index_to_id[2]: Document(page_content="baz"),
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    query_vec = FakeEmbeddings().embed_query(text="foo")
    output = docsearch.similarity_search_by_vector(query_vec, k=1)
    assert output == [Document(page_content="foo")]


def test_scann_vector_sim_with_score_threshold() -> None:
    """Test vector similarity."""
    texts = ["foo", "bar", "baz"]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
    index_to_id = docsearch.index_to_docstore_id
    expected_docstore = InMemoryDocstore(
        {
            index_to_id[0]: Document(page_content="foo"),
            index_to_id[1]: Document(page_content="bar"),
            index_to_id[2]: Document(page_content="baz"),
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    query_vec = FakeEmbeddings().embed_query(text="foo")
    output = docsearch.similarity_search_by_vector(query_vec, k=2, score_threshold=0.2)
    assert output == [Document(page_content="foo")]


def test_similarity_search_with_score_by_vector() -> None:
    """Test vector similarity with score by vector."""
    texts = ["foo", "bar", "baz"]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
    index_to_id = docsearch.index_to_docstore_id
    expected_docstore = InMemoryDocstore(
        {
            index_to_id[0]: Document(page_content="foo"),
            index_to_id[1]: Document(page_content="bar"),
            index_to_id[2]: Document(page_content="baz"),
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    query_vec = FakeEmbeddings().embed_query(text="foo")
    output = docsearch.similarity_search_with_score_by_vector(query_vec, k=1)
    assert len(output) == 1
    assert output[0][0] == Document(page_content="foo")


def test_similarity_search_with_score_by_vector_with_score_threshold() -> None:
    """Test vector similarity with score by vector."""
    texts = ["foo", "bar", "baz"]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
    index_to_id = docsearch.index_to_docstore_id
    expected_docstore = InMemoryDocstore(
        {
            index_to_id[0]: Document(page_content="foo"),
            index_to_id[1]: Document(page_content="bar"),
            index_to_id[2]: Document(page_content="baz"),
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    query_vec = FakeEmbeddings().embed_query(text="foo")
    output = docsearch.similarity_search_with_score_by_vector(
        query_vec,
        k=2,
        score_threshold=0.2,
    )
    assert len(output) == 1
    assert output[0][0] == Document(page_content="foo")
    assert output[0][1] < 0.2


def test_scann_with_metadatas() -> None:
    """Test end to end construction and search."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
    expected_docstore = InMemoryDocstore(
        {
            docsearch.index_to_docstore_id[0]: Document(
                page_content="foo", metadata={"page": 0}
            ),
            docsearch.index_to_docstore_id[1]: Document(
                page_content="bar", metadata={"page": 1}
            ),
            docsearch.index_to_docstore_id[2]: Document(
                page_content="baz", metadata={"page": 2}
            ),
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    output = docsearch.similarity_search("foo", k=1)
    assert output == [Document(page_content="foo", metadata={"page": 0})]


def test_scann_with_metadatas_and_filter() -> None:
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
    expected_docstore = InMemoryDocstore(
        {
            docsearch.index_to_docstore_id[0]: Document(
                page_content="foo", metadata={"page": 0}
            ),
            docsearch.index_to_docstore_id[1]: Document(
                page_content="bar", metadata={"page": 1}
            ),
            docsearch.index_to_docstore_id[2]: Document(
                page_content="baz", metadata={"page": 2}
            ),
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    output = docsearch.similarity_search("foo", k=1, filter={"page": 1})
    assert output == [Document(page_content="bar", metadata={"page": 1})]


def test_scann_with_metadatas_and_list_filter() -> None:
    texts = ["foo", "bar", "baz", "foo", "qux"]
    metadatas = [{"page": i} if i <= 3 else {"page": 3} for i in range(len(texts))]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
    expected_docstore = InMemoryDocstore(
        {
            docsearch.index_to_docstore_id[0]: Document(
                page_content="foo", metadata={"page": 0}
            ),
            docsearch.index_to_docstore_id[1]: Document(
                page_content="bar", metadata={"page": 1}
            ),
            docsearch.index_to_docstore_id[2]: Document(
                page_content="baz", metadata={"page": 2}
            ),
            docsearch.index_to_docstore_id[3]: Document(
                page_content="foo", metadata={"page": 3}
            ),
            docsearch.index_to_docstore_id[4]: Document(
                page_content="qux", metadata={"page": 3}
            ),
        }
    )
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__
    output = docsearch.similarity_search("foor", k=1, filter={"page": [0, 1, 2]})
    assert output == [Document(page_content="foo", metadata={"page": 0})]


def test_scann_search_not_found() -> None:
    """Test what happens when document is not found."""
    texts = ["foo", "bar", "baz"]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
    # Get rid of the docstore to purposefully induce errors.
    docsearch.docstore = InMemoryDocstore({})
    with pytest.raises(ValueError):
        docsearch.similarity_search("foo")


def test_scann_local_save_load() -> None:
    """Test end to end serialization."""
    texts = ["foo", "bar", "baz"]
    docsearch = ScaNN.from_texts(texts, FakeEmbeddings())
    temp_timestamp = datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S")
    with tempfile.TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder:
        docsearch.save_local(temp_folder)
        new_docsearch = ScaNN.load_local(temp_folder, FakeEmbeddings())
    assert new_docsearch.index is not None


def test_scann_normalize_l2() -> None:
    """Test normalize L2."""
    texts = ["foo", "bar", "baz"]
    emb = np.array(FakeEmbeddings().embed_documents(texts))
    # Test norm is 1.
    np.testing.assert_allclose(1, np.linalg.norm(normalize(emb), axis=-1))
    # Test that there is no NaN after normalization.
    np.testing.assert_array_equal(False, np.isnan(normalize(np.zeros(10))))