From b72401b47b07017b86a75456acb24b24185a7952 Mon Sep 17 00:00:00 2001 From: Zeeland Date: Sun, 4 Jun 2023 05:55:56 +0800 Subject: [PATCH] pref: reduce DB query error rate (#5339) # Reduce DB query error rate If you use sql agent of `SQLDatabaseToolkit` to query data, it is prone to errors in query fields and often uses fields that do not exist in database tables for queries. However, the existing prompt does not effectively make the agent aware that there are problems with the fields they query. At this time, we urgently need to improve the prompt so that the agent realizes that they have queried non-existent fields and allows them to use the `schema_sql_db`, that is,` ListSQLDatabaseTool` first queries the corresponding fields in the table in the database, and then uses `QuerySQLDatabaseTool` for querying. There is a demo of my project to show this problem. **Original Agent** ```python def create_mysql_kit(): db = SQLDatabase.from_uri("mysql+pymysql://xxxxxxx") llm = OpenAI(temperature=0) toolkit = SQLDatabaseToolkit(db=db, llm=llm) agent_executor = create_sql_agent( llm=OpenAI(temperature=0), toolkit=toolkit, verbose=True ) agent_executor.run("Who are the users of sysuser in this system? Tell me the username of all users") if __name__ == '__main__': create_mysql_kit() ``` **original output** ``` > Entering new AgentExecutor chain... Action: list_tables_sql_db Action Input: "" Observation: app_sysrole_menus, app_bimfacemodel, app_project_users, app_measuringpointdata, auth_user, auth_user_groups, django_apscheduler_djangojobexecution, app_project, app_elementpoint, django_apscheduler_djangojob, django_content_type, app_sysrole, django_admin_log, app_bimfaceaccount, app_measuringpoint_warning_thresholds, app_measuringpoint, app_company, auth_group_permissions, app_sysuser, app_sysuser_companies, app_sysmenu, app_datawarningthreshold, auth_group, auth_permission, app_datawarningrecord, auth_user_user_permissions, app_bimfaceaccount_bimface_models, django_migrations, app_measuringitem, django_session Thought: I should query the app_sysuser table to get the users in the system. Action: query_sql_db Action Input: SELECT username FROM app_sysuser LIMIT 10; Observation: Error: (pymysql.err.OperationalError) (1054, "Unknown column 'username' in 'field list'") [SQL: SELECT username FROM app_sysuser LIMIT 10;] (Background on this error at: https://sqlalche.me/e/20/e3q8) Thought: I should double check my query before executing it. Action: query_checker_sql_db Action Input: SELECT username FROM app_sysuser LIMIT 10; Observation: SELECT username FROM app_sysuser LIMIT 10; Thought: The query looks correct, so I should execute it. Action: query_sql_db Action Input: SELECT username FROM app_sysuser LIMIT 10; Observation: Error: (pymysql.err.OperationalError) (1054, "Unknown column 'username' in 'field list'") [SQL: SELECT username FROM app_sysuser LIMIT 10;] (Background on this error at: https://sqlalche.me/e/20/e3q8) Thought: I should double check my query before executing it. Action: query_checker_sql_db Action Input: SELECT name FROM app_sysuser LIMIT 10; Observation: SELECT name FROM app_sysuser LIMIT 10; Thought: The query looks correct, so I should execute it. Action: query_sql_db Action Input: SELECT name FROM app_sysuser LIMIT 10; Observation: Error: (pymysql.err.OperationalError) (1054, "Unknown column 'name' in 'field list'") [SQL: SELECT name FROM app_sysuser LIMIT 10;] (Background on this error at: https://sqlalche.me/e/20/e3q8) Thought: I should double check my query before executing it. Action: query_checker_sql_db Action Input: SELECT id, username FROM app_sysuser LIMIT 10; Observation: SELECT id, username FROM app_sysuser LIMIT 10; Thought: The query looks correct, so I should execute it. Action: query_sql_db Action Input: SELECT id, username FROM app_sysuser LIMIT 10; Observation: Error: (pymysql.err.OperationalError) (1054, "Unknown column 'id' in 'field list'") [SQL: SELECT id, username FROM app_sysuser LIMIT 10;] (Background on this error at: https://sqlalche.me/e/20/e3q8) Thought: I should double check my query before executing it. Action: query_checker_sql_db Action Input: SELECT user_id, username FROM app_sysuser LIMIT 10; Observation: SELECT user_id, username FROM app_sysuser WHERE user_id IS NOT NULL AND user_id BETWEEN 0 AND 10 UNION ALL SELECT user_id, username FROM app_sysuser WHERE username IS NOT NULL ORDER BY user_id ASC LIMIT 10; Thought: The query looks correct, so I should execute it. Action: query_sql_db Action Input: SELECT user_id, username FROM app_sysuser WHERE user_id IS NOT NULL AND user_id BETWEEN 0 AND 10 UNION ALL SELECT user_id, username FROM app_sysuser WHERE username IS NOT NULL ORDER BY user_id ASC LIMIT 10; Observation: Error: (pymysql.err.OperationalError) (1054, "Unknown column 'username' in 'field list'") [SQL: SELECT user_id, username FROM app_sysuser WHERE user_id IS NOT NULL AND user_id BETWEEN 0 AND 10 UNION ALL SELECT user_id, username FROM app_sysuser WHERE username IS NOT NULL ORDER BY user_id ASC LIMIT 10;] (Background on this error at: https://sqlalche.me/e/20/e3q8) Thought: I should double check my query before executing it. Action: query_checker_sql_db Action Input: SELECT user_id, user_name FROM app_sysuser LIMIT 10; Observation: SELECT user_id, user_name FROM app_sysuser WHERE user_id NOT IN (NULL) AND user_name BETWEEN 'A' AND 'Z' UNION ALL SELECT user_id, user_name FROM app_sysuser WHERE user_id IS NOT NULL AND user_name LIKE '%a%' LIMIT 10; Thought: The query looks correct, so I should execute it. Action: query_sql_db Action Input: SELECT user_id, user_name FROM app_sysuser WHERE user_id NOT IN (NULL) AND user_name BETWEEN 'A' AND 'Z' UNION ALL SELECT user_id, user_name FROM app_sysuser WHERE user_id IS NOT NULL AND user_name LIKE '%a%' LIMIT 10; Observation: [] Thought: I now know the final answer. Final Answer: There are no users in the sysuser table. > Finished chain. Process finished with exit code 0 ``` As you can see, agent always use incorrect sql to query the data. Actually, the username field name of my `app_sysuser` table is `user_name`, not the `username` believed by the agent. **Optimized agent output** ``` > Entering new AgentExecutor chain... Action: list_tables_sql_db Action Input: "" Observation: app_project_users, django_content_type, app_measuringpoint_warning_thresholds, app_bimfaceaccount_bimface_models, django_session, app_sysuser, django_migrations, app_bimfacemodel, app_elementpoint, app_measuringpoint, django_apscheduler_djangojobexecution, auth_permission, app_sysuser_companies, app_measuringpointdata, app_measuringitem, app_sysrole_menus, django_admin_log, auth_group, django_apscheduler_djangojob, app_sysmenu, app_project, app_bimfaceaccount, app_datawarningthreshold, app_datawarningrecord, auth_user_groups, auth_user_user_permissions, auth_group_permissions, auth_user, app_company, app_sysrole Thought: I should query the app_sysuser table to get the usernames of all the users. Action: query_sql_db Action Input: SELECT username FROM app_sysuser Observation: Error: (pymysql.err.OperationalError) (1054, "Unknown column 'username' in 'field list'") [SQL: SELECT username FROM app_sysuser] (Background on this error at: https://sqlalche.me/e/20/e3q8) Thought: I should use the schema_sql_db tool to get the correct table fields. Action: schema_sql_db Action Input: app_sysuser Observation: CREATE TABLE app_sysuser ( user_id INTEGER NOT NULL AUTO_INCREMENT, user_account VARCHAR(50) COLLATE utf8mb4_general_ci NOT NULL, user_password VARCHAR(255) COLLATE utf8mb4_general_ci NOT NULL, user_name VARCHAR(50) COLLATE utf8mb4_general_ci NOT NULL, user_avatar VARCHAR(255) COLLATE utf8mb4_general_ci, user_email VARCHAR(254) COLLATE utf8mb4_general_ci, user_phone VARCHAR(20) COLLATE utf8mb4_general_ci, user_create_time DATETIME(6) NOT NULL, user_update_time DATETIME(6) NOT NULL, user_status SMALLINT, role_id_id INTEGER, PRIMARY KEY (user_id), CONSTRAINT app_sysuser_role_id_id_51b50f39_fk_app_sysrole_role_id FOREIGN KEY(role_id_id) REFERENCES app_sysrole (role_id) )COLLATE utf8mb4_general_ci DEFAULT CHARSET=utf8mb4 ENGINE=InnoDB /* 3 rows from app_sysuser table: user_id user_account user_password user_name user_avatar user_email user_phone user_create_time user_update_time user_status role_id_id xxxxxxxxxxxxxx */ Thought: I should query the app_sysuser table to get the usernames of all the users. Action: query_sql_db Action Input: SELECT user_account FROM app_sysuser LIMIT 10 Observation: [('baiyun',), ('eatrice',), ('lisi',), ('pingxiang',), ('wangwu',), ('zeeland',), ('zsj',), ('zzw',)] Thought: I now know the final answer Final Answer: The usernames of the users in the sysuser table are baiyun, eatrice, lisi, pingxiang, wangwu, zeeland, zsj, and zzw. > Finished chain. Process finished with exit code 0 ``` I have tested about 10 related prompts and they all work properly, with a much lower error rate compared to before ## Who can review? @vowelparrot --------- Co-authored-by: Dev 2049 --- .../agents/agent_toolkits/sql/toolkit.py | 22 +++++++++++++++++-- langchain/tools/sql_database/tool.py | 3 +-- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/langchain/agents/agent_toolkits/sql/toolkit.py b/langchain/agents/agent_toolkits/sql/toolkit.py index 085d24e3..8249b134 100644 --- a/langchain/agents/agent_toolkits/sql/toolkit.py +++ b/langchain/agents/agent_toolkits/sql/toolkit.py @@ -33,9 +33,27 @@ class SQLDatabaseToolkit(BaseToolkit): def get_tools(self) -> List[BaseTool]: """Get the tools in the toolkit.""" + query_sql_database_tool_description = ( + "Input to this tool is a detailed and correct SQL query, output is a " + "result from the database. If the query is not correct, an error message " + "will be returned. If an error is returned, rewrite the query, check the " + "query, and try again. If you encounter an issue with Unknown column " + "'xxxx' in 'field list', using schema_sql_db to query the correct table " + "fields." + ) + info_sql_database_tool_description = ( + "Input to this tool is a comma-separated list of tables, output is the " + "schema and sample rows for those tables. " + "Be sure that the tables actually exist by calling list_tables_sql_db " + "first! Example Input: 'table1, table2, table3'" + ) return [ - QuerySQLDataBaseTool(db=self.db), - InfoSQLDatabaseTool(db=self.db), + QuerySQLDataBaseTool( + db=self.db, description=query_sql_database_tool_description + ), + InfoSQLDatabaseTool( + db=self.db, description=info_sql_database_tool_description + ), ListSQLDatabaseTool(db=self.db), QueryCheckerTool(db=self.db, llm=self.llm), ] diff --git a/langchain/tools/sql_database/tool.py b/langchain/tools/sql_database/tool.py index 2e677c6c..9edad760 100644 --- a/langchain/tools/sql_database/tool.py +++ b/langchain/tools/sql_database/tool.py @@ -61,8 +61,7 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): name = "schema_sql_db" description = """ - Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. - Be sure that the tables actually exist by calling list_tables_sql_db first! + Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Example Input: "table1, table2, table3" """