diff --git a/libs/langchain/langchain/document_loaders/csv_loader.py b/libs/langchain/langchain/document_loaders/csv_loader.py index bf9f31e05e..91d8700e25 100644 --- a/libs/langchain/langchain/document_loaders/csv_loader.py +++ b/libs/langchain/langchain/document_loaders/csv_loader.py @@ -1,6 +1,6 @@ import csv from io import TextIOWrapper -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Sequence from langchain.docstore.document import Document from langchain.document_loaders.base import BaseLoader @@ -36,6 +36,7 @@ class CSVLoader(BaseLoader): self, file_path: str, source_column: Optional[str] = None, + metadata_columns: Sequence[str] = (), csv_args: Optional[Dict] = None, encoding: Optional[str] = None, autodetect_encoding: bool = False, @@ -46,6 +47,7 @@ class CSVLoader(BaseLoader): file_path: The path to the CSV file. source_column: The name of the column in the CSV file to use as the source. Optional. Defaults to None. + metadata_columns: A sequence of column names to use as metadata. Optional. csv_args: A dictionary of arguments to pass to the csv.DictReader. Optional. Defaults to None. encoding: The encoding of the CSV file. Optional. Defaults to None. @@ -53,6 +55,7 @@ class CSVLoader(BaseLoader): """ self.file_path = file_path self.source_column = source_column + self.metadata_columns = metadata_columns self.encoding = encoding self.csv_args = csv_args or {} self.autodetect_encoding = autodetect_encoding @@ -85,9 +88,9 @@ class CSVLoader(BaseLoader): def __read_file(self, csvfile: TextIOWrapper) -> List[Document]: docs = [] + csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore for i, row in enumerate(csv_reader): - content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items()) try: source = ( row[self.source_column] @@ -98,7 +101,17 @@ class CSVLoader(BaseLoader): raise ValueError( f"Source column '{self.source_column}' not found in CSV file." ) + content = "\n".join( + f"{k.strip()}: {v.strip()}" + for k, v in row.items() + if k not in self.metadata_columns + ) metadata = {"source": source, "row": i} + for col in self.metadata_columns: + try: + metadata[col] = row[col] + except KeyError: + raise ValueError(f"Metadata column '{col}' not found in CSV file.") doc = Document(page_content=content, metadata=metadata) docs.append(doc)