@ -1,5 +1,6 @@
from __future__ import annotations
import logging
import operator
import os
import pickle
@ -15,6 +16,7 @@ from typing import (
Optional ,
Sized ,
Tuple ,
Union ,
)
import numpy as np
@ -26,6 +28,8 @@ from langchain.schema.embeddings import Embeddings
from langchain . schema . vectorstore import VectorStore
from langchain . vectorstores . utils import DistanceStrategy , maximal_marginal_relevance
logger = logging . getLogger ( __name__ )
def dependable_faiss_import ( no_avx2 : Optional [ bool ] = None ) - > Any :
"""
@ -82,7 +86,7 @@ class FAISS(VectorStore):
def __init__ (
self ,
embedding_function : Callable,
embedding_function : Union[ Callable, Embeddings ] ,
index : Any ,
docstore : Docstore ,
index_to_docstore_id : Dict [ int , str ] ,
@ -91,6 +95,11 @@ class FAISS(VectorStore):
distance_strategy : DistanceStrategy = DistanceStrategy . EUCLIDEAN_DISTANCE ,
) :
""" Initialize with necessary components. """
if not isinstance ( embedding_function , Embeddings ) :
logger . warning (
" `embedding_function` is expected to be an Embeddings object, support "
" for passing in a function will soon be removed. "
)
self . embedding_function = embedding_function
self . index = index
self . docstore = docstore
@ -108,6 +117,26 @@ class FAISS(VectorStore):
)
)
@property
def embeddings ( self ) - > Optional [ Embeddings ] :
return (
self . embedding_function
if isinstance ( self . embedding_function , Embeddings )
else None
)
def _embed_documents ( self , texts : List [ str ] ) - > List [ List [ float ] ] :
if isinstance ( self . embedding_function , Embeddings ) :
return self . embedding_function . embed_documents ( texts )
else :
return [ self . embedding_function ( text ) for text in texts ]
def _embed_query ( self , text : str ) - > List [ float ] :
if isinstance ( self . embedding_function , Embeddings ) :
return self . embedding_function . embed_query ( text )
else :
return self . embedding_function ( text )
def __add (
self ,
texts : Iterable [ str ] ,
@ -163,7 +192,8 @@ class FAISS(VectorStore):
Returns :
List of ids from adding the texts into the vectorstore .
"""
embeddings = [ self . embedding_function ( text ) for text in texts ]
texts = list ( texts )
embeddings = self . _embed_documents ( texts )
return self . __add ( texts , embeddings , metadatas = metadatas , ids = ids )
def add_embeddings (
@ -272,7 +302,7 @@ class FAISS(VectorStore):
List of documents most similar to the query text with
L2 distance in float . Lower score represents more similarity .
"""
embedding = self . embedding_function ( query )
embedding = self . _embed_query ( query )
docs = self . similarity_search_with_score_by_vector (
embedding ,
k ,
@ -465,7 +495,7 @@ class FAISS(VectorStore):
Returns :
List of Documents selected by maximal marginal relevance .
"""
embedding = self . embedding_function ( query )
embedding = self . _embed_query ( query )
docs = self . max_marginal_relevance_search_by_vector (
embedding ,
k = k ,
@ -561,7 +591,7 @@ class FAISS(VectorStore):
# Default to L2, currently other metric types not initialized.
index = faiss . IndexFlatL2 ( len ( embeddings [ 0 ] ) )
vecstore = cls (
embedding . embed_query ,
embedding ,
index ,
InMemoryDocstore ( ) ,
{ } ,
@ -696,9 +726,7 @@ class FAISS(VectorStore):
# load docstore and index_to_docstore_id
with open ( path / " {index_name} .pkl " . format ( index_name = index_name ) , " rb " ) as f :
docstore , index_to_docstore_id = pickle . load ( f )
return cls (
embeddings . embed_query , index , docstore , index_to_docstore_id , * * kwargs
)
return cls ( embeddings , index , docstore , index_to_docstore_id , * * kwargs )
def serialize_to_bytes ( self ) - > bytes :
""" Serialize FAISS index, docstore, and index_to_docstore_id to bytes. """
@ -713,9 +741,7 @@ class FAISS(VectorStore):
) - > FAISS :
""" Deserialize FAISS index, docstore, and index_to_docstore_id from bytes. """
index , docstore , index_to_docstore_id = pickle . loads ( serialized )
return cls (
embeddings . embed_query , index , docstore , index_to_docstore_id , * * kwargs
)
return cls ( embeddings , index , docstore , index_to_docstore_id , * * kwargs )
def _select_relevance_score_fn ( self ) - > Callable [ [ float ] , float ] :
"""