@ -38,6 +38,8 @@ class SVMRetriever(BaseRetriever):
""" Index of embeddings. """
texts : List [ str ]
""" List of texts to index. """
metadatas : Optional [ List [ dict ] ] = None
""" List of metadatas corresponding with each text. """
k : int = 4
""" Number of results to return. """
relevancy_threshold : Optional [ float ] = None
@ -51,10 +53,20 @@ class SVMRetriever(BaseRetriever):
@classmethod
def from_texts(
    cls,
    texts: List[str],
    embeddings: Embeddings,
    metadatas: Optional[List[dict]] = None,
    **kwargs: Any,
) -> SVMRetriever:
    """Build a retriever by indexing ``texts`` with ``embeddings``.

    Args:
        texts: Texts to index; also stored on the retriever as ``texts``.
        embeddings: Embedding model handed to ``create_index`` and stored
            on the retriever as ``embeddings``.
        metadatas: Optional list of metadata dicts, one per text. It is
            stored as-is; at retrieval time entries are looked up with the
            same position used for ``texts``, so the two sequences are
            expected to line up (no length check is performed here).
        **kwargs: Extra keyword arguments forwarded unchanged to the class
            constructor.

    Returns:
        A new retriever instance holding the freshly built index.
    """
    # The index is computed once, up front, from the full text list.
    index = create_index(texts, embeddings)
    # Single return path: metadatas must reach the constructor, otherwise
    # retrieved documents silently come back with empty metadata.
    return cls(
        embeddings=embeddings,
        index=index,
        texts=texts,
        metadatas=metadatas,
        **kwargs,
    )
@classmethod
def from_documents (
@ -64,7 +76,9 @@ class SVMRetriever(BaseRetriever):
* * kwargs : Any ,
) - > SVMRetriever :
texts , metadatas = zip ( * ( ( d . page_content , d . metadata ) for d in documents ) )
return cls . from_texts ( texts = texts , embeddings = embeddings , * * kwargs )
return cls . from_texts (
texts = texts , embeddings = embeddings , metadatas = metadatas , * * kwargs
)
def _get_relevant_documents (
self , query : str , * , run_manager : CallbackManagerForRetrieverRun
@ -108,5 +122,7 @@ class SVMRetriever(BaseRetriever):
self . relevancy_threshold is None
or normalized_similarities [ row ] > = self . relevancy_threshold
) :
top_k_results . append ( Document ( page_content = self . texts [ row - 1 ] ) )
metadata = self . metadatas [ row - 1 ] if self . metadatas else { }
doc = Document ( page_content = self . texts [ row - 1 ] , metadata = metadata )
top_k_results . append ( doc )
return top_k_results