diff --git a/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py index 44a48131e4..53c2ee423b 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py @@ -81,8 +81,14 @@ class CohereRerank(BaseDocumentCompressor): model = model or self.model top_n = top_n if (top_n is None or top_n > 0) else self.top_n results = self.client.rerank( - query, docs, model, top_n=top_n, max_chunks_per_doc=max_chunks_per_doc + query=query, + documents=docs, + model=model, + top_n=top_n, + max_chunks_per_doc=max_chunks_per_doc, ) + if hasattr(results, "results"): + results = getattr(results, "results") result_dicts = [] for res in results: result_dicts.append( diff --git a/libs/langchain/poetry.lock b/libs/langchain/poetry.lock index 8b2d15f551..66499273b0 100644 --- a/libs/langchain/poetry.lock +++ b/libs/langchain/poetry.lock @@ -3489,7 +3489,7 @@ tenacity = "^8.1.0" [package.extras] cli = ["typer (>=0.9.0,<0.10.0)"] -extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "azure-ai-documentintelligence (>=1.0.0b1,<2.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cloudpickle (>=2.0.0)", "cloudpickle (>=2.0.0)", "cohere (>=4,<5)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "elasticsearch (>=8.12.0,<9.0.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "friendli-client (>=1.2.4,<2.0.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "gradientai (>=1.4.0,<2.0.0)", "hdbcli (>=2.19.21,<3.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "httpx (>=0.24.1,<0.25.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.3,<6.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "nvidia-riva-client (>=2.14.0,<3.0.0)", "oci (>=2.119.1,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "oracle-ads (>=2.9.1,<3.0.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "premai (>=0.3.25,<0.4.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "rdflib (==7.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "tidb-vector (>=0.0.3,<1.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "tree-sitter (>=0.20.2,<0.21.0)", "tree-sitter-languages (>=1.8.0,<2.0.0)", "upstash-redis (>=0.15.0,<0.16.0)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)", "zhipuai (>=1.0.7,<2.0.0)"] +extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "azure-ai-documentintelligence (>=1.0.0b1,<2.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cloudpickle (>=2.0.0)", "cloudpickle (>=2.0.0)", "cohere (>=4,<5)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "elasticsearch (>=8.12.0,<9.0.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "friendli-client (>=1.2.4,<2.0.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "gradientai (>=1.4.0,<2.0.0)", "hdbcli (>=2.19.21,<3.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "httpx (>=0.24.1,<0.25.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.3,<6.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "nvidia-riva-client (>=2.14.0,<3.0.0)", "oci (>=2.119.1,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "oracle-ads (>=2.9.1,<3.0.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "premai (>=0.3.25,<0.4.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "rdflib (==7.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "tidb-vector (>=0.0.3,<1.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "tree-sitter (>=0.20.2,<0.21.0)", "tree-sitter-languages (>=1.8.0,<2.0.0)", "upstash-redis (>=0.15.0,<0.16.0)", "vdms (>=0.0.20,<0.0.21)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)", "zhipuai (>=1.0.7,<2.0.0)"] [package.source] type = "directory" @@ -3497,7 +3497,7 @@ url = "../community" [[package]] name = "langchain-core" -version = "0.1.34" +version = "0.1.35" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -9411,4 +9411,4 @@ text-helpers = ["chardet"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "89cdf7e6e23edb6910401dc9a84bb8fccb75fb300e9d87c445a917990bfa5e91" +content-hash = "d005d6f8e57e6831bd8801f8556fab20af7853444021c5009abe3aa99057f78a" diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 53334cf615..8384232cc8 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -36,7 +36,7 @@ jinja2 = {version = "^3", optional = true} tiktoken = {version = ">=0.3.2,<0.6.0", optional = true, python=">=3.9"} qdrant-client = {version = "^1.3.1", optional = true, python = ">=3.8.1,<3.12"} dataclasses-json = ">= 0.5.7, < 0.7" -cohere = {version = "^4", optional = true} +cohere = {version = ">=4,<6", optional = true} openai = {version = "<2", optional = true} nlpcloud = {version = "^1", optional = true} huggingface_hub = {version = "^0", optional = true} diff --git a/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_cohere_rerank.py b/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_cohere_rerank.py index 51eb0f626f..87609e4d59 100644 --- a/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_cohere_rerank.py +++ b/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_cohere_rerank.py @@ -1,8 +1,10 @@ import os import pytest +from pytest_mock import MockerFixture from langchain.retrievers.document_compressors import CohereRerank +from langchain.schema import Document os.environ["COHERE_API_KEY"] = "foo" @@ -14,3 +16,37 @@ def test_init() -> None: CohereRerank( top_n=5, model="rerank-english_v2.0", cohere_api_key="foo", user_agent="bar" ) + + +@pytest.mark.requires("cohere") +def test_rerank(mocker: MockerFixture) -> None: + mock_client = mocker.MagicMock() + mock_result = mocker.MagicMock() + mock_result.results = [ + mocker.MagicMock(index=0, relevance_score=0.8), + mocker.MagicMock(index=1, relevance_score=0.6), + ] + mock_client.rerank.return_value = mock_result + + test_documents = [ + Document(page_content="This is a test document."), + Document(page_content="Another test document."), + ] + test_query = "Test query" + + mocker.patch("cohere.Client", return_value=mock_client) + + reranker = CohereRerank(cohere_api_key="foo") + results = reranker.rerank(test_documents, test_query) + + mock_client.rerank.assert_called_once_with( + query=test_query, + documents=[doc.page_content for doc in test_documents], + model="rerank-english-v2.0", + top_n=3, + max_chunks_per_doc=None, + ) + assert results == [ + {"index": 0, "relevance_score": 0.8}, + {"index": 1, "relevance_score": 0.6}, + ]