diff --git a/libs/langchain/langchain/agents/agent_toolkits/pandas/base.py b/libs/langchain/langchain/agents/agent_toolkits/pandas/base.py index f32015c469..7891c47a68 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/pandas/base.py +++ b/libs/langchain/langchain/agents/agent_toolkits/pandas/base.py @@ -1,5 +1,5 @@ """Agent for working with pandas objects.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Sequence, Tuple from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent from langchain.agents.agent_toolkits.pandas.prompt import ( @@ -21,6 +21,7 @@ from langchain.chains.llm import LLMChain from langchain.schema import BasePromptTemplate from langchain.schema.language_model import BaseLanguageModel from langchain.schema.messages import SystemMessage +from langchain.tools import BaseTool from langchain.tools.python.tool import PythonAstREPLTool @@ -280,12 +281,13 @@ def create_pandas_dataframe_agent( agent_executor_kwargs: Optional[Dict[str, Any]] = None, include_df_in_prompt: Optional[bool] = True, number_of_head_rows: int = 5, + extra_tools: Sequence[BaseTool] = (), **kwargs: Dict[str, Any], ) -> AgentExecutor: """Construct a pandas agent from an LLM and dataframe.""" agent: BaseSingleActionAgent if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION: - prompt, tools = _get_prompt_and_tools( + prompt, base_tools = _get_prompt_and_tools( df, prefix=prefix, suffix=suffix, @@ -293,6 +295,7 @@ def create_pandas_dataframe_agent( include_df_in_prompt=include_df_in_prompt, number_of_head_rows=number_of_head_rows, ) + tools = base_tools + list(extra_tools) llm_chain = LLMChain( llm=llm, prompt=prompt, @@ -306,7 +309,7 @@ def create_pandas_dataframe_agent( **kwargs, ) elif agent_type == AgentType.OPENAI_FUNCTIONS: - _prompt, tools = _get_functions_prompt_and_tools( + _prompt, base_tools = _get_functions_prompt_and_tools( df, prefix=prefix, suffix=suffix, @@ -314,6 +317,7 @@ def create_pandas_dataframe_agent( include_df_in_prompt=include_df_in_prompt, number_of_head_rows=number_of_head_rows, ) + tools = base_tools + list(extra_tools) agent = OpenAIFunctionsAgent( llm=llm, prompt=_prompt,