diff --git a/langchain/retrievers/tfidf.py b/langchain/retrievers/tfidf.py index 2eef3cdd..2fa8a58c 100644 --- a/langchain/retrievers/tfidf.py +++ b/langchain/retrievers/tfidf.py @@ -2,7 +2,7 @@ Largely based on https://github.com/asvskartheek/Text-Retrieval/blob/master/TF-IDF%20Search%20Engine%20(SKLEARN).ipynb""" -from typing import Any, List +from typing import Any, Dict, List, Optional from pydantic import BaseModel @@ -21,10 +21,16 @@ class TFIDFRetriever(BaseRetriever, BaseModel): arbitrary_types_allowed = True @classmethod - def from_texts(cls, texts: List[str], **kwargs: Any) -> "TFIDFRetriever": + def from_texts( + cls, + texts: List[str], + tfidf_params: Optional[Dict[str, Any]] = None, + **kwargs: Any + ) -> "TFIDFRetriever": from sklearn.feature_extraction.text import TfidfVectorizer - vectorizer = TfidfVectorizer() + tfidf_params = tfidf_params or {} + vectorizer = TfidfVectorizer(**tfidf_params) tfidf_array = vectorizer.fit_transform(texts) docs = [Document(page_content=t) for t in texts] return cls(vectorizer=vectorizer, docs=docs, tfidf_array=tfidf_array, **kwargs) diff --git a/tests/integration_tests/retrievers/test_tfidf.py b/tests/integration_tests/retrievers/test_tfidf.py new file mode 100644 index 00000000..54dae33a --- /dev/null +++ b/tests/integration_tests/retrievers/test_tfidf.py @@ -0,0 +1,17 @@ +from langchain.retrievers.tfidf import TFIDFRetriever + + +def test_from_texts() -> None: + input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."] + tfidf_retriever = TFIDFRetriever.from_texts(texts=input_texts) + assert len(tfidf_retriever.docs) == 3 + assert tfidf_retriever.tfidf_array.toarray().shape == (3, 5) + + +def test_from_texts_with_tfidf_params() -> None: + input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."] + tfidf_retriever = TFIDFRetriever.from_texts( + texts=input_texts, tfidf_params={"min_df": 2} + ) + # should count only multiple words (have, pan) + assert tfidf_retriever.tfidf_array.toarray().shape == (3, 2)