Adding File-Like object support in CSV Agent Toolkit (#10409)

When a CSV comes from an in-memory or temporary source, passing the
file-like object (a subclass of IOBase) directly now lets the agent
creation process succeed instead of raising a ValueError.

Added an additional elif and tweaked value error message.
Added test to validate this functionality.

pandas `read_csv` supports file-like objects natively, but the current
implementation only accepts strings or paths to files.
https://pandas.pydata.org/docs/user_guide/io.html#io-read-csv-table

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/10509/head
James Barney 12 months ago committed by GitHub
parent 999163fbd6
commit 50128c8b39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,3 +1,4 @@
from io import IOBase
from typing import Any, List, Optional, Union
from langchain.agents.agent import AgentExecutor
@ -7,7 +8,7 @@ from langchain.schema.language_model import BaseLanguageModel
def create_csv_agent(
llm: BaseLanguageModel,
path: Union[str, List[str]],
path: Union[str, IOBase, List[Union[str, IOBase]]],
pandas_kwargs: Optional[dict] = None,
**kwargs: Any,
) -> AgentExecutor:
@ -20,14 +21,14 @@ def create_csv_agent(
)
_kwargs = pandas_kwargs or {}
if isinstance(path, str):
if isinstance(path, (str, IOBase)):
df = pd.read_csv(path, **_kwargs)
elif isinstance(path, list):
df = []
for item in path:
if not isinstance(item, str):
raise ValueError(f"Expected str, got {type(path)}")
if not isinstance(item, (str, IOBase)):
raise ValueError(f"Expected str or file-like object, got {type(path)}")
df.append(pd.read_csv(item, **_kwargs))
else:
raise ValueError(f"Expected str or list, got {type(path)}")
raise ValueError(f"Expected str, list, or file-like object, got {type(path)}")
return create_pandas_dataframe_agent(llm, df, **kwargs)

@ -1,3 +1,4 @@
import io
import re
import numpy as np
@ -34,6 +35,15 @@ def csv_list(tmp_path_factory: TempPathFactory) -> DataFrame:
return [filename1, filename2]
@pytest.fixture(scope="module")
def csv_file_like(tmp_path_factory: TempPathFactory) -> io.BytesIO:
    """Return an in-memory binary buffer holding a 4-row CSV.

    The frame is serialized as CSV text rather than pickled, so that
    ``pandas.read_csv`` can parse the buffer. The buffer is built directly
    from the encoded bytes, which leaves its read position at the start.

    ``tmp_path_factory`` is not used by the body; it is kept so the fixture
    signature stays unchanged.
    """
    random_data = np.random.rand(4, 4)
    df = DataFrame(random_data, columns=["name", "age", "food", "sport"])
    # ``to_csv()`` with no target returns the CSV document as a string.
    csv_text = df.to_csv(index=False)
    return io.BytesIO(csv_text.encode("utf-8"))
def test_csv_agent_creation(csv: str) -> None:
agent = create_csv_agent(OpenAI(temperature=0), csv)
assert isinstance(agent, AgentExecutor)
@ -55,3 +65,12 @@ def test_multi_csv(csv_list: list) -> None:
result = re.search(r".*(6).*", response)
assert result is not None
assert result.group(1) is not None
def test_file_like(csv_file_like: io.BytesIO) -> None:
    """Build a CSV agent from a file-like object and check the row count answer.

    The parameter name must match the ``csv_file_like`` fixture so pytest can
    inject it; requesting a non-existent fixture name errors before the test
    body runs.
    """
    agent = create_csv_agent(OpenAI(temperature=0), csv_file_like, verbose=True)
    assert isinstance(agent, AgentExecutor)
    response = agent.run("How many rows in the csv? Give me a number.")
    # The fixture's frame has 4 rows, so the answer should mention "4".
    result = re.search(r".*(4).*", response)
    assert result is not None
    assert result.group(1) is not None

Loading…
Cancel
Save