forked from Archives/langchain
- Modify the SVMRetriever class to add an optional relevancy_threshold.
- Modify the SVMRetriever.get_relevant_documents method to filter out documents with similarity scores below the relevancy threshold.
- Normalize the similarities to be between 0 and 1 so that the relevancy_threshold is easier to interpret.
- Limit the number of results to the top k documents or the number of relevant documents above the threshold, whichever is smaller.

This code will now return the top self.k results (or fewer, if not enough results meet the self.relevancy_threshold criterion).

The svm.LinearSVC implementation in scikit-learn is non-deterministic, which means that for SVMRetriever.from_texts(["bar", "world", "foo", "hello", "foo bar"]) and a query of "foo", the sorted index order (where index 0 is the query itself) could be [3 0 5 4 2 1] instead of [0 3 5 4 2 1]. If you pass in multiple "foo" texts, the order could be different each time. Here, we only care that 0 is the first element; otherwise the mapping between texts and similarities would be offset.

Example:

```python
retriever = SVMRetriever.from_texts(
    ["foo", "bar", "world", "hello", "foo bar"],
    OpenAIEmbeddings(),
    k=4,
    relevancy_threshold=.25
)
result = retriever.get_relevant_documents("foo")
```

yields

```python
[Document(page_content='foo', metadata={}), Document(page_content='foo bar', metadata={})]
```

---------

Co-authored-by: Brandon Sandoval <52767641+account00001@users.noreply.github.com>
This commit is contained in:
parent
7302787a7b
commit
4c02f4bc30
@ -4,7 +4,7 @@ https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb"""
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, List
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@ -22,6 +22,7 @@ class SVMRetriever(BaseRetriever, BaseModel):
|
|||||||
index: Any
|
index: Any
|
||||||
texts: List[str]
|
texts: List[str]
|
||||||
k: int = 4
|
k: int = 4
|
||||||
|
relevancy_threshold: Optional[float] = None
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
|
||||||
@ -52,9 +53,25 @@ class SVMRetriever(BaseRetriever, BaseModel):
|
|||||||
similarities = clf.decision_function(x)
|
similarities = clf.decision_function(x)
|
||||||
sorted_ix = np.argsort(-similarities)
|
sorted_ix = np.argsort(-similarities)
|
||||||
|
|
||||||
|
# svm.LinearSVC in scikit-learn is non-deterministic.
|
||||||
|
# if a text is the same as a query, there is no guarantee
|
||||||
|
# the query will be in the first index.
|
||||||
|
# this performs a simple swap, this works because anything
|
||||||
|
# left of the 0 should be equivalent.
|
||||||
|
zero_index = np.where(sorted_ix == 0)[0][0]
|
||||||
|
if zero_index != 0:
|
||||||
|
sorted_ix[0], sorted_ix[zero_index] = sorted_ix[zero_index], sorted_ix[0]
|
||||||
|
|
||||||
|
denominator = np.max(similarities) - np.min(similarities) + 1e-6
|
||||||
|
normalized_similarities = (similarities - np.min(similarities)) / denominator
|
||||||
|
|
||||||
top_k_results = []
|
top_k_results = []
|
||||||
for row in sorted_ix[1 : self.k + 1]:
|
for row in sorted_ix[1 : self.k + 1]:
|
||||||
top_k_results.append(Document(page_content=self.texts[row - 1]))
|
if (
|
||||||
|
self.relevancy_threshold is None
|
||||||
|
or normalized_similarities[row] >= self.relevancy_threshold
|
||||||
|
):
|
||||||
|
top_k_results.append(Document(page_content=self.texts[row - 1]))
|
||||||
return top_k_results
|
return top_k_results
|
||||||
|
|
||||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||||
|
Loading…
Reference in New Issue
Block a user