Fix: Pass along kwargs when creating a sql agent (#2350)

Currently, `agent_toolkits.sql.create_sql_agent()` passes kwargs to the
`ZeroShotAgent` that it creates but not to `AgentExecutor` that it also
creates. This prevents the caller from providing some useful arguments
like `max_iterations` and `early_stopping_method`

This PR changes `create_sql_agent` so that it passes kwargs to both
constructors.

---------

Co-authored-by: Zachary Jones <zjones@zetaglobal.com>
This commit is contained in:
Zach Jones 2023-04-04 00:50:51 -04:00 committed by GitHub
parent 7ed8d00bba
commit c969a779c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -20,6 +20,8 @@ def create_sql_agent(
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
top_k: int = 10,
max_iterations: Optional[int] = 15,
early_stopping_method: str = "force",
verbose: bool = False,
**kwargs: Any,
) -> AgentExecutor:
@ -41,5 +43,9 @@ def create_sql_agent(
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
return AgentExecutor.from_agent_and_tools(
agent=agent, tools=toolkit.get_tools(), verbose=verbose
agent=agent,
tools=toolkit.get_tools(),
verbose=verbose,
max_iterations=max_iterations,
early_stopping_method=early_stopping_method,
)