infra: add test for ensemble retriever to ensure multiple retrievers (#8401)

Add tests to ensemble retriever to ensure it works with combination of
multiple retrievers

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/17507/head
shibuiwilliam 4 months ago committed by GitHub
parent 5738143d4b
commit c502736841
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,6 +1,8 @@
import pytest
from langchain_core.documents import Document
from langchain.embeddings import FakeEmbeddings
from langchain.retrievers import KNNRetriever, TFIDFRetriever
from langchain.retrievers.bm25 import BM25Retriever
from langchain.retrievers.ensemble import EnsembleRetriever
@ -40,3 +42,38 @@ def test_weighted_reciprocal_rank() -> None:
result = ensemble_retriever.weighted_reciprocal_rank([[doc1, doc2], [doc2, doc1]])
assert result[0].page_content == "1"
assert result[1].page_content == "2"
@pytest.mark.requires("rank_bm25", "sklearn")
def test_ensemble_retriever_get_relevant_docs_with_multiple_retrievers() -> None:
doc_list_a = [
"I like apples",
"I like oranges",
"Apples and oranges are fruits",
]
doc_list_b = [
"I like melons",
"I like pineapples",
"Melons and pineapples are fruits",
]
doc_list_c = [
"I like avocados",
"I like strawberries",
"Avocados and strawberries are fruits",
]
dummy_retriever = BM25Retriever.from_texts(doc_list_a)
dummy_retriever.k = 1
tfidf_retriever = TFIDFRetriever.from_texts(texts=doc_list_b)
tfidf_retriever.k = 1
knn_retriever = KNNRetriever.from_texts(
texts=doc_list_c, embeddings=FakeEmbeddings(size=100)
)
knn_retriever.k = 1
ensemble_retriever = EnsembleRetriever(
retrievers=[dummy_retriever, tfidf_retriever, knn_retriever],
weights=[0.6, 0.3, 0.1],
)
docs = ensemble_retriever.get_relevant_documents("I like apples")
assert len(docs) == 3

Loading…
Cancel
Save