@ -150,17 +150,18 @@ def create_sql_agent(
prompt = prompt . partial ( top_k = str ( top_k ) )
if " dialect " in prompt . input_variables :
prompt = prompt . partial ( dialect = toolkit . dialect )
db_context = toolkit . get_context ( )
if " table_info " in prompt . input_variables :
prompt = prompt . partial ( table_info = db_context [ " table_info " ] )
tools = [
tool for tool in tools if not isinstance ( tool , InfoSQLDatabaseTool )
]
if " table_names " in prompt . input_variables :
prompt = prompt . partial ( table_names = db_context [ " table_names " ] )
tools = [
tool for tool in tools if not isinstance ( tool , ListSQLDatabaseTool )
]
if any ( key in prompt . input_variables for key in [ " table_info " , " table_names " ] ) :
db_context = toolkit . get_context ( )
if " table_info " in prompt . input_variables :
prompt = prompt . partial ( table_info = db_context [ " table_info " ] )
tools = [
tool for tool in tools if not isinstance ( tool , InfoSQLDatabaseTool )
]
if " table_names " in prompt . input_variables :
prompt = prompt . partial ( table_names = db_context [ " table_names " ] )
tools = [
tool for tool in tools if not isinstance ( tool , ListSQLDatabaseTool )
]
if agent_type == AgentType . ZERO_SHOT_REACT_DESCRIPTION :
if prompt is None :