From bd7e0a534cdd767406db66fb060a10cd65235a63 Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Fri, 28 Apr 2023 21:54:24 -0700
Subject: [PATCH] Harrison/csv loader (#3771)

Co-authored-by: mrT23
---
 langchain/document_loaders/csv_loader.py      |  27 +++--
 poetry.lock                                   |  30 ++++-
 pyproject.toml                                |   1 +
 .../document_loader/test_csv_loader.py        | 106 ++++++++++++++++++
 4 files changed, 148 insertions(+), 16 deletions(-)
 create mode 100644 tests/unit_tests/document_loader/test_csv_loader.py

diff --git a/langchain/document_loaders/csv_loader.py b/langchain/document_loaders/csv_loader.py
index 9911f605..54c0d8f5 100644
--- a/langchain/document_loaders/csv_loader.py
+++ b/langchain/document_loaders/csv_loader.py
@@ -1,4 +1,4 @@
-from csv import DictReader
+import csv
 from typing import Dict, List, Optional
 
 from langchain.docstore.document import Document
@@ -38,23 +38,30 @@ class CSVLoader(BaseLoader):
         self.encoding = encoding
         if csv_args is None:
             self.csv_args = {
-                "delimiter": ",",
-                "quotechar": '"',
+                "delimiter": csv.Dialect.delimiter,
+                "quotechar": csv.Dialect.quotechar,
             }
         else:
             self.csv_args = csv_args
 
     def load(self) -> List[Document]:
-        docs = []
+        """Load data into document objects."""
 
+        docs = []
         with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
-            csv = DictReader(csvfile, **self.csv_args)  # type: ignore
-            for i, row in enumerate(csv):
+            csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
+            for i, row in enumerate(csv_reader):
                 content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
-                if self.source_column is not None:
-                    source = row[self.source_column]
-                else:
-                    source = self.file_path
+                try:
+                    source = (
+                        row[self.source_column]
+                        if self.source_column is not None
+                        else self.file_path
+                    )
+                except KeyError:
+                    raise ValueError(
+                        f"Source column '{self.source_column}' not found in CSV file."
+                    )
                 metadata = {"source": source, "row": i}
                 doc = Document(page_content=content, metadata=metadata)
                 docs.append(doc)
diff --git a/poetry.lock b/poetry.lock
index f3b6f0c6..fe13f627 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry and should not be changed by hand.
 
 [[package]]
 name = "absl-py"
@@ -6270,6 +6270,24 @@ files = [
 pytest = ">=5.0.0"
 python-dotenv = ">=0.9.1"
 
+[[package]]
+name = "pytest-mock"
+version = "3.10.0"
+description = "Thin-wrapper around the mock package for easier use with pytest"
+category = "dev"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "pytest-mock-3.10.0.tar.gz", hash = "sha256:fbbdb085ef7c252a326fd8cdcac0aa3b1333d8811f131bdcc701002e1be7ed4f"},
+    {file = "pytest_mock-3.10.0-py3-none-any.whl", hash = "sha256:f4c973eeae0282963eb293eb173ce91b091a79c1334455acfac9ddee8a1c784b"},
+]
+
+[package.dependencies]
+pytest = ">=5.0"
+
+[package.extras]
+dev = ["pre-commit", "pytest-asyncio", "tox"]
+
 [[package]]
 name = "pytest-vcr"
 version = "1.0.2"
@@ -7655,7 +7673,7 @@ files = [
 ]
 
 [package.dependencies]
-greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and platform_machine == \"aarch64\" or python_version >= \"3\" and platform_machine == \"ppc64le\" or python_version >= \"3\" and platform_machine == \"x86_64\" or python_version >= \"3\" and platform_machine == \"amd64\" or python_version >= \"3\" and platform_machine == \"AMD64\" or python_version >= \"3\" and platform_machine == \"win32\" or python_version >= \"3\" and platform_machine == \"WIN32\""}
+greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\")"}
 
 [package.extras]
 aiomysql = ["aiomysql", "greenlet (!=0.4.17)"]
@@ -9418,15 +9436,15 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\
 cffi = ["cffi (>=1.11)"]
 
 [extras]
-all = ["aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "lancedb", "lark", "manifest-ml", "networkx", "nlpcloud", "nltk", "nomic", "openai", "opensearch-py", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "qdrant-client", "redis", "sentence-transformers", "spacy", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
-azure = ["azure-core", "azure-cosmos", "azure-identity", "openai"]
+all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence-transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "pinecone-text", "weaviate-client", "redis", "google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search", "arxiv", "azure-identity", "clickhouse-connect", "azure-cosmos", "lancedb", "lark", "pexpect", "pyvespa"]
+azure = ["azure-identity", "azure-cosmos", "openai", "azure-core"]
 cohere = ["cohere"]
 embeddings = ["sentence-transformers"]
