From e04b063ff40d7f70eaa91f135729071de60b219d Mon Sep 17 00:00:00 2001 From: dham Date: Sun, 22 Jan 2023 01:08:14 +0100 Subject: [PATCH] add faiss local saving/loading (#676) - This uses the faiss built-in `write_index` and `read_index` to save and load faiss indexes locally - Also fixes #674 - The save/load functions also use the faiss library, so I refactored the dependency into a function --- langchain/vectorstores/faiss.py | 40 +++++++++++++++---- .../vectorstores/test_faiss.py | 25 ++++++++++-- 2 files changed, 54 insertions(+), 11 deletions(-) diff --git a/langchain/vectorstores/faiss.py b/langchain/vectorstores/faiss.py index 20608c15..714d4e7b 100644 --- a/langchain/vectorstores/faiss.py +++ b/langchain/vectorstores/faiss.py @@ -14,6 +14,19 @@ from langchain.vectorstores.base import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance +def dependable_faiss_import() -> Any: + """Import faiss if available, otherwise raise error.""" + try: + import faiss + except ImportError: + raise ValueError( + "Could not import faiss python package. " + "Please install it with `pip install faiss` " + "or `pip install faiss-cpu` (depending on Python version)." + ) + return faiss + + class FAISS(VectorStore): """Wrapper around FAISS vector database. @@ -174,14 +187,7 @@ class FAISS(VectorStore): embeddings = OpenAIEmbeddings() faiss = FAISS.from_texts(texts, embeddings) """ - try: - import faiss - except ImportError: - raise ValueError( - "Could not import faiss python package. " - "Please it install it with `pip install faiss` " - "or `pip install faiss-cpu` (depending on Python version)." 
- ) + faiss = dependable_faiss_import() embeddings = embedding.embed_documents(texts) index = faiss.IndexFlatL2(len(embeddings[0])) index.add(np.array(embeddings, dtype=np.float32)) @@ -194,3 +200,21 @@ class FAISS(VectorStore): {index_to_id[i]: doc for i, doc in enumerate(documents)} ) return cls(embedding.embed_query, index, docstore, index_to_id) + + def save_local(self, path: str) -> None: + """Save FAISS index to disk. + + Args: + path: Path to save FAISS index to. + """ + faiss = dependable_faiss_import() + faiss.write_index(self.index, path) + + def load_local(self, path: str) -> None: + """Load FAISS index from disk. + + Args: + path: Path to load FAISS index from. + """ + faiss = dependable_faiss_import() + self.index = faiss.read_index(path) diff --git a/tests/integration_tests/vectorstores/test_faiss.py b/tests/integration_tests/vectorstores/test_faiss.py index c3d2ba57..07d60782 100644 --- a/tests/integration_tests/vectorstores/test_faiss.py +++ b/tests/integration_tests/vectorstores/test_faiss.py @@ -1,4 +1,5 @@ """Test FAISS functionality.""" +import tempfile from typing import List import pytest @@ -46,9 +47,15 @@ def test_faiss_with_metadatas() -> None: docsearch = FAISS.from_texts(texts, FakeEmbeddings(), metadatas=metadatas) expected_docstore = InMemoryDocstore( { - "0": Document(page_content="foo", metadata={"page": 0}), - "1": Document(page_content="bar", metadata={"page": 1}), - "2": Document(page_content="baz", metadata={"page": 2}), + docsearch.index_to_docstore_id[0]: Document( + page_content="foo", metadata={"page": 0} + ), + docsearch.index_to_docstore_id[1]: Document( + page_content="bar", metadata={"page": 1} + ), + docsearch.index_to_docstore_id[2]: Document( + page_content="baz", metadata={"page": 2} + ), } ) assert docsearch.docstore.__dict__ == expected_docstore.__dict__ @@ -82,3 +89,15 @@ def test_faiss_add_texts_not_supported() -> None: docsearch = FAISS(FakeEmbeddings().embed_query, None, Wikipedia(), {}) with 
pytest.raises(ValueError): docsearch.add_texts(["foo"]) + + +def test_faiss_local_save_load() -> None: + """Test end to end serialization.""" + texts = ["foo", "bar", "baz"] + docsearch = FAISS.from_texts(texts, FakeEmbeddings()) + + with tempfile.NamedTemporaryFile() as temp_file: + docsearch.save_local(temp_file.name) + docsearch.index = None + docsearch.load_local(temp_file.name) + assert docsearch.index is not None