From 2b2176a3c1bc3d1870a35f9594d13d817c011b3f Mon Sep 17 00:00:00 2001 From: Davis Chase <130488702+dev2049@users.noreply.github.com> Date: Wed, 24 May 2023 10:02:09 -0700 Subject: [PATCH] tfidf retriever (#5114) Co-authored-by: vempaliakhil96 --- .../indexes/retrievers/examples/tf_idf.ipynb | 45 ++++++++++++------- langchain/retrievers/tfidf.py | 35 ++++++++++++--- poetry.lock | 5 +-- pyproject.toml | 4 +- .../retrievers/test_tfidf.py | 17 +++++++ 5 files changed, 80 insertions(+), 26 deletions(-) rename tests/{integration_tests => unit_tests}/retrievers/test_tfidf.py (55%) diff --git a/docs/modules/indexes/retrievers/examples/tf_idf.ipynb b/docs/modules/indexes/retrievers/examples/tf_idf.ipynb index b594f7a8..fed3df6c 100644 --- a/docs/modules/indexes/retrievers/examples/tf_idf.ipynb +++ b/docs/modules/indexes/retrievers/examples/tf_idf.ipynb @@ -16,17 +16,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "a801b57c", "metadata": {}, "outputs": [], "source": [ - "# !pip install scikit-learn" + "# !pip install scikit-learn\n" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "393ac030", "metadata": { "tags": [] @@ -46,7 +46,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "id": "98b1c017", "metadata": { "tags": [] @@ -56,6 +56,27 @@ "retriever = TFIDFRetriever.from_texts([\"foo\", \"bar\", \"world\", \"hello\", \"foo bar\"])" ] }, + { + "cell_type": "markdown", + "id": "c016b266", + "metadata": {}, + "source": [ + "## Create a New Retriever with Documents\n", + "\n", + "You can now create a new retriever with the documents you created." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "53af4f00", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.schema import Document\n", + "retriever = TFIDFRetriever.from_documents([Document(page_content=\"foo\"), Document(page_content=\"bar\"), Document(page_content=\"world\"), Document(page_content=\"hello\"), Document(page_content=\"foo bar\")])" + ] + }, { "cell_type": "markdown", "id": "08437fa2", @@ -68,7 +89,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "id": "c0455218", "metadata": { "tags": [] @@ -80,7 +101,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "id": "7dfa5c29", "metadata": { "tags": [] @@ -95,7 +116,7 @@ " Document(page_content='world', metadata={})]" ] }, - "execution_count": 6, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -103,14 +124,6 @@ "source": [ "result" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "74bd9256", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -129,7 +142,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.11.3" } }, "nbformat": 4, diff --git a/langchain/retrievers/tfidf.py b/langchain/retrievers/tfidf.py index 2fa8a58c..3ccece7d 100644 --- a/langchain/retrievers/tfidf.py +++ b/langchain/retrievers/tfidf.py @@ -2,7 +2,9 @@ Largely based on https://github.com/asvskartheek/Text-Retrieval/blob/master/TF-IDF%20Search%20Engine%20(SKLEARN).ipynb""" -from typing import Any, Dict, List, Optional +from __future__ import annotations + +from typing import Any, Dict, Iterable, List, Optional from pydantic import BaseModel @@ -23,18 +25,39 @@ class TFIDFRetriever(BaseRetriever, BaseModel): @classmethod def from_texts( cls, - texts: List[str], + texts: Iterable[str], + metadatas: 
Optional[Iterable[dict]] = None, tfidf_params: Optional[Dict[str, Any]] = None, - **kwargs: Any - ) -> "TFIDFRetriever": - from sklearn.feature_extraction.text import TfidfVectorizer + **kwargs: Any, + ) -> TFIDFRetriever: + try: + from sklearn.feature_extraction.text import TfidfVectorizer + except ImportError: + raise ImportError( + "Could not import scikit-learn, please install with `pip install " + "scikit-learn`." + ) tfidf_params = tfidf_params or {} vectorizer = TfidfVectorizer(**tfidf_params) tfidf_array = vectorizer.fit_transform(texts) - docs = [Document(page_content=t) for t in texts] + metadatas = metadatas or ({} for _ in texts) + docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)] return cls(vectorizer=vectorizer, docs=docs, tfidf_array=tfidf_array, **kwargs) + @classmethod + def from_documents( + cls, + documents: Iterable[Document], + *, + tfidf_params: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> TFIDFRetriever: + texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents)) + return cls.from_texts( + texts=texts, tfidf_params=tfidf_params, metadatas=metadatas, **kwargs + ) + def get_relevant_documents(self, query: str) -> List[Document]: from sklearn.metrics.pairwise import cosine_similarity diff --git a/poetry.lock b/poetry.lock index b561b228..3092f4e2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6874,7 +6874,6 @@ files = [ {file = "pylance-0.4.12-cp38-abi3-macosx_10_15_x86_64.whl", hash = "sha256:2b86fb8dccc03094c0db37bef0d91bda60e8eb0d1eddf245c6971450c8d8a53f"}, {file = "pylance-0.4.12-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:0bc82914b13204187d673b5f3d45f93219c38a0e9d0542ba251074f639669789"}, {file = "pylance-0.4.12-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a4bcce77f99ecd4cbebbadb01e58d5d8138d40eb56bdcdbc3b20b0475e7a472"}, - {file = "pylance-0.4.12-cp38-abi3-win_amd64.whl", hash = "sha256:9616931c5300030adb9626d22515710a127d1e46a46737a7a0f980b52f13627c"}, ] [package.dependencies] @@ -10857,7 +10856,7 @@ azure = ["azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices- cohere = ["cohere"] docarray = ["docarray"] embeddings = ["sentence-transformers"] -extended-testing = ["atlassian-python-api", "beautifulsoup4", "beautifulsoup4", "chardet", "gql", "html2text", "jq", "lxml", "pandas", "pdfminer-six", "psychicapi", "pymupdf", "pypdf", "pypdfium2", "requests-toolbelt", "telethon", "tqdm", "zep-python"] +extended-testing = ["atlassian-python-api", "beautifulsoup4", "beautifulsoup4", "chardet", "gql", "html2text", "jq", "lxml", "pandas", "pdfminer-six", "psychicapi", "pymupdf", "pypdf", "pypdfium2", "requests-toolbelt", "scikit-learn", "telethon", "tqdm", "zep-python"] llms = ["anthropic", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openlm", "torch", "transformers"] openai = ["openai", "tiktoken"] qdrant = ["qdrant-client"] @@ -10866,4 +10865,4 @@ text-helpers = ["chardet"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "196588e10bb33939f5bae294a194ad01e803f40ed1087fe6a7a4b87e8d80712b" +content-hash = "640f7e8102328d7ec3f56778d7cdb76b4846fc407c99606e0aec31833bc3933e" diff --git a/pyproject.toml b/pyproject.toml index 8d3b8648..7f243bcc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ langkit = {version = ">=0.0.1.dev3, <0.1.0", optional = true} chardet = {version="^5.1.0", optional=true} requests-toolbelt = {version = "^1.0.0", optional = true} openlm = {version = "^0.0.5", optional = true} 
+scikit-learn = {version = "^1.2.2", optional = true} azure-ai-formrecognizer = {version = "^3.2.1", optional = true} azure-ai-vision = {version = "^0.11.1b1", optional = true} azure-cognitiveservices-speech = {version = "^1.28.0", optional = true} @@ -274,7 +275,8 @@ extended_testing = [ "zep-python", "gql", "requests_toolbelt", - "html2text" + "html2text", + "scikit-learn", ] [tool.ruff] diff --git a/tests/integration_tests/retrievers/test_tfidf.py b/tests/unit_tests/retrievers/test_tfidf.py similarity index 55% rename from tests/integration_tests/retrievers/test_tfidf.py rename to tests/unit_tests/retrievers/test_tfidf.py index 54dae33a..197eedd7 100644 --- a/tests/integration_tests/retrievers/test_tfidf.py +++ b/tests/unit_tests/retrievers/test_tfidf.py @@ -1,6 +1,10 @@ +import pytest + from langchain.retrievers.tfidf import TFIDFRetriever +from langchain.schema import Document +@pytest.mark.requires("sklearn") def test_from_texts() -> None: input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."] tfidf_retriever = TFIDFRetriever.from_texts(texts=input_texts) @@ -8,6 +12,7 @@ def test_from_texts() -> None: assert tfidf_retriever.tfidf_array.toarray().shape == (3, 5) +@pytest.mark.requires("sklearn") def test_from_texts_with_tfidf_params() -> None: input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."] tfidf_retriever = TFIDFRetriever.from_texts( @@ -15,3 +20,15 @@ def test_from_texts_with_tfidf_params() -> None: ) # should count only multiple words (have, pan) assert tfidf_retriever.tfidf_array.toarray().shape == (3, 2) + + +@pytest.mark.requires("sklearn") +def test_from_documents() -> None: + input_docs = [ + Document(page_content="I have a pen."), + Document(page_content="Do you have a pen?"), + Document(page_content="I have a bag."), + ] + tfidf_retriever = TFIDFRetriever.from_documents(documents=input_docs) + assert len(tfidf_retriever.docs) == 3 + assert tfidf_retriever.tfidf_array.toarray().shape == (3, 5)
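As a quick sketch of the interface this patch introduces (assuming scikit-learn is installed; the metadata keys and the query string below are purely illustrative), the new `metadatas` argument on `from_texts` and the new `from_documents` constructor can be exercised like this:

    from langchain.retrievers.tfidf import TFIDFRetriever
    from langchain.schema import Document

    # from_texts now accepts per-text metadata alongside the raw strings.
    retriever = TFIDFRetriever.from_texts(
        texts=["I have a pen.", "Do you have a pen?", "I have a bag."],
        metadatas=[{"source": "a"}, {"source": "b"}, {"source": "c"}],
    )

    # from_documents builds the same retriever straight from Document objects,
    # forwarding any TfidfVectorizer options through the keyword-only tfidf_params.
    docs = [
        Document(page_content="I have a pen.", metadata={"source": "a"}),
        Document(page_content="Do you have a pen?", metadata={"source": "b"}),
        Document(page_content="I have a bag.", metadata={"source": "c"}),
    ]
    retriever = TFIDFRetriever.from_documents(docs, tfidf_params={"min_df": 1})

    # Retrieval works as before; documents now carry their metadata through.
    results = retriever.get_relevant_documents("pen")
    print([d.page_content for d in results])

Note that, as the diff shows, `from_documents` simply unzips `page_content` and `metadata` from each Document and delegates to `from_texts`, so both entry points produce equivalent retrievers.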