mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
- **Description:** Fix #11737 issue (extra_tools option of create_pandas_dataframe_agent is not working), - **Issue:** #11737 , - **Dependencies:** no, - **Tag maintainer:** @baskaryan, @eyurtsev, @hwchase17 I needed this method at work, so I modified it myself and used it. There is a similar issue(#11737) and PR(#13018) of @PyroGenesis, so I combined my code at the original PR. You may be busy, but it would be great help for me if you checked. Thank you. - **Twitter handle:** @lunara_x If you need an .ipynb example about this, please tag me. I will share what I am working on after removing any work-related content. --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
77a15fa988
commit
f758c8adc4
@ -33,7 +33,8 @@ def _get_multi_prompt(
|
|||||||
input_variables: Optional[List[str]] = None,
|
input_variables: Optional[List[str]] = None,
|
||||||
include_df_in_prompt: Optional[bool] = True,
|
include_df_in_prompt: Optional[bool] = True,
|
||||||
number_of_head_rows: int = 5,
|
number_of_head_rows: int = 5,
|
||||||
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
extra_tools: Sequence[BaseTool] = (),
|
||||||
|
) -> Tuple[BasePromptTemplate, List[BaseTool]]:
|
||||||
num_dfs = len(dfs)
|
num_dfs = len(dfs)
|
||||||
if suffix is not None:
|
if suffix is not None:
|
||||||
suffix_to_use = suffix
|
suffix_to_use = suffix
|
||||||
@ -55,12 +56,13 @@ def _get_multi_prompt(
|
|||||||
df_locals = {}
|
df_locals = {}
|
||||||
for i, dataframe in enumerate(dfs):
|
for i, dataframe in enumerate(dfs):
|
||||||
df_locals[f"df{i + 1}"] = dataframe
|
df_locals[f"df{i + 1}"] = dataframe
|
||||||
tools = [PythonAstREPLTool(locals=df_locals)]
|
tools = [PythonAstREPLTool(locals=df_locals)] + list(extra_tools)
|
||||||
|
|
||||||
prompt = ZeroShotAgent.create_prompt(
|
prompt = ZeroShotAgent.create_prompt(
|
||||||
tools, prefix=prefix, suffix=suffix_to_use, input_variables=input_variables
|
tools,
|
||||||
|
prefix=prefix,
|
||||||
|
suffix=suffix_to_use,
|
||||||
|
input_variables=input_variables,
|
||||||
)
|
)
|
||||||
|
|
||||||
partial_prompt = prompt.partial()
|
partial_prompt = prompt.partial()
|
||||||
if "dfs_head" in input_variables:
|
if "dfs_head" in input_variables:
|
||||||
dfs_head = "\n\n".join([d.head(number_of_head_rows).to_markdown() for d in dfs])
|
dfs_head = "\n\n".join([d.head(number_of_head_rows).to_markdown() for d in dfs])
|
||||||
@ -77,7 +79,8 @@ def _get_single_prompt(
|
|||||||
input_variables: Optional[List[str]] = None,
|
input_variables: Optional[List[str]] = None,
|
||||||
include_df_in_prompt: Optional[bool] = True,
|
include_df_in_prompt: Optional[bool] = True,
|
||||||
number_of_head_rows: int = 5,
|
number_of_head_rows: int = 5,
|
||||||
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
extra_tools: Sequence[BaseTool] = (),
|
||||||
|
) -> Tuple[BasePromptTemplate, List[BaseTool]]:
|
||||||
if suffix is not None:
|
if suffix is not None:
|
||||||
suffix_to_use = suffix
|
suffix_to_use = suffix
|
||||||
include_df_head = True
|
include_df_head = True
|
||||||
@ -96,10 +99,13 @@ def _get_single_prompt(
|
|||||||
if prefix is None:
|
if prefix is None:
|
||||||
prefix = PREFIX
|
prefix = PREFIX
|
||||||
|
|
||||||
tools = [PythonAstREPLTool(locals={"df": df})]
|
tools = [PythonAstREPLTool(locals={"df": df})] + list(extra_tools)
|
||||||
|
|
||||||
prompt = ZeroShotAgent.create_prompt(
|
prompt = ZeroShotAgent.create_prompt(
|
||||||
tools, prefix=prefix, suffix=suffix_to_use, input_variables=input_variables
|
tools,
|
||||||
|
prefix=prefix,
|
||||||
|
suffix=suffix_to_use,
|
||||||
|
input_variables=input_variables,
|
||||||
)
|
)
|
||||||
|
|
||||||
partial_prompt = prompt.partial()
|
partial_prompt = prompt.partial()
|
||||||
@ -117,7 +123,8 @@ def _get_prompt_and_tools(
|
|||||||
input_variables: Optional[List[str]] = None,
|
input_variables: Optional[List[str]] = None,
|
||||||
include_df_in_prompt: Optional[bool] = True,
|
include_df_in_prompt: Optional[bool] = True,
|
||||||
number_of_head_rows: int = 5,
|
number_of_head_rows: int = 5,
|
||||||
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
extra_tools: Sequence[BaseTool] = (),
|
||||||
|
) -> Tuple[BasePromptTemplate, List[BaseTool]]:
|
||||||
try:
|
try:
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
@ -141,6 +148,7 @@ def _get_prompt_and_tools(
|
|||||||
input_variables=input_variables,
|
input_variables=input_variables,
|
||||||
include_df_in_prompt=include_df_in_prompt,
|
include_df_in_prompt=include_df_in_prompt,
|
||||||
number_of_head_rows=number_of_head_rows,
|
number_of_head_rows=number_of_head_rows,
|
||||||
|
extra_tools=extra_tools,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if not isinstance(df, pd.DataFrame):
|
if not isinstance(df, pd.DataFrame):
|
||||||
@ -152,6 +160,7 @@ def _get_prompt_and_tools(
|
|||||||
input_variables=input_variables,
|
input_variables=input_variables,
|
||||||
include_df_in_prompt=include_df_in_prompt,
|
include_df_in_prompt=include_df_in_prompt,
|
||||||
number_of_head_rows=number_of_head_rows,
|
number_of_head_rows=number_of_head_rows,
|
||||||
|
extra_tools=extra_tools,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -287,6 +296,7 @@ def create_pandas_dataframe_agent(
|
|||||||
) -> AgentExecutor:
|
) -> AgentExecutor:
|
||||||
"""Construct a pandas agent from an LLM and dataframe."""
|
"""Construct a pandas agent from an LLM and dataframe."""
|
||||||
agent: BaseSingleActionAgent
|
agent: BaseSingleActionAgent
|
||||||
|
base_tools: Sequence[BaseTool]
|
||||||
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
|
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
|
||||||
prompt, base_tools = _get_prompt_and_tools(
|
prompt, base_tools = _get_prompt_and_tools(
|
||||||
df,
|
df,
|
||||||
@ -295,8 +305,9 @@ def create_pandas_dataframe_agent(
|
|||||||
input_variables=input_variables,
|
input_variables=input_variables,
|
||||||
include_df_in_prompt=include_df_in_prompt,
|
include_df_in_prompt=include_df_in_prompt,
|
||||||
number_of_head_rows=number_of_head_rows,
|
number_of_head_rows=number_of_head_rows,
|
||||||
|
extra_tools=extra_tools,
|
||||||
)
|
)
|
||||||
tools = base_tools + list(extra_tools)
|
tools = base_tools
|
||||||
llm_chain = LLMChain(
|
llm_chain = LLMChain(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
@ -318,7 +329,7 @@ def create_pandas_dataframe_agent(
|
|||||||
include_df_in_prompt=include_df_in_prompt,
|
include_df_in_prompt=include_df_in_prompt,
|
||||||
number_of_head_rows=number_of_head_rows,
|
number_of_head_rows=number_of_head_rows,
|
||||||
)
|
)
|
||||||
tools = base_tools + list(extra_tools)
|
tools = list(base_tools) + list(extra_tools)
|
||||||
agent = OpenAIFunctionsAgent(
|
agent = OpenAIFunctionsAgent(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
prompt=_prompt,
|
prompt=_prompt,
|
||||||
|
Loading…
Reference in New Issue
Block a user