add save and load tfidf vectorizer and docs for TFIDFRetriever (#8112)

This adds `save_local` and `load_local` methods to `TFIDFRetriever`. They
persist and restore the fitted TF-IDF vectorizer together with the documents
and the TF-IDF array, so a fitted retriever can be reused.

<!-- Thank you for contributing to LangChain!

Replace this comment with:
- Description: add save_local and load_local to tfidf_vectorizer and
docs in tfidf_retriever
  - Issue: None
  - Dependencies: None
  - Tag maintainer: @rlancemartin, @eyurtsev
  - Twitter handle: @MlopsJ

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
  2. an example notebook showing its use.

Maintainer responsibilities:
  - General / Misc / if you don't know who to tag: @baskaryan
  - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev
  - Models / Prompts: @hwchase17, @baskaryan
  - Memory: @hwchase17
  - Agents / Tools / Toolkits: @hinthornw
  - Tracing / Callbacks: @agola11
  - Async: @agola11

If no one reviews your PR within a few days, feel free to @-mention the
same people again.

See contribution guidelines for more information on how to write/run
tests, lint, etc:
https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
 -->

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
shibuiwilliam 2023-08-04 15:06:27 +09:00 committed by GitHub
parent 0f68054401
commit 2759e2d857
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 141 additions and 4 deletions

View File

@ -16,7 +16,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 2,
"id": "a801b57c", "id": "a801b57c",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -26,7 +26,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 3,
"id": "393ac030", "id": "393ac030",
"metadata": { "metadata": {
"tags": [] "tags": []
@ -46,7 +46,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 4,
"id": "98b1c017", "id": "98b1c017",
"metadata": { "metadata": {
"tags": [] "tags": []
@ -133,6 +133,68 @@
"source": [ "source": [
"result" "result"
] ]
},
{
"cell_type": "markdown",
"id": "363f3c04",
"metadata": {},
"source": [
"## Save and load\n",
"\n",
"You can easily save and load this retriever, making it handy for local development!"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "10c90d03",
"metadata": {},
"outputs": [],
"source": [
"retriever.save_local(\"testing.pkl\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "fb3b153c",
"metadata": {},
"outputs": [],
"source": [
"retriever_copy = TFIDFRetriever.load_local(\"testing.pkl\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "c03ff3c7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='foo', metadata={}),\n",
" Document(page_content='foo bar', metadata={}),\n",
" Document(page_content='hello', metadata={}),\n",
" Document(page_content='world', metadata={})]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"retriever_copy.get_relevant_documents(\"foo\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2d7c5728",
"metadata": {},
"outputs": [],
"source": []
} }
], ],
"metadata": { "metadata": {
@ -151,7 +213,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.11.3" "version": "3.10.1"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
import pickle
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional from typing import Any, Dict, Iterable, List, Optional
from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.callbacks.manager import CallbackManagerForRetrieverRun
@ -76,3 +78,49 @@ class TFIDFRetriever(BaseRetriever):
) # Op -- (n_docs,1) -- Cosine Sim with each doc ) # Op -- (n_docs,1) -- Cosine Sim with each doc
return_docs = [self.docs[i] for i in results.argsort()[-self.k :][::-1]] return_docs = [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
return return_docs return return_docs
def save_local(
    self,
    folder_path: str,
    file_name: str = "tfidf_vectorizer",
) -> None:
    """Persist the retriever's state to two files under ``folder_path``.

    ``<file_name>.joblib`` receives ``self.vectorizer`` (via ``joblib.dump``)
    and ``<file_name>.pkl`` receives the ``(self.docs, self.tfidf_array)``
    tuple (via ``pickle.dump``).

    Args:
        folder_path: Directory to write into. It is created, including any
            missing parents, if it does not exist yet.
        file_name: Base name (without extension) shared by both files.

    Raises:
        ImportError: If ``joblib`` cannot be imported.
    """
    try:
        import joblib
    except ImportError as e:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "Could not import joblib, please install with `pip install joblib`."
        ) from e

    path = Path(folder_path)
    path.mkdir(exist_ok=True, parents=True)

    # Save vectorizer with joblib dump.
    joblib.dump(self.vectorizer, path / f"{file_name}.joblib")

    # Save docs and tfidf array as pickle.
    with open(path / f"{file_name}.pkl", "wb") as f:
        pickle.dump((self.docs, self.tfidf_array), f)
@classmethod
def load_local(
    cls,
    folder_path: str,
    file_name: str = "tfidf_vectorizer",
) -> TFIDFRetriever:
    """Rebuild a retriever from ``<file_name>.joblib`` and ``<file_name>.pkl``.

    The ``.joblib`` file is read with ``joblib.load`` to obtain the
    vectorizer; the ``.pkl`` file is read with ``pickle.load`` and unpacked
    into ``(docs, tfidf_array)``. The three values are passed to the class
    constructor.

    Warning:
        Both ``joblib.load`` and ``pickle.load`` can execute arbitrary code
        while deserializing. Only load files that come from a trusted source.

    Args:
        folder_path: Directory containing the two files.
        file_name: Base name (without extension) shared by both files.

    Returns:
        A new instance built from the loaded vectorizer, docs and TF-IDF
        array.

    Raises:
        ImportError: If ``joblib`` cannot be imported.
    """
    try:
        import joblib
    except ImportError as e:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "Could not import joblib, please install with `pip install joblib`."
        ) from e

    path = Path(folder_path)

    # SECURITY: the two loads below unpickle file contents; never point this
    # at files from an untrusted source.
    # Load vectorizer with joblib load.
    vectorizer = joblib.load(path / f"{file_name}.joblib")

    # Load docs and tfidf array as pickle.
    with open(path / f"{file_name}.pkl", "rb") as f:
        docs, tfidf_array = pickle.load(f)

    return cls(vectorizer=vectorizer, docs=docs, tfidf_array=tfidf_array)

View File

@ -1,3 +1,7 @@
import os
from datetime import datetime
from tempfile import TemporaryDirectory
import pytest import pytest
from langchain.retrievers.tfidf import TFIDFRetriever from langchain.retrievers.tfidf import TFIDFRetriever
@ -32,3 +36,26 @@ def test_from_documents() -> None:
tfidf_retriever = TFIDFRetriever.from_documents(documents=input_docs) tfidf_retriever = TFIDFRetriever.from_documents(documents=input_docs)
assert len(tfidf_retriever.docs) == 3 assert len(tfidf_retriever.docs) == 3
assert tfidf_retriever.tfidf_array.toarray().shape == (3, 5) assert tfidf_retriever.tfidf_array.toarray().shape == (3, 5)
@pytest.mark.requires("sklearn")
def test_save_local_load_local() -> None:
    """Round-trip a fitted retriever through ``save_local``/``load_local``.

    Checks that both output files are created and that the reloaded
    retriever has the same number of docs and TF-IDF array shape as a
    freshly fitted one.
    """
    # Function-scope import: the visible module imports only `datetime`
    # from the datetime package.
    from datetime import timezone

    input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."]
    tfidf_retriever = TFIDFRetriever.from_texts(texts=input_texts)

    file_name = "tfidf_vectorizer"
    # Timezone-aware replacement for datetime.utcnow(), which is deprecated
    # since Python 3.12; the formatted string is the same UTC timestamp.
    temp_timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    # The suffix is part of the directory *name*, so it carries no trailing
    # path separator.
    with TemporaryDirectory(suffix="_" + temp_timestamp) as temp_folder:
        tfidf_retriever.save_local(
            folder_path=temp_folder,
            file_name=file_name,
        )
        assert os.path.exists(os.path.join(temp_folder, f"{file_name}.joblib"))
        assert os.path.exists(os.path.join(temp_folder, f"{file_name}.pkl"))

        loaded_tfidf_retriever = TFIDFRetriever.load_local(
            folder_path=temp_folder,
            file_name=file_name,
        )
    assert len(loaded_tfidf_retriever.docs) == 3
    assert loaded_tfidf_retriever.tfidf_array.toarray().shape == (3, 5)