Harrison/index name (#2869)

Co-authored-by: Mesum Raza Hemani <mes.javacca@gmail.com>
This commit is contained in:
Harrison Chase 2023-04-13 22:01:32 -07:00 committed by GitHub
parent dcb17503f2
commit 8a98e5b50b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -373,39 +373,47 @@ class FAISS(VectorStore):
embeddings = [t[1] for t in text_embeddings]
return cls.__from(texts, embeddings, embedding, metadatas, **kwargs)
def save_local(self, folder_path: str, index_name: str = "index") -> None:
    """Save FAISS index, docstore, and index_to_docstore_id to disk.

    Two files are written into ``folder_path``: ``<index_name>.faiss``
    holding the FAISS index, and ``<index_name>.pkl`` holding the pickled
    ``(docstore, index_to_docstore_id)`` tuple.

    Args:
        folder_path: folder path to save index, docstore,
            and index_to_docstore_id to. The folder (and any missing
            parents) is created if it does not already exist.
        index_name: base file name, without extension, used for both
            files. Defaults to "index".
    """
    path = Path(folder_path)
    path.mkdir(exist_ok=True, parents=True)

    # save index separately since it is not picklable
    faiss = dependable_faiss_import()
    faiss.write_index(self.index, str(path / f"{index_name}.faiss"))

    # save docstore and index_to_docstore_id
    with open(path / f"{index_name}.pkl", "wb") as f:
        pickle.dump((self.docstore, self.index_to_docstore_id), f)
@classmethod
def load_local(
    cls, folder_path: str, embeddings: Embeddings, index_name: str = "index"
) -> FAISS:
    """Load FAISS index, docstore, and index_to_docstore_id from disk.

    Reads ``<index_name>.faiss`` and ``<index_name>.pkl`` from
    ``folder_path``.

    Args:
        folder_path: folder path to load index, docstore,
            and index_to_docstore_id from.
        embeddings: Embeddings to use when generating queries
        index_name: base file name, without extension, of the files to
            load. Defaults to "index".

    Returns:
        A new instance built from ``embeddings.embed_query``, the loaded
        index, the docstore, and index_to_docstore_id.
    """
    path = Path(folder_path)

    # load index separately since it is not picklable
    faiss = dependable_faiss_import()
    index = faiss.read_index(str(path / f"{index_name}.faiss"))

    # load docstore and index_to_docstore_id
    # SECURITY: unpickling can execute arbitrary code; only load files
    # that come from a trusted source.
    with open(path / f"{index_name}.pkl", "rb") as f:
        docstore, index_to_docstore_id = pickle.load(f)
    return cls(embeddings.embed_query, index, docstore, index_to_docstore_id)