@ -5,7 +5,7 @@ https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb"""
from __future__ import annotations
import concurrent . futures
from typing import Any , List, Optional
from typing import Any , Iterable, List, Optional
import numpy as np
from langchain_core . callbacks import CallbackManagerForRetrieverRun
@ -38,6 +38,8 @@ class KNNRetriever(BaseRetriever):
""" Index of embeddings. """
texts : List [ str ]
""" List of texts to index. """
metadatas : Optional [ List [ dict ] ] = None
""" List of metadatas corresponding with each text. """
k : int = 4
""" Number of results to return. """
relevancy_threshold : Optional [ float ] = None
@ -51,10 +53,32 @@ class KNNRetriever(BaseRetriever):
@classmethod
def from_texts(
    cls,
    texts: List[str],
    embeddings: Embeddings,
    metadatas: Optional[List[dict]] = None,
    **kwargs: Any,
) -> KNNRetriever:
    """Create a retriever from raw texts.

    Args:
        texts: Texts to index.
        embeddings: Embeddings object handed to ``create_index`` together
            with ``texts`` and stored on the retriever.
        metadatas: Optional metadata dicts, one per entry in ``texts`` and
            in the same order. Defaults to ``None``.
        **kwargs: Extra keyword arguments forwarded unchanged to the class
            constructor.

    Returns:
        A retriever holding the embeddings, the built index, the texts and
        the metadatas.
    """
    index = create_index(texts, embeddings)
    # Single return path: ``metadatas`` must reach the constructor so it is
    # stored alongside ``texts`` (an earlier return without it would drop
    # the metadata and make this call unreachable).
    return cls(
        embeddings=embeddings,
        index=index,
        texts=texts,
        metadatas=metadatas,
        **kwargs,
    )
@classmethod
def from_documents(
    cls,
    documents: Iterable[Document],
    embeddings: Embeddings,
    **kwargs: Any,
) -> KNNRetriever:
    """Create a retriever from documents.

    Each document's ``page_content`` becomes an indexed text and its
    ``metadata`` is kept at the same position, then construction is
    delegated to ``from_texts``.

    Args:
        documents: Documents to index. Consumed in a single pass, so a
            one-shot iterator or generator is acceptable.
        embeddings: Embeddings object forwarded to ``from_texts``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``from_texts``.

    Returns:
        Whatever ``from_texts`` returns for the collected texts and
        metadatas.
    """
    # Collect into real lists in one pass. ``zip(*pairs)`` would fail to
    # unpack when ``documents`` is empty and would hand tuples to
    # ``from_texts``, whose parameters are declared as lists.
    texts: List[str] = []
    metadatas: List[dict] = []
    for doc in documents:
        texts.append(doc.page_content)
        metadatas.append(doc.metadata)
    return cls.from_texts(
        texts=texts, embeddings=embeddings, metadatas=metadatas, **kwargs
    )
def _get_relevant_documents (
self , query : str , * , run_manager : CallbackManagerForRetrieverRun
@ -71,7 +95,10 @@ class KNNRetriever(BaseRetriever):
normalized_similarities = ( similarities - np . min ( similarities ) ) / denominator
top_k_results = [
Document ( page_content = self . texts [ row ] )
Document (
page_content = self . texts [ row ] ,
metadata = self . metadatas [ row ] if self . metadatas else { } ,
)
for row in sorted_ix [ 0 : self . k ]
if (
self . relevancy_threshold is None