Update Cohere Reranker (#4180)

The forward ref annotations don't get updated if we only iimport with
type checking

---------

Co-authored-by: Abhinav Verma <abhinav_win12@yahoo.co.in>
parallel_dir_loader
Zander Chase 1 year ago committed by GitHub
parent d84bb02881
commit 84cfa76e00
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2,7 +2,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Dict, Sequence
from pydantic import root_validator
from pydantic import Extra, root_validator
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from langchain.schema import Document
@ -10,6 +10,13 @@ from langchain.utils import get_from_dict_or_env
if TYPE_CHECKING:
from cohere import Client
else:
# We do to avoid pydantic annotation issues when actually instantiating
# while keeping this import optional
try:
from cohere import Client
except ImportError:
pass
class CohereRerank(BaseDocumentCompressor):
@ -17,7 +24,13 @@ class CohereRerank(BaseDocumentCompressor):
top_n: int = 3
model: str = "rerank-english-v2.0"
@root_validator()
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
cohere_api_key = get_from_dict_or_env(

@ -0,0 +1,8 @@
"""Test the cohere reranker."""
from langchain.retrievers.document_compressors.cohere_rerank import CohereRerank
def test_cohere_reranker_init() -> None:
"""Test the cohere reranker initializes correctly."""
CohereRerank()
Loading…
Cancel
Save