tfidf retriever (#5114)

Co-authored-by: vempaliakhil96 <vempaliakhil96@gmail.com>
This commit is contained in:
Davis Chase 2023-05-24 10:02:09 -07:00 committed by GitHub
parent b00c77dc62
commit 2b2176a3c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 80 additions and 26 deletions

View File

@ -16,17 +16,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "a801b57c",
"metadata": {},
"outputs": [],
"source": [
"# !pip install scikit-learn"
"# !pip install scikit-learn\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "393ac030",
"metadata": {
"tags": []
@ -46,7 +46,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"id": "98b1c017",
"metadata": {
"tags": []
@ -56,6 +56,27 @@
"retriever = TFIDFRetriever.from_texts([\"foo\", \"bar\", \"world\", \"hello\", \"foo bar\"])"
]
},
{
"cell_type": "markdown",
"id": "c016b266",
"metadata": {},
"source": [
"## Create a New Retriever with Documents\n",
"\n",
"You can now create a new retriever with the documents you created."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "53af4f00",
"metadata": {},
"outputs": [],
"source": [
"from langchain.schema import Document\n",
"retriever = TFIDFRetriever.from_documents([Document(page_content=\"foo\"), Document(page_content=\"bar\"), Document(page_content=\"world\"), Document(page_content=\"hello\"), Document(page_content=\"foo bar\")])"
]
},
{
"cell_type": "markdown",
"id": "08437fa2",
@ -68,7 +89,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "c0455218",
"metadata": {
"tags": []
@ -80,7 +101,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "7dfa5c29",
"metadata": {
"tags": []
@ -95,7 +116,7 @@
" Document(page_content='world', metadata={})]"
]
},
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@ -103,14 +124,6 @@
"source": [
"result"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "74bd9256",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@ -129,7 +142,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.11.3"
}
},
"nbformat": 4,

View File

@ -2,7 +2,9 @@
Largely based on
https://github.com/asvskartheek/Text-Retrieval/blob/master/TF-IDF%20Search%20Engine%20(SKLEARN).ipynb"""
from typing import Any, Dict, List, Optional
from __future__ import annotations
from typing import Any, Dict, Iterable, List, Optional
from pydantic import BaseModel
@ -23,18 +25,39 @@ class TFIDFRetriever(BaseRetriever, BaseModel):
@classmethod
def from_texts(
cls,
texts: List[str],
texts: Iterable[str],
metadatas: Optional[Iterable[dict]] = None,
tfidf_params: Optional[Dict[str, Any]] = None,
**kwargs: Any
) -> "TFIDFRetriever":
from sklearn.feature_extraction.text import TfidfVectorizer
**kwargs: Any,
) -> TFIDFRetriever:
try:
from sklearn.feature_extraction.text import TfidfVectorizer
except ImportError:
raise ImportError(
"Could not import scikit-learn, please install with `pip install "
"scikit-learn`."
)
tfidf_params = tfidf_params or {}
vectorizer = TfidfVectorizer(**tfidf_params)
tfidf_array = vectorizer.fit_transform(texts)
docs = [Document(page_content=t) for t in texts]
metadatas = metadatas or ({} for _ in texts)
docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]
return cls(vectorizer=vectorizer, docs=docs, tfidf_array=tfidf_array, **kwargs)
@classmethod
def from_documents(
cls,
documents: Iterable[Document],
*,
tfidf_params: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> TFIDFRetriever:
texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
return cls.from_texts(
texts=texts, tfidf_params=tfidf_params, metadatas=metadatas, **kwargs
)
def get_relevant_documents(self, query: str) -> List[Document]:
from sklearn.metrics.pairwise import cosine_similarity

5
poetry.lock generated
View File

@ -6874,7 +6874,6 @@ files = [
{file = "pylance-0.4.12-cp38-abi3-macosx_10_15_x86_64.whl", hash = "sha256:2b86fb8dccc03094c0db37bef0d91bda60e8eb0d1eddf245c6971450c8d8a53f"},
{file = "pylance-0.4.12-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:0bc82914b13204187d673b5f3d45f93219c38a0e9d0542ba251074f639669789"},
{file = "pylance-0.4.12-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a4bcce77f99ecd4cbebbadb01e58d5d8138d40eb56bdcdbc3b20b0475e7a472"},
{file = "pylance-0.4.12-cp38-abi3-win_amd64.whl", hash = "sha256:9616931c5300030adb9626d22515710a127d1e46a46737a7a0f980b52f13627c"},
]
[package.dependencies]
@ -10857,7 +10856,7 @@ azure = ["azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-
cohere = ["cohere"]
docarray = ["docarray"]
embeddings = ["sentence-transformers"]
extended-testing = ["atlassian-python-api", "beautifulsoup4", "beautifulsoup4", "chardet", "gql", "html2text", "jq", "lxml", "pandas", "pdfminer-six", "psychicapi", "pymupdf", "pypdf", "pypdfium2", "requests-toolbelt", "telethon", "tqdm", "zep-python"]
extended-testing = ["atlassian-python-api", "beautifulsoup4", "beautifulsoup4", "chardet", "gql", "html2text", "jq", "lxml", "pandas", "pdfminer-six", "psychicapi", "pymupdf", "pypdf", "pypdfium2", "requests-toolbelt", "scikit-learn", "telethon", "tqdm", "zep-python"]
llms = ["anthropic", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openlm", "torch", "transformers"]
openai = ["openai", "tiktoken"]
qdrant = ["qdrant-client"]
@ -10866,4 +10865,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "196588e10bb33939f5bae294a194ad01e803f40ed1087fe6a7a4b87e8d80712b"
content-hash = "640f7e8102328d7ec3f56778d7cdb76b4846fc407c99606e0aec31833bc3933e"

View File

@ -93,6 +93,7 @@ langkit = {version = ">=0.0.1.dev3, <0.1.0", optional = true}
chardet = {version="^5.1.0", optional=true}
requests-toolbelt = {version = "^1.0.0", optional = true}
openlm = {version = "^0.0.5", optional = true}
scikit-learn = {version = "^1.2.2", optional = true}
azure-ai-formrecognizer = {version = "^3.2.1", optional = true}
azure-ai-vision = {version = "^0.11.1b1", optional = true}
azure-cognitiveservices-speech = {version = "^1.28.0", optional = true}
@ -274,7 +275,8 @@ extended_testing = [
"zep-python",
"gql",
"requests_toolbelt",
"html2text"
"html2text",
"scikit-learn",
]
[tool.ruff]

View File

@ -1,6 +1,10 @@
import pytest
from langchain.retrievers.tfidf import TFIDFRetriever
from langchain.schema import Document
@pytest.mark.requires("sklearn")
def test_from_texts() -> None:
input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."]
tfidf_retriever = TFIDFRetriever.from_texts(texts=input_texts)
@ -8,6 +12,7 @@ def test_from_texts() -> None:
assert tfidf_retriever.tfidf_array.toarray().shape == (3, 5)
@pytest.mark.requires("sklearn")
def test_from_texts_with_tfidf_params() -> None:
input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."]
tfidf_retriever = TFIDFRetriever.from_texts(
@ -15,3 +20,15 @@ def test_from_texts_with_tfidf_params() -> None:
)
# should count only multiple words (have, pan)
assert tfidf_retriever.tfidf_array.toarray().shape == (3, 2)
@pytest.mark.requires("sklearn")
def test_from_documents() -> None:
input_docs = [
Document(page_content="I have a pen."),
Document(page_content="Do you have a pen?"),
Document(page_content="I have a bag."),
]
tfidf_retriever = TFIDFRetriever.from_documents(documents=input_docs)
assert len(tfidf_retriever.docs) == 3
assert tfidf_retriever.tfidf_array.toarray().shape == (3, 5)