mirror of https://github.com/hwchase17/langchain
Harrison/knn retriever (#4083)
Co-authored-by: Yuichi Tateno (secon) <hotchpotch@users.noreply.github.com>pull/4084/head
parent
65c3b146c9
commit
5f30cc8713
@ -0,0 +1,111 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "ab66dd43",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# kNN Retriever\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook goes over how to use a retriever that under the hood uses an kNN.\n",
|
||||||
|
"\n",
|
||||||
|
"Largely based on https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "393ac030",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.retrievers import KNNRetriever\n",
|
||||||
|
"from langchain.embeddings import OpenAIEmbeddings"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "aaf80e7f",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Create New Retriever with Texts"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "98b1c017",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"retriever = KNNRetriever.from_texts([\"foo\", \"bar\", \"world\", \"hello\", \"foo bar\"], OpenAIEmbeddings())"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "08437fa2",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Use Retriever\n",
|
||||||
|
"\n",
|
||||||
|
"We can now use the retriever!"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "c0455218",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"result = retriever.get_relevant_documents(\"foo\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "7dfa5c29",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"[Document(page_content='foo', metadata={}),\n",
|
||||||
|
" Document(page_content='foo bar', metadata={}),\n",
|
||||||
|
" Document(page_content='hello', metadata={}),\n",
|
||||||
|
" Document(page_content='bar', metadata={})]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"result"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.16"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -0,0 +1,64 @@
|
|||||||
|
"""KNN Retriever.
|
||||||
|
Largely based on
|
||||||
|
https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import concurrent.futures
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from langchain.schema import BaseRetriever, Document
|
||||||
|
|
||||||
|
|
||||||
|
def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
    """Embed every context and collect the vectors into a single array.

    Each context is passed to ``embeddings.embed_query`` on a worker thread of
    a ``ThreadPoolExecutor``; the results keep the order of ``contexts``.

    Args:
        contexts: Texts to embed.
        embeddings: Object exposing ``embed_query(text)``.

    Returns:
        ``np.array`` built from the list of per-context embeddings.
    """
    with concurrent.futures.ThreadPoolExecutor() as pool:
        # Materialise inside the ``with`` so all work finishes (and any
        # worker exception surfaces) before the pool is shut down.
        vectors = list(pool.map(embeddings.embed_query, contexts))
    return np.array(vectors)
|
||||||
|
|
||||||
|
|
||||||
|
class KNNRetriever(BaseRetriever, BaseModel):
    """Retriever that returns the stored texts most similar to a query.

    Texts are ranked by the cosine similarity between the query embedding and
    each row of ``index``; at most ``k`` documents are returned.
    """

    # Embedding model used to embed the query (and the texts in ``from_texts``).
    embeddings: Embeddings
    # Embeddings of ``texts``; row ``i`` is looked up against ``texts[i]``.
    index: Any
    # Raw texts that are wrapped into ``Document`` objects on retrieval.
    texts: List[str]
    # Maximum number of documents returned per query.
    k: int = 4
    # Optional cut-off compared against the min-max normalized similarities;
    # ``None`` disables the filtering.
    relevancy_threshold: Optional[float] = None

    class Config:

        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @classmethod
    def from_texts(
        cls, texts: List[str], embeddings: Embeddings, **kwargs: Any
    ) -> KNNRetriever:
        """Build a retriever by embedding ``texts`` with ``embeddings``.

        Args:
            texts: Texts to index.
            embeddings: Embedding model used for both indexing and querying.
            **kwargs: Extra fields forwarded to the constructor
                (e.g. ``k``, ``relevancy_threshold``).

        Returns:
            A new ``KNNRetriever`` holding the texts and their index.
        """
        index = create_index(texts, embeddings)
        return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Return up to ``k`` documents ordered by decreasing similarity.

        Args:
            query: Text to search for.

        Returns:
            Documents for the best-matching texts. When
            ``relevancy_threshold`` is set, entries whose min-max normalized
            similarity falls below it are dropped. An empty index yields an
            empty list.
        """
        index = np.asarray(self.index)
        # Nothing indexed: the axis-1 sum / max / min below would raise on an
        # empty array, so answer early (and skip the embedding call).
        if index.size == 0:
            return []

        query_embeds = np.array(self.embeddings.embed_query(query))
        # Scale both sides to unit L2 norm so the dot product below is the
        # cosine similarity.
        index_embeds = index / np.sqrt((index**2).sum(1, keepdims=True))
        query_embeds = query_embeds / np.sqrt((query_embeds**2).sum())

        similarities = index_embeds.dot(query_embeds)
        sorted_ix = np.argsort(-similarities)

        # Min-max normalize; the epsilon avoids division by zero when every
        # similarity is identical.
        lowest = np.min(similarities)
        denominator = np.max(similarities) - lowest + 1e-6
        normalized_similarities = (similarities - lowest) / denominator

        threshold = self.relevancy_threshold
        return [
            Document(page_content=self.texts[row])
            for row in sorted_ix[: self.k]
            if threshold is None or normalized_similarities[row] >= threshold
        ]

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Async retrieval is not implemented for this retriever.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
|
Loading…
Reference in New Issue