From 4c02f4bc30eb8640b9c28976be8fe9a86529d31c Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sun, 16 Apr 2023 12:57:18 -0700 Subject: [PATCH] Fix bug in svm.LinearSVC, add support for a relevancy_threshold (#2959) (#2981) - Modify SVMRetriever class to add an optional relevancy_threshold - Modify SVMRetriever.get_relevant_documents method to filter out documents with similarity scores below the relevancy threshold - Normalize the similarities to be between 0 and 1 so the relevancy_threshold makes more sense - The number of results is limited to the top k documents or the number of relevant documents above the threshold, whichever is smaller This code will now return the top self.k results (or fewer, if not enough results meet the self.relevancy_threshold criterion). The svm.LinearSVC implementation in scikit-learn is non-deterministic, which means that for SVMRetriever.from_texts(["bar", "world", "foo", "hello", "foo bar"]) the sorted index order could be [3 0 5 4 2 1] instead of [0 3 5 4 2 1] with a query of "foo". If you pass in multiple "foo" texts, the order could be different each time. Here, we only care that index 0 (the query) is the first element; otherwise the texts and similarities will be offset from each other.
Example: ```python retriever = SVMRetriever.from_texts( ["foo", "bar", "world", "hello", "foo bar"], OpenAIEmbeddings(), k=4, relevancy_threshold=.25 ) result = retriever.get_relevant_documents("foo") ``` yields ```python [Document(page_content='foo', metadata={}), Document(page_content='foo bar', metadata={})] ``` --------- Co-authored-by: Brandon Sandoval <52767641+account00001@users.noreply.github.com> --- langchain/retrievers/svm.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/langchain/retrievers/svm.py b/langchain/retrievers/svm.py index d69abc4b..2060d3b0 100644 --- a/langchain/retrievers/svm.py +++ b/langchain/retrievers/svm.py @@ -4,7 +4,7 @@ https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb""" from __future__ import annotations -from typing import Any, List +from typing import Any, List, Optional import numpy as np from pydantic import BaseModel @@ -22,6 +22,7 @@ class SVMRetriever(BaseRetriever, BaseModel): index: Any texts: List[str] k: int = 4 + relevancy_threshold: Optional[float] = None class Config: @@ -52,9 +53,25 @@ class SVMRetriever(BaseRetriever, BaseModel): similarities = clf.decision_function(x) sorted_ix = np.argsort(-similarities) + # svm.LinearSVC in scikit-learn is non-deterministic. + # if a text is the same as a query, there is no guarantee + # the query will be in the first index. + # this performs a simple swap, this works because anything + # left of the 0 should be equivalent. 
+ zero_index = np.where(sorted_ix == 0)[0][0] + if zero_index != 0: + sorted_ix[0], sorted_ix[zero_index] = sorted_ix[zero_index], sorted_ix[0] + + denominator = np.max(similarities) - np.min(similarities) + 1e-6 + normalized_similarities = (similarities - np.min(similarities)) / denominator + top_k_results = [] for row in sorted_ix[1 : self.k + 1]: - top_k_results.append(Document(page_content=self.texts[row - 1])) + if ( + self.relevancy_threshold is None + or normalized_similarities[row] >= self.relevancy_threshold + ): + top_k_results.append(Document(page_content=self.texts[row - 1])) return top_k_results async def aget_relevant_documents(self, query: str) -> List[Document]: