Add hybrid retriever that does not require any external service (#8108)

- Until now, hybrid search was limited to modules requiring external
services, such as Weaviate/Pinecone Hybrid Search. However, I have
developed a hybrid retriever that can merge a list of retrievers using
the [Reciprocal Rank
Fusion](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf)
algorithm. This new approach, similar to Weaviate hybrid search, does
not require the initialization of any external service.
  - Dependencies: No  - Twitter handle: dayuanjian21687

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Dayuan Jiang 2023-07-25 11:16:10 +09:00 committed by GitHub
parent 04e45f9cde
commit 125ae6d9de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 330 additions and 0 deletions

View File

@ -0,0 +1,102 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Ensemble Retriever\n",
"\n",
"The `EnsembleRetriever` takes a list of retrievers as input, ensembles the results of their `get_relevant_documents()` methods, and reranks the results based on the [Reciprocal Rank Fusion](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) algorithm.\n",
"\n",
"By leveraging the strengths of different algorithms, the `EnsembleRetriever` can achieve better performance than any single algorithm.\n",
"\n",
"The most common pattern is to combine a sparse retriever (like BM25) with a dense retriever (like embedding similarity), because their strengths are complementary. This is also known as \"hybrid search\". The sparse retriever is good at finding relevant documents based on keywords, while the dense retriever is good at finding relevant documents based on semantic similarity."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.retrievers import BM25Retriever, EnsembleRetriever\n",
"from langchain.vectorstores import FAISS"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"doc_list = [\n",
" \"I like apples\",\n",
" \"I like oranges\",\n",
" \"Apples and oranges are fruits\",\n",
"]\n",
"\n",
"# initialize the bm25 retriever and faiss retriever\n",
"bm25_retriever = BM25Retriever.from_texts(doc_list)\n",
"bm25_retriever.k = 2\n",
"\n",
"embedding = OpenAIEmbeddings()\n",
"faiss_vectorstore = FAISS.from_texts(doc_list, embedding)\n",
"faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={\"k\": 2})\n",
"\n",
"# initialize the ensemble retriever\n",
"ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5])"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='I like apples', metadata={}),\n",
" Document(page_content='Apples and oranges are fruits', metadata={})]"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"docs = ensemble_retriever.get_relevant_documents(\"apples\")\n",
"docs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -6,6 +6,7 @@ from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.docarray import DocArrayRetriever from langchain.retrievers.docarray import DocArrayRetriever
from langchain.retrievers.elastic_search_bm25 import ElasticSearchBM25Retriever from langchain.retrievers.elastic_search_bm25 import ElasticSearchBM25Retriever
from langchain.retrievers.ensemble import EnsembleRetriever
from langchain.retrievers.google_cloud_enterprise_search import ( from langchain.retrievers.google_cloud_enterprise_search import (
GoogleCloudEnterpriseSearchRetriever, GoogleCloudEnterpriseSearchRetriever,
) )
@ -64,4 +65,5 @@ __all__ = [
"ZepRetriever", "ZepRetriever",
"ZillizRetriever", "ZillizRetriever",
"DocArrayRetriever", "DocArrayRetriever",
"EnsembleRetriever",
] ]

View File