-llms = ["anthropic", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "torch", "transformers"]
+llms = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"]
 openai = ["openai"]
 qdrant = ["qdrant-client"]
 
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "2ef913e267f1a10beee9f97924dc38df89fbe8a1ddc0de0f6e6f04e272763823"
+content-hash = "bfd037a6e4fbe62305bc1305999cc0cc83d84a740ebf8f66036d9ae4d59a5760"
diff --git a/pyproject.toml b/pyproject.toml
index c1630158..d27f9699 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -101,6 +101,7 @@ freezegun = "^1.2.2"
 responses = "^0.22.0"
 pytest-asyncio = "^0.20.3"
 lark = "^1.1.5"
+pytest-mock = "^3.10.0"
 
 [tool.poetry.group.test_integration]
 optional = true
diff --git a/tests/unit_tests/document_loader/test_csv_loader.py b/tests/unit_tests/document_loader/test_csv_loader.py
new file mode 100644
index 00000000..98169969
--- /dev/null
+++ b/tests/unit_tests/document_loader/test_csv_loader.py
@@ -0,0 +1,106 @@
+from pytest_mock import MockerFixture
+
+from langchain.docstore.document import Document
+from langchain.document_loaders.csv_loader import CSVLoader
+
+
+class TestCSVLoader:
+    # Tests that a CSV file with valid data is loaded successfully.
+    def test_csv_loader_load_valid_data(self, mocker: MockerFixture) -> None:
+        # Setup
+        file_path = "test.csv"
+        expected_docs = [
+            Document(
+                page_content="column1: value1\ncolumn2: value2\ncolumn3: value3",
+                metadata={"source": file_path, "row": 0},
+            ),
+            Document(
+                page_content="column1: value4\ncolumn2: value5\ncolumn3: value6",
+                metadata={"source": file_path, "row": 1},
+            ),
+        ]
+        mocker.patch("builtins.open", mocker.mock_open())
+        mock_csv_reader = mocker.patch("csv.DictReader")
+        mock_csv_reader.return_value = [
+            {"column1": "value1", "column2": "value2", "column3": "value3"},
+            {"column1": "value4", "column2": "value5", "column3": "value6"},
+        ]
+
+        # Exercise
+        loader = CSVLoader(file_path=file_path)
+        result = loader.load()
+
+        # Assert
+        assert result == expected_docs
+
+    # Tests that an empty CSV file is handled correctly.
+    def test_csv_loader_load_empty_file(self, mocker: MockerFixture) -> None:
+        # Setup
+        file_path = "test.csv"
+        expected_docs: list = []
+        mocker.patch("builtins.open", mocker.mock_open())
+        mock_csv_reader = mocker.patch("csv.DictReader")
+        mock_csv_reader.return_value = []
+
+        # Exercise
+        loader = CSVLoader(file_path=file_path)
+        result = loader.load()
+
+        # Assert
+        assert result == expected_docs
+
+    # Tests that a CSV file with only one row is handled correctly.
+    def test_csv_loader_load_single_row_file(self, mocker: MockerFixture) -> None:
+        # Setup
+        file_path = "test.csv"
+        expected_docs = [
+            Document(
+                page_content="column1: value1\ncolumn2: value2\ncolumn3: value3",
+                metadata={"source": file_path, "row": 0},
+            )
+        ]
+        mocker.patch("builtins.open", mocker.mock_open())
+        mock_csv_reader = mocker.patch("csv.DictReader")
+        mock_csv_reader.return_value = [
+            {"column1": "value1", "column2": "value2", "column3": "value3"}
+        ]
+
+        # Exercise
+        loader = CSVLoader(file_path=file_path)
+        result = loader.load()
+
+        # Assert
+        assert result == expected_docs
+
+    # Tests that a CSV file with only one column is handled correctly.
+    def test_csv_loader_load_single_column_file(self, mocker: MockerFixture) -> None:
+        # Setup
+        file_path = "test.csv"
+        expected_docs = [
+            Document(
+                page_content="column1: value1",
+                metadata={"source": file_path, "row": 0},
+            ),
+            Document(
+                page_content="column1: value2",
+                metadata={"source": file_path, "row": 1},
+            ),
+            Document(
+                page_content="column1: value3",
+                metadata={"source": file_path, "row": 2},
+            ),
+        ]
+        mocker.patch("builtins.open", mocker.mock_open())
+        mock_csv_reader = mocker.patch("csv.DictReader")
+        mock_csv_reader.return_value = [
+            {"column1": "value1"},
+            {"column1": "value2"},
+            {"column1": "value3"},
+        ]
+
+        # Exercise
+        loader = CSVLoader(file_path=file_path)
+        result = loader.load()
+
+        # Assert
+        assert result == expected_docs