diff --git a/docs/extras/integrations/retrievers/tf_idf.ipynb b/docs/extras/integrations/retrievers/tf_idf.ipynb index 45558c0e59..e94091e6df 100644 --- a/docs/extras/integrations/retrievers/tf_idf.ipynb +++ b/docs/extras/integrations/retrievers/tf_idf.ipynb @@ -16,7 +16,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "a801b57c", "metadata": {}, "outputs": [], @@ -26,7 +26,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "393ac030", "metadata": { "tags": [] @@ -46,7 +46,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "id": "98b1c017", "metadata": { "tags": [] @@ -133,6 +133,68 @@ "source": [ "result" ] + }, + { + "cell_type": "markdown", + "id": "363f3c04", + "metadata": {}, + "source": [ + "## Save and load\n", + "\n", + "You can easily save and load this retriever, making it handy for local development!" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "10c90d03", + "metadata": {}, + "outputs": [], + "source": [ + "retriever.save_local(\"testing.pkl\")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "fb3b153c", + "metadata": {}, + "outputs": [], + "source": [ + "retriever_copy = TFIDFRetriever.load_local(\"testing.pkl\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "c03ff3c7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Document(page_content='foo', metadata={}),\n", + " Document(page_content='foo bar', metadata={}),\n", + " Document(page_content='hello', metadata={}),\n", + " Document(page_content='world', metadata={})]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "retriever_copy.get_relevant_documents(\"foo\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2d7c5728", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -151,7 +213,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.3" + "version": "3.10.1" } }, "nbformat": 4, diff --git a/libs/langchain/langchain/retrievers/tfidf.py b/libs/langchain/langchain/retrievers/tfidf.py index 1d910f18ec..d5758f3942 100644 --- a/libs/langchain/langchain/retrievers/tfidf.py +++ b/libs/langchain/langchain/retrievers/tfidf.py @@ -1,5 +1,7 @@ from __future__ import annotations +import pickle +from pathlib import Path from typing import Any, Dict, Iterable, List, Optional from langchain.callbacks.manager import CallbackManagerForRetrieverRun @@ -76,3 +78,49 @@ class TFIDFRetriever(BaseRetriever): ) # Op -- (n_docs,1) -- Cosine Sim with each doc return_docs = [self.docs[i] for i in results.argsort()[-self.k :][::-1]] return return_docs + + def save_local( + self, + folder_path: str, + file_name: str = "tfidf_vectorizer", + ) -> None: + try: + import joblib + except ImportError: + raise ImportError( + "Could not import joblib, please install with `pip install joblib`." + ) + + path = Path(folder_path) + path.mkdir(exist_ok=True, parents=True) + + # Save vectorizer with joblib dump. + joblib.dump(self.vectorizer, path / f"{file_name}.joblib") + + # Save docs and tfidf array as pickle. + with open(path / f"{file_name}.pkl", "wb") as f: + pickle.dump((self.docs, self.tfidf_array), f) + + @classmethod + def load_local( + cls, + folder_path: str, + file_name: str = "tfidf_vectorizer", + ) -> TFIDFRetriever: + try: + import joblib + except ImportError: + raise ImportError( + "Could not import joblib, please install with `pip install joblib`." + ) + + path = Path(folder_path) + + # Load vectorizer with joblib load. + vectorizer = joblib.load(path / f"{file_name}.joblib") + + # Load docs and tfidf array as pickle. + with open(path / f"{file_name}.pkl", "rb") as f: + docs, tfidf_array = pickle.load(f) + + return cls(vectorizer=vectorizer, docs=docs, tfidf_array=tfidf_array) diff --git a/libs/langchain/tests/unit_tests/retrievers/test_tfidf.py b/libs/langchain/tests/unit_tests/retrievers/test_tfidf.py index 197eedd751..484fa6f0c6 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_tfidf.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_tfidf.py @@ -1,3 +1,7 @@ +import os +from datetime import datetime +from tempfile import TemporaryDirectory + import pytest from langchain.retrievers.tfidf import TFIDFRetriever @@ -32,3 +36,26 @@ def test_from_documents() -> None: tfidf_retriever = TFIDFRetriever.from_documents(documents=input_docs) assert len(tfidf_retriever.docs) == 3 assert tfidf_retriever.tfidf_array.toarray().shape == (3, 5) + + +@pytest.mark.requires("sklearn") +def test_save_local_load_local() -> None: + input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."] + tfidf_retriever = TFIDFRetriever.from_texts(texts=input_texts) + + file_name = "tfidf_vectorizer" + temp_timestamp = datetime.utcnow().strftime("%Y%m%d-%H%M%S") + with TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder: + tfidf_retriever.save_local( + folder_path=temp_folder, + file_name=file_name, + ) + assert os.path.exists(os.path.join(temp_folder, f"{file_name}.joblib")) + assert os.path.exists(os.path.join(temp_folder, f"{file_name}.pkl")) + + loaded_tfidf_retriever = TFIDFRetriever.load_local( + folder_path=temp_folder, + file_name=file_name, + ) + assert len(loaded_tfidf_retriever.docs) == 3 + assert loaded_tfidf_retriever.tfidf_array.toarray().shape == (3, 5)