|
|
|
@ -17,6 +17,7 @@ from langchain.utilities.sql_database import SQLDatabase
|
|
|
|
|
from langchain_experimental.pydantic_v1 import Extra, Field, root_validator
|
|
|
|
|
|
|
|
|
|
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
|
|
|
|
|
SQL_QUERY = "SQLQuery:"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SQLDatabaseChain(Chain):
|
|
|
|
@ -110,7 +111,7 @@ class SQLDatabaseChain(Chain):
|
|
|
|
|
run_manager: Optional[CallbackManagerForChainRun] = None,
|
|
|
|
|
) -> Dict[str, Any]:
|
|
|
|
|
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
|
|
|
|
input_text = f"{inputs[self.input_key]}\nSQLQuery:"
|
|
|
|
|
input_text = f"{inputs[self.input_key]}\n{SQL_QUERY}"
|
|
|
|
|
_run_manager.on_text(input_text, verbose=self.verbose)
|
|
|
|
|
# If not present, then defaults to None which is all tables.
|
|
|
|
|
table_names_to_use = inputs.get("table_names_to_use")
|
|
|
|
@ -140,6 +141,8 @@ class SQLDatabaseChain(Chain):
|
|
|
|
|
sql_cmd
|
|
|
|
|
) # output: sql generation (no checker)
|
|
|
|
|
intermediate_steps.append({"sql_cmd": sql_cmd}) # input: sql exec
|
|
|
|
|
if SQL_QUERY in sql_cmd:
|
|
|
|
|
sql_cmd = sql_cmd.split(SQL_QUERY)[1].strip()
|
|
|
|
|
result = self.database.run(sql_cmd)
|
|
|
|
|
intermediate_steps.append(str(result)) # output: sql exec
|
|
|
|
|
else:
|
|
|
|
|