mirror of
https://github.com/hwchase17/langchain
synced 2024-10-31 15:20:26 +00:00
langchain[patch]: fix-cohere-reranker-rerank-method with cohere v5 (#19486)
#### Description Fixed the following error with `rerank` method from `CohereRerank`: ``` ---> [79](https://vscode-remote+wsl-002bubuntu.vscode-resource.vscode-cdn.net/home/jjmov99/legal-colombia/~/legal-colombia/.venv/lib/python3.11/site-packages/langchain/retrievers/document_compressors/cohere_rerank.py:79) results = self.client.rerank( [80](https://vscode-remote+wsl-002bubuntu.vscode-resource.vscode-cdn.net/home/jjmov99/legal-colombia/~/legal-colombia/.venv/lib/python3.11/site-packages/langchain/retrievers/document_compressors/cohere_rerank.py:80) query, docs, model, top_n=top_n, max_chunks_per_doc=max_chunks_per_doc [81](https://vscode-remote+wsl-002bubuntu.vscode-resource.vscode-cdn.net/home/jjmov99/legal-colombia/~/legal-colombia/.venv/lib/python3.11/site-packages/langchain/retrievers/document_compressors/cohere_rerank.py:81) ) [82](https://vscode-remote+wsl-002bubuntu.vscode-resource.vscode-cdn.net/home/jjmov99/legal-colombia/~/legal-colombia/.venv/lib/python3.11/site-packages/langchain/retrievers/document_compressors/cohere_rerank.py:82) result_dicts = [] [83](https://vscode-remote+wsl-002bubuntu.vscode-resource.vscode-cdn.net/home/jjmov99/legal-colombia/~/legal-colombia/.venv/lib/python3.11/site-packages/langchain/retrievers/document_compressors/cohere_rerank.py:83) for res in results.results: TypeError: BaseCohere.rerank() takes 1 positional argument but 4 positional arguments (and 2 keyword-only arguments) were given ``` This was easily fixed going from this: ``` def rerank( self, documents: Sequence[Union[str, Document, dict]], query: str, *, model: Optional[str] = None, top_n: Optional[int] = -1, max_chunks_per_doc: Optional[int] = None, ) -> List[Dict[str, Any]]: ... if len(documents) == 0: # to avoid empty api call return [] docs = [ doc.page_content if isinstance(doc, Document) else doc for doc in documents ] model = model or self.model top_n = top_n if (top_n is None or top_n > 0) else self.top_n results = self.client.rerank( query, docs, model, top_n=top_n, max_chunks_per_doc=max_chunks_per_doc ) result_dicts = [] for res in results: result_dicts.append( {"index": res.index, "relevance_score": res.relevance_score} ) return result_dicts ``` to this: ``` def rerank( self, documents: Sequence[Union[str, Document, dict]], query: str, *, model: Optional[str] = None, top_n: Optional[int] = -1, max_chunks_per_doc: Optional[int] = None, ) -> List[Dict[str, Any]]: ... if len(documents) == 0: # to avoid empty api call return [] docs = [ doc.page_content if isinstance(doc, Document) else doc for doc in documents ] model = model or self.model top_n = top_n if (top_n is None or top_n > 0) else self.top_n results = self.client.rerank( query=query, documents=docs, model=model, top_n=top_n, max_chunks_per_doc=max_chunks_per_doc <------------- ) result_dicts = [] for res in results.results: <------------- result_dicts.append( {"index": res.index, "relevance_score": res.relevance_score} ) return result_dicts ``` #### Unit & Integration tests I added a unit test to check the behaviour of `rerank`. Also fixed the original integration test which was failing. #### Format & Linting Everything worked properly with `make lint_diff`, `make format_diff` and `make format`. However I noticed an error coming from other part of the library when doing `make lint`: ``` (langchain-py3.9) ➜ langchain git:(master) make format [ "." = "" ] || poetry run ruff format . 1636 files left unchanged [ "." = "" ] || poetry run ruff --select I --fix . (langchain-py3.9) ➜ langchain git:(master) make lint ./scripts/check_pydantic.sh . ./scripts/lint_imports.sh poetry run ruff . [ "." = "" ] || poetry run ruff format . --diff 1636 files already formatted [ "." = "" ] || poetry run ruff --select I . [ "." = "" ] || mkdir -p .mypy_cache && poetry run mypy . --cache-dir .mypy_cache langchain/agents/openai_assistant/base.py:252: error: Argument "file_ids" to "create" of "Assistants" has incompatible type "Optional[Any]"; expected "Union[list[str], NotGiven]" [arg-type] langchain/agents/openai_assistant/base.py:374: error: Argument "file_ids" to "create" of "AsyncAssistants" has incompatible type "Optional[Any]"; expected "Union[list[str], NotGiven]" [arg-type] Found 2 errors in 1 file (checked 1634 source files) make: *** [Makefile:65: lint] Error 1 ``` --------- Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
332996b4b2
commit
51baa1b5cf
@ -81,8 +81,14 @@ class CohereRerank(BaseDocumentCompressor):
|
||||
model = model or self.model
|
||||
top_n = top_n if (top_n is None or top_n > 0) else self.top_n
|
||||
results = self.client.rerank(
|
||||
query, docs, model, top_n=top_n, max_chunks_per_doc=max_chunks_per_doc
|
||||
query=query,
|
||||
documents=docs,
|
||||
model=model,
|
||||
top_n=top_n,
|
||||
max_chunks_per_doc=max_chunks_per_doc,
|
||||
)
|
||||
if hasattr(results, "results"):
|
||||
results = getattr(results, "results")
|
||||
result_dicts = []
|
||||
for res in results:
|
||||
result_dicts.append(
|
||||
|
6
libs/langchain/poetry.lock
generated
6
libs/langchain/poetry.lock
generated
@ -3489,7 +3489,7 @@ tenacity = "^8.1.0"
|
||||
|
||||
[package.extras]
|
||||
cli = ["typer (>=0.9.0,<0.10.0)"]
|
||||
extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "azure-ai-documentintelligence (>=1.0.0b1,<2.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cloudpickle (>=2.0.0)", "cloudpickle (>=2.0.0)", "cohere (>=4,<5)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "elasticsearch (>=8.12.0,<9.0.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "friendli-client (>=1.2.4,<2.0.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "gradientai (>=1.4.0,<2.0.0)", "hdbcli (>=2.19.21,<3.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "httpx (>=0.24.1,<0.25.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.3,<6.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "nvidia-riva-client (>=2.14.0,<3.0.0)", "oci (>=2.119.1,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "oracle-ads (>=2.9.1,<3.0.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "premai (>=0.3.25,<0.4.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "rdflib (==7.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "tidb-vector (>=0.0.3,<1.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "tree-sitter (>=0.20.2,<0.21.0)", "tree-sitter-languages (>=1.8.0,<2.0.0)", "upstash-redis (>=0.15.0,<0.16.0)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)", "zhipuai (>=1.0.7,<2.0.0)"]
|
||||
extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "azure-ai-documentintelligence (>=1.0.0b1,<2.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cloudpickle (>=2.0.0)", "cloudpickle (>=2.0.0)", "cohere (>=4,<5)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "elasticsearch (>=8.12.0,<9.0.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "friendli-client (>=1.2.4,<2.0.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "gradientai (>=1.4.0,<2.0.0)", "hdbcli (>=2.19.21,<3.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "httpx (>=0.24.1,<0.25.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.3,<6.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "nvidia-riva-client (>=2.14.0,<3.0.0)", "oci (>=2.119.1,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "oracle-ads (>=2.9.1,<3.0.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "premai (>=0.3.25,<0.4.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "rdflib (==7.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "tidb-vector (>=0.0.3,<1.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "tree-sitter (>=0.20.2,<0.21.0)", "tree-sitter-languages (>=1.8.0,<2.0.0)", "upstash-redis (>=0.15.0,<0.16.0)", "vdms (>=0.0.20,<0.0.21)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)", "zhipuai (>=1.0.7,<2.0.0)"]
|
||||
|
||||
[package.source]
|
||||
type = "directory"
|
||||
@ -3497,7 +3497,7 @@ url = "../community"
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "0.1.34"
|
||||
version = "0.1.35"
|
||||
description = "Building applications with LLMs through composability"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
@ -9411,4 +9411,4 @@ text-helpers = ["chardet"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "89cdf7e6e23edb6910401dc9a84bb8fccb75fb300e9d87c445a917990bfa5e91"
|
||||
content-hash = "d005d6f8e57e6831bd8801f8556fab20af7853444021c5009abe3aa99057f78a"
|
||||
|
@ -36,7 +36,7 @@ jinja2 = {version = "^3", optional = true}
|
||||
tiktoken = {version = ">=0.3.2,<0.6.0", optional = true, python=">=3.9"}
|
||||
qdrant-client = {version = "^1.3.1", optional = true, python = ">=3.8.1,<3.12"}
|
||||
dataclasses-json = ">= 0.5.7, < 0.7"
|
||||
cohere = {version = "^4", optional = true}
|
||||
cohere = {version = ">=4,<6", optional = true}
|
||||
openai = {version = "<2", optional = true}
|
||||
nlpcloud = {version = "^1", optional = true}
|
||||
huggingface_hub = {version = "^0", optional = true}
|
||||
|
@ -1,8 +1,10 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from langchain.retrievers.document_compressors import CohereRerank
|
||||
from langchain.schema import Document
|
||||
|
||||
os.environ["COHERE_API_KEY"] = "foo"
|
||||
|
||||
@ -14,3 +16,37 @@ def test_init() -> None:
|
||||
CohereRerank(
|
||||
top_n=5, model="rerank-english_v2.0", cohere_api_key="foo", user_agent="bar"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("cohere")
|
||||
def test_rerank(mocker: MockerFixture) -> None:
|
||||
mock_client = mocker.MagicMock()
|
||||
mock_result = mocker.MagicMock()
|
||||
mock_result.results = [
|
||||
mocker.MagicMock(index=0, relevance_score=0.8),
|
||||
mocker.MagicMock(index=1, relevance_score=0.6),
|
||||
]
|
||||
mock_client.rerank.return_value = mock_result
|
||||
|
||||
test_documents = [
|
||||
Document(page_content="This is a test document."),
|
||||
Document(page_content="Another test document."),
|
||||
]
|
||||
test_query = "Test query"
|
||||
|
||||
mocker.patch("cohere.Client", return_value=mock_client)
|
||||
|
||||
reranker = CohereRerank(cohere_api_key="foo")
|
||||
results = reranker.rerank(test_documents, test_query)
|
||||
|
||||
mock_client.rerank.assert_called_once_with(
|
||||
query=test_query,
|
||||
documents=[doc.page_content for doc in test_documents],
|
||||
model="rerank-english-v2.0",
|
||||
top_n=3,
|
||||
max_chunks_per_doc=None,
|
||||
)
|
||||
assert results == [
|
||||
{"index": 0, "relevance_score": 0.8},
|
||||
{"index": 1, "relevance_score": 0.6},
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user