from __future__ import annotations

from typing import Any, Callable, Dict, Iterable, List, Optional

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field
from langchain_core.retrievers import BaseRetriever


def default_preprocessing_func(text: str) -> List[str]:
    """Tokenize ``text`` by splitting on whitespace (``str.split()``)."""
    return text.split()


class BM25Retriever(BaseRetriever):
    """`BM25` retriever without Elasticsearch."""

    vectorizer: Any
    """ BM25 vectorizer."""
    docs: List[Document] = Field(repr=False)
    """ List of documents."""
    k: int = 4
    """ Number of documents to return."""
    preprocess_func: Callable[[str], List[str]] = default_preprocessing_func
    """ Preprocessing function to use on the text before BM25 vectorization."""

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @classmethod
    def from_texts(
        cls,
        texts: Iterable[str],
        metadatas: Optional[Iterable[dict]] = None,
        bm25_params: Optional[Dict[str, Any]] = None,
        preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
        **kwargs: Any,
    ) -> BM25Retriever:
        """
        Create a BM25Retriever from a list of texts.

        Args:
            texts: A list of texts to vectorize. Any iterable is accepted,
                including one-shot iterators; it is materialized once.
            metadatas: A list of metadata dicts to associate with each text.
                If omitted (or empty), every text gets an empty dict.
            bm25_params: Parameters to pass to the BM25 vectorizer.
            preprocess_func: A function to preprocess each text before
                vectorization.
            **kwargs: Any other arguments to pass to the retriever.

        Returns:
            A BM25Retriever instance.

        Raises:
            ImportError: If ``rank_bm25`` is not installed.
            ValueError: If ``metadatas`` is given but its length differs
                from the number of texts.
        """
        try:
            from rank_bm25 import BM25Okapi
        except ImportError as e:
            raise ImportError(
                "Could not import rank_bm25, please install with `pip install "
                "rank_bm25`."
            ) from e

        # ``texts`` is read twice below (tokenization, then Document
        # construction). Materialize it once so a generator is not exhausted
        # by the first pass, which would leave ``docs`` empty while the
        # BM25 index still holds every text.
        texts = list(texts)
        texts_processed = [preprocess_func(t) for t in texts]
        bm25_params = bm25_params or {}
        vectorizer = BM25Okapi(texts_processed, **bm25_params)

        # Falsy ``metadatas`` (None or empty) falls back to one fresh empty
        # dict per text, matching the previous behavior.
        metadata_list = list(metadatas) if metadatas else [{} for _ in texts]
        if len(metadata_list) != len(texts):
            # ``zip`` would silently truncate here, leaving the index and
            # ``docs`` out of sync; fail early with a clear message instead.
            raise ValueError(
                f"Got {len(texts)} texts but {len(metadata_list)} metadatas; "
                "they must have the same length."
            )
        docs = [
            Document(page_content=t, metadata=m)
            for t, m in zip(texts, metadata_list)
        ]
        return cls(
            vectorizer=vectorizer, docs=docs, preprocess_func=preprocess_func, **kwargs
        )

    @classmethod
    def from_documents(
        cls,
        documents: Iterable[Document],
        *,
        bm25_params: Optional[Dict[str, Any]] = None,
        preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
        **kwargs: Any,
    ) -> BM25Retriever:
        """
        Create a BM25Retriever from a list of Documents.

        Args:
            documents: A list of Documents to vectorize.
            bm25_params: Parameters to pass to the BM25 vectorizer.
            preprocess_func: A function to preprocess each text before
                vectorization.
            **kwargs: Any other arguments to pass to the retriever.

        Returns:
            A BM25Retriever instance.

        Raises:
            ValueError: If ``documents`` is empty.
        """
        documents = list(documents)
        if not documents:
            # Previously this surfaced as an opaque tuple-unpacking error.
            raise ValueError(
                "Cannot create a BM25Retriever from an empty list of documents."
            )
        texts = [d.page_content for d in documents]
        metadatas = [d.metadata for d in documents]
        return cls.from_texts(
            texts=texts,
            bm25_params=bm25_params,
            metadatas=metadatas,
            preprocess_func=preprocess_func,
            **kwargs,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return up to ``self.k`` documents ranked by BM25 score for ``query``.

        The query is tokenized with the same ``preprocess_func`` used to build
        the index, then ranked via the vectorizer's ``get_top_n``.
        """
        processed_query = self.preprocess_func(query)
        return_docs = self.vectorizer.get_top_n(processed_query, self.docs, n=self.k)
        return return_docs