pref: reduce DB query error rate (#5339)

# Reduce DB query error rate

If you use sql agent of `SQLDatabaseToolkit` to query data, it is prone
to errors in query fields and often uses fields that do not exist in
database tables for queries. However, the existing prompt does not
effectively make the agent aware that there are problems with the fields
they query. At this time, we urgently need to improve the prompt so that
the agent realizes that they have queried non-existent fields and allows
them to use the `schema_sql_db`, that is,` ListSQLDatabaseTool` first
queries the corresponding fields in the table in the database, and then
uses `QuerySQLDatabaseTool` for querying.

There is a demo of my project to show this problem.

**Original Agent**

```python
def create_mysql_kit():
    db = SQLDatabase.from_uri("mysql+pymysql://xxxxxxx")
    llm = OpenAI(temperature=0)

    toolkit = SQLDatabaseToolkit(db=db, llm=llm)
    agent_executor = create_sql_agent(
        llm=OpenAI(temperature=0),
        toolkit=toolkit,
        verbose=True
    )
    agent_executor.run("Who are the users of sysuser in this system? Tell me the username of all users")


if __name__ == '__main__':
    create_mysql_kit()

```

**original output**

```
> Entering new AgentExecutor chain...
Action: list_tables_sql_db
Action Input: ""
Observation: app_sysrole_menus, app_bimfacemodel, app_project_users, app_measuringpointdata, auth_user, auth_user_groups, django_apscheduler_djangojobexecution, app_project, app_elementpoint, django_apscheduler_djangojob, django_content_type, app_sysrole, django_admin_log, app_bimfaceaccount, app_measuringpoint_warning_thresholds, app_measuringpoint, app_company, auth_group_permissions, app_sysuser, app_sysuser_companies, app_sysmenu, app_datawarningthreshold, auth_group, auth_permission, app_datawarningrecord, auth_user_user_permissions, app_bimfaceaccount_bimface_models, django_migrations, app_measuringitem, django_session
Thought: I should query the app_sysuser table to get the users in the system.
Action: query_sql_db
Action Input: SELECT username FROM app_sysuser LIMIT 10;
Observation: Error: (pymysql.err.OperationalError) (1054, "Unknown column 'username' in 'field list'")
[SQL: SELECT username FROM app_sysuser LIMIT 10;]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
Thought: I should double check my query before executing it.
Action: query_checker_sql_db
Action Input: SELECT username FROM app_sysuser LIMIT 10;
Observation: 

SELECT username FROM app_sysuser LIMIT 10;
Thought: The query looks correct, so I should execute it.
Action: query_sql_db
Action Input: SELECT username FROM app_sysuser LIMIT 10;
Observation: Error: (pymysql.err.OperationalError) (1054, "Unknown column 'username' in 'field list'")
[SQL: SELECT username FROM app_sysuser LIMIT 10;]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
Thought: I should double check my query before executing it.
Action: query_checker_sql_db
Action Input: SELECT name FROM app_sysuser LIMIT 10;
Observation: 

SELECT name FROM app_sysuser LIMIT 10;
Thought: The query looks correct, so I should execute it.
Action: query_sql_db
Action Input: SELECT name FROM app_sysuser LIMIT 10;
Observation: Error: (pymysql.err.OperationalError) (1054, "Unknown column 'name' in 'field list'")
[SQL: SELECT name FROM app_sysuser LIMIT 10;]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
Thought: I should double check my query before executing it.
Action: query_checker_sql_db
Action Input: SELECT id, username FROM app_sysuser LIMIT 10;
Observation: 

SELECT id, username FROM app_sysuser LIMIT 10;
Thought: The query looks correct, so I should execute it.
Action: query_sql_db
Action Input: SELECT id, username FROM app_sysuser LIMIT 10;
Observation: Error: (pymysql.err.OperationalError) (1054, "Unknown column 'id' in 'field list'")
[SQL: SELECT id, username FROM app_sysuser LIMIT 10;]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
Thought: I should double check my query before executing it.
Action: query_checker_sql_db
Action Input: SELECT user_id, username FROM app_sysuser LIMIT 10;
Observation: 

SELECT user_id, username FROM app_sysuser WHERE user_id IS NOT NULL AND user_id BETWEEN 0 AND 10 UNION ALL SELECT user_id, username FROM app_sysuser WHERE username IS NOT NULL ORDER BY user_id ASC LIMIT 10;
Thought: The query looks correct, so I should execute it.
Action: query_sql_db
Action Input: SELECT user_id, username FROM app_sysuser WHERE user_id IS NOT NULL AND user_id BETWEEN 0 AND 10 UNION ALL SELECT user_id, username FROM app_sysuser WHERE username IS NOT NULL ORDER BY user_id ASC LIMIT 10;
Observation: Error: (pymysql.err.OperationalError) (1054, "Unknown column 'username' in 'field list'")
[SQL: SELECT user_id, username FROM app_sysuser WHERE user_id IS NOT NULL AND user_id BETWEEN 0 AND 10 UNION ALL SELECT user_id, username FROM app_sysuser WHERE username IS NOT NULL ORDER BY user_id ASC LIMIT 10;]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
Thought: I should double check my query before executing it.
Action: query_checker_sql_db
Action Input: SELECT user_id, user_name FROM app_sysuser LIMIT 10;
Observation: 

SELECT user_id, user_name FROM app_sysuser WHERE user_id NOT IN (NULL) AND user_name BETWEEN 'A' AND 'Z' UNION ALL SELECT user_id, user_name FROM app_sysuser WHERE user_id IS NOT NULL AND user_name LIKE '%a%' LIMIT 10;
Thought: The query looks correct, so I should execute it.
Action: query_sql_db
Action Input: SELECT user_id, user_name FROM app_sysuser WHERE user_id NOT IN (NULL) AND user_name BETWEEN 'A' AND 'Z' UNION ALL SELECT user_id, user_name FROM app_sysuser WHERE user_id IS NOT NULL AND user_name LIKE '%a%' LIMIT 10;
Observation: []
Thought: I now know the final answer.
Final Answer: There are no users in the sysuser table.

> Finished chain.

Process finished with exit code 0
```

As you can see, agent always use incorrect sql to query the data.
Actually, the username field name of my `app_sysuser` table is
`user_name`, not the `username` believed by the agent.

**Optimized agent output**

```
> Entering new AgentExecutor chain...
Action: list_tables_sql_db
Action Input: ""
Observation: app_project_users, django_content_type, app_measuringpoint_warning_thresholds, app_bimfaceaccount_bimface_models, django_session, app_sysuser, django_migrations, app_bimfacemodel, app_elementpoint, app_measuringpoint, django_apscheduler_djangojobexecution, auth_permission, app_sysuser_companies, app_measuringpointdata, app_measuringitem, app_sysrole_menus, django_admin_log, auth_group, django_apscheduler_djangojob, app_sysmenu, app_project, app_bimfaceaccount, app_datawarningthreshold, app_datawarningrecord, auth_user_groups, auth_user_user_permissions, auth_group_permissions, auth_user, app_company, app_sysrole
Thought: I should query the app_sysuser table to get the usernames of all the users.
Action: query_sql_db
Action Input: SELECT username FROM app_sysuser
Observation: Error: (pymysql.err.OperationalError) (1054, "Unknown column 'username' in 'field list'")
[SQL: SELECT username FROM app_sysuser]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
Thought: I should use the schema_sql_db tool to get the correct table fields.
Action: schema_sql_db
Action Input: app_sysuser
Observation: 
CREATE TABLE app_sysuser (
	user_id INTEGER NOT NULL AUTO_INCREMENT, 
	user_account VARCHAR(50) COLLATE utf8mb4_general_ci NOT NULL, 
	user_password VARCHAR(255) COLLATE utf8mb4_general_ci NOT NULL, 
	user_name VARCHAR(50) COLLATE utf8mb4_general_ci NOT NULL, 
	user_avatar VARCHAR(255) COLLATE utf8mb4_general_ci, 
	user_email VARCHAR(254) COLLATE utf8mb4_general_ci, 
	user_phone VARCHAR(20) COLLATE utf8mb4_general_ci, 
	user_create_time DATETIME(6) NOT NULL, 
	user_update_time DATETIME(6) NOT NULL, 
	user_status SMALLINT, 
	role_id_id INTEGER, 
	PRIMARY KEY (user_id), 
	CONSTRAINT app_sysuser_role_id_id_51b50f39_fk_app_sysrole_role_id FOREIGN KEY(role_id_id) REFERENCES app_sysrole (role_id)
)COLLATE utf8mb4_general_ci DEFAULT CHARSET=utf8mb4 ENGINE=InnoDB

/*
3 rows from app_sysuser table:
user_id	user_account	user_password	user_name	user_avatar	user_email	user_phone	user_create_time	user_update_time	user_status	role_id_id
xxxxxxxxxxxxxx
*/
Thought: I should query the app_sysuser table to get the usernames of all the users.
Action: query_sql_db
Action Input: SELECT user_account FROM app_sysuser LIMIT 10
Observation: [('baiyun',), ('eatrice',), ('lisi',), ('pingxiang',), ('wangwu',), ('zeeland',), ('zsj',), ('zzw',)]
Thought: I now know the final answer
Final Answer: The usernames of the users in the sysuser table are baiyun, eatrice, lisi, pingxiang, wangwu, zeeland, zsj, and zzw.

> Finished chain.

Process finished with exit code 0

```

I have tested about 10 related prompts and they all work properly, with
a much lower error rate compared to before


## Who can review?

@vowelparrot

---------

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
searx_updates
Zeeland 12 months ago committed by GitHub
parent ce6dbe41a9
commit b72401b47b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -33,9 +33,27 @@ class SQLDatabaseToolkit(BaseToolkit):
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
query_sql_database_tool_description = (
"Input to this tool is a detailed and correct SQL query, output is a "
"result from the database. If the query is not correct, an error message "
"will be returned. If an error is returned, rewrite the query, check the "
"query, and try again. If you encounter an issue with Unknown column "
"'xxxx' in 'field list', using schema_sql_db to query the correct table "
"fields."
)
info_sql_database_tool_description = (
"Input to this tool is a comma-separated list of tables, output is the "
"schema and sample rows for those tables. "
"Be sure that the tables actually exist by calling list_tables_sql_db "
"first! Example Input: 'table1, table2, table3'"
)
return [
QuerySQLDataBaseTool(db=self.db),
InfoSQLDatabaseTool(db=self.db),
QuerySQLDataBaseTool(
db=self.db, description=query_sql_database_tool_description
),
InfoSQLDatabaseTool(
db=self.db, description=info_sql_database_tool_description
),
ListSQLDatabaseTool(db=self.db),
QueryCheckerTool(db=self.db, llm=self.llm),
]

@ -61,8 +61,7 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
name = "schema_sql_db"
description = """
Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
Be sure that the tables actually exist by calling list_tables_sql_db first!
Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
Example Input: "table1, table2, table3"
"""

Loading…
Cancel
Save