2023-07-25 02:16:10 +00:00
|
|
|
import pytest
|
2024-05-16 21:24:27 +00:00
|
|
|
from langchain.retrievers.ensemble import EnsembleRetriever
|
2023-11-21 16:35:29 +00:00
|
|
|
from langchain_core.documents import Document
|
2024-05-08 20:46:52 +00:00
|
|
|
from langchain_core.embeddings import FakeEmbeddings
|
2023-07-25 02:16:10 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.requires("rank_bm25")
def test_ensemble_retriever_get_relevant_docs() -> None:
    """Query an ensemble built from one BM25 retriever listed twice.

    The underlying retriever is limited to ``k = 1`` and the test asserts
    that the ensemble returns exactly one document for the query.
    """
    # Imported lazily: the test is gated on the optional ``rank_bm25`` extra.
    from langchain_community.retrievers import BM25Retriever

    corpus = [
        "I like apples",
        "I like oranges",
        "Apples and oranges are fruits",
    ]

    bm25 = BM25Retriever.from_texts(corpus)
    bm25.k = 1

    # The very same retriever instance is supplied twice; no weights are given.
    ensemble_retriever = EnsembleRetriever(  # type: ignore[call-arg]
        retrievers=[bm25, bm25]
    )

    hits = ensemble_retriever.invoke("I like apples")
    assert len(hits) == 1
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.requires("rank_bm25")
def test_weighted_reciprocal_rank() -> None:
    """Check the order produced by ``weighted_reciprocal_rank`` for two weightings.

    Two opposite rankings of the same two documents are fused, first with
    weights ``[0.4, 0.5]`` and then with ``[0.5, 0.4]``; the expected leading
    document flips between the two runs.
    """
    # Imported lazily: the test is gated on the optional ``rank_bm25`` extra.
    from langchain_community.retrievers import BM25Retriever

    first_doc = Document(page_content="1")
    second_doc = Document(page_content="2")

    bm25 = BM25Retriever.from_texts(["1", "2"])
    ensemble_retriever = EnsembleRetriever(
        retrievers=[bm25, bm25],
        weights=[0.4, 0.5],
        c=0,
    )

    # Second ranking carries the larger weight, so its top document ("2")
    # is expected to come out first.
    fused = ensemble_retriever.weighted_reciprocal_rank(
        [[first_doc, second_doc], [second_doc, first_doc]]
    )
    assert fused[0].page_content == "2"
    assert fused[1].page_content == "1"

    # Swap the weights: now the first ranking dominates and "1" should lead.
    ensemble_retriever.weights = [0.5, 0.4]
    fused = ensemble_retriever.weighted_reciprocal_rank(
        [[first_doc, second_doc], [second_doc, first_doc]]
    )
    assert fused[0].page_content == "1"
    assert fused[1].page_content == "2"
|
2024-02-14 05:22:03 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.requires("rank_bm25", "sklearn")
def test_ensemble_retriever_get_relevant_docs_with_multiple_retrievers() -> None:
    """Query an ensemble of BM25, TF-IDF and KNN retrievers over separate corpora.

    Each retriever is limited to ``k = 1`` and indexes its own list of texts;
    the test asserts that the ensemble returns three documents in total.
    """
    # Imported lazily: the test is gated on optional extras.
    from langchain_community.retrievers import (
        BM25Retriever,
        KNNRetriever,
        TFIDFRetriever,
    )

    apple_orange_texts = [
        "I like apples",
        "I like oranges",
        "Apples and oranges are fruits",
    ]
    melon_pineapple_texts = [
        "I like melons",
        "I like pineapples",
        "Melons and pineapples are fruits",
    ]
    avocado_strawberry_texts = [
        "I like avocados",
        "I like strawberries",
        "Avocados and strawberries are fruits",
    ]

    bm25 = BM25Retriever.from_texts(apple_orange_texts)
    bm25.k = 1

    tfidf = TFIDFRetriever.from_texts(texts=melon_pineapple_texts)
    tfidf.k = 1

    knn = KNNRetriever.from_texts(
        texts=avocado_strawberry_texts,
        embeddings=FakeEmbeddings(size=100),
    )
    knn.k = 1

    ensemble_retriever = EnsembleRetriever(
        retrievers=[bm25, tfidf, knn],
        weights=[0.6, 0.3, 0.1],
    )

    hits = ensemble_retriever.invoke("I like apples")
    assert len(hits) == 3
|