From 85f57ab4cd034c0e9ad959a323bf4f95cd16616f Mon Sep 17 00:00:00 2001 From: billytrend-cohere <144115527+billytrend-cohere@users.noreply.github.com> Date: Wed, 27 Mar 2024 10:41:53 -0500 Subject: [PATCH] cohere[patch]: Fix cohere rerank (#19624) Fix cohere rerank inspired by https://github.com/langchain-ai/langchain/pull/19486 --- libs/partners/cohere/langchain_cohere/rerank.py | 8 ++++++-- .../tests/integration_tests/test_rerank.py | 16 ++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) create mode 100644 libs/partners/cohere/tests/integration_tests/test_rerank.py diff --git a/libs/partners/cohere/langchain_cohere/rerank.py b/libs/partners/cohere/langchain_cohere/rerank.py index 5c8c2bcfc8..f946b3ea36 100644 --- a/libs/partners/cohere/langchain_cohere/rerank.py +++ b/libs/partners/cohere/langchain_cohere/rerank.py @@ -69,10 +69,14 @@ class CohereRerank(BaseDocumentCompressor): model = model or self.model top_n = top_n if (top_n is None or top_n > 0) else self.top_n results = self.client.rerank( - query, docs, model, top_n=top_n, max_chunks_per_doc=max_chunks_per_doc + query=query, + documents=docs, + model=model, + top_n=top_n, + max_chunks_per_doc=max_chunks_per_doc, ) result_dicts = [] - for res in results: + for res in results.results: result_dicts.append( {"index": res.index, "relevance_score": res.relevance_score} ) diff --git a/libs/partners/cohere/tests/integration_tests/test_rerank.py b/libs/partners/cohere/tests/integration_tests/test_rerank.py new file mode 100644 index 0000000000..f9a2ebd0ae --- /dev/null +++ b/libs/partners/cohere/tests/integration_tests/test_rerank.py @@ -0,0 +1,16 @@ +"""Test Cohere reranks.""" +from langchain_core.documents import Document + +from langchain_cohere import CohereRerank + + +def test_langchain_cohere_rerank_documents() -> None: + """Test cohere rerank.""" + rerank = CohereRerank() + test_documents = [ + Document(page_content="This is a test document."), + Document(page_content="Another test document."), + ] + test_query = "Test query" + results = rerank.rerank(test_documents, test_query) + assert len(results) == 2