diff --git a/langchain/retrievers/knn.py b/langchain/retrievers/knn.py index d6204723..362a0ec2 100644 --- a/langchain/retrievers/knn.py +++ b/langchain/retrievers/knn.py @@ -51,13 +51,14 @@ class KNNRetriever(BaseRetriever, BaseModel): denominator = np.max(similarities) - np.min(similarities) + 1e-6 normalized_similarities = (similarities - np.min(similarities)) / denominator - top_k_results = [] - for row in sorted_ix[0 : self.k]: + top_k_results = [ + Document(page_content=self.texts[row]) + for row in sorted_ix[0 : self.k] if ( self.relevancy_threshold is None or normalized_similarities[row] >= self.relevancy_threshold - ): - top_k_results.append(Document(page_content=self.texts[row])) + ) + ] return top_k_results async def aget_relevant_documents(self, query: str) -> List[Document]: