diff --git a/langchain/vectorstores/faiss.py b/langchain/vectorstores/faiss.py index 9b139807..4157fa9d 100644 --- a/langchain/vectorstores/faiss.py +++ b/langchain/vectorstores/faiss.py @@ -373,39 +373,47 @@ class FAISS(VectorStore): embeddings = [t[1] for t in text_embeddings] return cls.__from(texts, embeddings, embedding, metadatas, **kwargs) - def save_local(self, folder_path: str) -> None: + def save_local(self, folder_path: str, index_name: str = "index") -> None: """Save FAISS index, docstore, and index_to_docstore_id to disk. Args: folder_path: folder path to save index, docstore, and index_to_docstore_id to. + index_name: for saving with a specific index file name """ path = Path(folder_path) path.mkdir(exist_ok=True, parents=True) # save index separately since it is not picklable faiss = dependable_faiss_import() - faiss.write_index(self.index, str(path / "index.faiss")) + faiss.write_index( + self.index, str(path / "{index_name}.faiss".format(index_name=index_name)) + ) # save docstore and index_to_docstore_id - with open(path / "index.pkl", "wb") as f: + with open(path / "{index_name}.pkl".format(index_name=index_name), "wb") as f: pickle.dump((self.docstore, self.index_to_docstore_id), f) @classmethod - def load_local(cls, folder_path: str, embeddings: Embeddings) -> FAISS: + def load_local( + cls, folder_path: str, embeddings: Embeddings, index_name: str = "index" + ) -> FAISS: """Load FAISS index, docstore, and index_to_docstore_id to disk. Args: folder_path: folder path to load index, docstore, and index_to_docstore_id from. embeddings: Embeddings to use when generating queries + index_name: for saving with a specific index file name """ path = Path(folder_path) # load index separately since it is not picklable faiss = dependable_faiss_import() - index = faiss.read_index(str(path / "index.faiss")) + index = faiss.read_index( + str(path / "{index_name}.faiss".format(index_name=index_name)) + ) # load docstore and index_to_docstore_id - with open(path / "index.pkl", "rb") as f: + with open(path / "{index_name}.pkl".format(index_name=index_name), "rb") as f: docstore, index_to_docstore_id = pickle.load(f) return cls(embeddings.embed_query, index, docstore, index_to_docstore_id)