Fix bug in svm.LinearSVC, add support for a relevancy_threshold (#2959) (#2981)

- Modify SVMRetriever class to add an optional relevancy_threshold
- Modify SVMRetriever.get_relevant_documents method to filter out
documents with similarity scores below the relevancy threshold
- Normalized the similarities to be between 0 and 1 so the
relevancy_threshold makes more sense
- The number of results are limited to the top k documents or the
maximum number of relevant documents above the threshold, whichever is
smaller

This code will now return the top self.k results (or less, if there are
not enough results that meet the self.relevancy_threshold criteria).

The svm.LinearSVC implementation in scikit-learn is non-deterministic,
which means
SVMRetriever.from_texts(["bar", "world", "foo", "hello", "foo bar"])
could return [3 0 5 4 2 1] instead of [0 3 5 4 2 1] with a query of
"foo".
If you pass in multiple "foo" texts, the order could be different each
time. Here, we only care if the 0 is the first element, otherwise it
will offset the text and similarities.


Example:
```python
retriever = SVMRetriever.from_texts(
  ["foo", "bar", "world", "hello", "foo bar"],
  OpenAIEmbeddings(),
  k=4,
  relevancy_threshold=.25
)

result = retriever.get_relevant_documents("foo")
```
yields
```python
[Document(page_content='foo', metadata={}), Document(page_content='foo bar', metadata={})]
```

---------

Co-authored-by: Brandon Sandoval <52767641+account00001@users.noreply.github.com>
This commit is contained in:
Harrison Chase 2023-04-16 12:57:18 -07:00 committed by GitHub
parent 7302787a7b
commit 4c02f4bc30
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,7 +4,7 @@ https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb"""
from __future__ import annotations
from typing import Any, List
from typing import Any, List, Optional
import numpy as np
from pydantic import BaseModel
@ -22,6 +22,7 @@ class SVMRetriever(BaseRetriever, BaseModel):
index: Any
texts: List[str]
k: int = 4
relevancy_threshold: Optional[float] = None
class Config:
@ -52,9 +53,25 @@ class SVMRetriever(BaseRetriever, BaseModel):
similarities = clf.decision_function(x)
sorted_ix = np.argsort(-similarities)
# svm.LinearSVC in scikit-learn is non-deterministic.
# if a text is the same as a query, there is no guarantee
# the query will be in the first index.
# this performs a simple swap, this works because anything
# left of the 0 should be equivalent.
zero_index = np.where(sorted_ix == 0)[0][0]
if zero_index != 0:
sorted_ix[0], sorted_ix[zero_index] = sorted_ix[zero_index], sorted_ix[0]
denominator = np.max(similarities) - np.min(similarities) + 1e-6
normalized_similarities = (similarities - np.min(similarities)) / denominator
top_k_results = []
for row in sorted_ix[1 : self.k + 1]:
top_k_results.append(Document(page_content=self.texts[row - 1]))
if (
self.relevancy_threshold is None
or normalized_similarities[row] >= self.relevancy_threshold
):
top_k_results.append(Document(page_content=self.texts[row - 1]))
return top_k_results
async def aget_relevant_documents(self, query: str) -> List[Document]: