From 12861273e12cd97e173e27c6319fb68f0739747a Mon Sep 17 00:00:00 2001 From: Kirushikesh DB <49152921+Kirushikesh@users.noreply.github.com> Date: Fri, 29 Mar 2024 13:52:35 +0530 Subject: [PATCH] experimental[patch]: Removed 'SQLResults:' from the LLMResponse in SQLDatabaseChain (#17104) **Description:** When using the SQLDatabaseChain with Llama2-70b LLM and, SQLite database. I was getting `Warning: You can only execute one statement at a time.`. ``` from langchain.sql_database import SQLDatabase from langchain_experimental.sql import SQLDatabaseChain sql_database_path = '/dccstor/mmdataretrieval/mm_dataset/swimming_record/rag_data/swimmingdataset.db' sql_db = get_database(sql_database_path) db_chain = SQLDatabaseChain.from_llm(mistral, sql_db, verbose=True, callbacks = [callback_obj]) db_chain.invoke({ "query": "What is the best time of Lance Larson in men's 100 meter butterfly competition?" }) ``` Error: ``` Warning Traceback (most recent call last) Cell In[31], line 3 1 import langchain 2 langchain.debug=False ----> 3 db_chain.invoke({ 4 "query": "What is the best time of Lance Larson in men's 100 meter butterfly competition?" 5 }) File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/langchain/chains/base.py:162, in Chain.invoke(self, input, config, **kwargs) 160 except BaseException as e: 161 run_manager.on_chain_error(e) --> 162 raise e 163 run_manager.on_chain_end(outputs) 164 final_outputs: Dict[str, Any] = self.prep_outputs( 165 inputs, outputs, return_only_outputs 166 ) File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/langchain/chains/base.py:156, in Chain.invoke(self, input, config, **kwargs) 149 run_manager = callback_manager.on_chain_start( 150 dumpd(self), 151 inputs, 152 name=run_name, 153 ) 154 try: 155 outputs = ( --> 156 self._call(inputs, run_manager=run_manager) 157 if new_arg_supported 158 else self._call(inputs) 159 ) 160 except BaseException as e: 161 run_manager.on_chain_error(e) File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/langchain_experimental/sql/base.py:198, in SQLDatabaseChain._call(self, inputs, run_manager) 194 except Exception as exc: 195 # Append intermediate steps to exception, to aid in logging and later 196 # improvement of few shot prompt seeds 197 exc.intermediate_steps = intermediate_steps # type: ignore --> 198 raise exc File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/langchain_experimental/sql/base.py:143, in SQLDatabaseChain._call(self, inputs, run_manager) 139 intermediate_steps.append( 140 sql_cmd 141 ) # output: sql generation (no checker) 142 intermediate_steps.append({"sql_cmd": sql_cmd}) # input: sql exec --> 143 result = self.database.run(sql_cmd) 144 intermediate_steps.append(str(result)) # output: sql exec 145 else: File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/langchain_community/utilities/sql_database.py:436, in SQLDatabase.run(self, command, fetch, include_columns) 425 def run( 426 self, 427 command: str, 428 fetch: Literal["all", "one"] = "all", 429 include_columns: bool = False, 430 ) -> str: 431 """Execute a SQL command and return a string representing the results. 432 433 If the statement returns rows, a string of the results is returned. 434 If the statement returns no rows, an empty string is returned. 435 """ --> 436 result = self._execute(command, fetch) 438 res = [ 439 { 440 column: truncate_word(value, length=self._max_string_length) (...) 443 for r in result 444 ] 446 if not include_columns: File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/langchain_community/utilities/sql_database.py:413, in SQLDatabase._execute(self, command, fetch) 410 elif self.dialect == "postgresql": # postgresql 411 connection.exec_driver_sql("SET search_path TO %s", (self._schema,)) --> 413 cursor = connection.execute(text(command)) 414 if cursor.returns_rows: 415 if fetch == "all": File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/base.py:1416, in Connection.execute(self, statement, parameters, execution_options) 1414 raise exc.ObjectNotExecutableError(statement) from err 1415 else: -> 1416 return meth( 1417 self, 1418 distilled_parameters, 1419 execution_options or NO_OPTIONS, 1420 ) File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/sql/elements.py:516, in ClauseElement._execute_on_connection(self, connection, distilled_params, execution_options) 514 if TYPE_CHECKING: 515 assert isinstance(self, Executable) --> 516 return connection._execute_clauseelement( 517 self, distilled_params, execution_options 518 ) 519 else: 520 raise exc.ObjectNotExecutableError(self) File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/base.py:1639, in Connection._execute_clauseelement(self, elem, distilled_parameters, execution_options) 1627 compiled_cache: Optional[CompiledCacheType] = execution_options.get( 1628 "compiled_cache", self.engine._compiled_cache 1629 ) 1631 compiled_sql, extracted_params, cache_hit = elem._compile_w_cache( 1632 dialect=dialect, 1633 compiled_cache=compiled_cache, (...) 1637 linting=self.dialect.compiler_linting | compiler.WARN_LINTING, 1638 ) -> 1639 ret = self._execute_context( 1640 dialect, 1641 dialect.execution_ctx_cls._init_compiled, 1642 compiled_sql, 1643 distilled_parameters, 1644 execution_options, 1645 compiled_sql, 1646 distilled_parameters, 1647 elem, 1648 extracted_params, 1649 cache_hit=cache_hit, 1650 ) 1651 if has_events: 1652 self.dispatch.after_execute( 1653 self, 1654 elem, (...) 1658 ret, 1659 ) File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/base.py:1848, in Connection._execute_context(self, dialect, constructor, statement, parameters, execution_options, *args, **kw) 1843 return self._exec_insertmany_context( 1844 dialect, 1845 context, 1846 ) 1847 else: -> 1848 return self._exec_single_context( 1849 dialect, context, statement, parameters 1850 ) File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/base.py:1988, in Connection._exec_single_context(self, dialect, context, statement, parameters) 1985 result = context._setup_result_proxy() 1987 except BaseException as e: -> 1988 self._handle_dbapi_exception( 1989 e, str_statement, effective_parameters, cursor, context 1990 ) 1992 return result File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/base.py:2346, in Connection._handle_dbapi_exception(self, e, statement, parameters, cursor, context, is_sub_exec) 2344 else: 2345 assert exc_info[1] is not None -> 2346 raise exc_info[1].with_traceback(exc_info[2]) 2347 finally: 2348 del self._reentrant_error File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/base.py:1969, in Connection._exec_single_context(self, dialect, context, statement, parameters) 1967 break 1968 if not evt_handled: -> 1969 self.dialect.do_execute( 1970 cursor, str_statement, effective_parameters, context 1971 ) 1973 if self._has_events or self.engine._has_events: 1974 self.dispatch.after_cursor_execute( 1975 self, 1976 cursor, (...) 1980 context.executemany, 1981 ) File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/default.py:922, in DefaultDialect.do_execute(self, cursor, statement, parameters, context) 921 def do_execute(self, cursor, statement, parameters, context=None): --> 922 cursor.execute(statement, parameters) Warning: You can only execute one statement at a time. ``` **Issue:** The Error occurs because when generating the SQLQuery, the llm_input includes the stop character of "\nSQLResult:", so for this user query the LLM generated response is **SELECT Time FROM men_butterfly_100m WHERE Swimmer = 'Lance Larson';\nSQLResult:** it is required to remove the SQLResult suffix on the llm response before executing it on the database. ``` llm_inputs = { "input": input_text, "top_k": str(self.top_k), "dialect": self.database.dialect, "table_info": table_info, "stop": ["\nSQLResult:"], } sql_cmd = self.llm_chain.predict( callbacks=_run_manager.get_child(), **llm_inputs, ).strip() if SQL_RESULT in sql_cmd: sql_cmd = sql_cmd.split(SQL_RESULT)[0].strip() result = self.database.run(sql_cmd) ``` --------- Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> --- libs/experimental/langchain_experimental/sql/base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/libs/experimental/langchain_experimental/sql/base.py b/libs/experimental/langchain_experimental/sql/base.py index b6de5e67c9..7376b08115 100644 --- a/libs/experimental/langchain_experimental/sql/base.py +++ b/libs/experimental/langchain_experimental/sql/base.py @@ -18,6 +18,7 @@ from langchain_experimental.pydantic_v1 import Extra, Field, root_validator INTERMEDIATE_STEPS_KEY = "intermediate_steps" SQL_QUERY = "SQLQuery:" +SQL_RESULT = "SQLResult:" class SQLDatabaseChain(Chain): @@ -143,6 +144,8 @@ class SQLDatabaseChain(Chain): intermediate_steps.append({"sql_cmd": sql_cmd}) # input: sql exec if SQL_QUERY in sql_cmd: sql_cmd = sql_cmd.split(SQL_QUERY)[1].strip() + if SQL_RESULT in sql_cmd: + sql_cmd = sql_cmd.split(SQL_RESULT)[0].strip() result = self.database.run(sql_cmd) intermediate_steps.append(str(result)) # output: sql exec else: