mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
return sql intermediate steps (#792)
This commit is contained in:
parent
ae5695ad32
commit
94ae126747
@ -68,7 +68,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 4,
|
||||
"id": "15ff81df",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
@ -96,7 +96,7 @@
|
||||
"' There are 9 employees.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -188,6 +188,62 @@
|
||||
"db_chain.run(\"How many employees are there in the foobar table?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "88d8b969",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Return Intermediate Steps\n",
|
||||
"\n",
|
||||
"You can also return the intermediate steps of the SQLDatabaseChain. This allows you to access the SQL statement that was generated, as well as the result of running that against the SQL Database."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "38559487",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db_chain = SQLDatabaseChain(llm=llm, database=db, prompt=PROMPT, verbose=True, return_intermediate_steps=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "78b6af4d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
|
||||
"How many employees are there in the foobar table? \n",
|
||||
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT COUNT(*) FROM Employee;\u001b[0m\n",
|
||||
"SQLResult: \u001b[33;1m\u001b[1;3m[(9,)]\u001b[0m\n",
|
||||
"Answer:\u001b[32;1m\u001b[1;3m There are 9 employees in the foobar table.\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[' SELECT COUNT(*) FROM Employee;', '[(9,)]']"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"result = db_chain(\"How many employees are there in the foobar table?\")\n",
|
||||
"result[\"intermediate_steps\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b408f800",
|
||||
@ -405,7 +461,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.16"
|
||||
"version": "3.10.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -34,6 +34,7 @@ class SQLDatabaseChain(Chain, BaseModel):
|
||||
"""Number of results to return from the query"""
|
||||
input_key: str = "query" #: :meta private:
|
||||
output_key: str = "result" #: :meta private:
|
||||
return_intermediate_steps: bool = False
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -55,9 +56,12 @@ class SQLDatabaseChain(Chain, BaseModel):
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
if not self.return_intermediate_steps:
|
||||
return [self.output_key]
|
||||
else:
|
||||
return [self.output_key, "intermediate_steps"]
|
||||
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)
|
||||
input_text = f"{inputs[self.input_key]} \nSQLQuery:"
|
||||
self.callback_manager.on_text(input_text, verbose=self.verbose)
|
||||
@ -71,10 +75,12 @@ class SQLDatabaseChain(Chain, BaseModel):
|
||||
"table_info": table_info,
|
||||
"stop": ["\nSQLResult:"],
|
||||
}
|
||||
|
||||
intermediate_steps = []
|
||||
sql_cmd = llm_chain.predict(**llm_inputs)
|
||||
intermediate_steps.append(sql_cmd)
|
||||
self.callback_manager.on_text(sql_cmd, color="green", verbose=self.verbose)
|
||||
result = self.database.run(sql_cmd)
|
||||
intermediate_steps.append(result)
|
||||
self.callback_manager.on_text("\nSQLResult: ", verbose=self.verbose)
|
||||
self.callback_manager.on_text(result, color="yellow", verbose=self.verbose)
|
||||
self.callback_manager.on_text("\nAnswer:", verbose=self.verbose)
|
||||
@ -82,7 +88,10 @@ class SQLDatabaseChain(Chain, BaseModel):
|
||||
llm_inputs["input"] = input_text
|
||||
final_result = llm_chain.predict(**llm_inputs)
|
||||
self.callback_manager.on_text(final_result, color="green", verbose=self.verbose)
|
||||
return {self.output_key: final_result}
|
||||
chain_result: Dict[str, Any] = {self.output_key: final_result}
|
||||
if self.return_intermediate_steps:
|
||||
chain_result["intermediate_steps"] = intermediate_steps
|
||||
return chain_result
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
|
Loading…
Reference in New Issue
Block a user