You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/partners/voyageai/tests/unit_tests/test_rerank.py

84 lines
2.9 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from collections import namedtuple
from typing import Any
import pytest # type: ignore
from langchain_core.documents import Document
from voyageai.api_resources import VoyageResponse # type: ignore
from voyageai.object import RerankingObject # type: ignore
from langchain_voyageai.rerank import VoyageAIRerank
# Corpus of six short, topically unrelated passages used as rerank input.
# Adjacent string literals are concatenated at compile time, so each list
# entry is a single string. The mocked rerank response built below refers to
# entries by position (index 1 and index 0), so the order here matters.
doc_list = [
    "The Mediterranean diet emphasizes fish, olive oil, and vegetables"
    ", believed to reduce chronic diseases.",
    "Photosynthesis in plants converts light energy into glucose and "
    "produces essential oxygen.",
    "20th-century innovations, from radios to smartphones, centered "
    "on electronic advancements.",
    "Rivers provide water, irrigation, and habitat for aquatic species, "
    "vital for ecosystems.",
    # Possessive written with a plain ASCII apostrophe. The text previously
    # read "Apples" (typo), and a typographic apostrophe triggers
    # ambiguous-Unicode warnings in code viewers.
    "Apple's conference call to discuss fourth fiscal quarter results and "
    "business updates is scheduled for Thursday, November 2, 2023 at 2:00 "
    "p.m. PT / 5:00 p.m. ET.",
    "Shakespeare's works, like 'Hamlet' and 'A Midsummer Night's Dream,' "
    "endure in literature.",
]
# Wrap each raw passage in a LangChain ``Document``, preserving ``doc_list`` order.
documents = [Document(page_content=passage) for passage in doc_list]
@pytest.mark.requires("voyageai")
def test_init() -> None:
    """Smoke-test that ``VoyageAIRerank`` can be built from an API key and model.

    The test passes as long as construction does not raise.
    """
    model_name = "rerank-lite-1"
    VoyageAIRerank(model=model_name, voyage_api_key="foo")
def get_mock_rerank_result() -> RerankingObject:
    """Build a canned ``RerankingObject`` standing in for a Voyage AI rerank reply.

    The fake response ranks ``doc_list[1]`` first (score 0.9) and
    ``doc_list[0]`` second (score 0.8), and reports 255 total tokens.

    Returns:
        A ``RerankingObject`` wrapping the fake response and ``doc_list``.
    """
    ResultItem = namedtuple("VoyageResultItem", ["index", "relevance_score"])
    TokenUsage = namedtuple("Usage", ["total_tokens"])

    # (document index, relevance score) pairs, best match first.
    ranking = ((1, 0.9), (0, 0.8))

    fake_response = VoyageResponse()
    fake_response.data = [
        ResultItem(index=position, relevance_score=score)
        for position, score in ranking
    ]
    fake_response.usage = TokenUsage(total_tokens=255)
    return RerankingObject(response=fake_response, documents=doc_list)
@pytest.mark.requires("voyageai")
def test_rerank_unit_test(mocker: Any) -> None:
    """Reranked output follows the order and scores of the mocked API reply.

    ``voyageai.Client.rerank`` is patched to return ``get_mock_rerank_result()``,
    which ranks ``doc_list[1]`` (0.9) ahead of ``doc_list[0]`` (0.8).
    """
    patched_rerank = mocker.patch("voyageai.Client.rerank")
    patched_rerank.return_value = get_mock_rerank_result()

    top_doc = Document(
        page_content="Photosynthesis in plants converts light energy into "
        "glucose and produces essential oxygen.",
        metadata={"relevance_score": 0.9},
    )
    runner_up = Document(
        page_content="The Mediterranean diet emphasizes fish, olive oil, and "
        "vegetables, believed to reduce chronic diseases.",
        metadata={"relevance_score": 0.8},
    )
    expected = [top_doc, runner_up]

    reranker = VoyageAIRerank(voyage_api_key="foo", model="rerank-lite-1")
    reranked = reranker.compress_documents(
        documents=documents,
        query="When is the Apple's conference call scheduled?",
    )
    assert expected == reranked
@pytest.mark.requires("voyageai")
def test_rerank_empty_input() -> None:
    """Reranking an empty document list yields an empty result.

    Marked ``requires("voyageai")`` like the other tests in this module that
    construct ``VoyageAIRerank``, so all of them are gated the same way.
    """
    rerank = VoyageAIRerank(
        voyage_api_key="foo",
        model="rerank-lite-1",
    )
    result = rerank.compress_documents(
        documents=[], query="When is the Apple's conference call scheduled?"
    )
    assert len(result) == 0