diff --git a/langchain/retrievers/svm.py b/langchain/retrievers/svm.py index d69abc4b..2060d3b0 100644 --- a/langchain/retrievers/svm.py +++ b/langchain/retrievers/svm.py @@ -4,7 +4,7 @@ https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb""" from __future__ import annotations -from typing import Any, List +from typing import Any, List, Optional import numpy as np from pydantic import BaseModel @@ -22,6 +22,7 @@ class SVMRetriever(BaseRetriever, BaseModel): index: Any texts: List[str] k: int = 4 + relevancy_threshold: Optional[float] = None class Config: @@ -52,9 +53,25 @@ class SVMRetriever(BaseRetriever, BaseModel): similarities = clf.decision_function(x) sorted_ix = np.argsort(-similarities) + # svm.LinearSVC in scikit-learn is non-deterministic. + # if a text is the same as a query, there is no guarantee + # the query will be in the first index. + # this performs a simple swap, this works because anything + # left of the 0 should be equivalent. + zero_index = np.where(sorted_ix == 0)[0][0] + if zero_index != 0: + sorted_ix[0], sorted_ix[zero_index] = sorted_ix[zero_index], sorted_ix[0] + + denominator = np.max(similarities) - np.min(similarities) + 1e-6 + normalized_similarities = (similarities - np.min(similarities)) / denominator + top_k_results = [] for row in sorted_ix[1 : self.k + 1]: - top_k_results.append(Document(page_content=self.texts[row - 1])) + if ( + self.relevancy_threshold is None + or normalized_similarities[row] >= self.relevancy_threshold + ): + top_k_results.append(Document(page_content=self.texts[row - 1])) return top_k_results async def aget_relevant_documents(self, query: str) -> List[Document]: