SVM retriever (#2947) (#2949)

Add SVM retriever class, based on
https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb.

Testing still WIP, but the logic is correct (I have a local
implementation outside of Langchain working).

---------

Co-authored-by: Lance Martin <122662504+PineappleExpress808@users.noreply.github.com>
Co-authored-by: rlm <31treehaus@31s-MacBook-Pro.local>
This commit is contained in:
Harrison Chase 2023-04-15 12:49:59 -07:00 committed by GitHub
parent baf350e32b
commit 274b25c010
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 191 additions and 0 deletions

View File

@ -0,0 +1,128 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "ab66dd43",
"metadata": {},
"source": [
"# SVM Retriever\n",
"\n",
"This notebook goes over how to use a retriever that under the hood uses an SVM using scikit-learn.\n",
"\n",
"Largely based on https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "393ac030",
"metadata": {},
"outputs": [],
"source": [
"from langchain.retrievers import SVMRetriever\n",
"from langchain.embeddings import OpenAIEmbeddings"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a801b57c",
"metadata": {},
"outputs": [],
"source": [
"# !pip install scikit-learn"
]
},
{
"cell_type": "markdown",
"id": "aaf80e7f",
"metadata": {},
"source": [
"## Create New Retriever with Texts"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "98b1c017",
"metadata": {},
"outputs": [],
"source": [
"retriever = SVMRetriever.from_texts([\"foo\", \"bar\", \"world\", \"hello\", \"foo bar\"], OpenAIEmbeddings())"
]
},
{
"cell_type": "markdown",
"id": "08437fa2",
"metadata": {},
"source": [
"## Use Retriever\n",
"\n",
"We can now use the retriever!"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c0455218",
"metadata": {},
"outputs": [],
"source": [
"result = retriever.get_relevant_documents(\"foo\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "7dfa5c29",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='foo', metadata={}),\n",
" Document(page_content='foo bar', metadata={}),\n",
" Document(page_content='hello', metadata={}),\n",
" Document(page_content='world', metadata={})]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "74bd9256",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -4,6 +4,7 @@ from langchain.retrievers.elastic_search_bm25 import ElasticSearchBM25Retriever
from langchain.retrievers.metal import MetalRetriever from langchain.retrievers.metal import MetalRetriever
from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever
from langchain.retrievers.remote_retriever import RemoteLangChainRetriever from langchain.retrievers.remote_retriever import RemoteLangChainRetriever
from langchain.retrievers.svm import SVMRetriever
from langchain.retrievers.tfidf import TFIDFRetriever from langchain.retrievers.tfidf import TFIDFRetriever
from langchain.retrievers.weaviate_hybrid_search import WeaviateHybridSearchRetriever from langchain.retrievers.weaviate_hybrid_search import WeaviateHybridSearchRetriever
@ -16,4 +17,5 @@ __all__ = [
"TFIDFRetriever", "TFIDFRetriever",
"WeaviateHybridSearchRetriever", "WeaviateHybridSearchRetriever",
"DataberryRetriever", "DataberryRetriever",
"SVMRetriever",
] ]

View File

@ -0,0 +1,61 @@
"""SMV Retriever.
Largely based on
https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb"""
from __future__ import annotations
from typing import Any, List
import numpy as np
from pydantic import BaseModel
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
    """Embed every text and stack the resulting vectors into one array.

    Args:
        contexts: Texts to embed, in the order their rows should appear.
        embeddings: Object whose ``embed_query`` method turns one text into
            a vector; it is called once per text.

    Returns:
        The array built by ``np.array`` from the per-text vectors, with
        row ``i`` corresponding to ``contexts[i]``.
    """
    vectors = []
    for text in contexts:
        vectors.append(embeddings.embed_query(text))
    return np.array(vectors)
class SVMRetriever(BaseRetriever, BaseModel):
    """Retriever that ranks stored texts against a query with a linear SVM.

    For each query a ``LinearSVC`` is fitted with the query embedding as the
    only positive example and every stored text embedding as a negative
    example; texts are then ordered by their decision-function score.
    """

    # Embedding model used for the query (and, in ``from_texts``, the texts).
    embeddings: Embeddings
    # Array of text embeddings; row ``i`` must correspond to ``texts[i]``.
    index: Any
    # Raw texts returned as ``Document.page_content``.
    texts: List[str]
    # Maximum number of documents returned per query.
    k: int = 4

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @classmethod
    def from_texts(
        cls, texts: List[str], embeddings: Embeddings, **kwargs: Any
    ) -> SVMRetriever:
        """Build a retriever by embedding ``texts`` with ``embeddings``.

        Args:
            texts: Texts to make retrievable.
            embeddings: Embedding model used to build the index and, later,
                to embed queries.
            **kwargs: Extra field values forwarded to the constructor
                (for example ``k``).

        Returns:
            A retriever whose ``index`` holds one embedding row per text.
        """
        index = create_index(texts, embeddings)
        return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Return up to ``k`` stored texts ranked by SVM score for ``query``.

        Args:
            query: Text to search for.

        Returns:
            At most ``k`` documents, highest decision-function score first.
            An empty list when no texts are stored.
        """
        # Nothing to rank; also avoids concatenating the query vector with
        # an empty index, which fails on a shape mismatch.
        if not self.texts:
            return []

        # Imported lazily so scikit-learn is only required when retrieving.
        from sklearn import svm

        query_embeds = np.array(self.embeddings.embed_query(query))
        # Row 0 is the query; rows 1..n are the stored texts.
        x = np.concatenate([query_embeds[None, ...], self.index])
        y = np.zeros(x.shape[0])
        y[0] = 1  # the query is the single positive example

        clf = svm.LinearSVC(
            class_weight="balanced", verbose=False, max_iter=10000, tol=1e-6, C=0.1
        )
        clf.fit(x, y)

        similarities = clf.decision_function(x)
        sorted_ix = np.argsort(-similarities)

        # The query row is not guaranteed to rank first (e.g. a stored text
        # identical to the query ties with it). Remove it by value instead of
        # assuming it sits at position 0; otherwise ``row - 1`` becomes -1 and
        # silently returns the last text.
        text_rows = sorted_ix[sorted_ix != 0]
        return [
            Document(page_content=self.texts[row - 1])
            for row in text_rows[: self.k]
        ]

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Async retrieval is not supported by this retriever.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError