option for csv agent to not include df in prompt (#4610)

This commit is contained in:
Harrison Chase 2023-05-12 21:55:22 -07:00 committed by GitHub
parent 7d425cbf38
commit 485ecc3580
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 35 additions and 11 deletions

View File

@ -2,7 +2,11 @@
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from langchain.agents.agent import AgentExecutor from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.pandas.prompt import PREFIX, SUFFIX from langchain.agents.agent_toolkits.pandas.prompt import (
PREFIX,
SUFFIX_NO_DF,
SUFFIX_WITH_DF,
)
from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.base_language import BaseLanguageModel from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
@ -15,7 +19,7 @@ def create_pandas_dataframe_agent(
df: Any, df: Any,
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = PREFIX, prefix: str = PREFIX,
suffix: str = SUFFIX, suffix: Optional[str] = None,
input_variables: Optional[List[str]] = None, input_variables: Optional[List[str]] = None,
verbose: bool = False, verbose: bool = False,
return_intermediate_steps: bool = False, return_intermediate_steps: bool = False,
@ -23,6 +27,7 @@ def create_pandas_dataframe_agent(
max_execution_time: Optional[float] = None, max_execution_time: Optional[float] = None,
early_stopping_method: str = "force", early_stopping_method: str = "force",
agent_executor_kwargs: Optional[Dict[str, Any]] = None, agent_executor_kwargs: Optional[Dict[str, Any]] = None,
include_df_in_prompt: Optional[bool] = True,
**kwargs: Dict[str, Any], **kwargs: Dict[str, Any],
) -> AgentExecutor: ) -> AgentExecutor:
"""Construct a pandas agent from an LLM and dataframe.""" """Construct a pandas agent from an LLM and dataframe."""
@ -35,14 +40,27 @@ def create_pandas_dataframe_agent(
if not isinstance(df, pd.DataFrame): if not isinstance(df, pd.DataFrame):
raise ValueError(f"Expected pandas object, got {type(df)}") raise ValueError(f"Expected pandas object, got {type(df)}")
if include_df_in_prompt is not None and suffix is not None:
if input_variables is None: raise ValueError("If suffix is specified, include_df_in_prompt should not be.")
input_variables = ["df", "input", "agent_scratchpad"] if suffix is not None:
suffix_to_use = suffix
if input_variables is None:
input_variables = ["df", "input", "agent_scratchpad"]
else:
if include_df_in_prompt:
suffix_to_use = SUFFIX_WITH_DF
input_variables = ["df", "input", "agent_scratchpad"]
else:
suffix_to_use = SUFFIX_NO_DF
input_variables = ["input", "agent_scratchpad"]
tools = [PythonAstREPLTool(locals={"df": df})] tools = [PythonAstREPLTool(locals={"df": df})]
prompt = ZeroShotAgent.create_prompt( prompt = ZeroShotAgent.create_prompt(
tools, prefix=prefix, suffix=suffix, input_variables=input_variables tools, prefix=prefix, suffix=suffix_to_use, input_variables=input_variables
) )
partial_prompt = prompt.partial(df=str(df.head().to_markdown())) if "df" in input_variables:
partial_prompt = prompt.partial(df=str(df.head().to_markdown()))
else:
partial_prompt = prompt
llm_chain = LLMChain( llm_chain = LLMChain(
llm=llm, llm=llm,
prompt=partial_prompt, prompt=partial_prompt,

View File

@ -4,7 +4,12 @@ PREFIX = """
You are working with a pandas dataframe in Python. The name of the dataframe is `df`. You are working with a pandas dataframe in Python. The name of the dataframe is `df`.
You should use the tools below to answer the question posed of you:""" You should use the tools below to answer the question posed of you:"""
SUFFIX = """ SUFFIX_NO_DF = """
Begin!
Question: {input}
{agent_scratchpad}"""
SUFFIX_WITH_DF = """
This is the result of `print(df.head())`: This is the result of `print(df.head())`:
{df} {df}

View File

@ -102,7 +102,7 @@ class PythonAstREPLTool(BaseTool):
output = mystdout.getvalue() output = mystdout.getvalue()
except Exception as e: except Exception as e:
sys.stdout = old_stdout sys.stdout = old_stdout
output = str(e) output = repr(e)
return output return output
except Exception as e: except Exception as e:
return "{}: {}".format(type(e).__name__, str(e)) return "{}: {}".format(type(e).__name__, str(e))

View File

@ -21,5 +21,6 @@ class PythonREPL(BaseModel):
output = mystdout.getvalue() output = mystdout.getvalue()
except Exception as e: except Exception as e:
sys.stdout = old_stdout sys.stdout = old_stdout
output = str(e) output = repr(e)
print(output)
return output return output

View File

@ -54,7 +54,7 @@ def test_python_repl_no_previous_variables() -> None:
foo = 3 # noqa: F841 foo = 3 # noqa: F841
repl = PythonREPL() repl = PythonREPL()
output = repl.run("print(foo)") output = repl.run("print(foo)")
assert output == "name 'foo' is not defined" assert output == """NameError("name 'foo' is not defined")"""
def test_python_repl_pass_in_locals() -> None: def test_python_repl_pass_in_locals() -> None: