From 77fc2f76442b922764436cc267220a5df8b7a022 Mon Sep 17 00:00:00 2001 From: Sian Cao Date: Thu, 19 Oct 2023 14:51:28 +0800 Subject: [PATCH] fix: impl missing embeddings method (#10823) FAISS does not implement the embeddings property and uses embed_query to embed texts, which is wrong for some embedding models. --------- Co-authored-by: Bagatur --- .../langchain/langchain/vectorstores/faiss.py | 48 ++++++++++++++----- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/libs/langchain/langchain/vectorstores/faiss.py b/libs/langchain/langchain/vectorstores/faiss.py index 2fc3bb68eb..0b867d2709 100644 --- a/libs/langchain/langchain/vectorstores/faiss.py +++ b/libs/langchain/langchain/vectorstores/faiss.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import operator import os import pickle @@ -15,6 +16,7 @@ from typing import ( Optional, Sized, Tuple, + Union, ) import numpy as np @@ -26,6 +28,8 @@ from langchain.schema.embeddings import Embeddings from langchain.schema.vectorstore import VectorStore from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance +logger = logging.getLogger(__name__) + def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any: """ @@ -82,7 +86,7 @@ class FAISS(VectorStore): def __init__( self, - embedding_function: Callable, + embedding_function: Union[Callable, Embeddings], index: Any, docstore: Docstore, index_to_docstore_id: Dict[int, str], @@ -91,6 +95,11 @@ class FAISS(VectorStore): distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE, ): """Initialize with necessary components.""" + if not isinstance(embedding_function, Embeddings): + logger.warning( + "`embedding_function` is expected to be an Embeddings object, support " + "for passing in a function will soon be removed." 
+ ) self.embedding_function = embedding_function self.index = index self.docstore = docstore @@ -108,6 +117,26 @@ class FAISS(VectorStore): ) ) + @property + def embeddings(self) -> Optional[Embeddings]: + return ( + self.embedding_function + if isinstance(self.embedding_function, Embeddings) + else None + ) + + def _embed_documents(self, texts: List[str]) -> List[List[float]]: + if isinstance(self.embedding_function, Embeddings): + return self.embedding_function.embed_documents(texts) + else: + return [self.embedding_function(text) for text in texts] + + def _embed_query(self, text: str) -> List[float]: + if isinstance(self.embedding_function, Embeddings): + return self.embedding_function.embed_query(text) + else: + return self.embedding_function(text) + def __add( self, texts: Iterable[str], @@ -163,7 +192,8 @@ class FAISS(VectorStore): Returns: List of ids from adding the texts into the vectorstore. """ - embeddings = [self.embedding_function(text) for text in texts] + texts = list(texts) + embeddings = self._embed_documents(texts) return self.__add(texts, embeddings, metadatas=metadatas, ids=ids) def add_embeddings( @@ -272,7 +302,7 @@ class FAISS(VectorStore): List of documents most similar to the query text with L2 distance in float. Lower score represents more similarity. """ - embedding = self.embedding_function(query) + embedding = self._embed_query(query) docs = self.similarity_search_with_score_by_vector( embedding, k, @@ -465,7 +495,7 @@ class FAISS(VectorStore): Returns: List of Documents selected by maximal marginal relevance. """ - embedding = self.embedding_function(query) + embedding = self._embed_query(query) docs = self.max_marginal_relevance_search_by_vector( embedding, k=k, @@ -561,7 +591,7 @@ class FAISS(VectorStore): # Default to L2, currently other metric types not initialized. 
index = faiss.IndexFlatL2(len(embeddings[0])) vecstore = cls( - embedding.embed_query, + embedding, index, InMemoryDocstore(), {}, @@ -696,9 +726,7 @@ class FAISS(VectorStore): # load docstore and index_to_docstore_id with open(path / "{index_name}.pkl".format(index_name=index_name), "rb") as f: docstore, index_to_docstore_id = pickle.load(f) - return cls( - embeddings.embed_query, index, docstore, index_to_docstore_id, **kwargs - ) + return cls(embeddings, index, docstore, index_to_docstore_id, **kwargs) def serialize_to_bytes(self) -> bytes: """Serialize FAISS index, docstore, and index_to_docstore_id to bytes.""" @@ -713,9 +741,7 @@ class FAISS(VectorStore): ) -> FAISS: """Deserialize FAISS index, docstore, and index_to_docstore_id from bytes.""" index, docstore, index_to_docstore_id = pickle.loads(serialized) - return cls( - embeddings.embed_query, index, docstore, index_to_docstore_id, **kwargs - ) + return cls(embeddings, index, docstore, index_to_docstore_id, **kwargs) def _select_relevance_score_fn(self) -> Callable[[float], float]: """