add save and load tfidf vectorizer and docs for TFIDFRetriever (#8112)

This is to add save_local and load_local to tfidf_vectorizer and docs in
tfidf_retriever to make the vectorizer reusable.

<!-- Thank you for contributing to LangChain!

Replace this comment with:
- Description: add save_local and load_local to tfidf_vectorizer and
docs in tfidf_retriever
  - Issue: None
  - Dependencies: None
  - Tag maintainer: @rlancemartin, @eyurtsev
  - Twitter handle: @MlopsJ

Please make sure you're PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
  2. an example notebook showing its use.

Maintainer responsibilities:
  - General / Misc / if you don't know who to tag: @baskaryan
  - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev
  - Models / Prompts: @hwchase17, @baskaryan
  - Memory: @hwchase17
  - Agents / Tools / Toolkits: @hinthornw
  - Tracing / Callbacks: @agola11
  - Async: @agola11

If no one reviews your PR within a few days, feel free to @-mention the
same people again.

See contribution guidelines for more information on how to write/run
tests, lint, etc:
https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
 -->

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
shibuiwilliam 2023-08-04 15:06:27 +09:00 committed by GitHub
parent 0f68054401
commit 2759e2d857
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 141 additions and 4 deletions

View File

@ -16,7 +16,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"id": "a801b57c",
"metadata": {},
"outputs": [],
@ -26,7 +26,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"id": "393ac030",
"metadata": {
"tags": []
@ -46,7 +46,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "98b1c017",
"metadata": {
"tags": []
@ -133,6 +133,68 @@
"source": [
"result"
]
},
{
"cell_type": "markdown",
"id": "363f3c04",
"metadata": {},
"source": [
"## Save and load\n",
"\n",
"You can easily save and load this retriever, making it handy for local development!"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "10c90d03",
"metadata": {},
"outputs": [],
"source": [
"retriever.save_local(\"testing.pkl\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "fb3b153c",
"metadata": {},
"outputs": [],
"source": [
"retriever_copy = TFIDFRetriever.load_local(\"testing.pkl\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "c03ff3c7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='foo', metadata={}),\n",
" Document(page_content='foo bar', metadata={}),\n",
" Document(page_content='hello', metadata={}),\n",
" Document(page_content='world', metadata={})]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"retriever_copy.get_relevant_documents(\"foo\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2d7c5728",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@ -151,7 +213,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
"version": "3.10.1"
}
},
"nbformat": 4,

View File

@ -1,5 +1,7 @@
from __future__ import annotations
import pickle
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
@ -76,3 +78,49 @@ class TFIDFRetriever(BaseRetriever):
) # Op -- (n_docs,1) -- Cosine Sim with each doc
return_docs = [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
return return_docs
def save_local(
self,
folder_path: str,
file_name: str = "tfidf_vectorizer",
) -> None:
try:
import joblib
except ImportError:
raise ImportError(
"Could not import joblib, please install with `pip install joblib`."
)
path = Path(folder_path)
path.mkdir(exist_ok=True, parents=True)
# Save vectorizer with joblib dump.
joblib.dump(self.vectorizer, path / f"{file_name}.joblib")
# Save docs and tfidf array as pickle.
with open(path / f"{file_name}.pkl", "wb") as f:
pickle.dump((self.docs, self.tfidf_array), f)
@classmethod
def load_local(
cls,
folder_path: str,
file_name: str = "tfidf_vectorizer",
) -> TFIDFRetriever:
try:
import joblib
except ImportError:
raise ImportError(
"Could not import joblib, please install with `pip install joblib`."
)
path = Path(folder_path)
# Load vectorizer with joblib load.
vectorizer = joblib.load(path / f"{file_name}.joblib")
# Load docs and tfidf array as pickle.
with open(path / f"{file_name}.pkl", "rb") as f:
docs, tfidf_array = pickle.load(f)
return cls(vectorizer=vectorizer, docs=docs, tfidf_array=tfidf_array)

View File

@ -1,3 +1,7 @@
import os
from datetime import datetime
from tempfile import TemporaryDirectory
import pytest
from langchain.retrievers.tfidf import TFIDFRetriever
@ -32,3 +36,26 @@ def test_from_documents() -> None:
tfidf_retriever = TFIDFRetriever.from_documents(documents=input_docs)
assert len(tfidf_retriever.docs) == 3
assert tfidf_retriever.tfidf_array.toarray().shape == (3, 5)
@pytest.mark.requires("sklearn")
def test_save_local_load_local() -> None:
input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."]
tfidf_retriever = TFIDFRetriever.from_texts(texts=input_texts)
file_name = "tfidf_vectorizer"
temp_timestamp = datetime.utcnow().strftime("%Y%m%d-%H%M%S")
with TemporaryDirectory(suffix="_" + temp_timestamp + "/") as temp_folder:
tfidf_retriever.save_local(
folder_path=temp_folder,
file_name=file_name,
)
assert os.path.exists(os.path.join(temp_folder, f"{file_name}.joblib"))
assert os.path.exists(os.path.join(temp_folder, f"{file_name}.pkl"))
loaded_tfidf_retriever = TFIDFRetriever.load_local(
folder_path=temp_folder,
file_name=file_name,
)
assert len(loaded_tfidf_retriever.docs) == 3
assert loaded_tfidf_retriever.tfidf_array.toarray().shape == (3, 5)