diff --git a/docs/modules/indexes/retrievers/examples/svm_retriever.ipynb b/docs/modules/indexes/retrievers/examples/svm_retriever.ipynb new file mode 100644 index 00000000..ad14b33d --- /dev/null +++ b/docs/modules/indexes/retrievers/examples/svm_retriever.ipynb @@ -0,0 +1,128 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ab66dd43", + "metadata": {}, + "source": [ + "# SVM Retriever\n", + "\n", + "This notebook goes over how to use a retriever that under the hood uses an SVM using scikit-learn.\n", + "\n", + "Largely based on https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "393ac030", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.retrievers import SVMRetriever\n", + "from langchain.embeddings import OpenAIEmbeddings" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "a801b57c", + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install scikit-learn" + ] + }, + { + "cell_type": "markdown", + "id": "aaf80e7f", + "metadata": {}, + "source": [ + "## Create New Retriever with Texts" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "98b1c017", + "metadata": {}, + "outputs": [], + "source": [ + "retriever = SVMRetriever.from_texts([\"foo\", \"bar\", \"world\", \"hello\", \"foo bar\"], OpenAIEmbeddings())" + ] + }, + { + "cell_type": "markdown", + "id": "08437fa2", + "metadata": {}, + "source": [ + "## Use Retriever\n", + "\n", + "We can now use the retriever!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "c0455218", + "metadata": {}, + "outputs": [], + "source": [ + "result = retriever.get_relevant_documents(\"foo\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "7dfa5c29", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Document(page_content='foo', metadata={}),\n", + " Document(page_content='foo bar', metadata={}),\n", + " Document(page_content='hello', metadata={}),\n", + " Document(page_content='world', metadata={})]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "74bd9256", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/retrievers/__init__.py b/langchain/retrievers/__init__.py index cb1ed4ce..9ca44072 100644 --- a/langchain/retrievers/__init__.py +++ b/langchain/retrievers/__init__.py @@ -4,6 +4,7 @@ from langchain.retrievers.elastic_search_bm25 import ElasticSearchBM25Retriever from langchain.retrievers.metal import MetalRetriever from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever from langchain.retrievers.remote_retriever import RemoteLangChainRetriever +from langchain.retrievers.svm import SVMRetriever from langchain.retrievers.tfidf import TFIDFRetriever from langchain.retrievers.weaviate_hybrid_search import WeaviateHybridSearchRetriever @@ -16,4 +17,5 @@ __all__ = [ "TFIDFRetriever", 
"""SVM Retriever.

Largely based on
https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb
"""

from __future__ import annotations

from typing import Any, List

import numpy as np
from pydantic import BaseModel

from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document


def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
    """Embed every text in ``contexts`` and collect the vectors in one array.

    Each text is embedded with ``embeddings.embed_query``; entry ``i`` of the
    result corresponds to ``contexts[i]``.

    Args:
        contexts: Texts to embed.
        embeddings: Embedding model used to embed each text.

    Returns:
        Array built from the per-text embeddings, in input order.
    """
    return np.array([embeddings.embed_query(split) for split in contexts])


class SVMRetriever(BaseRetriever, BaseModel):
    """Retriever that ranks stored texts with a linear SVM fitted per query."""

    # Embedding model used to embed incoming queries.
    embeddings: Embeddings
    # Embeddings of ``texts``; ``from_texts`` builds it with ``create_index``.
    index: Any
    # Raw texts; row ``i`` of ``index`` is expected to embed ``texts[i]``.
    texts: List[str]
    # Maximum number of documents returned per query.
    k: int = 4

    class Config:

        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @classmethod
    def from_texts(
        cls, texts: List[str], embeddings: Embeddings, **kwargs: Any
    ) -> SVMRetriever:
        """Build a retriever by embedding ``texts`` with ``embeddings``.

        Args:
            texts: Texts to index.
            embeddings: Embedding model for the texts and for later queries.
            **kwargs: Extra fields forwarded to the constructor (e.g. ``k``).

        Returns:
            A retriever whose index holds one embedding per text.
        """
        index = create_index(texts, embeddings)
        return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Return up to ``k`` stored texts ranked by SVM score for ``query``.

        The query embedding is stacked on top of the index as the only
        positive example, a ``LinearSVC`` is fitted on that problem, and the
        stored texts are ordered by descending ``decision_function`` value.

        Args:
            query: Text to search for.

        Returns:
            At most ``k`` documents, best match first; an empty list when the
            retriever holds no texts.
        """
        if not self.texts:
            # Nothing to rank, and the fit below cannot run on the query alone.
            return []

        # Imported here so scikit-learn is only required once retrieval runs.
        from sklearn import svm

        query_embeds = np.array(self.embeddings.embed_query(query))
        # Row 0 is the query; rows 1..n are the indexed texts.
        x = np.concatenate([query_embeds[None, ...], self.index])
        y = np.zeros(x.shape[0])
        y[0] = 1

        clf = svm.LinearSVC(
            class_weight="balanced", verbose=False, max_iter=10000, tol=1e-6, C=0.1
        )
        clf.fit(x, y)

        similarities = clf.decision_function(x)
        sorted_ix = np.argsort(-similarities)

        # The query row is not guaranteed to rank first. Drop it wherever it
        # lands instead of skipping position 0; otherwise ``row - 1`` becomes
        # -1 and silently returns the last stored text.
        doc_rows = sorted_ix[sorted_ix != 0][: max(self.k, 0)]
        return [Document(page_content=self.texts[row - 1]) for row in doc_rows]

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Async retrieval is not supported; always raises NotImplementedError."""
        raise NotImplementedError