You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/community/tests/unit_tests/retrievers/test_ensemble.py

87 lines
2.6 KiB
Python

import pytest
from langchain.retrievers.ensemble import EnsembleRetriever
from langchain_core.documents import Document
from langchain_core.embeddings import FakeEmbeddings
@pytest.mark.requires("rank_bm25")
def test_ensemble_retriever_get_relevant_docs() -> None:
    """Ensemble built from one BM25 retriever listed twice returns one document.

    The BM25 retriever is capped at ``k = 1`` and the very same instance is
    passed twice to ``EnsembleRetriever``; the test expects ``invoke`` to
    return exactly one document for the query.
    """
    # Imported inside the test, as in the rest of this module.
    from langchain_community.retrievers import BM25Retriever

    corpus = [
        "I like apples",
        "I like oranges",
        "Apples and oranges are fruits",
    ]
    bm25 = BM25Retriever.from_texts(corpus)
    bm25.k = 1

    # No explicit weights are supplied here.
    ensemble = EnsembleRetriever(  # type: ignore[call-arg]
        retrievers=[bm25, bm25]
    )

    retrieved = ensemble.invoke("I like apples")
    assert len(retrieved) == 1
@pytest.mark.requires("rank_bm25")
def test_weighted_reciprocal_rank() -> None:
    """Check that ``weighted_reciprocal_rank`` ordering follows the weights.

    Two documents are ranked in opposite orders by two result lists. With
    weights ``[0.4, 0.5]`` (and ``c=0``) the test expects document "2" first;
    after the weights are swapped to ``[0.5, 0.4]`` it expects document "1"
    first.
    """
    from langchain_community.retrievers import BM25Retriever

    first, second = (Document(page_content=text) for text in ("1", "2"))

    # The retriever only serves to construct the ensemble; the ranking method
    # is called directly with hand-built result lists below.
    bm25 = BM25Retriever.from_texts(["1", "2"])
    ensemble = EnsembleRetriever(
        retrievers=[bm25, bm25],
        weights=[0.4, 0.5],
        c=0,
    )

    # Second result list carries the larger weight.
    ranked_second_heavier = ensemble.weighted_reciprocal_rank(
        [[first, second], [second, first]]
    )
    assert ranked_second_heavier[0].page_content == "2"
    assert ranked_second_heavier[1].page_content == "1"

    # Swap the weights so the first result list carries the larger weight.
    ensemble.weights = [0.5, 0.4]
    ranked_first_heavier = ensemble.weighted_reciprocal_rank(
        [[first, second], [second, first]]
    )
    assert ranked_first_heavier[0].page_content == "1"
    assert ranked_first_heavier[1].page_content == "2"
@pytest.mark.requires("rank_bm25", "sklearn")
def test_ensemble_retriever_get_relevant_docs_with_multiple_retrievers() -> None:
    """Ensemble of BM25, TF-IDF and KNN retrievers returns three documents.

    Each retriever is built over its own three-sentence corpus and capped at
    ``k = 1``; they are combined with weights ``[0.6, 0.3, 0.1]`` and the test
    expects ``invoke`` to return three documents in total.
    """
    from langchain_community.retrievers import (
        BM25Retriever,
        KNNRetriever,
        TFIDFRetriever,
    )

    # BM25 over the apples/oranges corpus.
    apple_texts = [
        "I like apples",
        "I like oranges",
        "Apples and oranges are fruits",
    ]
    bm25 = BM25Retriever.from_texts(apple_texts)
    bm25.k = 1

    # TF-IDF over the melons/pineapples corpus.
    melon_texts = [
        "I like melons",
        "I like pineapples",
        "Melons and pineapples are fruits",
    ]
    tfidf = TFIDFRetriever.from_texts(texts=melon_texts)
    tfidf.k = 1

    # KNN over the avocados/strawberries corpus, using fake embeddings.
    avocado_texts = [
        "I like avocados",
        "I like strawberries",
        "Avocados and strawberries are fruits",
    ]
    knn = KNNRetriever.from_texts(
        texts=avocado_texts,
        embeddings=FakeEmbeddings(size=100),
    )
    knn.k = 1

    ensemble = EnsembleRetriever(
        retrievers=[bm25, tfidf, knn],
        weights=[0.6, 0.3, 0.1],
    )

    retrieved = ensemble.invoke("I like apples")
    assert len(retrieved) == 3