diff --git a/libs/langchain/langchain/agents/agent_toolkits/csv/base.py b/libs/langchain/langchain/agents/agent_toolkits/csv/base.py index 90aa8dd77a..f16b8772fd 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/csv/base.py +++ b/libs/langchain/langchain/agents/agent_toolkits/csv/base.py @@ -1,3 +1,4 @@ +from io import IOBase from typing import Any, List, Optional, Union from langchain.agents.agent import AgentExecutor @@ -7,7 +8,7 @@ from langchain.schema.language_model import BaseLanguageModel def create_csv_agent( llm: BaseLanguageModel, - path: Union[str, List[str]], + path: Union[str, IOBase, List[Union[str, IOBase]]], pandas_kwargs: Optional[dict] = None, **kwargs: Any, ) -> AgentExecutor: @@ -20,14 +21,14 @@ def create_csv_agent( ) _kwargs = pandas_kwargs or {} - if isinstance(path, str): + if isinstance(path, (str, IOBase)): df = pd.read_csv(path, **_kwargs) elif isinstance(path, list): df = [] for item in path: - if not isinstance(item, str): - raise ValueError(f"Expected str, got {type(path)}") + if not isinstance(item, (str, IOBase)): + raise ValueError(f"Expected str or file-like object, got {type(path)}") df.append(pd.read_csv(item, **_kwargs)) else: - raise ValueError(f"Expected str or list, got {type(path)}") + raise ValueError(f"Expected str, list, or file-like object, got {type(path)}") return create_pandas_dataframe_agent(llm, df, **kwargs) diff --git a/libs/langchain/tests/integration_tests/agent/test_csv_agent.py b/libs/langchain/tests/integration_tests/agent/test_csv_agent.py index c45607e50b..08169edb6e 100644 --- a/libs/langchain/tests/integration_tests/agent/test_csv_agent.py +++ b/libs/langchain/tests/integration_tests/agent/test_csv_agent.py @@ -1,3 +1,4 @@ +import io import re import numpy as np @@ -34,6 +35,15 @@ def csv_list(tmp_path_factory: TempPathFactory) -> DataFrame: return [filename1, filename2] +@pytest.fixture(scope="module") +def csv_file_like(tmp_path_factory: TempPathFactory) -> io.BytesIO: + random_data = np.random.rand(4, 4) + df = DataFrame(random_data, columns=["name", "age", "food", "sport"]) + buffer = io.BytesIO() + df.to_pickle(buffer) + return buffer + + def test_csv_agent_creation(csv: str) -> None: agent = create_csv_agent(OpenAI(temperature=0), csv) assert isinstance(agent, AgentExecutor) @@ -55,3 +65,12 @@ def test_multi_csv(csv_list: list) -> None: result = re.search(r".*(6).*", response) assert result is not None assert result.group(1) is not None + + +def test_file_like(file_like: io.BytesIO) -> None: + agent = create_csv_agent(OpenAI(temperature=0), file_like, verbose=True) + assert isinstance(agent, AgentExecutor) + response = agent.run("How many rows in the csv? Give me a number.") + result = re.search(r".*(4).*", response) + assert result is not None + assert result.group(1) is not None