Harrison/tfidf retriever (#2440)

This commit is contained in:
Harrison Chase 2023-04-05 07:36:49 -07:00 committed by GitHub
parent a63cfad558
commit 00bc8df640
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 205 additions and 31 deletions

View File

@ -14,7 +14,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 1,
"id": "393ac030", "id": "393ac030",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -32,13 +32,13 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 2,
"id": "bcb3c8c2", "id": "bcb3c8c2",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"elasticsearch_url=\"http://localhost:9200\"\n", "elasticsearch_url=\"http://localhost:9200\"\n",
"retriever = ElasticSearchBM25Retriever.create(elasticsearch_url, \"langchain-index-3\")" "retriever = ElasticSearchBM25Retriever.create(elasticsearch_url, \"langchain-index-4\")"
] ]
}, },
{ {
@ -66,21 +66,21 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": 3,
"id": "98b1c017", "id": "98b1c017",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"['386c76c9-4355-4c12-aaeb-7b80054caf93',\n", "['cbd4cb47-8d9f-4f34-b80e-ea871bc49856',\n",
" 'fffd279c-a0c9-4158-a904-6e242c517c99',\n", " 'f3bd2e24-76d1-4f9b-826b-ec4c0e8c7365',\n",
" '7f5528a3-18d0-43b0-894d-f6770a002219',\n", " '8631bfc8-7c12-48ee-ab56-8ad5f373676e',\n",
" 'e2ef5e32-d5bd-44e2-b045-cfc5a8e0a0a1',\n", " '8be8374c-3253-4d87-928d-d73550a2ecf0',\n",
" 'cce8ba48-e473-4235-bca2-2c8d65e73ccf']" " 'd79f457b-2842-4eab-ae10-77aa420b53d7']"
] ]
}, },
"execution_count": 14, "execution_count": 3,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -101,7 +101,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": 4,
"id": "c0455218", "id": "c0455218",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -111,7 +111,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": 5,
"id": "7dfa5c29", "id": "7dfa5c29",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -122,7 +122,7 @@
" Document(page_content='foo bar', metadata={})]" " Document(page_content='foo bar', metadata={})]"
] ]
}, },
"execution_count": 16, "execution_count": 5,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }

View File

@ -0,0 +1,127 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "ab66dd43",
"metadata": {},
"source": [
"# TF-IDF Retriever\n",
"\n",
"This notebook goes over how to use a retriever that under the hood uses TF-IDF using scikit-learn.\n",
"\n",
"For more information on the details of TF-IDF see [this blog post](https://medium.com/data-science-bootcamp/tf-idf-basics-of-information-retrieval-48de122b2a4c)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "393ac030",
"metadata": {},
"outputs": [],
"source": [
"from langchain.retrievers import TFIDFRetriever"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a801b57c",
"metadata": {},
"outputs": [],
"source": [
"# !pip install scikit-learn"
]
},
{
"cell_type": "markdown",
"id": "aaf80e7f",
"metadata": {},
"source": [
"## Create New Retriever with Texts"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "98b1c017",
"metadata": {},
"outputs": [],
"source": [
"retriever = TFIDFRetriever.from_texts([\"foo\", \"bar\", \"world\", \"hello\", \"foo bar\"])"
]
},
{
"cell_type": "markdown",
"id": "08437fa2",
"metadata": {},
"source": [
"## Use Retriever\n",
"\n",
"We can now use the retriever!"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c0455218",
"metadata": {},
"outputs": [],
"source": [
"result = retriever.get_relevant_documents(\"foo\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "7dfa5c29",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='foo', metadata={}),\n",
" Document(page_content='foo bar', metadata={}),\n",
" Document(page_content='hello', metadata={}),\n",
" Document(page_content='world', metadata={})]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "74bd9256",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -3,6 +3,7 @@ from langchain.retrievers.elastic_search_bm25 import ElasticSearchBM25Retriever
from langchain.retrievers.metal import MetalRetriever from langchain.retrievers.metal import MetalRetriever
from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever
from langchain.retrievers.remote_retriever import RemoteLangChainRetriever from langchain.retrievers.remote_retriever import RemoteLangChainRetriever
from langchain.retrievers.tfidf import TFIDFRetriever
__all__ = [ __all__ = [
"ChatGPTPluginRetriever", "ChatGPTPluginRetriever",
@ -10,4 +11,5 @@ __all__ = [
"PineconeHybridSearchRetriever", "PineconeHybridSearchRetriever",
"MetalRetriever", "MetalRetriever",
"ElasticSearchBM25Retriever", "ElasticSearchBM25Retriever",
"TFIDFRetriever",
] ]

View File

@ -49,29 +49,27 @@ class ElasticSearchBM25Retriever(BaseRetriever):
es = Elasticsearch(elasticsearch_url) es = Elasticsearch(elasticsearch_url)
# Define the index settings and mappings # Define the index settings and mappings
index_settings = { settings = {
"settings": { "analysis": {"analyzer": {"default": {"type": "standard"}}},
"analysis": {"analyzer": {"default": {"type": "standard"}}}, "similarity": {
"similarity": { "custom_bm25": {
"custom_bm25": { "type": "BM25",
"type": "BM25", "k1": k1,
"k1": k1, "b": b,
"b": b,
}
},
},
"mappings": {
"properties": {
"content": {
"type": "text",
"similarity": "custom_bm25", # Use the custom BM25 similarity
}
} }
}, },
} }
mappings = {
"properties": {
"content": {
"type": "text",
"similarity": "custom_bm25", # Use the custom BM25 similarity
}
}
}
# Create the index with the specified settings and mappings # Create the index with the specified settings and mappings
es.indices.create(index=index_name, body=index_settings) es.indices.create(index=index_name, mappings=mappings, settings=settings)
return cls(es, index_name) return cls(es, index_name)
def add_texts( def add_texts(

View File

@ -0,0 +1,47 @@
"""TF-IDF Retriever.
Largely based on
https://github.com/asvskartheek/Text-Retrieval/blob/master/TF-IDF%20Search%20Engine%20(SKLEARN).ipynb"""
from typing import Any, List
from pydantic import BaseModel
from langchain.schema import BaseRetriever, Document
class TFIDFRetriever(BaseRetriever, BaseModel):
vectorizer: Any
docs: List[Document]
tfidf_array: Any
k: int = 4
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@classmethod
def from_texts(cls, texts: List[str], **kwargs: Any) -> "TFIDFRetriever":
from sklearn.feature_extraction.text import TfidfVectorizer
vectorizer = TfidfVectorizer()
tfidf_array = vectorizer.fit_transform(texts)
docs = [Document(page_content=t) for t in texts]
return cls(vectorizer=vectorizer, docs=docs, tfidf_array=tfidf_array, **kwargs)
def get_relevant_documents(self, query: str) -> List[Document]:
from sklearn.metrics.pairwise import cosine_similarity
query_vec = self.vectorizer.transform(
[query]
) # Ip -- (n_docs,x), Op -- (n_docs,n_Feats)
results = cosine_similarity(self.tfidf_array, query_vec).reshape(
(-1,)
) # Op -- (n_docs,1) -- Cosine Sim with each doc
return_docs = []
for i in results.argsort()[-self.k :][::-1]:
return_docs.append(self.docs[i])
return return_docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError