mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Harrison/tfidf parameters (#3481)
Co-authored-by: pao <go5kuramubon@gmail.com> Co-authored-by: KyoHattori <kyo.hattori@abejainc.com>
This commit is contained in:
parent
eda69b13f3
commit
7257f9e015
@ -2,7 +2,7 @@
|
||||
|
||||
Largely based on
|
||||
https://github.com/asvskartheek/Text-Retrieval/blob/master/TF-IDF%20Search%20Engine%20(SKLEARN).ipynb"""
|
||||
from typing import Any, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -21,10 +21,16 @@ class TFIDFRetriever(BaseRetriever, BaseModel):
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@classmethod
|
||||
def from_texts(cls, texts: List[str], **kwargs: Any) -> "TFIDFRetriever":
|
||||
def from_texts(
|
||||
cls,
|
||||
texts: List[str],
|
||||
tfidf_params: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any
|
||||
) -> "TFIDFRetriever":
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
|
||||
vectorizer = TfidfVectorizer()
|
||||
tfidf_params = tfidf_params or {}
|
||||
vectorizer = TfidfVectorizer(**tfidf_params)
|
||||
tfidf_array = vectorizer.fit_transform(texts)
|
||||
docs = [Document(page_content=t) for t in texts]
|
||||
return cls(vectorizer=vectorizer, docs=docs, tfidf_array=tfidf_array, **kwargs)
|
||||
|
17
tests/integration_tests/retrievers/test_tfidf.py
Normal file
17
tests/integration_tests/retrievers/test_tfidf.py
Normal file
@ -0,0 +1,17 @@
|
||||
from langchain.retrievers.tfidf import TFIDFRetriever
|
||||
|
||||
|
||||
def test_from_texts() -> None:
|
||||
input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."]
|
||||
tfidf_retriever = TFIDFRetriever.from_texts(texts=input_texts)
|
||||
assert len(tfidf_retriever.docs) == 3
|
||||
assert tfidf_retriever.tfidf_array.toarray().shape == (3, 5)
|
||||
|
||||
|
||||
def test_from_texts_with_tfidf_params() -> None:
|
||||
input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."]
|
||||
tfidf_retriever = TFIDFRetriever.from_texts(
|
||||
texts=input_texts, tfidf_params={"min_df": 2}
|
||||
)
|
||||
# should keep only words that appear in at least two documents (have, pen)
|
||||
assert tfidf_retriever.tfidf_array.toarray().shape == (3, 2)
|
Loading…
Reference in New Issue
Block a user