Harrison/tfidf parameters (#3481)

Co-authored-by: pao <go5kuramubon@gmail.com>
Co-authored-by: KyoHattori <kyo.hattori@abejainc.com>
This commit is contained in:
Harrison Chase 2023-04-24 22:19:58 -07:00 committed by GitHub
parent eda69b13f3
commit 7257f9e015
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 3 deletions

View File

@ -2,7 +2,7 @@
Largely based on
https://github.com/asvskartheek/Text-Retrieval/blob/master/TF-IDF%20Search%20Engine%20(SKLEARN).ipynb"""
from typing import Any, List
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
@ -21,10 +21,16 @@ class TFIDFRetriever(BaseRetriever, BaseModel):
arbitrary_types_allowed = True
@classmethod
def from_texts(cls, texts: List[str], **kwargs: Any) -> "TFIDFRetriever":
def from_texts(
cls,
texts: List[str],
tfidf_params: Optional[Dict[str, Any]] = None,
**kwargs: Any
) -> "TFIDFRetriever":
from sklearn.feature_extraction.text import TfidfVectorizer
vectorizer = TfidfVectorizer()
tfidf_params = tfidf_params or {}
vectorizer = TfidfVectorizer(**tfidf_params)
tfidf_array = vectorizer.fit_transform(texts)
docs = [Document(page_content=t) for t in texts]
return cls(vectorizer=vectorizer, docs=docs, tfidf_array=tfidf_array, **kwargs)

View File

@ -0,0 +1,17 @@
from langchain.retrievers.tfidf import TFIDFRetriever
def test_from_texts() -> None:
    """Check that `from_texts` with default settings indexes every input text.

    Expects one stored document per input text and a 3 x 5 TF-IDF matrix.
    """
    corpus = ["I have a pen.", "Do you have a pen?", "I have a bag."]
    retriever = TFIDFRetriever.from_texts(texts=corpus)

    stored_docs = retriever.docs
    assert len(stored_docs) == 3

    dense_matrix = retriever.tfidf_array.toarray()
    assert dense_matrix.shape == (3, 5)
def test_from_texts_with_tfidf_params() -> None:
    """Check that `tfidf_params` is forwarded to the underlying vectorizer.

    Passing ``min_df=2`` is expected to shrink the TF-IDF matrix to 3 x 2.
    """
    corpus = ["I have a pen.", "Do you have a pen?", "I have a bag."]
    vectorizer_options = {"min_df": 2}
    retriever = TFIDFRetriever.from_texts(
        texts=corpus, tfidf_params=vectorizer_options
    )

    # With min_df=2 only terms found in at least two texts should be kept
    # ("have" and "pen"), hence two columns.
    dense_matrix = retriever.tfidf_array.toarray()
    assert dense_matrix.shape == (3, 2)