Bagatur/refac faiss (#9076)

Code cleanup and bug fix in deletion
Bagatur committed via GitHub
parent 3eccd72382
commit 358562769a

@@ -7,7 +7,16 @@ import pickle
 import uuid
 import warnings
 from pathlib import Path
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Sized,
+    Tuple,
+)
 
 import numpy as np
@@ -46,16 +55,29 @@ def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
     return faiss
 
 
+def _len_check_if_sized(x: Any, y: Any, x_name: str, y_name: str) -> None:
+    if isinstance(x, Sized) and isinstance(y, Sized) and len(x) != len(y):
+        raise ValueError(
+            f"{x_name} and {y_name} expected to be equal length but "
+            f"len({x_name})={len(x)} and len({y_name})={len(y)}"
+        )
+    return
+
+
 class FAISS(VectorStore):
     """Wrapper around FAISS vector database.
 
-    To use, you should have the ``faiss`` python package installed.
+    To use, you must have the ``faiss`` python package installed.
 
     Example:
         .. code-block:: python
 
-            from langchain import FAISS
-            faiss = FAISS(embedding_function, index, docstore, index_to_docstore_id)
+            from langchain.embeddings.openai import OpenAIEmbeddings
+            from langchain.vectorstores import FAISS
+
+            embeddings = OpenAIEmbeddings()
+            texts = ["FAISS is an important library", "LangChain supports FAISS"]
+            faiss = FAISS.from_texts(texts, embeddings)
 
     """
@@ -87,44 +109,43 @@ class FAISS(VectorStore):
             )
         )
 
+    @property
+    def embeddings(self) -> Optional[Embeddings]:
+        # TODO: Accept embeddings object directly
+        return None
+
     def __add(
         self,
         texts: Iterable[str],
         embeddings: Iterable[List[float]],
-        metadatas: Optional[List[dict]] = None,
+        metadatas: Optional[Iterable[dict]] = None,
         ids: Optional[List[str]] = None,
-        **kwargs: Any,
     ) -> List[str]:
+        faiss = dependable_faiss_import()
         if not isinstance(self.docstore, AddableMixin):
             raise ValueError(
                 "If trying to add texts, the underlying docstore should support "
                 f"adding items, which {self.docstore} does not"
             )
-        documents = []
-        for i, text in enumerate(texts):
-            metadata = metadatas[i] if metadatas else {}
-            documents.append(Document(page_content=text, metadata=metadata))
-        if ids is None:
-            ids = [str(uuid.uuid4()) for _ in texts]
-        # Add to the index, the index_to_id mapping, and the docstore.
-        starting_len = len(self.index_to_docstore_id)
-        faiss = dependable_faiss_import()
+
+        _len_check_if_sized(texts, metadatas, "texts", "metadatas")
+        _metadatas = metadatas or ({} for _ in texts)
+        documents = [
+            Document(page_content=t, metadata=m) for t, m in zip(texts, _metadatas)
+        ]
+
+        _len_check_if_sized(documents, embeddings, "documents", "embeddings")
+        _len_check_if_sized(documents, ids, "documents", "ids")
+
+        # Add to the index.
         vector = np.array(embeddings, dtype=np.float32)
         if self._normalize_L2:
             faiss.normalize_L2(vector)
         self.index.add(vector)
-        # Get list of index, id, and docs.
-        full_info = [(starting_len + i, ids[i], doc) for i, doc in enumerate(documents)]
+
         # Add information to docstore and index.
-        self.docstore.add({_id: doc for _, _id, doc in full_info})
-        index_to_id = {index: _id for index, _id, _ in full_info}
+        ids = ids or [str(uuid.uuid4()) for _ in texts]
+        self.docstore.add({id_: doc for id_, doc in zip(ids, documents)})
+        starting_len = len(self.index_to_docstore_id)
+        index_to_id = {starting_len + j: id_ for j, id_ in enumerate(ids)}
         self.index_to_docstore_id.update(index_to_id)
-        return [_id for _, _id, _ in full_info]
+        return ids
 
     def add_texts(
         self,
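Note: a hedged sketch (not part of the diff) of how the rewritten `__add` surfaces through the public `add_texts` API: ids map onto consecutive index positions, and mismatched `metadatas` now fail fast before anything is written to the index. `ToyEmbeddings` is a made-up stand-in for a real `Embeddings` implementation; assumes `faiss-cpu` and this version of `langchain` are installed.

from typing import List

from langchain.embeddings.base import Embeddings
from langchain.vectorstores import FAISS

class ToyEmbeddings(Embeddings):
    """Deterministic 2-dimensional embeddings, for illustration only."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(t) for t in texts]

    def embed_query(self, text: str) -> List[float]:
        return [float(sum(map(ord, text)) % 97), 1.0]

store = FAISS.from_texts(["foo", "bar"], ToyEmbeddings(), ids=["id-foo", "id-bar"])
print(store.index_to_docstore_id)  # {0: 'id-foo', 1: 'id-bar'}

try:
    store.add_texts(["baz", "qux"], metadatas=[{"page": 0}])
except ValueError as err:
    print(err)  # texts and metadatas expected to be equal length but len(texts)=2 and len(metadatas)=1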
@@ -143,14 +164,8 @@ class FAISS(VectorStore):
         Returns:
             List of ids from adding the texts into the vectorstore.
         """
-        if not isinstance(self.docstore, AddableMixin):
-            raise ValueError(
-                "If trying to add texts, the underlying docstore should support "
-                f"adding items, which {self.docstore} does not"
-            )
-        # Embed and create the documents.
         embeddings = [self.embedding_function(text) for text in texts]
-        return self.__add(texts, embeddings, metadatas=metadatas, ids=ids, **kwargs)
+        return self.__add(texts, embeddings, metadatas=metadatas, ids=ids)
 
     def add_embeddings(
         self,
@@ -170,15 +185,9 @@ class FAISS(VectorStore):
         Returns:
             List of ids from adding the texts into the vectorstore.
         """
-        if not isinstance(self.docstore, AddableMixin):
-            raise ValueError(
-                "If trying to add texts, the underlying docstore should support "
-                f"adding items, which {self.docstore} does not"
-            )
         # Embed and create the documents.
         texts, embeddings = zip(*text_embeddings)
-        return self.__add(texts, embeddings, metadatas=metadatas, ids=ids, **kwargs)
+        return self.__add(texts, embeddings, metadatas=metadatas, ids=ids)
 
     def similarity_search_with_score_by_vector(
         self,
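Note: `add_texts` and `add_embeddings` no longer duplicate the `AddableMixin` docstore check; both delegate it to `__add`. A hedged sketch of the `add_embeddings` path, continuing the `ToyEmbeddings` stand-in from the note above:

emb = ToyEmbeddings()
new_texts = ["baz", "qux"]
pairs = list(zip(new_texts, emb.embed_documents(new_texts)))

new_ids = store.add_embeddings(pairs, ids=["id-baz", "id-qux"])
print(new_ids)                          # ['id-baz', 'id-qux']
print(len(store.index_to_docstore_id))  # 4 entries, positions 0..3 stay contiguous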
@@ -480,22 +489,26 @@ class FAISS(VectorStore):
         """
         if ids is None:
             raise ValueError("No ids provided to delete.")
-        overlapping = set(ids).intersection(self.index_to_docstore_id.values())
-        if not overlapping:
-            raise ValueError("ids do not exist in the current object")
-        _reversed_index = {v: k for k, v in self.index_to_docstore_id.items()}
-        index_to_delete = [_reversed_index[i] for i in ids]
-        # Removing ids from index.
+        missing_ids = set(ids).difference(self.index_to_docstore_id.values())
+        if missing_ids:
+            raise ValueError(
+                f"Some specified ids do not exist in the current store. Ids not found: "
+                f"{missing_ids}"
+            )
+
+        reversed_index = {id_: idx for idx, id_ in self.index_to_docstore_id.items()}
+        index_to_delete = [reversed_index[id_] for id_ in ids]
+
         self.index.remove_ids(np.array(index_to_delete, dtype=np.int64))
-        for _id in index_to_delete:
-            del self.index_to_docstore_id[_id]
-        # Remove items from docstore.
         self.docstore.delete(ids)
+
+        remaining_ids = [
+            id_
+            for i, id_ in sorted(self.index_to_docstore_id.items())
+            if i not in index_to_delete
+        ]
+        self.index_to_docstore_id = {i: id_ for i, id_ in enumerate(remaining_ids)}
+
         return True
 
     def merge_from(self, target: FAISS) -> None:
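Note: this is the deletion bug fix the commit message refers to. Previously `delete` removed vectors from the FAISS index but left the old positions in `index_to_docstore_id`, so the mapping drifted out of sync once FAISS renumbered the remaining vectors after `remove_ids`. A hedged sketch of the intended behaviour, continuing the `ToyEmbeddings` stand-in from the note above and mirroring the `test_delete` added below:

store = FAISS.from_texts(["foo", "bar", "baz"], ToyEmbeddings(), ids=["a", "b", "c"])
store.delete(["b"])

# remove_ids shifts the remaining vectors down, and the rewritten delete() rebuilds
# index_to_docstore_id with contiguous positions that match the index again.
print(store.index_to_docstore_id)                              # {0: 'a', 1: 'c'}
print(len(store.index_to_docstore_id) == store.index.ntotal)   # True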
@@ -533,50 +546,32 @@ class FAISS(VectorStore):
     @classmethod
     def __from(
         cls,
-        texts: List[str],
+        texts: Iterable[str],
         embeddings: List[List[float]],
         embedding: Embeddings,
-        metadatas: Optional[List[dict]] = None,
+        metadatas: Optional[Iterable[dict]] = None,
         ids: Optional[List[str]] = None,
         normalize_L2: bool = False,
+        distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
         **kwargs: Any,
     ) -> FAISS:
         faiss = dependable_faiss_import()
-        distance_strategy = kwargs.get(
-            "distance_strategy", DistanceStrategy.EUCLIDEAN_DISTANCE
-        )
         if distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
             index = faiss.IndexFlatIP(len(embeddings[0]))
         else:
             # Default to L2, currently other metric types not initialized.
             index = faiss.IndexFlatL2(len(embeddings[0]))
-        vector = np.array(embeddings, dtype=np.float32)
-        if normalize_L2 and distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
-            faiss.normalize_L2(vector)
-        index.add(vector)
-        documents = []
-        if ids is None:
-            ids = [str(uuid.uuid4()) for _ in texts]
-        for i, text in enumerate(texts):
-            metadata = metadatas[i] if metadatas else {}
-            documents.append(Document(page_content=text, metadata=metadata))
-        index_to_id = dict(enumerate(ids))
-        if len(index_to_id) != len(documents):
-            raise Exception(
-                f"{len(index_to_id)} ids provided for {len(documents)} documents."
-                " Each document should have an id."
-            )
-        docstore = InMemoryDocstore(dict(zip(index_to_id.values(), documents)))
-        return cls(
+        vecstore = cls(
             embedding.embed_query,
             index,
-            docstore,
-            index_to_id,
+            InMemoryDocstore(),
+            {},
             normalize_L2=normalize_L2,
+            distance_strategy=distance_strategy,
             **kwargs,
         )
+        vecstore.__add(texts, embeddings, metadatas=metadatas, ids=ids)
+        return vecstore
 
     @classmethod
     def from_texts(
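Note: `__from` no longer assembles the docstore and id map by hand; it builds the store empty and delegates population to `__add`, so the new length checks also apply to `from_texts` and `from_embeddings`. `distance_strategy` is likewise promoted from a `**kwargs` lookup to an explicit keyword argument. A hedged sketch of passing it through `from_texts` (the `DistanceStrategy` import path is assumed for this langchain version; `ToyEmbeddings` is the stand-in from the earlier note):

from langchain.vectorstores import FAISS
from langchain.vectorstores.utils import DistanceStrategy

store = FAISS.from_texts(
    ["foo", "bar"],
    ToyEmbeddings(),
    distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
)
print(type(store.index).__name__)  # IndexFlatIP rather than the default IndexFlatL2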
@@ -601,6 +596,7 @@ class FAISS(VectorStore):
                 from langchain import FAISS
                 from langchain.embeddings import OpenAIEmbeddings
+
                 embeddings = OpenAIEmbeddings()
                 faiss = FAISS.from_texts(texts, embeddings)
         """
@@ -617,9 +613,9 @@ class FAISS(VectorStore):
     @classmethod
     def from_embeddings(
         cls,
-        text_embeddings: List[Tuple[str, List[float]]],
+        text_embeddings: Iterable[Tuple[str, List[float]]],
         embedding: Embeddings,
-        metadatas: Optional[List[dict]] = None,
+        metadatas: Optional[Iterable[dict]] = None,
         ids: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> FAISS:
@@ -637,9 +633,10 @@ class FAISS(VectorStore):
                 from langchain import FAISS
                 from langchain.embeddings import OpenAIEmbeddings
+
                 embeddings = OpenAIEmbeddings()
                 text_embeddings = embeddings.embed_documents(texts)
-                text_embedding_pairs = list(zip(texts, text_embeddings))
+                text_embedding_pairs = zip(texts, text_embeddings)
                 faiss = FAISS.from_embeddings(text_embedding_pairs, embeddings)
         """
         texts = [t[0] for t in text_embeddings]

@@ -10477,7 +10477,7 @@ clarifai = ["clarifai"]
 cohere = ["cohere"]
 docarray = ["docarray"]
 embeddings = ["sentence-transformers"]
-extended-testing = ["amazon-textract-caller", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "mwparserfromhell", "mwxml", "newspaper3k", "openai", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "xata", "xmltodict"]
+extended-testing = ["amazon-textract-caller", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "faiss-cpu", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "mwparserfromhell", "mwxml", "newspaper3k", "openai", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "xata", "xmltodict"]
 javascript = ["esprima"]
 llms = ["clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openlm", "torch", "transformers"]
 openai = ["openai", "tiktoken"]
@@ -10487,4 +10487,4 @@ text-helpers = ["chardet"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "fb53fa05a5258de15427c0f69f2070265842bd530f139ed4e0ed71cd3b70ad36"
+content-hash = "6e85bdaca0b4a62bace541dd914266b49a4d7f90c7be2030fab639bf7efc23c6"

@@ -334,6 +334,7 @@ extended_testing = [
     "feedparser",
     "xata",
     "xmltodict",
+    "faiss-cpu",
 ]
 
 [tool.ruff]

@@ -7,11 +7,12 @@ import pytest
 from langchain.docstore.document import Document
 from langchain.docstore.in_memory import InMemoryDocstore
-from langchain.docstore.wikipedia import Wikipedia
 from langchain.vectorstores.faiss import FAISS
 from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
+from tests.unit_tests.agents.test_react import FakeDocstore
 
 
+@pytest.mark.requires("faiss")
 def test_faiss() -> None:
     """Test end to end construction and search."""
     texts = ["foo", "bar", "baz"]
@@ -29,6 +30,7 @@ def test_faiss() -> None:
     assert output == [Document(page_content="foo")]
 
 
+@pytest.mark.requires("faiss")
 def test_faiss_vector_sim() -> None:
     """Test vector similarity."""
     texts = ["foo", "bar", "baz"]
@@ -47,6 +49,7 @@ def test_faiss_vector_sim() -> None:
     assert output == [Document(page_content="foo")]
 
 
+@pytest.mark.requires("faiss")
 def test_faiss_vector_sim_with_score_threshold() -> None:
     """Test vector similarity."""
     texts = ["foo", "bar", "baz"]
@@ -65,6 +68,7 @@ def test_faiss_vector_sim_with_score_threshold() -> None:
     assert output == [Document(page_content="foo")]
 
 
+@pytest.mark.requires("faiss")
 def test_similarity_search_with_score_by_vector() -> None:
     """Test vector similarity with score by vector."""
     texts = ["foo", "bar", "baz"]
@@ -84,6 +88,7 @@ def test_similarity_search_with_score_by_vector() -> None:
     assert output[0][0] == Document(page_content="foo")
 
 
+@pytest.mark.requires("faiss")
 def test_similarity_search_with_score_by_vector_with_score_threshold() -> None:
     """Test vector similarity with score by vector."""
     texts = ["foo", "bar", "baz"]
@@ -108,6 +113,7 @@ def test_similarity_search_with_score_by_vector_with_score_threshold() -> None:
     assert output[0][1] < 0.2
 
 
+@pytest.mark.requires("faiss")
 def test_faiss_mmr() -> None:
     texts = ["foo", "foo", "fou", "foy"]
     docsearch = FAISS.from_texts(texts, FakeEmbeddings())
@@ -122,6 +128,7 @@ def test_faiss_mmr() -> None:
     assert output[1][0] != Document(page_content="foo")
 
 
+@pytest.mark.requires("faiss")
 def test_faiss_mmr_with_metadatas() -> None:
     texts = ["foo", "foo", "fou", "foy"]
     metadatas = [{"page": i} for i in range(len(texts))]
@@ -136,6 +143,7 @@ def test_faiss_mmr_with_metadatas() -> None:
     assert output[1][0] != Document(page_content="foo", metadata={"page": 0})
 
 
+@pytest.mark.requires("faiss")
 def test_faiss_mmr_with_metadatas_and_filter() -> None:
     texts = ["foo", "foo", "fou", "foy"]
     metadatas = [{"page": i} for i in range(len(texts))]
@@ -149,6 +157,7 @@ def test_faiss_mmr_with_metadatas_and_filter() -> None:
     assert output[0][1] == 0.0
 
 
+@pytest.mark.requires("faiss")
 def test_faiss_mmr_with_metadatas_and_list_filter() -> None:
     texts = ["foo", "foo", "fou", "foy"]
     metadatas = [{"page": i} if i <= 3 else {"page": 3} for i in range(len(texts))]
@@ -163,6 +172,7 @@ def test_faiss_mmr_with_metadatas_and_list_filter() -> None:
     assert output[1][0] != Document(page_content="foo", metadata={"page": 0})
 
 
+@pytest.mark.requires("faiss")
 def test_faiss_with_metadatas() -> None:
     """Test end to end construction and search."""
     texts = ["foo", "bar", "baz"]
@@ -186,6 +196,7 @@ def test_faiss_with_metadatas() -> None:
     assert output == [Document(page_content="foo", metadata={"page": 0})]
 
 
+@pytest.mark.requires("faiss")
 def test_faiss_with_metadatas_and_filter() -> None:
     texts = ["foo", "bar", "baz"]
     metadatas = [{"page": i} for i in range(len(texts))]
@@ -208,6 +219,7 @@ def test_faiss_with_metadatas_and_filter() -> None:
     assert output == [Document(page_content="bar", metadata={"page": 1})]
 
 
+@pytest.mark.requires("faiss")
 def test_faiss_with_metadatas_and_list_filter() -> None:
     texts = ["foo", "bar", "baz", "foo", "qux"]
     metadatas = [{"page": i} if i <= 3 else {"page": 3} for i in range(len(texts))]
@@ -236,6 +248,7 @@ def test_faiss_with_metadatas_and_list_filter() -> None:
     assert output == [Document(page_content="foo", metadata={"page": 0})]
 
 
+@pytest.mark.requires("faiss")
 def test_faiss_search_not_found() -> None:
     """Test what happens when document is not found."""
     texts = ["foo", "bar", "baz"]
@@ -246,6 +259,7 @@ def test_faiss_search_not_found() -> None:
         docsearch.similarity_search("foo")
 
 
+@pytest.mark.requires("faiss")
 def test_faiss_add_texts() -> None:
     """Test end to end adding of texts."""
     # Create initial doc store.
@@ -257,13 +271,15 @@ def test_faiss_add_texts() -> None:
     assert output == [Document(page_content="foo"), Document(page_content="foo")]
 
 
+@pytest.mark.requires("faiss")
 def test_faiss_add_texts_not_supported() -> None:
     """Test adding of texts to a docstore that doesn't support it."""
-    docsearch = FAISS(FakeEmbeddings().embed_query, None, Wikipedia(), {})
+    docsearch = FAISS(FakeEmbeddings().embed_query, None, FakeDocstore(), {})
     with pytest.raises(ValueError):
         docsearch.add_texts(["foo"])
 
 
+@pytest.mark.requires("faiss")
 def test_faiss_local_save_load() -> None:
     """Test end to end serialization."""
     texts = ["foo", "bar", "baz"]
@@ -275,6 +291,7 @@ def test_faiss_local_save_load() -> None:
     assert new_docsearch.index is not None
 
 
+@pytest.mark.requires("faiss")
 def test_faiss_similarity_search_with_relevance_scores() -> None:
     """Test the similarity search with normalized similarities."""
     texts = ["foo", "bar", "baz"]
@@ -289,6 +306,7 @@ def test_faiss_similarity_search_with_relevance_scores() -> None:
     assert score == 1.0
 
 
+@pytest.mark.requires("faiss")
 def test_faiss_similarity_search_with_relevance_scores_with_threshold() -> None:
     """Test the similarity search with normalized similarities with score threshold."""
     texts = ["foo", "bar", "baz"]
@@ -306,6 +324,7 @@ def test_faiss_similarity_search_with_relevance_scores_with_threshold() -> None:
     assert score == 1.0
 
 
+@pytest.mark.requires("faiss")
 def test_faiss_invalid_normalize_fn() -> None:
     """Test the similarity search with normalized similarities."""
     texts = ["foo", "bar", "baz"]
@@ -316,12 +335,22 @@ def test_faiss_invalid_normalize_fn() -> None:
         docsearch.similarity_search_with_relevance_scores("foo", k=1)
 
 
+@pytest.mark.requires("faiss")
 def test_missing_normalize_score_fn() -> None:
     """Test doesn't perform similarity search without a valid distance strategy."""
-    texts = ["foo", "bar", "baz"]
-    faiss_instance = FAISS.from_texts(texts, FakeEmbeddings(), distance_strategy="fake")
     with pytest.raises(ValueError):
+        texts = ["foo", "bar", "baz"]
+        faiss_instance = FAISS.from_texts(
+            texts, FakeEmbeddings(), distance_strategy="fake"
+        )
         faiss_instance.similarity_search_with_relevance_scores("foo", k=2)
+
+
+@pytest.mark.requires("faiss")
+def test_delete() -> None:
+    """Test the similarity search with normalized similarities."""
+    ids = ["a", "b", "c"]
+    docsearch = FAISS.from_texts(["foo", "bar", "baz"], FakeEmbeddings(), ids=ids)
+    docsearch.delete(ids[1:2])
+    result = docsearch.similarity_search("bar", k=2)
+    assert sorted([d.page_content for d in result]) == ["baz", "foo"]
+    assert docsearch.index_to_docstore_id == {0: ids[0], 1: ids[2]}