|
|
@ -38,17 +38,20 @@ class FlashrankRerank(BaseDocumentCompressor):
|
|
|
|
@root_validator(pre=True)
|
|
|
|
@root_validator(pre=True)
|
|
|
|
def validate_environment(cls, values: Dict) -> Dict:
|
|
|
|
def validate_environment(cls, values: Dict) -> Dict:
|
|
|
|
"""Validate that api key and python package exists in environment."""
|
|
|
|
"""Validate that api key and python package exists in environment."""
|
|
|
|
try:
|
|
|
|
if "client" in values:
|
|
|
|
from flashrank import Ranker
|
|
|
|
return values
|
|
|
|
except ImportError:
|
|
|
|
else:
|
|
|
|
raise ImportError(
|
|
|
|
try:
|
|
|
|
"Could not import flashrank python package. "
|
|
|
|
from flashrank import Ranker
|
|
|
|
"Please install it with `pip install flashrank`."
|
|
|
|
except ImportError:
|
|
|
|
)
|
|
|
|
raise ImportError(
|
|
|
|
|
|
|
|
"Could not import flashrank python package. "
|
|
|
|
|
|
|
|
"Please install it with `pip install flashrank`."
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
values["model"] = values.get("model", DEFAULT_MODEL_NAME)
|
|
|
|
values["model"] = values.get("model", DEFAULT_MODEL_NAME)
|
|
|
|
values["client"] = Ranker(model_name=values["model"])
|
|
|
|
values["client"] = Ranker(model_name=values["model"])
|
|
|
|
return values
|
|
|
|
return values
|
|
|
|
|
|
|
|
|
|
|
|
def compress_documents(
|
|
|
|
def compress_documents(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|