Add option for the CSV agent to not include the dataframe in the prompt (#4610)

This commit is contained in:
Harrison Chase 2023-05-12 21:55:22 -07:00 committed by GitHub
parent 7d425cbf38
commit 485ecc3580
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 35 additions and 11 deletions

View File

@ -2,7 +2,11 @@
from typing import Any, Dict, List, Optional
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.pandas.prompt import PREFIX, SUFFIX
from langchain.agents.agent_toolkits.pandas.prompt import (
PREFIX,
SUFFIX_NO_DF,
SUFFIX_WITH_DF,
)
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
@ -15,7 +19,7 @@ def create_pandas_dataframe_agent(
df: Any,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = PREFIX,
suffix: str = SUFFIX,
suffix: Optional[str] = None,
input_variables: Optional[List[str]] = None,
verbose: bool = False,
return_intermediate_steps: bool = False,
@ -23,6 +27,7 @@ def create_pandas_dataframe_agent(
max_execution_time: Optional[float] = None,
early_stopping_method: str = "force",
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
include_df_in_prompt: Optional[bool] = True,
**kwargs: Dict[str, Any],
) -> AgentExecutor:
"""Construct a pandas agent from an LLM and dataframe."""
@ -35,14 +40,27 @@ def create_pandas_dataframe_agent(
if not isinstance(df, pd.DataFrame):
raise ValueError(f"Expected pandas object, got {type(df)}")
if include_df_in_prompt is not None and suffix is not None:
raise ValueError("If suffix is specified, include_df_in_prompt should not be.")
if suffix is not None:
suffix_to_use = suffix
if input_variables is None:
input_variables = ["df", "input", "agent_scratchpad"]
else:
if include_df_in_prompt:
suffix_to_use = SUFFIX_WITH_DF
input_variables = ["df", "input", "agent_scratchpad"]
else:
suffix_to_use = SUFFIX_NO_DF
input_variables = ["input", "agent_scratchpad"]
tools = [PythonAstREPLTool(locals={"df": df})]
prompt = ZeroShotAgent.create_prompt(
tools, prefix=prefix, suffix=suffix, input_variables=input_variables
tools, prefix=prefix, suffix=suffix_to_use, input_variables=input_variables
)
if "df" in input_variables:
partial_prompt = prompt.partial(df=str(df.head().to_markdown()))
else:
partial_prompt = prompt
llm_chain = LLMChain(
llm=llm,
prompt=partial_prompt,

View File

@ -4,7 +4,12 @@ PREFIX = """
You are working with a pandas dataframe in Python. The name of the dataframe is `df`.
You should use the tools below to answer the question posed of you:"""
SUFFIX = """
SUFFIX_NO_DF = """
Begin!
Question: {input}
{agent_scratchpad}"""
SUFFIX_WITH_DF = """
This is the result of `print(df.head())`:
{df}

View File

@ -102,7 +102,7 @@ class PythonAstREPLTool(BaseTool):
output = mystdout.getvalue()
except Exception as e:
sys.stdout = old_stdout
output = str(e)
output = repr(e)
return output
except Exception as e:
return "{}: {}".format(type(e).__name__, str(e))

View File

@ -21,5 +21,6 @@ class PythonREPL(BaseModel):
output = mystdout.getvalue()
except Exception as e:
sys.stdout = old_stdout
output = str(e)
output = repr(e)
print(output)
return output

View File

@ -54,7 +54,7 @@ def test_python_repl_no_previous_variables() -> None:
foo = 3 # noqa: F841
repl = PythonREPL()
output = repl.run("print(foo)")
assert output == "name 'foo' is not defined"
assert output == """NameError("name 'foo' is not defined")"""
def test_python_repl_pass_in_locals() -> None: