From 5de64e6d609180f193304dcad619b1c203cd2b42 Mon Sep 17 00:00:00 2001 From: benchello Date: Mon, 9 Oct 2023 17:56:45 -0400 Subject: [PATCH] Add option to specify metadata columns in CSV loader (#11576) #### Description This PR adds the option to specify additional metadata columns in the CSVLoader beyond just `source`. The current CSV loader includes all columns in `page_content`, so if we want some columns to go into `page_content` and others into `metadata`, we have to do something like the following: ``` csv = pd.read_csv( "path_to_csv" ).to_dict("records") documents = [ Document( page_content=doc["content"], metadata={ "last_modified_by": doc["last_modified_by"], "point_of_contact": doc["point_of_contact"], } ) for doc in csv ] ``` #### Usage Example Usage: ``` csv_test = CSVLoader( file_path="path_to_csv", metadata_columns=["last_modified_by", "point_of_contact"] ) ``` Example CSV: ``` content,last_modified_by,point_of_contact "hello world","Person A","Person B" ``` Example Result: ``` Document { page_content: "content: hello world", metadata: { row: 0, source: 'path_to_csv', last_modified_by: 'Person A', point_of_contact: 'Person B', } } ``` --------- Co-authored-by: Ben Chello Co-authored-by: Bagatur --- .../langchain/document_loaders/csv_loader.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/document_loaders/csv_loader.py b/libs/langchain/langchain/document_loaders/csv_loader.py index bf9f31e05e..91d8700e25 100644 --- a/libs/langchain/langchain/document_loaders/csv_loader.py +++ b/libs/langchain/langchain/document_loaders/csv_loader.py @@ -1,6 +1,6 @@ import csv from io import TextIOWrapper -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Sequence from langchain.docstore.document import Document from langchain.document_loaders.base import BaseLoader @@ -36,6 +36,7 @@ class CSVLoader(BaseLoader): self, file_path: str, source_column: Optional[str] = None, + 
metadata_columns: Sequence[str] = (), csv_args: Optional[Dict] = None, encoding: Optional[str] = None, autodetect_encoding: bool = False, @@ -46,6 +47,7 @@ class CSVLoader(BaseLoader): file_path: The path to the CSV file. source_column: The name of the column in the CSV file to use as the source. Optional. Defaults to None. + metadata_columns: A sequence of column names to use as metadata. Optional. csv_args: A dictionary of arguments to pass to the csv.DictReader. Optional. Defaults to None. encoding: The encoding of the CSV file. Optional. Defaults to None. @@ -53,6 +55,7 @@ class CSVLoader(BaseLoader): """ self.file_path = file_path self.source_column = source_column + self.metadata_columns = metadata_columns self.encoding = encoding self.csv_args = csv_args or {} self.autodetect_encoding = autodetect_encoding @@ -85,9 +88,9 @@ class CSVLoader(BaseLoader): def __read_file(self, csvfile: TextIOWrapper) -> List[Document]: docs = [] + csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore for i, row in enumerate(csv_reader): - content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items()) try: source = ( row[self.source_column] @@ -98,7 +101,17 @@ class CSVLoader(BaseLoader): raise ValueError( f"Source column '{self.source_column}' not found in CSV file." ) + content = "\n".join( + f"{k.strip()}: {v.strip()}" + for k, v in row.items() + if k not in self.metadata_columns + ) metadata = {"source": source, "row": i} + for col in self.metadata_columns: + try: + metadata[col] = row[col] + except KeyError: + raise ValueError(f"Metadata column '{col}' not found in CSV file.") doc = Document(page_content=content, metadata=metadata) docs.append(doc)