Harrison/knn retriever (#4083)

Co-authored-by: Yuichi Tateno (secon) <hotchpotch@users.noreply.github.com>
fix_agent_callbacks
Harrison Chase 1 year ago committed by GitHub
parent 65c3b146c9
commit 5f30cc8713
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,111 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "ab66dd43",
"metadata": {},
"source": [
"# kNN Retriever\n",
"\n",
"This notebook goes over how to use a retriever that under the hood uses an kNN.\n",
"\n",
"Largely based on https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "393ac030",
"metadata": {},
"outputs": [],
"source": [
"from langchain.retrievers import KNNRetriever\n",
"from langchain.embeddings import OpenAIEmbeddings"
]
},
{
"cell_type": "markdown",
"id": "aaf80e7f",
"metadata": {},
"source": [
"## Create New Retriever with Texts"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "98b1c017",
"metadata": {},
"outputs": [],
"source": [
"retriever = KNNRetriever.from_texts([\"foo\", \"bar\", \"world\", \"hello\", \"foo bar\"], OpenAIEmbeddings())"
]
},
{
"cell_type": "markdown",
"id": "08437fa2",
"metadata": {},
"source": [
"## Use Retriever\n",
"\n",
"We can now use the retriever!"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c0455218",
"metadata": {},
"outputs": [],
"source": [
"result = retriever.get_relevant_documents(\"foo\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "7dfa5c29",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='foo', metadata={}),\n",
" Document(page_content='foo bar', metadata={}),\n",
" Document(page_content='hello', metadata={}),\n",
" Document(page_content='bar', metadata={})]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -56,7 +56,7 @@
{
"data": {
"text/plain": [
"AIMessage(content=\"J'aime programmer.\", additional_kwargs={})"
"AIMessage(content=\"J'aime programmer.\", additional_kwargs={}, example=False)"
]
},
"execution_count": 3,

@ -162,17 +162,25 @@ class ChatOpenAI(BaseChatModel):
"OPENAI_ORGANIZATION",
default="",
)
openai_api_base = get_from_dict_or_env(
values,
"openai_api_base",
"OPENAI_API_BASE",
default="",
)
try:
import openai
openai.api_key = openai_api_key
if openai_organization:
openai.organization = openai_organization
except ImportError:
raise ValueError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
openai.api_key = openai_api_key
if openai_organization:
openai.organization = openai_organization
if openai_api_base:
openai.api_base = openai_api_base
try:
values["client"] = openai.ChatCompletion
except AttributeError:

@ -2,6 +2,7 @@ from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.databerry import DataberryRetriever
from langchain.retrievers.elastic_search_bm25 import ElasticSearchBM25Retriever
from langchain.retrievers.knn import KNNRetriever
from langchain.retrievers.metal import MetalRetriever
from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever
from langchain.retrievers.remote_retriever import RemoteLangChainRetriever
@ -25,5 +26,6 @@ __all__ = [
"DataberryRetriever",
"TimeWeightedVectorStoreRetriever",
"SVMRetriever",
"KNNRetriever",
"VespaRetriever",
]

@ -0,0 +1,64 @@
"""KNN Retriever.
Largely based on
https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb"""
from __future__ import annotations
import concurrent.futures
from typing import Any, List, Optional
import numpy as np
from pydantic import BaseModel
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
with concurrent.futures.ThreadPoolExecutor() as executor:
return np.array(list(executor.map(embeddings.embed_query, contexts)))
class KNNRetriever(BaseRetriever, BaseModel):
embeddings: Embeddings
index: Any
texts: List[str]
k: int = 4
relevancy_threshold: Optional[float] = None
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@classmethod
def from_texts(
cls, texts: List[str], embeddings: Embeddings, **kwargs: Any
) -> KNNRetriever:
index = create_index(texts, embeddings)
return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)
def get_relevant_documents(self, query: str) -> List[Document]:
query_embeds = np.array(self.embeddings.embed_query(query))
# calc L2 norm
index_embeds = self.index / np.sqrt((self.index**2).sum(1, keepdims=True))
query_embeds = query_embeds / np.sqrt((query_embeds**2).sum())
similarities = index_embeds.dot(query_embeds)
sorted_ix = np.argsort(-similarities)
denominator = np.max(similarities) - np.min(similarities) + 1e-6
normalized_similarities = (similarities - np.min(similarities)) / denominator
top_k_results = []
for row in sorted_ix[0 : self.k]:
if (
self.relevancy_threshold is None
or normalized_similarities[row] >= self.relevancy_threshold
):
top_k_results.append(Document(page_content=self.texts[row]))
return top_k_results
async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError
Loading…
Cancel
Save