community[patch]: avoid executing `toolkit.get_context()` when not necessary (#19762)

If `prompt` is passed into `create_sql_agent()`, then
`toolkit.get_context()` shouldn't be executed against the database
unless relevant prompt variables (`table_info` or `table_names`) are
present .
pull/19765/head
Arturs Konfino 6 months ago committed by GitHub
parent ec7a59c96c
commit 2319212d54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -150,17 +150,18 @@ def create_sql_agent(
prompt = prompt.partial(top_k=str(top_k))
if "dialect" in prompt.input_variables:
prompt = prompt.partial(dialect=toolkit.dialect)
db_context = toolkit.get_context()
if "table_info" in prompt.input_variables:
prompt = prompt.partial(table_info=db_context["table_info"])
tools = [
tool for tool in tools if not isinstance(tool, InfoSQLDatabaseTool)
]
if "table_names" in prompt.input_variables:
prompt = prompt.partial(table_names=db_context["table_names"])
tools = [
tool for tool in tools if not isinstance(tool, ListSQLDatabaseTool)
]
if any(key in prompt.input_variables for key in ["table_info", "table_names"]):
db_context = toolkit.get_context()
if "table_info" in prompt.input_variables:
prompt = prompt.partial(table_info=db_context["table_info"])
tools = [
tool for tool in tools if not isinstance(tool, InfoSQLDatabaseTool)
]
if "table_names" in prompt.input_variables:
prompt = prompt.partial(table_names=db_context["table_names"])
tools = [
tool for tool in tools if not isinstance(tool, ListSQLDatabaseTool)
]
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
if prompt is None:

Loading…
Cancel
Save