From ba023d53ca14becd1b4008b052efb0579e904391 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Wed, 17 May 2023 21:40:49 -0700 Subject: [PATCH] Harrison/faiss norm (#4903) Co-authored-by: Jiaxin Shan --- langchain/vectorstores/faiss.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/langchain/vectorstores/faiss.py b/langchain/vectorstores/faiss.py index 2085f95a..14397916 100644 --- a/langchain/vectorstores/faiss.py +++ b/langchain/vectorstores/faiss.py @@ -81,6 +81,7 @@ class FAISS(VectorStore): relevance_score_fn: Optional[ Callable[[float], float] ] = _default_relevance_score_fn, + normalize_L2: bool = False, ): """Initialize with necessary components.""" self.embedding_function = embedding_function @@ -88,6 +89,7 @@ class FAISS(VectorStore): self.docstore = docstore self.index_to_docstore_id = index_to_docstore_id self.relevance_score_fn = relevance_score_fn + self._normalize_L2 = normalize_L2 def __add( self, @@ -107,7 +109,11 @@ class FAISS(VectorStore): documents.append(Document(page_content=text, metadata=metadata)) # Add to the index, the index_to_id mapping, and the docstore. starting_len = len(self.index_to_docstore_id) - self.index.add(np.array(embeddings, dtype=np.float32)) + faiss = dependable_faiss_import() + vector = np.array(embeddings, dtype=np.float32) + if self._normalize_L2: + faiss.normalize_L2(vector) + self.index.add(vector) # Get list of index, id, and docs. full_info = [ (starting_len + i, str(uuid.uuid4()), doc) @@ -182,7 +188,11 @@ class FAISS(VectorStore): Returns: List of Documents most similar to the query and score for each """ - scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k) + faiss = dependable_faiss_import() + vector = np.array([embedding], dtype=np.float32) + if self._normalize_L2: + faiss.normalize_L2(vector) + scores, indices = self.index.search(vector, k) docs = [] for j, i in enumerate(indices[0]): if i == -1: @@ -356,11 +366,15 @@ class FAISS(VectorStore): embeddings: List[List[float]], embedding: Embeddings, metadatas: Optional[List[dict]] = None, + normalize_L2: bool = False, **kwargs: Any, ) -> FAISS: faiss = dependable_faiss_import() index = faiss.IndexFlatL2(len(embeddings[0])) - index.add(np.array(embeddings, dtype=np.float32)) + vector = np.array(embeddings, dtype=np.float32) + if normalize_L2: + faiss.normalize_L2(vector) + index.add(vector) documents = [] for i, text in enumerate(texts): metadata = metadatas[i] if metadatas else {} @@ -369,7 +383,14 @@ class FAISS(VectorStore): docstore = InMemoryDocstore( {index_to_id[i]: doc for i, doc in enumerate(documents)} ) - return cls(embedding.embed_query, index, docstore, index_to_id, **kwargs) + return cls( + embedding.embed_query, + index, + docstore, + index_to_id, + normalize_L2=normalize_L2, + **kwargs, + ) @classmethod def from_texts(