mirror of https://github.com/hwchase17/langchain
Harrison/knn retriever (#4083)
Co-authored-by: Yuichi Tateno (secon) <hotchpotch@users.noreply.github.com>pull/4084/head
parent
65c3b146c9
commit
5f30cc8713
@ -0,0 +1,111 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "ab66dd43",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# kNN Retriever\n",
|
||||
"\n",
|
||||
"This notebook goes over how to use a retriever that under the hood uses an kNN.\n",
|
||||
"\n",
|
||||
"Largely based on https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "393ac030",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.retrievers import KNNRetriever\n",
|
||||
"from langchain.embeddings import OpenAIEmbeddings"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "aaf80e7f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Create New Retriever with Texts"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "98b1c017",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"retriever = KNNRetriever.from_texts([\"foo\", \"bar\", \"world\", \"hello\", \"foo bar\"], OpenAIEmbeddings())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "08437fa2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Use Retriever\n",
|
||||
"\n",
|
||||
"We can now use the retriever!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "c0455218",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"result = retriever.get_relevant_documents(\"foo\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "7dfa5c29",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Document(page_content='foo', metadata={}),\n",
|
||||
" Document(page_content='foo bar', metadata={}),\n",
|
||||
" Document(page_content='hello', metadata={}),\n",
|
||||
" Document(page_content='bar', metadata={})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"result"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -0,0 +1,64 @@
|
||||
"""KNN Retriever.
|
||||
Largely based on
|
||||
https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import concurrent.futures
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
|
||||
def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
return np.array(list(executor.map(embeddings.embed_query, contexts)))
|
||||
|
||||
|
||||
class KNNRetriever(BaseRetriever, BaseModel):
|
||||
embeddings: Embeddings
|
||||
index: Any
|
||||
texts: List[str]
|
||||
k: int = 4
|
||||
relevancy_threshold: Optional[float] = None
|
||||
|
||||
class Config:
|
||||
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls, texts: List[str], embeddings: Embeddings, **kwargs: Any
|
||||
) -> KNNRetriever:
|
||||
index = create_index(texts, embeddings)
|
||||
return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
query_embeds = np.array(self.embeddings.embed_query(query))
|
||||
# calc L2 norm
|
||||
index_embeds = self.index / np.sqrt((self.index**2).sum(1, keepdims=True))
|
||||
query_embeds = query_embeds / np.sqrt((query_embeds**2).sum())
|
||||
|
||||
similarities = index_embeds.dot(query_embeds)
|
||||
sorted_ix = np.argsort(-similarities)
|
||||
|
||||
denominator = np.max(similarities) - np.min(similarities) + 1e-6
|
||||
normalized_similarities = (similarities - np.min(similarities)) / denominator
|
||||
|
||||
top_k_results = []
|
||||
for row in sorted_ix[0 : self.k]:
|
||||
if (
|
||||
self.relevancy_threshold is None
|
||||
or normalized_similarities[row] >= self.relevancy_threshold
|
||||
):
|
||||
top_k_results.append(Document(page_content=self.texts[row]))
|
||||
return top_k_results
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
raise NotImplementedError
|
Loading…
Reference in New Issue