mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Adding File-Like object support in CSV Agent Toolkit (#10409)
When loading a CSV from a direct or temporary source, passing the file-like object (a subclass of IOBase) directly now allows the agent creation process to succeed instead of raising a ValueError. Added an additional elif and tweaked the ValueError message. Added a test to validate this functionality. pandas `read_csv` supports file-like objects natively, but the current implementation only accepts strings or paths to files. https://pandas.pydata.org/docs/user_guide/io.html#io-read-csv-table --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
999163fbd6
commit
50128c8b39
@ -1,3 +1,4 @@
|
|||||||
|
from io import IOBase
|
||||||
from typing import Any, List, Optional, Union
|
from typing import Any, List, Optional, Union
|
||||||
|
|
||||||
from langchain.agents.agent import AgentExecutor
|
from langchain.agents.agent import AgentExecutor
|
||||||
@ -7,7 +8,7 @@ from langchain.schema.language_model import BaseLanguageModel
|
|||||||
|
|
||||||
def create_csv_agent(
|
def create_csv_agent(
|
||||||
llm: BaseLanguageModel,
|
llm: BaseLanguageModel,
|
||||||
path: Union[str, List[str]],
|
path: Union[str, IOBase, List[Union[str, IOBase]]],
|
||||||
pandas_kwargs: Optional[dict] = None,
|
pandas_kwargs: Optional[dict] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AgentExecutor:
|
) -> AgentExecutor:
|
||||||
@ -20,14 +21,14 @@ def create_csv_agent(
|
|||||||
)
|
)
|
||||||
|
|
||||||
_kwargs = pandas_kwargs or {}
|
_kwargs = pandas_kwargs or {}
|
||||||
if isinstance(path, str):
|
if isinstance(path, (str, IOBase)):
|
||||||
df = pd.read_csv(path, **_kwargs)
|
df = pd.read_csv(path, **_kwargs)
|
||||||
elif isinstance(path, list):
|
elif isinstance(path, list):
|
||||||
df = []
|
df = []
|
||||||
for item in path:
|
for item in path:
|
||||||
if not isinstance(item, str):
|
if not isinstance(item, (str, IOBase)):
|
||||||
raise ValueError(f"Expected str, got {type(path)}")
|
raise ValueError(f"Expected str or file-like object, got {type(path)}")
|
||||||
df.append(pd.read_csv(item, **_kwargs))
|
df.append(pd.read_csv(item, **_kwargs))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Expected str or list, got {type(path)}")
|
raise ValueError(f"Expected str, list, or file-like object, got {type(path)}")
|
||||||
return create_pandas_dataframe_agent(llm, df, **kwargs)
|
return create_pandas_dataframe_agent(llm, df, **kwargs)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import io
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -34,6 +35,15 @@ def csv_list(tmp_path_factory: TempPathFactory) -> DataFrame:
|
|||||||
return [filename1, filename2]
|
return [filename1, filename2]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def csv_file_like(tmp_path_factory: TempPathFactory) -> io.BytesIO:
    """Return an in-memory binary buffer containing a 4x4 DataFrame as CSV.

    The frame has four rows of random floats under the columns
    ``name``, ``age``, ``food`` and ``sport``. The data is serialized as CSV
    text (not pickled) because the buffer is meant to be handed to
    ``pd.read_csv``, which cannot parse pickle bytes.

    Args:
        tmp_path_factory: Unused here; kept so the fixture signature stays
            unchanged.

    Returns:
        A ``BytesIO`` positioned at offset 0, ready to be read.
    """
    random_data = np.random.rand(4, 4)
    df = DataFrame(random_data, columns=["name", "age", "food", "sport"])
    # index=False keeps the CSV to exactly the four named columns.
    csv_bytes = df.to_csv(index=False).encode("utf-8")
    # Building the buffer from bytes leaves the cursor at the start, so no
    # explicit rewind is needed before the first read.
    # NOTE: the fixture is module-scoped, so the same buffer object is shared;
    # a consumer that reads it to the end must seek(0) before reading again.
    return io.BytesIO(csv_bytes)
|
||||||
|
|
||||||
|
|
||||||
def test_csv_agent_creation(csv: str) -> None:
|
def test_csv_agent_creation(csv: str) -> None:
|
||||||
agent = create_csv_agent(OpenAI(temperature=0), csv)
|
agent = create_csv_agent(OpenAI(temperature=0), csv)
|
||||||
assert isinstance(agent, AgentExecutor)
|
assert isinstance(agent, AgentExecutor)
|
||||||
@ -55,3 +65,12 @@ def test_multi_csv(csv_list: list) -> None:
|
|||||||
result = re.search(r".*(6).*", response)
|
result = re.search(r".*(6).*", response)
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.group(1) is not None
|
assert result.group(1) is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_like(csv_file_like: io.BytesIO) -> None:
    """Check that a CSV agent can be created from a file-like object.

    Requests the ``csv_file_like`` fixture (pytest resolves fixtures by
    parameter name), builds an agent from it, and asks the agent for the row
    count, expecting the digit 4 to appear somewhere in the answer.
    """
    agent = create_csv_agent(OpenAI(temperature=0), csv_file_like, verbose=True)
    assert isinstance(agent, AgentExecutor)
    response = agent.run("How many rows in the csv? Give me a number.")
    # Loose match: the reply is free-form text, so only look for a "4" in it.
    result = re.search(r".*(4).*", response)
    assert result is not None
    assert result.group(1) is not None
|
||||||
|
Loading…
Reference in New Issue
Block a user