From 94ae126747974a005d8fdf653100c934351626d9 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 30 Jan 2023 15:10:48 -0800 Subject: [PATCH] return sql intermediate steps (#792) --- docs/modules/chains/examples/sqlite.ipynb | 62 +++++++++++++++++++++-- langchain/chains/sql_database/base.py | 17 +++++-- 2 files changed, 72 insertions(+), 7 deletions(-) diff --git a/docs/modules/chains/examples/sqlite.ipynb b/docs/modules/chains/examples/sqlite.ipynb index 7a6c8180c8..548ca4c693 100644 --- a/docs/modules/chains/examples/sqlite.ipynb +++ b/docs/modules/chains/examples/sqlite.ipynb @@ -68,7 +68,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "id": "15ff81df", "metadata": { "pycharm": { @@ -96,7 +96,7 @@ "' There are 9 employees.'" ] }, - "execution_count": 3, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -188,6 +188,62 @@ "db_chain.run(\"How many employees are there in the foobar table?\")" ] }, + { + "cell_type": "markdown", + "id": "88d8b969", + "metadata": {}, + "source": [ + "## Return Intermediate Steps\n", + "\n", + "You can also return the intermediate steps of the SQLDatabaseChain. This allows you to access the SQL statement that was generated, as well as the result of running that against the SQL Database." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "38559487", + "metadata": {}, + "outputs": [], + "source": [ + "db_chain = SQLDatabaseChain(llm=llm, database=db, prompt=PROMPT, verbose=True, return_intermediate_steps=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "78b6af4d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n", + "How many employees are there in the foobar table? \n", + "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT COUNT(*) FROM Employee;\u001b[0m\n", + "SQLResult: \u001b[33;1m\u001b[1;3m[(9,)]\u001b[0m\n", + "Answer:\u001b[32;1m\u001b[1;3m There are 9 employees in the foobar table.\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "[' SELECT COUNT(*) FROM Employee;', '[(9,)]']" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result = db_chain(\"How many employees are there in the foobar table?\")\n", + "result[\"intermediate_steps\"]" + ] + }, { "cell_type": "markdown", "id": "b408f800", @@ -405,7 +461,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.10.9" } }, "nbformat": 4, diff --git a/langchain/chains/sql_database/base.py b/langchain/chains/sql_database/base.py index dd9bbe8454..0c35acbe69 100644 --- a/langchain/chains/sql_database/base.py +++ b/langchain/chains/sql_database/base.py @@ -34,6 +34,7 @@ class SQLDatabaseChain(Chain, BaseModel): """Number of results to return from the query""" input_key: str = "query" #: :meta private: output_key: str = "result" #: :meta private: + return_intermediate_steps: bool = False class Config: """Configuration for this pydantic object.""" @@ -55,9 +56,12 @@ class SQLDatabaseChain(Chain, BaseModel): :meta private: """ - return [self.output_key] + if not self.return_intermediate_steps: + return [self.output_key] + else: + return [self.output_key, "intermediate_steps"] - def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: + def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: llm_chain = LLMChain(llm=self.llm, prompt=self.prompt) input_text = f"{inputs[self.input_key]} \nSQLQuery:" self.callback_manager.on_text(input_text, verbose=self.verbose) @@ -71,10 +75,12 @@ class SQLDatabaseChain(Chain, BaseModel): "table_info": table_info, "stop": ["\nSQLResult:"], } - + intermediate_steps = [] sql_cmd = llm_chain.predict(**llm_inputs) + intermediate_steps.append(sql_cmd) self.callback_manager.on_text(sql_cmd, color="green", verbose=self.verbose) result = self.database.run(sql_cmd) + intermediate_steps.append(result) self.callback_manager.on_text("\nSQLResult: ", verbose=self.verbose) self.callback_manager.on_text(result, color="yellow", verbose=self.verbose) self.callback_manager.on_text("\nAnswer:", verbose=self.verbose) @@ -82,7 +88,10 @@ class SQLDatabaseChain(Chain, BaseModel): llm_inputs["input"] = input_text final_result = llm_chain.predict(**llm_inputs) self.callback_manager.on_text(final_result, color="green", verbose=self.verbose) - return {self.output_key: final_result} + chain_result: Dict[str, Any] = {self.output_key: final_result} + if self.return_intermediate_steps: + chain_result["intermediate_steps"] = intermediate_steps + return chain_result @property def _chain_type(self) -> str: