diff --git a/langchain/retrievers/document_compressors/cohere_rerank.py b/langchain/retrievers/document_compressors/cohere_rerank.py index 43b084d7..41513c65 100644 --- a/langchain/retrievers/document_compressors/cohere_rerank.py +++ b/langchain/retrievers/document_compressors/cohere_rerank.py @@ -2,7 +2,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Dict, Sequence -from pydantic import root_validator +from pydantic import Extra, root_validator from langchain.retrievers.document_compressors.base import BaseDocumentCompressor from langchain.schema import Document @@ -10,6 +10,13 @@ from langchain.utils import get_from_dict_or_env if TYPE_CHECKING: from cohere import Client +else: + # We do to avoid pydantic annotation issues when actually instantiating + # while keeping this import optional + try: + from cohere import Client + except ImportError: + pass class CohereRerank(BaseDocumentCompressor): @@ -17,7 +24,13 @@ class CohereRerank(BaseDocumentCompressor): top_n: int = 3 model: str = "rerank-english-v2.0" - @root_validator() + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" cohere_api_key = get_from_dict_or_env( diff --git a/tests/integration_tests/retrievers/document_compressors/test_cohere_reranker.py b/tests/integration_tests/retrievers/document_compressors/test_cohere_reranker.py new file mode 100644 index 00000000..66745204 --- /dev/null +++ b/tests/integration_tests/retrievers/document_compressors/test_cohere_reranker.py @@ -0,0 +1,8 @@ +"""Test the cohere reranker.""" + +from langchain.retrievers.document_compressors.cohere_rerank import CohereRerank + + +def test_cohere_reranker_init() -> None: + """Test the cohere reranker initializes correctly.""" + CohereRerank()