mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
add save and load tfidf vectorizer and docs for TFIDFRetriever (#8112)
This adds `save_local` and `load_local` to `TFIDFRetriever` so that the fitted TF-IDF vectorizer and the indexed docs can be saved to disk and reused. <!-- Thank you for contributing to LangChain! Replace this comment with: - Description: add save_local and load_local to tfidf_vectorizer and docs in tfidf_retriever - Issue: None - Dependencies: None - Tag maintainer: @rlancemartin, @eyurtsev - Twitter handle: @MlopsJ Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. Maintainer responsibilities: - General / Misc / if you don't know who to tag: @baskaryan - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev - Models / Prompts: @hwchase17, @baskaryan - Memory: @hwchase17 - Agents / Tools / Toolkits: @hinthornw - Tracing / Callbacks: @agola11 - Async: @agola11 If no one reviews your PR within a few days, feel free to @-mention the same people again. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md --> --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
0f68054401
commit
2759e2d857
@ -16,7 +16,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 2,
|
||||
"id": "a801b57c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -26,7 +26,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 3,
|
||||
"id": "393ac030",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@ -46,7 +46,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 4,
|
||||
"id": "98b1c017",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@ -133,6 +133,68 @@
|
||||
"source": [
|
||||
"result"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "363f3c04",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Save and load\n",
|
||||
"\n",
|
||||
"You can easily save and load this retriever, making it handy for local development!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "10c90d03",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"retriever.save_local(\"testing.pkl\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "fb3b153c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"retriever_copy = TFIDFRetriever.load_local(\"testing.pkl\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "c03ff3c7",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Document(page_content='foo', metadata={}),\n",
|
||||
" Document(page_content='foo bar', metadata={}),\n",
|
||||
" Document(page_content='hello', metadata={}),\n",
|
||||
" Document(page_content='world', metadata={})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"retriever_copy.get_relevant_documents(\"foo\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2d7c5728",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@ -151,7 +213,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.3"
|
||||
"version": "3.10.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
@ -76,3 +78,49 @@ class TFIDFRetriever(BaseRetriever):
|
||||
) # Op -- (n_docs,1) -- Cosine Sim with each doc
|
||||
return_docs = [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
|
||||
return return_docs
|
||||
|
||||
def save_local(
    self,
    folder_path: str,
    file_name: str = "tfidf_vectorizer",
) -> None:
    """Persist the retriever's state to a folder on disk.

    Two files are written into ``folder_path``, which is created (including
    missing parents) if it does not exist yet:

    * ``<file_name>.joblib`` -- ``self.vectorizer``, dumped with joblib.
    * ``<file_name>.pkl`` -- the tuple ``(self.docs, self.tfidf_array)``,
      dumped with pickle.

    Args:
        folder_path: Directory to write the two files into. Note that this
            is a directory, not a file path.
        file_name: Base name (without extension) shared by both files.

    Raises:
        ImportError: If ``joblib`` is not installed.
    """
    try:
        import joblib
    except ImportError as err:
        # Chain explicitly so the traceback shows the original import
        # failure as the direct cause rather than as an unrelated error.
        raise ImportError(
            "Could not import joblib, please install with `pip install joblib`."
        ) from err

    path = Path(folder_path)
    path.mkdir(exist_ok=True, parents=True)

    # Save vectorizer with joblib dump.
    joblib.dump(self.vectorizer, path / f"{file_name}.joblib")

    # Save docs and tfidf array as pickle.
    with open(path / f"{file_name}.pkl", "wb") as f:
        pickle.dump((self.docs, self.tfidf_array), f)
|
||||
|
||||
@classmethod
def load_local(
    cls,
    folder_path: str,
    file_name: str = "tfidf_vectorizer",
) -> TFIDFRetriever:
    """Rebuild a retriever from files previously written by ``save_local``.

    Reads ``<file_name>.joblib`` (the vectorizer) and ``<file_name>.pkl``
    (a pickled ``(docs, tfidf_array)`` tuple) from ``folder_path`` and passes
    them to the class constructor.

    WARNING: both files are deserialized with pickle-based loaders, which
    can execute arbitrary code while loading. Only call this on files you
    created yourself or otherwise fully trust.

    Args:
        folder_path: Directory containing the two saved files.
        file_name: Base name (without extension) shared by both files.

    Returns:
        A new retriever built from the loaded vectorizer, docs and TF-IDF
        array.

    Raises:
        ImportError: If ``joblib`` is not installed.
    """
    try:
        import joblib
    except ImportError as err:
        # Chain explicitly so the traceback shows the original import
        # failure as the direct cause rather than as an unrelated error.
        raise ImportError(
            "Could not import joblib, please install with `pip install joblib`."
        ) from err

    path = Path(folder_path)

    # SECURITY: joblib.load and pickle.load below will run code embedded in
    # a maliciously crafted file. Never point this at untrusted input.

    # Load vectorizer with joblib load.
    vectorizer = joblib.load(path / f"{file_name}.joblib")

    # Load docs and tfidf array as pickle.
    with open(path / f"{file_name}.pkl", "rb") as f:
        docs, tfidf_array = pickle.load(f)

    return cls(vectorizer=vectorizer, docs=docs, tfidf_array=tfidf_array)
|
||||
|
@ -1,3 +1,7 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.retrievers.tfidf import TFIDFRetriever
|
||||
@ -32,3 +36,26 @@ def test_from_documents() -> None:
|
||||
tfidf_retriever = TFIDFRetriever.from_documents(documents=input_docs)
|
||||
assert len(tfidf_retriever.docs) == 3
|
||||
assert tfidf_retriever.tfidf_array.toarray().shape == (3, 5)
|
||||
|
||||
|
||||
@pytest.mark.requires("sklearn")
def test_save_local_load_local() -> None:
    """Round-trip a retriever through save_local / load_local.

    Checks that both expected files are written and that the reloaded
    retriever has the same number of docs and the same TF-IDF matrix shape
    as one built from the three input texts.
    """
    input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."]
    tfidf_retriever = TFIDFRetriever.from_texts(texts=input_texts)

    file_name = "tfidf_vectorizer"
    temp_timestamp = datetime.utcnow().strftime("%Y%m%d-%H%M%S")
    # The suffix must not contain a path separator: it is appended to the
    # generated directory name, so a trailing "/" is not part of a valid name.
    with TemporaryDirectory(suffix="_" + temp_timestamp) as temp_folder:
        tfidf_retriever.save_local(
            folder_path=temp_folder,
            file_name=file_name,
        )
        assert os.path.exists(os.path.join(temp_folder, f"{file_name}.joblib"))
        assert os.path.exists(os.path.join(temp_folder, f"{file_name}.pkl"))

        loaded_tfidf_retriever = TFIDFRetriever.load_local(
            folder_path=temp_folder,
            file_name=file_name,
        )
    assert len(loaded_tfidf_retriever.docs) == 3
    assert loaded_tfidf_retriever.tfidf_array.toarray().shape == (3, 5)
|
||||
|
Loading…
Reference in New Issue
Block a user