diff --git a/langchain/agents/agent_toolkits/sql/toolkit.py b/langchain/agents/agent_toolkits/sql/toolkit.py index 085d24e3..8249b134 100644 --- a/langchain/agents/agent_toolkits/sql/toolkit.py +++ b/langchain/agents/agent_toolkits/sql/toolkit.py @@ -33,9 +33,27 @@ class SQLDatabaseToolkit(BaseToolkit): def get_tools(self) -> List[BaseTool]: """Get the tools in the toolkit.""" + query_sql_database_tool_description = ( + "Input to this tool is a detailed and correct SQL query, output is a " + "result from the database. If the query is not correct, an error message " + "will be returned. If an error is returned, rewrite the query, check the " + "query, and try again. If you encounter an issue with Unknown column " + "'xxxx' in 'field list', using schema_sql_db to query the correct table " + "fields." + ) + info_sql_database_tool_description = ( + "Input to this tool is a comma-separated list of tables, output is the " + "schema and sample rows for those tables. " + "Be sure that the tables actually exist by calling list_tables_sql_db " + "first! Example Input: 'table1, table2, table3'" + ) return [ - QuerySQLDataBaseTool(db=self.db), - InfoSQLDatabaseTool(db=self.db), + QuerySQLDataBaseTool( + db=self.db, description=query_sql_database_tool_description + ), + InfoSQLDatabaseTool( + db=self.db, description=info_sql_database_tool_description + ), ListSQLDatabaseTool(db=self.db), QueryCheckerTool(db=self.db, llm=self.llm), ] diff --git a/langchain/tools/sql_database/tool.py b/langchain/tools/sql_database/tool.py index 2e677c6c..9edad760 100644 --- a/langchain/tools/sql_database/tool.py +++ b/langchain/tools/sql_database/tool.py @@ -61,8 +61,7 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): name = "schema_sql_db" description = """ - Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. - Be sure that the tables actually exist by calling list_tables_sql_db first! + Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Example Input: "table1, table2, table3" """