return sql intermediate steps (#792)

This commit is contained in:
Harrison Chase 2023-01-30 15:10:48 -08:00 committed by GitHub
parent ae5695ad32
commit 94ae126747
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 72 additions and 7 deletions

View File

@ -68,7 +68,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "15ff81df",
"metadata": {
"pycharm": {
@ -96,7 +96,7 @@
"' There are 9 employees.'"
]
},
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@ -188,6 +188,62 @@
"db_chain.run(\"How many employees are there in the foobar table?\")"
]
},
{
"cell_type": "markdown",
"id": "88d8b969",
"metadata": {},
"source": [
"## Return Intermediate Steps\n",
"\n",
"You can also return the intermediate steps of the SQLDatabaseChain. This allows you to access the SQL statement that was generated, as well as the result of running that against the SQL Database."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "38559487",
"metadata": {},
"outputs": [],
"source": [
"db_chain = SQLDatabaseChain(llm=llm, database=db, prompt=PROMPT, verbose=True, return_intermediate_steps=True)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "78b6af4d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
"How many employees are there in the foobar table? \n",
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT COUNT(*) FROM Employee;\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[(9,)]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m There are 9 employees in the foobar table.\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"[' SELECT COUNT(*) FROM Employee;', '[(9,)]']"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result = db_chain(\"How many employees are there in the foobar table?\")\n",
"result[\"intermediate_steps\"]"
]
},
{
"cell_type": "markdown",
"id": "b408f800",
@ -405,7 +461,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
"version": "3.10.9"
}
},
"nbformat": 4,

View File

@ -34,6 +34,7 @@ class SQLDatabaseChain(Chain, BaseModel):
"""Number of results to return from the query"""
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
return_intermediate_steps: bool = False
class Config:
"""Configuration for this pydantic object."""
@ -55,9 +56,12 @@ class SQLDatabaseChain(Chain, BaseModel):
:meta private:
"""
if not self.return_intermediate_steps:
return [self.output_key]
else:
return [self.output_key, "intermediate_steps"]
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)
input_text = f"{inputs[self.input_key]} \nSQLQuery:"
self.callback_manager.on_text(input_text, verbose=self.verbose)
@ -71,10 +75,12 @@ class SQLDatabaseChain(Chain, BaseModel):
"table_info": table_info,
"stop": ["\nSQLResult:"],
}
intermediate_steps = []
sql_cmd = llm_chain.predict(**llm_inputs)
intermediate_steps.append(sql_cmd)
self.callback_manager.on_text(sql_cmd, color="green", verbose=self.verbose)
result = self.database.run(sql_cmd)
intermediate_steps.append(result)
self.callback_manager.on_text("\nSQLResult: ", verbose=self.verbose)
self.callback_manager.on_text(result, color="yellow", verbose=self.verbose)
self.callback_manager.on_text("\nAnswer:", verbose=self.verbose)
@ -82,7 +88,10 @@ class SQLDatabaseChain(Chain, BaseModel):
llm_inputs["input"] = input_text
final_result = llm_chain.predict(**llm_inputs)
self.callback_manager.on_text(final_result, color="green", verbose=self.verbose)
return {self.output_key: final_result}
chain_result: Dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:
chain_result["intermediate_steps"] = intermediate_steps
return chain_result
@property
def _chain_type(self) -> str: