mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Add SVM retriever class, based on https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb. Testing still WIP, but the logic is correct (I have a local implementation outside of Langchain working). --------- Co-authored-by: Lance Martin <122662504+PineappleExpress808@users.noreply.github.com> Co-authored-by: rlm <31treehaus@31s-MacBook-Pro.local>
This commit is contained in:
parent
baf350e32b
commit
274b25c010
128
docs/modules/indexes/retrievers/examples/svm_retriever.ipynb
Normal file
128
docs/modules/indexes/retrievers/examples/svm_retriever.ipynb
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "ab66dd43",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# SVM Retriever\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook shows how to use a retriever that, under the hood, uses an SVM from scikit-learn.\n",
|
||||||
|
"\n",
|
||||||
|
"Largely based on https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "393ac030",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.retrievers import SVMRetriever\n",
|
||||||
|
"from langchain.embeddings import OpenAIEmbeddings"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "a801b57c",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# !pip install scikit-learn"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "aaf80e7f",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Create New Retriever with Texts"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "98b1c017",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"retriever = SVMRetriever.from_texts([\"foo\", \"bar\", \"world\", \"hello\", \"foo bar\"], OpenAIEmbeddings())"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "08437fa2",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Use Retriever\n",
|
||||||
|
"\n",
|
||||||
|
"We can now use the retriever!"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "c0455218",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"result = retriever.get_relevant_documents(\"foo\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"id": "7dfa5c29",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"[Document(page_content='foo', metadata={}),\n",
|
||||||
|
" Document(page_content='foo bar', metadata={}),\n",
|
||||||
|
" Document(page_content='hello', metadata={}),\n",
|
||||||
|
" Document(page_content='world', metadata={})]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"result"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "74bd9256",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -4,6 +4,7 @@ from langchain.retrievers.elastic_search_bm25 import ElasticSearchBM25Retriever
|
|||||||
from langchain.retrievers.metal import MetalRetriever
|
from langchain.retrievers.metal import MetalRetriever
|
||||||
from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever
|
from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever
|
||||||
from langchain.retrievers.remote_retriever import RemoteLangChainRetriever
|
from langchain.retrievers.remote_retriever import RemoteLangChainRetriever
|
||||||
|
from langchain.retrievers.svm import SVMRetriever
|
||||||
from langchain.retrievers.tfidf import TFIDFRetriever
|
from langchain.retrievers.tfidf import TFIDFRetriever
|
||||||
from langchain.retrievers.weaviate_hybrid_search import WeaviateHybridSearchRetriever
|
from langchain.retrievers.weaviate_hybrid_search import WeaviateHybridSearchRetriever
|
||||||
|
|
||||||
@ -16,4 +17,5 @@ __all__ = [
|
|||||||
"TFIDFRetriever",
|
"TFIDFRetriever",
|
||||||
"WeaviateHybridSearchRetriever",
|
"WeaviateHybridSearchRetriever",
|
||||||
"DataberryRetriever",
|
"DataberryRetriever",
|
||||||
|
"SVMRetriever",
|
||||||
]
|
]
|
||||||
|
61
langchain/retrievers/svm.py
Normal file
61
langchain/retrievers/svm.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
"""SVM Retriever.
|
||||||
|
Largely based on
|
||||||
|
https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from langchain.schema import BaseRetriever, Document
|
||||||
|
|
||||||
|
|
||||||
|
def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
    """Embed each text and stack the resulting vectors into a single array.

    Args:
        contexts: The texts to embed, in the order they should appear as rows.
        embeddings: Object whose ``embed_query`` method turns one text into a
            vector; it is called once per entry of ``contexts``.

    Returns:
        ``np.array`` built from the list of per-text embeddings, one row per
        text.
    """
    vectors = []
    for text in contexts:
        vectors.append(embeddings.embed_query(text))
    return np.array(vectors)
|
||||||
|
|
||||||
|
|
||||||
|
class SVMRetriever(BaseRetriever, BaseModel):
    """Retriever that ranks stored texts with a linear SVM fitted per query.

    For every query a ``sklearn.svm.LinearSVC`` is trained with the query
    embedding as the single positive example and all stored embeddings as
    negatives; stored texts are then ordered by the classifier's decision
    function.
    """

    # Embedding object; its ``embed_query`` method is used for the query.
    embeddings: Embeddings
    # Array of stored-text embeddings, row ``i`` belonging to ``texts[i]``
    # (this is how ``from_texts`` builds it).
    index: Any
    # The raw texts returned as ``Document.page_content``.
    texts: List[str]
    # Maximum number of documents returned per query.
    k: int = 4

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @classmethod
    def from_texts(
        cls, texts: List[str], embeddings: Embeddings, **kwargs: Any
    ) -> SVMRetriever:
        """Build a retriever by embedding ``texts`` with ``embeddings``.

        Args:
            texts: Texts to index and later return as documents.
            embeddings: Embedding object used for both indexing and querying.
            **kwargs: Extra field values forwarded to the constructor
                (for example ``k``).

        Returns:
            A new ``SVMRetriever`` holding the texts and their embeddings.
        """
        index = create_index(texts, embeddings)
        return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Return up to ``k`` stored texts ranked by SVM decision score.

        Args:
            query: The query string to embed and rank against.

        Returns:
            At most ``self.k`` documents, best score first. An empty list is
            returned when no texts are stored.
        """
        # Nothing to rank; also avoids fitting a classifier on a single class.
        if not self.texts:
            return []

        from sklearn import svm

        query_embeds = np.array(self.embeddings.embed_query(query))
        # Row 0 is the query; rows 1..n are the stored texts.
        x = np.concatenate([query_embeds[None, ...], self.index])
        y = np.zeros(x.shape[0])
        y[0] = 1  # the query is the only positive example

        clf = svm.LinearSVC(
            class_weight="balanced", verbose=False, max_iter=10000, tol=1e-6, C=0.1
        )
        clf.fit(x, y)

        similarities = clf.decision_function(x)
        sorted_ix = np.argsort(-similarities)

        # Do not assume the query row sorts first: if it does not (e.g. a
        # stored text has the same embedding as the query), a positional
        # slice would keep row 0 and ``texts[0 - 1]`` would wrongly yield the
        # last text. Drop row 0 explicitly, then take the top k.
        doc_rows = sorted_ix[sorted_ix != 0][: self.k]

        # Shift by one to map rows of ``x`` back to positions in ``texts``.
        return [Document(page_content=self.texts[row - 1]) for row in doc_rows]

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Async retrieval is not implemented for this retriever."""
        raise NotImplementedError
|
Loading…
Reference in New Issue
Block a user