from typing import Any, Dict, List, Tuple

from pydantic import BaseModel, ConfigDict, Field

from langchain_community.cross_encoders.base import BaseCrossEncoder

DEFAULT_MODEL_NAME = "BAAI/bge-reranker-base"


class HuggingFaceCrossEncoder(BaseModel, BaseCrossEncoder):
    """HuggingFace cross encoder models.

    Wraps a ``sentence_transformers.CrossEncoder`` that is created when the
    instance is constructed and exposes it through :meth:`score`.

    Example:
        .. code-block:: python

            from langchain_community.cross_encoders import HuggingFaceCrossEncoder

            model_name = "BAAI/bge-reranker-base"
            model_kwargs = {'device': 'cpu'}
            hf = HuggingFaceCrossEncoder(
                model_name=model_name,
                model_kwargs=model_kwargs
            )
    """

    client: Any = None  #: :meta private:
    model_name: str = DEFAULT_MODEL_NAME
    """Model name to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass to the model."""

    def __init__(self, **kwargs: Any):
        """Validate the fields and build the underlying cross encoder.

        Args:
            **kwargs: Field values for this model (``model_name``,
                ``model_kwargs``, ``client``). Unknown keys are rejected
                because ``model_config`` sets ``extra="forbid"``.

        Raises:
            ImportError: If the ``sentence_transformers`` package is not
                installed.
        """
        super().__init__(**kwargs)
        # Imported lazily so the dependency is only required when this class
        # is actually instantiated.
        try:
            import sentence_transformers
        except ImportError as exc:
            raise ImportError(
                "Could not import sentence_transformers python package. "
                "Please install it with `pip install sentence-transformers`."
            ) from exc

        # Any `client` supplied through kwargs is replaced here.
        self.client = sentence_transformers.CrossEncoder(
            self.model_name, **self.model_kwargs
        )

    # `protected_namespaces=()` lets the `model_`-prefixed fields above
    # coexist with pydantic's reserved `model_` namespace.
    model_config = ConfigDict(extra="forbid", protected_namespaces=())

    def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]:
        """Compute similarity scores using a HuggingFace transformer model.

        Args:
            text_pairs: The list of text pairs to score for similarity.

        Returns:
            One score per input pair. When the model emits more than one
            value per pair, a list holding the value at index 1 of each row
            (the "relevant" score) is returned.
        """
        scores = self.client.predict(text_pairs)
        # Some models, e.g. bert-multilingual-passage-reranking-msmarco, give
        # two scores per pair (not_relevant and relevant with respect to the
        # query) instead of a single one.
        if len(scores.shape) > 1:
            # Keep only the relevant score. Build a real list rather than a
            # lazy `map` object so the result can be indexed, measured with
            # `len()` and iterated more than once.
            return [pair_scores[1] for pair_scores in scores]
        # NOTE(review): the single-score result is passed through exactly as
        # the client returned it - presumably an array-like; confirm whether
        # callers need a plain list here as well.
        return scores