mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
community[patch]: added content_columns option to CSVLoader (#23809)
**Description:** Adding a new option to the CSVLoader that allows us to explicitly specify the columns that are used for generating the Document content. Currently these are implicitly set as "all fields not part of the metadata_columns". In some cases, however, it is useful to have a field both as metadata and as part of the document content.
This commit is contained in:
parent
ab527027ac
commit
6a8f8a56ac
@ -104,6 +104,8 @@ class CSVLoader(BaseLoader):
|
||||
csv_args: Optional[Dict] = None,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = False,
|
||||
*,
|
||||
content_columns: Sequence[str] = (),
|
||||
):
|
||||
"""
|
||||
|
||||
@ -116,6 +118,8 @@ class CSVLoader(BaseLoader):
|
||||
Optional. Defaults to None.
|
||||
encoding: The encoding of the CSV file. Optional. Defaults to None.
|
||||
autodetect_encoding: Whether to try to autodetect the file encoding.
|
||||
content_columns: A sequence of column names to use for the document content.
|
||||
If not present, use all columns that are not part of the metadata.
|
||||
"""
|
||||
self.file_path = file_path
|
||||
self.source_column = source_column
|
||||
@ -123,6 +127,7 @@ class CSVLoader(BaseLoader):
|
||||
self.encoding = encoding
|
||||
self.csv_args = csv_args or {}
|
||||
self.autodetect_encoding = autodetect_encoding
|
||||
self.content_columns = content_columns
|
||||
|
||||
def lazy_load(self) -> Iterator[Document]:
|
||||
try:
|
||||
@ -163,7 +168,11 @@ class CSVLoader(BaseLoader):
|
||||
if isinstance(v, str) else ','.join(map(str.strip, v))
|
||||
if isinstance(v, list) else v}"""
|
||||
for k, v in row.items()
|
||||
if k not in self.metadata_columns
|
||||
if (
|
||||
k in self.content_columns
|
||||
if self.content_columns
|
||||
else k not in self.metadata_columns
|
||||
)
|
||||
)
|
||||
metadata = {"source": source, "row": i}
|
||||
for col in self.metadata_columns:
|
||||
|
@ -108,6 +108,27 @@ class TestCSVLoader:
|
||||
# Assert
|
||||
assert result == expected_docs
|
||||
|
||||
def test_csv_loader_content_columns(self) -> None:
    """Check that ``content_columns`` restricts page_content to the named columns."""
    # Setup
    csv_path = self._get_csv_file_path("test_none_col.csv")
    selected_columns = ("column1", "column3")
    expected_contents = (
        "column1: value1\ncolumn3: value3",
        "column1: value6\ncolumn3: value8",
    )
    expected_docs = [
        Document(
            page_content=content,
            metadata={"source": csv_path, "row": row_index},
        )
        for row_index, content in enumerate(expected_contents)
    ]

    # Exercise
    result = CSVLoader(
        file_path=csv_path, content_columns=selected_columns
    ).load()

    # Assert
    assert result == expected_docs
|
||||
|
||||
# utility functions
|
||||
def _get_csv_file_path(self, file_name: str) -> str:
    """Return, as a string, the path of *file_name* inside ``test_docs/csv``.

    The directory is resolved relative to the location of this module.
    """
    csv_dir = Path(__file__).resolve().parent.joinpath("test_docs", "csv")
    return str(csv_dir / file_name)
|
||||
|
Loading…
Reference in New Issue
Block a user