Add option to specify metadata columns in CSV loader (#11576)

#### Description
This PR adds the option to specify additional metadata columns in the
CSVLoader beyond just `Source`.

The current CSV loader includes all columns in `page_content`. If we
want to choose which columns go into `page_content` and which go into
`metadata`, we have to do something like the following:
```
csv = pd.read_csv(
        "path_to_csv"
    ).to_dict("records")

documents = [
        Document(
            page_content=doc["content"],
            metadata={
                "last_modified_by": doc["last_modified_by"],
                "point_of_contact": doc["point_of_contact"],
            }
        ) for doc in csv
    ]
```
#### Usage
Example Usage:
```
csv_test  =  CSVLoader(
      file_path="path_to_csv", 
      metadata_columns=["last_modified_by", "point_of_contact"]
 )
```
Example CSV:
```
content, last_modified_by, point_of_contact
"hello world", "Person A", "Person B"
```

Example Result:
```
Document {
 page_content: "content: hello world"
 metadata: {
 row: 0,
 source: 'path_to_csv',
 last_modified_by: 'Person A',
 point_of_contact: 'Person B',
 }
}
```

---------

Co-authored-by: Ben Chello <bchello@dropbox.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/11553/head
benchello 11 months ago committed by GitHub
parent 447a523662
commit 5de64e6d60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,6 @@
import csv import csv
from io import TextIOWrapper from io import TextIOWrapper
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Sequence
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader from langchain.document_loaders.base import BaseLoader
@ -36,6 +36,7 @@ class CSVLoader(BaseLoader):
self, self,
file_path: str, file_path: str,
source_column: Optional[str] = None, source_column: Optional[str] = None,
metadata_columns: Sequence[str] = (),
csv_args: Optional[Dict] = None, csv_args: Optional[Dict] = None,
encoding: Optional[str] = None, encoding: Optional[str] = None,
autodetect_encoding: bool = False, autodetect_encoding: bool = False,
@ -46,6 +47,7 @@ class CSVLoader(BaseLoader):
file_path: The path to the CSV file. file_path: The path to the CSV file.
source_column: The name of the column in the CSV file to use as the source. source_column: The name of the column in the CSV file to use as the source.
Optional. Defaults to None. Optional. Defaults to None.
metadata_columns: A sequence of column names to use as metadata. Optional.
csv_args: A dictionary of arguments to pass to the csv.DictReader. csv_args: A dictionary of arguments to pass to the csv.DictReader.
Optional. Defaults to None. Optional. Defaults to None.
encoding: The encoding of the CSV file. Optional. Defaults to None. encoding: The encoding of the CSV file. Optional. Defaults to None.
@ -53,6 +55,7 @@ class CSVLoader(BaseLoader):
""" """
self.file_path = file_path self.file_path = file_path
self.source_column = source_column self.source_column = source_column
self.metadata_columns = metadata_columns
self.encoding = encoding self.encoding = encoding
self.csv_args = csv_args or {} self.csv_args = csv_args or {}
self.autodetect_encoding = autodetect_encoding self.autodetect_encoding = autodetect_encoding
@ -85,9 +88,9 @@ class CSVLoader(BaseLoader):
def __read_file(self, csvfile: TextIOWrapper) -> List[Document]: def __read_file(self, csvfile: TextIOWrapper) -> List[Document]:
docs = [] docs = []
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
for i, row in enumerate(csv_reader): for i, row in enumerate(csv_reader):
content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
try: try:
source = ( source = (
row[self.source_column] row[self.source_column]
@ -98,7 +101,17 @@ class CSVLoader(BaseLoader):
raise ValueError( raise ValueError(
f"Source column '{self.source_column}' not found in CSV file." f"Source column '{self.source_column}' not found in CSV file."
) )
content = "\n".join(
f"{k.strip()}: {v.strip()}"
for k, v in row.items()
if k not in self.metadata_columns
)
metadata = {"source": source, "row": i} metadata = {"source": source, "row": i}
for col in self.metadata_columns:
try:
metadata[col] = row[col]
except KeyError:
raise ValueError(f"Metadata column '{col}' not found in CSV file.")
doc = Document(page_content=content, metadata=metadata) doc = Document(page_content=content, metadata=metadata)
docs.append(doc) docs.append(doc)

Loading…
Cancel
Save