forked from Archives/langchain

tfidf retriever (#5114)
Co-authored-by: vempaliakhil96 <vempaliakhil96@gmail.com>

parent: b00c77dc62
commit: 2b2176a3c1
tfidf.ipynb

@@ -16,17 +16,17 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 1,
    "id": "a801b57c",
    "metadata": {},
    "outputs": [],
    "source": [
-    "# !pip install scikit-learn"
+    "# !pip install scikit-learn\n"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": null,
    "id": "393ac030",
    "metadata": {
     "tags": []
@@ -46,7 +46,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 3,
    "id": "98b1c017",
    "metadata": {
     "tags": []
@@ -56,6 +56,27 @@
    "retriever = TFIDFRetriever.from_texts([\"foo\", \"bar\", \"world\", \"hello\", \"foo bar\"])"
   ]
  },
+ {
+  "cell_type": "markdown",
+  "id": "c016b266",
+  "metadata": {},
+  "source": [
+   "## Create a New Retriever with Documents\n",
+   "\n",
+   "You can now create a new retriever with the documents you created."
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": 5,
+  "id": "53af4f00",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "from langchain.schema import Document\n",
+   "retriever = TFIDFRetriever.from_documents([Document(page_content=\"foo\"), Document(page_content=\"bar\"), Document(page_content=\"world\"), Document(page_content=\"hello\"), Document(page_content=\"foo bar\")])"
+  ]
+ },
  {
   "cell_type": "markdown",
   "id": "08437fa2",
@@ -68,7 +89,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 6,
    "id": "c0455218",
    "metadata": {
     "tags": []
@@ -80,7 +101,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 7,
    "id": "7dfa5c29",
    "metadata": {
     "tags": []
@@ -95,7 +116,7 @@
      " Document(page_content='world', metadata={})]"
     ]
    },
-   "execution_count": 6,
+   "execution_count": 7,
    "metadata": {},
    "output_type": "execute_result"
   }
@@ -103,14 +124,6 @@
   "source": [
    "result"
   ]
  },
- {
-  "cell_type": "code",
-  "execution_count": null,
-  "id": "74bd9256",
-  "metadata": {},
-  "outputs": [],
-  "source": []
- }
 ],
 "metadata": {
@@ -129,7 +142,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.10.6"
+  "version": "3.11.3"
  }
 },
 "nbformat": 4,
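The notebook changes above add a `from_documents` construction path next to the existing `from_texts` one. A minimal end-to-end sketch of both, assuming scikit-learn is installed and using the import path from the unit tests below; the `{"source": "demo"}` metadata is illustrative, not from the PR:

```python
from langchain.retrievers.tfidf import TFIDFRetriever
from langchain.schema import Document

# From raw strings: each text is wrapped in a Document (empty metadata by default).
retriever = TFIDFRetriever.from_texts(["foo", "bar", "world", "hello", "foo bar"])

# From Documents: page_content and metadata are both carried through.
docs = [Document(page_content=t, metadata={"source": "demo"}) for t in ["foo", "bar"]]
retriever = TFIDFRetriever.from_documents(docs)

# Retrieval ranks the stored documents by TF-IDF cosine similarity to the query.
print(retriever.get_relevant_documents("foo"))
```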
langchain/retrievers/tfidf.py

@@ -2,7 +2,9 @@

 Largely based on
 https://github.com/asvskartheek/Text-Retrieval/blob/master/TF-IDF%20Search%20Engine%20(SKLEARN).ipynb"""
-from typing import Any, Dict, List, Optional
+from __future__ import annotations
+
+from typing import Any, Dict, Iterable, List, Optional

 from pydantic import BaseModel

@@ -23,18 +25,39 @@ class TFIDFRetriever(BaseRetriever, BaseModel):
     @classmethod
     def from_texts(
         cls,
-        texts: List[str],
+        texts: Iterable[str],
+        metadatas: Optional[Iterable[dict]] = None,
         tfidf_params: Optional[Dict[str, Any]] = None,
-        **kwargs: Any
-    ) -> "TFIDFRetriever":
-        from sklearn.feature_extraction.text import TfidfVectorizer
+        **kwargs: Any,
+    ) -> TFIDFRetriever:
+        try:
+            from sklearn.feature_extraction.text import TfidfVectorizer
+        except ImportError:
+            raise ImportError(
+                "Could not import scikit-learn, please install with `pip install "
+                "scikit-learn`."
+            )

         tfidf_params = tfidf_params or {}
         vectorizer = TfidfVectorizer(**tfidf_params)
         tfidf_array = vectorizer.fit_transform(texts)
-        docs = [Document(page_content=t) for t in texts]
+        metadatas = metadatas or ({} for _ in texts)
+        docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]
         return cls(vectorizer=vectorizer, docs=docs, tfidf_array=tfidf_array, **kwargs)

+    @classmethod
+    def from_documents(
+        cls,
+        documents: Iterable[Document],
+        *,
+        tfidf_params: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> TFIDFRetriever:
+        texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
+        return cls.from_texts(
+            texts=texts, tfidf_params=tfidf_params, metadatas=metadatas, **kwargs
+        )
+
     def get_relevant_documents(self, query: str) -> List[Document]:
         from sklearn.metrics.pairwise import cosine_similarity

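The hunk cuts off at the opening lines of `get_relevant_documents`. As a reference for how the pieces fit together, here is a sketch of the standard TF-IDF retrieval step that the method's imports point to (transform the query with the fitted vectorizer, rank by cosine similarity); this is plain scikit-learn, not necessarily the exact upstream body:

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

texts = ["foo", "bar", "world", "hello", "foo bar"]
vectorizer = TfidfVectorizer()
tfidf_array = vectorizer.fit_transform(texts)       # what from_texts stores

query_vec = vectorizer.transform(["foo"])              # shape (1, n_terms)
scores = cosine_similarity(query_vec, tfidf_array)[0]  # one score per text
ranked = scores.argsort()[::-1]                        # best matches first
print([texts[i] for i in ranked[:4]])                  # ['foo', 'foo bar', ...]
```

The remaining hunks register scikit-learn as an optional dependency (poetry.lock, pyproject.toml) and extend the unit tests.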
poetry.lock (generated): 5 changed lines
@@ -6874,7 +6874,6 @@ files = [
     {file = "pylance-0.4.12-cp38-abi3-macosx_10_15_x86_64.whl", hash = "sha256:2b86fb8dccc03094c0db37bef0d91bda60e8eb0d1eddf245c6971450c8d8a53f"},
     {file = "pylance-0.4.12-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:0bc82914b13204187d673b5f3d45f93219c38a0e9d0542ba251074f639669789"},
     {file = "pylance-0.4.12-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a4bcce77f99ecd4cbebbadb01e58d5d8138d40eb56bdcdbc3b20b0475e7a472"},
-    {file = "pylance-0.4.12-cp38-abi3-win_amd64.whl", hash = "sha256:9616931c5300030adb9626d22515710a127d1e46a46737a7a0f980b52f13627c"},
 ]

 [package.dependencies]
@@ -10857,7 +10856,7 @@ azure = ["azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-
 cohere = ["cohere"]
 docarray = ["docarray"]
 embeddings = ["sentence-transformers"]
-extended-testing = ["atlassian-python-api", "beautifulsoup4", "beautifulsoup4", "chardet", "gql", "html2text", "jq", "lxml", "pandas", "pdfminer-six", "psychicapi", "pymupdf", "pypdf", "pypdfium2", "requests-toolbelt", "telethon", "tqdm", "zep-python"]
+extended-testing = ["atlassian-python-api", "beautifulsoup4", "beautifulsoup4", "chardet", "gql", "html2text", "jq", "lxml", "pandas", "pdfminer-six", "psychicapi", "pymupdf", "pypdf", "pypdfium2", "requests-toolbelt", "scikit-learn", "telethon", "tqdm", "zep-python"]
 llms = ["anthropic", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openlm", "torch", "transformers"]
 openai = ["openai", "tiktoken"]
 qdrant = ["qdrant-client"]
@@ -10866,4 +10865,4 @@ text-helpers = ["chardet"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "196588e10bb33939f5bae294a194ad01e803f40ed1087fe6a7a4b87e8d80712b"
+content-hash = "640f7e8102328d7ec3f56778d7cdb76b4846fc407c99606e0aec31833bc3933e"
pyproject.toml

@@ -93,6 +93,7 @@ langkit = {version = ">=0.0.1.dev3, <0.1.0", optional = true}
 chardet = {version="^5.1.0", optional=true}
 requests-toolbelt = {version = "^1.0.0", optional = true}
 openlm = {version = "^0.0.5", optional = true}
+scikit-learn = {version = "^1.2.2", optional = true}
 azure-ai-formrecognizer = {version = "^3.2.1", optional = true}
 azure-ai-vision = {version = "^0.11.1b1", optional = true}
 azure-cognitiveservices-speech = {version = "^1.28.0", optional = true}
@@ -274,7 +275,8 @@ extended_testing = [
     "zep-python",
     "gql",
     "requests_toolbelt",
-    "html2text"
+    "html2text",
+    "scikit-learn",
 ]

 [tool.ruff]
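With scikit-learn wired in as an optional extra rather than a hard dependency, the guarded import added to `from_texts` is what users hit in a bare environment. A hypothetical session (not from the PR) illustrating what that guard does at runtime when scikit-learn is absent:

```python
from langchain.retrievers.tfidf import TFIDFRetriever  # imports fine without sklearn

try:
    TFIDFRetriever.from_texts(["foo"])
except ImportError as err:
    # Raised by the try/except added in from_texts above, with the message
    # "Could not import scikit-learn, please install with `pip install scikit-learn`."
    print(err)
```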
test_tfidf.py

@@ -1,6 +1,10 @@
+import pytest
+
 from langchain.retrievers.tfidf import TFIDFRetriever
+from langchain.schema import Document


+@pytest.mark.requires("sklearn")
 def test_from_texts() -> None:
     input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."]
     tfidf_retriever = TFIDFRetriever.from_texts(texts=input_texts)
@@ -8,6 +12,7 @@ def test_from_texts() -> None:
     assert tfidf_retriever.tfidf_array.toarray().shape == (3, 5)


+@pytest.mark.requires("sklearn")
 def test_from_texts_with_tfidf_params() -> None:
     input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."]
     tfidf_retriever = TFIDFRetriever.from_texts(
@@ -15,3 +20,15 @@ def test_from_texts_with_tfidf_params() -> None:
     )
     # should count only multiple words (have, pen)
     assert tfidf_retriever.tfidf_array.toarray().shape == (3, 2)
+
+
+@pytest.mark.requires("sklearn")
+def test_from_documents() -> None:
+    input_docs = [
+        Document(page_content="I have a pen."),
+        Document(page_content="Do you have a pen?"),
+        Document(page_content="I have a bag."),
+    ]
+    tfidf_retriever = TFIDFRetriever.from_documents(documents=input_docs)
+    assert len(tfidf_retriever.docs) == 3
+    assert tfidf_retriever.tfidf_array.toarray().shape == (3, 5)
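The asserted shapes follow from `TfidfVectorizer` defaults: the default token pattern keeps only tokens of two or more word characters, so "I" and "a" are dropped and the three texts span the vocabulary {bag, do, have, pen, you}, giving a (3, 5) matrix. The `tfidf_params` used in the second test are cut off in the hunk; `min_df=2` is one setting consistent with the comment and the (3, 2) assertion. A quick check:

```python
from sklearn.feature_extraction.text import TfidfVectorizer

texts = ["I have a pen.", "Do you have a pen?", "I have a bag."]

# Default tokenization: single-character tokens ("I", "a") are discarded.
print(sorted(TfidfVectorizer().fit(texts).vocabulary_))
# ['bag', 'do', 'have', 'pen', 'you']  -> tfidf_array shape (3, 5)

# min_df=2 keeps only terms appearing in at least two texts: "have", "pen".
print(sorted(TfidfVectorizer(min_df=2).fit(texts).vocabulary_))
# ['have', 'pen']  -> tfidf_array shape (3, 2)
```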