From b33d2346db6082bf00243f60af5eb66e6681274c Mon Sep 17 00:00:00 2001 From: wenngong <76683249+wenngong@users.noreply.github.com> Date: Tue, 25 Jun 2024 05:50:08 +0800 Subject: [PATCH] community: FlashrankRerank support loading customer client (#23350) Description: FlashrankRerank Document compressor support loading customer client Issue: #23338 Co-authored-by: gongwn1 --- .../document_compressors/flashrank_rerank.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/libs/community/langchain_community/document_compressors/flashrank_rerank.py b/libs/community/langchain_community/document_compressors/flashrank_rerank.py index dd3307b43e..fd66bee659 100644 --- a/libs/community/langchain_community/document_compressors/flashrank_rerank.py +++ b/libs/community/langchain_community/document_compressors/flashrank_rerank.py @@ -38,17 +38,20 @@ class FlashrankRerank(BaseDocumentCompressor): @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - try: - from flashrank import Ranker - except ImportError: - raise ImportError( - "Could not import flashrank python package. " - "Please install it with `pip install flashrank`." - ) + if "client" in values: + return values + else: + try: + from flashrank import Ranker + except ImportError: + raise ImportError( + "Could not import flashrank python package. " + "Please install it with `pip install flashrank`." + ) - values["model"] = values.get("model", DEFAULT_MODEL_NAME) - values["client"] = Ranker(model_name=values["model"]) - return values + values["model"] = values.get("model", DEFAULT_MODEL_NAME) + values["client"] = Ranker(model_name=values["model"]) + return values def compress_documents( self,