@ -0,0 +1,184 @@
"""
Ensemble retriever that ensembles the results of
multiple retrievers by using weighted Reciprocal Rank Fusion.
"""
from typing import Any, Dict, List
from pydantic import root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document
class EnsembleRetriever(BaseRetriever):
    """Retriever that fuses the rankings produced by several retrievers.

    Every wrapped retriever is asked the same query; the ranked lists they
    return are merged with weighted Reciprocal Rank Fusion (RRF).

    Args:
        retrievers: A list of retrievers to ensemble.
        weights: A list of weights corresponding to the retrievers. Defaults to
            equal weighting for all retrievers.
        c: A constant added to the rank, controlling the balance between the
            importance of high-ranked items and the consideration given to
            lower-ranked items. Default is 60.
    """

    retrievers: List[BaseRetriever]
    weights: List[float]
    c: int = 60

    @root_validator(pre=True)
    def set_weights(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Fill in equal weights when the caller did not supply any.

        Args:
            values: The raw keyword arguments given to the constructor.

        Returns:
            The same mapping, with ``weights`` set to ``1 / n`` per retriever
            when ``weights`` was missing or empty and at least one retriever
            was given.
        """
        if not values.get("weights"):
            retrievers = values.get("retrievers")
            # Derive the default only when there is something to divide by.
            # A missing or empty ``retrievers`` is left untouched so that it
            # neither raises KeyError nor ZeroDivisionError here; the regular
            # field validation that runs after this pre-validator deals with it.
            if retrievers:
                n_retrievers = len(retrievers)
                values["weights"] = [1 / n_retrievers] * n_retrievers
        return values

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get the relevant documents for a given query.

        Args:
            query: The query to search for.
            run_manager: Callback manager for this retriever run; child
                managers derived from it are handed to each wrapped retriever.

        Returns:
            A list of reranked documents.
        """
        return self.rank_fusion(query, run_manager)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Asynchronously get the relevant documents for a given query.

        Args:
            query: The query to search for.
            run_manager: Async callback manager for this retriever run; child
                managers derived from it are handed to each wrapped retriever.

        Returns:
            A list of reranked documents.
        """
        return await self.arank_fusion(query, run_manager)

    def rank_fusion(
        self, query: str, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Query every retriever and fuse the results.

        The per-retriever results are combined with
        ``weighted_reciprocal_rank``.

        Args:
            query: The query to search for.
            run_manager: Callback manager of the enclosing run. Each retriever
                gets a child manager tagged ``retriever_<position>`` (1-based).

        Returns:
            A list of reranked documents.
        """
        # Get the results of all retrievers, in the order they were configured
        # so that they line up with ``self.weights``.
        retriever_docs = [
            retriever.get_relevant_documents(
                query, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
            )
            for i, retriever in enumerate(self.retrievers)
        ]

        # Apply rank fusion.
        return self.weighted_reciprocal_rank(retriever_docs)

    async def arank_fusion(
        self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Asynchronously query every retriever and fuse the results.

        The per-retriever results are combined with
        ``weighted_reciprocal_rank``. Retrievers are awaited one after another,
        in configuration order.

        Args:
            query: The query to search for.
            run_manager: Async callback manager of the enclosing run. Each
                retriever gets a child manager tagged ``retriever_<position>``
                (1-based).

        Returns:
            A list of reranked documents.
        """
        # Get the results of all retrievers, in the order they were configured
        # so that they line up with ``self.weights``.
        retriever_docs = [
            await retriever.aget_relevant_documents(
                query, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
            )
            for i, retriever in enumerate(self.retrievers)
        ]

        # Apply rank fusion.
        return self.weighted_reciprocal_rank(retriever_docs)

    def weighted_reciprocal_rank(
        self, doc_lists: List[List[Document]]
    ) -> List[Document]:
        """Perform weighted Reciprocal Rank Fusion on multiple rank lists.

        Documents are identified by their ``page_content``. A document at
        1-based position ``rank`` in a list with weight ``w`` contributes
        ``w * (1 / (rank + c))``; contributions are summed across lists.

        You can find more details about RRF here:
        https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf

        Args:
            doc_lists: A list of rank lists, one per weight, where each rank
                list contains unique items.

        Returns:
            The fused documents sorted by weighted RRF score in descending
            order. Documents with equal scores keep the order in which their
            content was first encountered. When the same content occurs in
            several lists, the ``Document`` object from the last occurrence
            is the one returned.

        Raises:
            ValueError: If the number of rank lists differs from the number of
                weights.
        """
        if len(doc_lists) != len(self.weights):
            raise ValueError(
                "Number of rank lists must be equal to the number of weights."
            )

        # Plain dicts preserve insertion order, so keying the scores by
        # page_content here (instead of first collecting a set of strings)
        # gives ties a reproducible, first-seen order rather than one that
        # depends on string hashing.
        rrf_score_dic: Dict[str, float] = {}
        page_content_to_doc_map: Dict[str, Document] = {}
        for doc_list, weight in zip(doc_lists, self.weights):
            for rank, doc in enumerate(doc_list, start=1):
                rrf_score = weight * (1 / (rank + self.c))
                rrf_score_dic[doc.page_content] = (
                    rrf_score_dic.get(doc.page_content, 0.0) + rrf_score
                )
                # Later occurrences overwrite earlier ones.
                page_content_to_doc_map[doc.page_content] = doc

        # sorted() is stable, also with reverse=True, so equal scores stay in
        # insertion order.
        sorted_documents = sorted(
            rrf_score_dic, key=lambda x: rrf_score_dic[x], reverse=True
        )

        # Map the sorted page_content back to the document objects.
        return [
            page_content_to_doc_map[page_content]
            for page_content in sorted_documents
        ]

View File

@ -0,0 +1,42 @@
import pytest
from langchain.retrievers.bm25 import BM25Retriever
from langchain.retrievers.ensemble import EnsembleRetriever
from langchain.schema import Document
@pytest.mark.requires("rank_bm25")
def test_ensemble_retriever_get_relevant_docs() -> None:
    """Fusing two copies of a top-1 BM25 retriever returns a single document."""
    texts = ["I like apples", "I like oranges", "Apples and oranges are fruits"]

    bm25 = BM25Retriever.from_texts(texts)
    bm25.k = 1

    # The same retriever is registered twice and no weights are passed, so the
    # ensemble falls back to its default weighting.
    ensemble = EnsembleRetriever(retrievers=[bm25] * 2)

    results = ensemble.get_relevant_documents("I like apples")
    assert len(results) == 1
@pytest.mark.requires("rank_bm25")
def test_weighted_reciprocal_rank() -> None:
    """The list carrying the larger weight decides the fused order."""
    first, second = (Document(page_content=text) for text in ("1", "2"))
    bm25 = BM25Retriever.from_texts(["1", "2"])
    ensemble = EnsembleRetriever(retrievers=[bm25, bm25], weights=[0.4, 0.5], c=0)

    # Two rankings that disagree: the first puts "1" on top, the second "2".
    rankings = [[first, second], [second, first]]

    # Second list weighs more, so its top document "2" wins.
    fused = ensemble.weighted_reciprocal_rank(rankings)
    assert fused[0].page_content == "2"
    assert fused[1].page_content == "1"

    # Swap the weights: now the first list dominates and "1" wins.
    ensemble.weights = [0.5, 0.4]
    fused = ensemble.weighted_reciprocal_rank(rankings)
    assert fused[0].page_content == "1"
    assert fused[1].page_content == "2"