diff --git a/libs/experimental/langchain_experimental/sql/base.py b/libs/experimental/langchain_experimental/sql/base.py index b6de5e67c9..7376b08115 100644 --- a/libs/experimental/langchain_experimental/sql/base.py +++ b/libs/experimental/langchain_experimental/sql/base.py @@ -18,6 +18,7 @@ from langchain_experimental.pydantic_v1 import Extra, Field, root_validator INTERMEDIATE_STEPS_KEY = "intermediate_steps" SQL_QUERY = "SQLQuery:" +SQL_RESULT = "SQLResult:" class SQLDatabaseChain(Chain): @@ -143,6 +144,8 @@ class SQLDatabaseChain(Chain): intermediate_steps.append({"sql_cmd": sql_cmd}) # input: sql exec if SQL_QUERY in sql_cmd: sql_cmd = sql_cmd.split(SQL_QUERY)[1].strip() + if SQL_RESULT in sql_cmd: + sql_cmd = sql_cmd.split(SQL_RESULT)[0].strip() result = self.database.run(sql_cmd) intermediate_steps.append(str(result)) # output: sql exec else: