From 485ecc35805f9d8cf97ca32cdb751e862adad37f Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Fri, 12 May 2023 21:55:22 -0700 Subject: [PATCH] option for csv agent to not include df in prompt (#4610) --- .../agents/agent_toolkits/pandas/base.py | 32 +++++++++++++++---- .../agents/agent_toolkits/pandas/prompt.py | 7 +++- langchain/tools/python/tool.py | 2 +- langchain/utilities/python.py | 3 +- tests/unit_tests/test_python.py | 2 +- 5 files changed, 35 insertions(+), 11 deletions(-) diff --git a/langchain/agents/agent_toolkits/pandas/base.py b/langchain/agents/agent_toolkits/pandas/base.py index 200bf453..0913337e 100644 --- a/langchain/agents/agent_toolkits/pandas/base.py +++ b/langchain/agents/agent_toolkits/pandas/base.py @@ -2,7 +2,11 @@ from typing import Any, Dict, List, Optional from langchain.agents.agent import AgentExecutor -from langchain.agents.agent_toolkits.pandas.prompt import PREFIX, SUFFIX +from langchain.agents.agent_toolkits.pandas.prompt import ( + PREFIX, + SUFFIX_NO_DF, + SUFFIX_WITH_DF, +) from langchain.agents.mrkl.base import ZeroShotAgent from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager @@ -15,7 +19,7 @@ def create_pandas_dataframe_agent( df: Any, callback_manager: Optional[BaseCallbackManager] = None, prefix: str = PREFIX, - suffix: str = SUFFIX, + suffix: Optional[str] = None, input_variables: Optional[List[str]] = None, verbose: bool = False, return_intermediate_steps: bool = False, @@ -23,6 +27,7 @@ def create_pandas_dataframe_agent( max_execution_time: Optional[float] = None, early_stopping_method: str = "force", agent_executor_kwargs: Optional[Dict[str, Any]] = None, + include_df_in_prompt: Optional[bool] = True, **kwargs: Dict[str, Any], ) -> AgentExecutor: """Construct a pandas agent from an LLM and dataframe.""" @@ -35,14 +40,27 @@ def create_pandas_dataframe_agent( if not isinstance(df, pd.DataFrame): raise ValueError(f"Expected pandas object, got {type(df)}") - - if input_variables is None: - input_variables = ["df", "input", "agent_scratchpad"] + if include_df_in_prompt is not None and suffix is not None: + raise ValueError("If suffix is specified, include_df_in_prompt should not be.") + if suffix is not None: + suffix_to_use = suffix + if input_variables is None: + input_variables = ["df", "input", "agent_scratchpad"] + else: + if include_df_in_prompt: + suffix_to_use = SUFFIX_WITH_DF + input_variables = ["df", "input", "agent_scratchpad"] + else: + suffix_to_use = SUFFIX_NO_DF + input_variables = ["input", "agent_scratchpad"] tools = [PythonAstREPLTool(locals={"df": df})] prompt = ZeroShotAgent.create_prompt( - tools, prefix=prefix, suffix=suffix, input_variables=input_variables + tools, prefix=prefix, suffix=suffix_to_use, input_variables=input_variables ) - partial_prompt = prompt.partial(df=str(df.head().to_markdown())) + if "df" in input_variables: + partial_prompt = prompt.partial(df=str(df.head().to_markdown())) + else: + partial_prompt = prompt llm_chain = LLMChain( llm=llm, prompt=partial_prompt, diff --git a/langchain/agents/agent_toolkits/pandas/prompt.py b/langchain/agents/agent_toolkits/pandas/prompt.py index 525d92e2..7988c61f 100644 --- a/langchain/agents/agent_toolkits/pandas/prompt.py +++ b/langchain/agents/agent_toolkits/pandas/prompt.py @@ -4,7 +4,12 @@ PREFIX = """ You are working with a pandas dataframe in Python. The name of the dataframe is `df`. You should use the tools below to answer the question posed of you:""" -SUFFIX = """ +SUFFIX_NO_DF = """ +Begin! +Question: {input} +{agent_scratchpad}""" + +SUFFIX_WITH_DF = """ This is the result of `print(df.head())`: {df} diff --git a/langchain/tools/python/tool.py b/langchain/tools/python/tool.py index c947cc62..e9062ecc 100644 --- a/langchain/tools/python/tool.py +++ b/langchain/tools/python/tool.py @@ -102,7 +102,7 @@ class PythonAstREPLTool(BaseTool): output = mystdout.getvalue() except Exception as e: sys.stdout = old_stdout - output = str(e) + output = repr(e) return output except Exception as e: return "{}: {}".format(type(e).__name__, str(e)) diff --git a/langchain/utilities/python.py b/langchain/utilities/python.py index 4abb22a1..39b8419e 100644 --- a/langchain/utilities/python.py +++ b/langchain/utilities/python.py @@ -21,5 +21,6 @@ class PythonREPL(BaseModel): output = mystdout.getvalue() except Exception as e: sys.stdout = old_stdout - output = str(e) + output = repr(e) + print(output) return output diff --git a/tests/unit_tests/test_python.py b/tests/unit_tests/test_python.py index 28319542..33bf70e9 100644 --- a/tests/unit_tests/test_python.py +++ b/tests/unit_tests/test_python.py @@ -54,7 +54,7 @@ def test_python_repl_no_previous_variables() -> None: foo = 3 # noqa: F841 repl = PythonREPL() output = repl.run("print(foo)") - assert output == "name 'foo' is not defined" + assert output == """NameError("name 'foo' is not defined")""" def test_python_repl_pass_in_locals() -> None: