From cdc20d1203b4c122aa9cc23d1266c7e2dcd93e68 Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Sun, 14 May 2023 18:25:59 -0700
Subject: [PATCH] Harrison/json loader fix (#4686)

Co-authored-by: Triet Le <112841660+triet-lq-holistics@users.noreply.github.com>
---
 langchain/document_loaders/json_loader.py |  92 ++++++++-----
 poetry.lock                               |  14 +-
 pyproject.toml                            |   2 +-
 .../document_loader/test_json_loader.py   | 123 ++++++++++++++++++
 4 files changed, 187 insertions(+), 44 deletions(-)
 create mode 100644 tests/unit_tests/document_loader/test_json_loader.py

diff --git a/langchain/document_loaders/json_loader.py b/langchain/document_loaders/json_loader.py
index 2100640f..f1e594b2 100644
--- a/langchain/document_loaders/json_loader.py
+++ b/langchain/document_loaders/json_loader.py
@@ -1,7 +1,7 @@
 """Loader that loads data from JSON."""
 import json
 from pathlib import Path
-from typing import Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 from langchain.docstore.document import Document
 from langchain.document_loaders.base import BaseLoader
@@ -23,6 +23,7 @@ class JSONLoader(BaseLoader):
         jq_schema: str,
         content_key: Optional[str] = None,
         metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
+        text_content: bool = True,
     ):
         """Initialize the JSONLoader.
 
@@ -35,6 +36,8 @@ class JSONLoader(BaseLoader):
             metadata_func (Callable[Dict, Dict]): A function that takes in the JSON
                 object extracted by the jq_schema and the default metadata and returns
                 a dict of the updated metadata.
+            text_content (bool): Boolean flag indicating whether the content is in
+                string format. Defaults to True.
         """
         try:
             import jq  # noqa:F401
@@ -47,58 +50,75 @@
         self._jq_schema = jq.compile(jq_schema)
         self._content_key = content_key
         self._metadata_func = metadata_func
+        self._text_content = text_content
 
     def load(self) -> List[Document]:
         """Load and return documents from the JSON file."""
-
         data = self._jq_schema.input(json.loads(self.file_path.read_text()))
 
         # Perform some validation
         # This is not a perfect validation, but it should catch most cases
         # and prevent the user from getting a cryptic error later on.
         if self._content_key is not None:
-            sample = data.first()
-            if not isinstance(sample, dict):
-                raise ValueError(
-                    f"Expected the jq schema to result in a list of objects (dict), \
-                        so sample must be a dict but got `{type(sample)}`"
-                )
-
-            if sample.get(self._content_key) is None:
-                raise ValueError(
-                    f"Expected the jq schema to result in a list of objects (dict) \
-                        with the key `{self._content_key}`"
-                )
-
-            if self._metadata_func is not None:
-                sample_metadata = self._metadata_func(sample, {})
-                if not isinstance(sample_metadata, dict):
-                    raise ValueError(
-                        f"Expected the metadata_func to return a dict but got \
-                            `{type(sample_metadata)}`"
-                    )
+            self._validate_content_key(data)
 
         docs = []
-
         for i, sample in enumerate(data, 1):
             metadata = dict(
                 source=str(self.file_path),
                 seq_num=i,
             )
+            text = self._get_text(sample=sample, metadata=metadata)
+            docs.append(Document(page_content=text, metadata=metadata))
+
+        return docs
 
-            if self._content_key is not None:
-                text = sample.get(self._content_key)
-                if self._metadata_func is not None:
-                    # We pass in the metadata dict to the metadata_func
-                    # so that the user can customize the default metadata
-                    # based on the content of the JSON object.
-                    metadata = self._metadata_func(sample, metadata)
-            else:
-                text = sample
+    def _get_text(self, sample: Any, metadata: dict) -> str:
+        """Convert sample to string format."""
+        if self._content_key is not None:
+            content = sample.get(self._content_key)
+            if self._metadata_func is not None:
+                # We pass in the metadata dict to the metadata_func
+                # so that the user can customize the default metadata
+                # based on the content of the JSON object.
+                metadata = self._metadata_func(sample, metadata)
+        else:
+            content = sample
+
+        if self._text_content and not isinstance(content, str):
+            raise ValueError(
+                f"Expected page_content to be a string, got {type(content)} instead. \
+                    Set `text_content=False` if the desired input for \
+                    `page_content` is not a string"
+            )
 
-            # In case the text is None, set it to an empty string
-            text = text or ""
+        # In case the content is None, the branches below return an empty string
+        elif isinstance(content, str):
+            return content
+        elif isinstance(content, dict):
+            return json.dumps(content) if content else ""
+        else:
+            return str(content) if content is not None else ""
+
+    def _validate_content_key(self, data: Any) -> None:
+        """Check if the content key is valid."""
+        sample = data.first()
+        if not isinstance(sample, dict):
+            raise ValueError(
+                f"Expected the jq schema to result in a list of objects (dict), \
+                    so sample must be a dict but got `{type(sample)}`"
+            )
 
-            docs.append(Document(page_content=text, metadata=metadata))
+        if sample.get(self._content_key) is None:
+            raise ValueError(
+                f"Expected the jq schema to result in a list of objects (dict) \
+                    with the key `{self._content_key}`"
+            )
 
-        return docs
+        if self._metadata_func is not None:
+            sample_metadata = self._metadata_func(sample, {})
+            if not isinstance(sample_metadata, dict):
+                raise ValueError(
+                    f"Expected the metadata_func to return a dict but got \
+                        `{type(sample_metadata)}`"
+                )
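
A minimal usage sketch of the `text_content` flag introduced above (the file name `sample.json` and its contents are hypothetical; the `jq` extra is assumed to be installed):

    from langchain.document_loaders.json_loader import JSONLoader

    # sample.json (hypothetical): [{"id": 1, "text": "hello"}, {"id": 2, "text": "world"}]

    # Default behavior (text_content=True): each jq result must already be a
    # string, otherwise load() raises a ValueError.
    loader = JSONLoader(file_path="sample.json", jq_schema=".[].text")
    docs = loader.load()  # page_content: "hello", "world"

    # With text_content=False, non-string results are coerced to strings:
    # dicts via json.dumps, numbers/bools via str(), and None becomes "".
    loader = JSONLoader(file_path="sample.json", jq_schema=".[].id", text_content=False)
    docs = loader.load()  # page_content: "1", "2"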
diff --git a/poetry.lock b/poetry.lock
index 58342d0c..688b0364 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry and should not be changed by hand.
 
 [[package]]
 name = "absl-py"
@@ -9994,18 +9994,18 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\
 cffi = ["cffi (>=1.11)"]
 
 [extras]
-all = ["O365", "aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "hnswlib", "html2text", "huggingface_hub", "jina", "jinja2", "jq", "lancedb", "lark", "manifest-ml", "networkx", "nlpcloud", "nltk", "nomic", "openai", "opensearch-py", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "protobuf", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "pyvespa", "qdrant-client", "redis", "sentence-transformers", "spacy", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
-azure = ["azure-core", "azure-cosmos", "azure-identity", "openai"]
+all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence-transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "pinecone-text", "weaviate-client", "redis", "google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search", "arxiv", "azure-identity", "clickhouse-connect", "azure-cosmos", "lancedb", "lark", "pexpect", "pyvespa", "O365", "jq", "docarray", "protobuf", "hnswlib", "steamship", "pdfminer-six"]
+azure = ["azure-identity", "azure-cosmos", "openai", "azure-core"]
 cohere = ["cohere"]
 embeddings = ["sentence-transformers"]
-extended-testing = ["pdfminer-six", "pypdf", "tqdm"]
-hnswlib = ["docarray", "hnswlib", "protobuf"]
+extended-testing = ["pypdf", "pdfminer-six", "tqdm", "jq"]
+hnswlib = ["docarray", "protobuf", "hnswlib"]
 in-memory-store = ["docarray"]
-llms = ["anthropic", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "torch", "transformers"]
+llms = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"]
 openai = ["openai", "tiktoken"]
 qdrant = ["qdrant-client"]
 
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "6d5c4aa06539e6f7c7531c30d73cbf08fbdea75486bf4b81c106b9e678a13b45"
+content-hash = "42b518704c39bc25c6da05f81a9488a9a6fecfd7784b3c9915d30127ce384a63"
diff --git a/pyproject.toml b/pyproject.toml
index 5f31b8ed..3c141f54 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -171,7 +171,7 @@ azure = ["azure-identity", "azure-cosmos", "openai", "azure-core"]
 all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence-transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "pinecone-text", "weaviate-client", "redis", "google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary", "boto3", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search", "arxiv", "azure-identity", "clickhouse-connect", "azure-cosmos", "lancedb", "lark", "pexpect", "pyvespa", "O365", "jq", "docarray", "protobuf", "hnswlib", "steamship", "pdfminer-six"]
"azure-identity", "clickhouse-connect", "azure-cosmos", "lancedb", "lark", "pexpect", "pyvespa", "O365", "jq", "docarray", "protobuf", "hnswlib", "steamship", "pdfminer-six"] # An extra used to be able to add extended testing. extended_testing = [ - "pypdf", "pdfminer.six", "tqdm" + "pypdf", "pdfminer.six", "tqdm", "jq" ] [tool.ruff] diff --git a/tests/unit_tests/document_loader/test_json_loader.py b/tests/unit_tests/document_loader/test_json_loader.py new file mode 100644 index 00000000..31739d4d --- /dev/null +++ b/tests/unit_tests/document_loader/test_json_loader.py @@ -0,0 +1,123 @@ +import pytest +from pytest import raises +from pytest_mock import MockerFixture + +from langchain.docstore.document import Document +from langchain.document_loaders.json_loader import JSONLoader + + +@pytest.mark.requires("jq") +def test_load_valid_string_content(mocker: MockerFixture) -> None: + file_path = "/workspaces/langchain/test.json" + expected_docs = [ + Document( + page_content="value1", + metadata={"source": file_path, "seq_num": 1}, + ), + Document( + page_content="value2", + metadata={"source": file_path, "seq_num": 2}, + ), + ] + mocker.patch("builtins.open", mocker.mock_open()) + mock_csv_reader = mocker.patch("pathlib.Path.read_text") + mock_csv_reader.return_value = '[{"text": "value1"}, {"text": "value2"}]' + + loader = JSONLoader(file_path=file_path, jq_schema=".[].text", text_content=True) + result = loader.load() + + assert result == expected_docs + + +@pytest.mark.requires("jq") +def test_load_valid_dict_content(mocker: MockerFixture) -> None: + file_path = "/workspaces/langchain/test.json" + expected_docs = [ + Document( + page_content='{"text": "value1"}', + metadata={"source": file_path, "seq_num": 1}, + ), + Document( + page_content='{"text": "value2"}', + metadata={"source": file_path, "seq_num": 2}, + ), + ] + mocker.patch("builtins.open", mocker.mock_open()) + mock_csv_reader = mocker.patch("pathlib.Path.read_text") + mock_csv_reader.return_value = """ + [{"text": "value1"}, {"text": "value2"}] + """ + + loader = JSONLoader(file_path=file_path, jq_schema=".[]", text_content=False) + result = loader.load() + + assert result == expected_docs + + +@pytest.mark.requires("jq") +def test_load_valid_bool_content(mocker: MockerFixture) -> None: + file_path = "/workspaces/langchain/test.json" + expected_docs = [ + Document( + page_content="False", + metadata={"source": file_path, "seq_num": 1}, + ), + Document( + page_content="True", + metadata={"source": file_path, "seq_num": 2}, + ), + ] + mocker.patch("builtins.open", mocker.mock_open()) + mock_csv_reader = mocker.patch("pathlib.Path.read_text") + mock_csv_reader.return_value = """ + [ + {"flag": false}, {"flag": true} + ] + """ + + loader = JSONLoader(file_path=file_path, jq_schema=".[].flag", text_content=False) + result = loader.load() + + assert result == expected_docs + + +@pytest.mark.requires("jq") +def test_load_valid_numeric_content(mocker: MockerFixture) -> None: + file_path = "/workspaces/langchain/test.json" + expected_docs = [ + Document( + page_content="99", + metadata={"source": file_path, "seq_num": 1}, + ), + Document( + page_content="99.5", + metadata={"source": file_path, "seq_num": 2}, + ), + ] + mocker.patch("builtins.open", mocker.mock_open()) + mock_csv_reader = mocker.patch("pathlib.Path.read_text") + mock_csv_reader.return_value = """ + [ + {"num": 99}, {"num": 99.5} + ] + """ + + loader = JSONLoader(file_path=file_path, jq_schema=".[].num", text_content=False) + result = loader.load() + + assert result == 
+
+
+@pytest.mark.requires("jq")
+def test_load_invalid_text_content(mocker: MockerFixture) -> None:
+    file_path = "/workspaces/langchain/test.json"
+    mocker.patch("builtins.open", mocker.mock_open())
+    mock_read_text = mocker.patch("pathlib.Path.read_text")
+    mock_read_text.return_value = """
+    [{"text": "value1"}, {"text": "value2"}]
+    """
+
+    loader = JSONLoader(file_path=file_path, jq_schema=".[]", text_content=True)
+
+    with raises(ValueError):
+        loader.load()
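
A short illustrative sketch of the `metadata_func` hook that `_validate_content_key` guards; the file `chat.json`, its records, and the `add_sender` helper are hypothetical, not part of the patch:

    from langchain.document_loaders.json_loader import JSONLoader

    # chat.json (hypothetical): [{"text": "hi", "sender": "ana"}, {"text": "yo", "sender": "bo"}]
    def add_sender(record: dict, metadata: dict) -> dict:
        # metadata arrives pre-filled with "source" and "seq_num"; it must
        # come back as a dict, or the loader raises a ValueError.
        metadata["sender"] = record.get("sender")
        return metadata

    loader = JSONLoader(
        file_path="chat.json",
        jq_schema=".[]",
        content_key="text",
        metadata_func=add_sender,
    )
    docs = loader.load()  # page_content "hi"/"yo", each with a "sender" metadata key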