community[patch]: undo create_sql_agent breaking (#16797)

pull/16801/head
Bagatur 8 months ago committed by GitHub
parent ef2bd745cb
commit daf820c77b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import warnings import warnings
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence, Union from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union
from langchain_core.messages import AIMessage, SystemMessage from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.prompts import BasePromptTemplate, PromptTemplate from langchain_core.prompts import BasePromptTemplate, PromptTemplate
@ -40,6 +40,7 @@ def create_sql_agent(
prefix: Optional[str] = None, prefix: Optional[str] = None,
suffix: Optional[str] = None, suffix: Optional[str] = None,
format_instructions: Optional[str] = None, format_instructions: Optional[str] = None,
input_variables: Optional[List[str]] = None,
top_k: int = 10, top_k: int = 10,
max_iterations: Optional[int] = 15, max_iterations: Optional[int] = 15,
max_execution_time: Optional[float] = None, max_execution_time: Optional[float] = None,
@ -69,6 +70,7 @@ def create_sql_agent(
format_instructions: Formatting instructions to pass to format_instructions: Formatting instructions to pass to
ZeroShotAgent.create_prompt() when 'agent_type' is ZeroShotAgent.create_prompt() when 'agent_type' is
"zero-shot-react-description". Otherwise ignored. "zero-shot-react-description". Otherwise ignored.
input_variables: DEPRECATED.
top_k: Number of rows to query for by default. top_k: Number of rows to query for by default.
max_iterations: Passed to AgentExecutor init. max_iterations: Passed to AgentExecutor init.
max_execution_time: Passed to AgentExecutor init. max_execution_time: Passed to AgentExecutor init.
@ -119,6 +121,9 @@ def create_sql_agent(
raise ValueError( raise ValueError(
"Must provide exactly one of 'toolkit' or 'db'. Received both." "Must provide exactly one of 'toolkit' or 'db'. Received both."
) )
if input_variables:
kwargs = kwargs or {}
kwargs["input_variables"] = input_variables
if kwargs: if kwargs:
warnings.warn( warnings.warn(
f"Received additional kwargs {kwargs} which are no longer supported." f"Received additional kwargs {kwargs} which are no longer supported."

Loading…
Cancel
Save