add faiss local saving/loading (#676)

- This uses the faiss built-in `write_index` and `load_index` to save
and load faiss indexes locally
- Also fixes #674
- The save/load functions also use the faiss library, so I refactored
the dependency into a function
harrison/document-split
dham 1 year ago committed by GitHub
parent e45f7e40e8
commit e04b063ff4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -14,6 +14,19 @@ from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
def dependable_faiss_import() -> Any:
    """Import faiss if available, otherwise raise error.

    Centralizes the optional-dependency import so every caller
    (constructors, save/load helpers) produces the same error message.

    Returns:
        The imported ``faiss`` module.

    Raises:
        ValueError: If the ``faiss`` python package is not installed.
            (ValueError rather than ImportError, matching the module's
            existing optional-dependency convention so callers' handlers
            keep working.)
    """
    try:
        import faiss
    except ImportError:
        # Fixed wording: original read "Please it install it".
        raise ValueError(
            "Could not import faiss python package. "
            "Please install it with `pip install faiss` "
            "or `pip install faiss-cpu` (depending on Python version)."
        )
    return faiss
class FAISS(VectorStore):
"""Wrapper around FAISS vector database.
@ -174,14 +187,7 @@ class FAISS(VectorStore):
embeddings = OpenAIEmbeddings()
faiss = FAISS.from_texts(texts, embeddings)
"""
try:
import faiss
except ImportError:
raise ValueError(
"Could not import faiss python package. "
"Please it install it with `pip install faiss` "
"or `pip install faiss-cpu` (depending on Python version)."
)
faiss = dependable_faiss_import()
embeddings = embedding.embed_documents(texts)
index = faiss.IndexFlatL2(len(embeddings[0]))
index.add(np.array(embeddings, dtype=np.float32))
@ -194,3 +200,21 @@ class FAISS(VectorStore):
{index_to_id[i]: doc for i, doc in enumerate(documents)}
)
return cls(embedding.embed_query, index, docstore, index_to_id)
def save_local(self, path: str) -> None:
    """Serialize the underlying FAISS index to a local file.

    Args:
        path: Destination file path for the serialized index.
    """
    faiss_lib = dependable_faiss_import()
    faiss_lib.write_index(self.index, path)
def load_local(self, path: str) -> None:
    """Read a serialized FAISS index from a local file.

    Replaces this instance's current index in place.

    Args:
        path: File path to read the serialized index from.
    """
    faiss_lib = dependable_faiss_import()
    self.index = faiss_lib.read_index(path)

@ -1,4 +1,5 @@
"""Test FAISS functionality."""
import tempfile
from typing import List
import pytest
@ -46,9 +47,15 @@ def test_faiss_with_metadatas() -> None:
docsearch = FAISS.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
expected_docstore = InMemoryDocstore(
{
"0": Document(page_content="foo", metadata={"page": 0}),
"1": Document(page_content="bar", metadata={"page": 1}),
"2": Document(page_content="baz", metadata={"page": 2}),
docsearch.index_to_docstore_id[0]: Document(
page_content="foo", metadata={"page": 0}
),
docsearch.index_to_docstore_id[1]: Document(
page_content="bar", metadata={"page": 1}
),
docsearch.index_to_docstore_id[2]: Document(
page_content="baz", metadata={"page": 2}
),
}
)
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
@ -82,3 +89,15 @@ def test_faiss_add_texts_not_supported() -> None:
docsearch = FAISS(FakeEmbeddings().embed_query, None, Wikipedia(), {})
with pytest.raises(ValueError):
docsearch.add_texts(["foo"])
def test_faiss_local_save_load() -> None:
    """Round-trip a FAISS store through save_local/load_local on disk."""
    docsearch = FAISS.from_texts(["foo", "bar", "baz"], FakeEmbeddings())
    with tempfile.NamedTemporaryFile() as temp_file:
        docsearch.save_local(temp_file.name)
        # Clear the in-memory index so the assertion proves the reload worked.
        docsearch.index = None
        docsearch.load_local(temp_file.name)
        assert docsearch.index is not None

Loading…
Cancel
Save