Mirror of https://github.com/hwchase17/langchain
Commit 12861273e1
**Description:** When using `SQLDatabaseChain` with the Llama2-70b LLM and a SQLite database, I was getting `Warning: You can only execute one statement at a time.`.

```
from langchain.sql_database import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain

sql_database_path = '/dccstor/mmdataretrieval/mm_dataset/swimming_record/rag_data/swimmingdataset.db'
sql_db = get_database(sql_database_path)
db_chain = SQLDatabaseChain.from_llm(mistral, sql_db, verbose=True, callbacks=[callback_obj])
db_chain.invoke({
    "query": "What is the best time of Lance Larson in men's 100 meter butterfly competition?"
})
```

Error:
```
Warning                                   Traceback (most recent call last)
Cell In[31], line 3
      1 import langchain
      2 langchain.debug=False
----> 3 db_chain.invoke({
      4     "query": "What is the best time of Lance Larson in men's 100 meter butterfly competition?"
      5 })

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/langchain/chains/base.py:162, in Chain.invoke(self, input, config, **kwargs)
    160 except BaseException as e:
    161     run_manager.on_chain_error(e)
--> 162     raise e
    163 run_manager.on_chain_end(outputs)
    164 final_outputs: Dict[str, Any] = self.prep_outputs(
    165     inputs, outputs, return_only_outputs
    166 )

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/langchain/chains/base.py:156, in Chain.invoke(self, input, config, **kwargs)
    149 run_manager = callback_manager.on_chain_start(
    150     dumpd(self),
    151     inputs,
    152     name=run_name,
    153 )
    154 try:
    155     outputs = (
--> 156         self._call(inputs, run_manager=run_manager)
    157         if new_arg_supported
    158         else self._call(inputs)
    159     )
    160 except BaseException as e:
    161     run_manager.on_chain_error(e)

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/langchain_experimental/sql/base.py:198, in SQLDatabaseChain._call(self, inputs, run_manager)
    194 except Exception as exc:
    195     # Append intermediate steps to exception, to aid in logging and later
    196     # improvement of few shot prompt seeds
    197     exc.intermediate_steps = intermediate_steps  # type: ignore
--> 198     raise exc

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/langchain_experimental/sql/base.py:143, in SQLDatabaseChain._call(self, inputs, run_manager)
    139 intermediate_steps.append(
    140     sql_cmd
    141 )  # output: sql generation (no checker)
    142 intermediate_steps.append({"sql_cmd": sql_cmd})  # input: sql exec
--> 143 result = self.database.run(sql_cmd)
    144 intermediate_steps.append(str(result))  # output: sql exec
    145 else:

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/langchain_community/utilities/sql_database.py:436, in SQLDatabase.run(self, command, fetch, include_columns)
    425 def run(
    426     self,
    427     command: str,
    428     fetch: Literal["all", "one"] = "all",
    429     include_columns: bool = False,
    430 ) -> str:
    431     """Execute a SQL command and return a string representing the results.
    432
    433     If the statement returns rows, a string of the results is returned.
    434     If the statement returns no rows, an empty string is returned.
    435     """
--> 436     result = self._execute(command, fetch)
    438     res = [
    439         {
    440             column: truncate_word(value, length=self._max_string_length)
   (...)
    443         for r in result
    444     ]
    446     if not include_columns:

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/langchain_community/utilities/sql_database.py:413, in SQLDatabase._execute(self, command, fetch)
    410 elif self.dialect == "postgresql":  # postgresql
    411     connection.exec_driver_sql("SET search_path TO %s", (self._schema,))
--> 413 cursor = connection.execute(text(command))
    414 if cursor.returns_rows:
    415     if fetch == "all":

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/base.py:1416, in Connection.execute(self, statement, parameters, execution_options)
   1414     raise exc.ObjectNotExecutableError(statement) from err
   1415 else:
-> 1416     return meth(
   1417         self,
   1418         distilled_parameters,
   1419         execution_options or NO_OPTIONS,
   1420     )

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/sql/elements.py:516, in ClauseElement._execute_on_connection(self, connection, distilled_params, execution_options)
    514 if TYPE_CHECKING:
    515     assert isinstance(self, Executable)
--> 516     return connection._execute_clauseelement(
    517         self, distilled_params, execution_options
    518     )
    519 else:
    520     raise exc.ObjectNotExecutableError(self)

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/base.py:1639, in Connection._execute_clauseelement(self, elem, distilled_parameters, execution_options)
   1627 compiled_cache: Optional[CompiledCacheType] = execution_options.get(
   1628     "compiled_cache", self.engine._compiled_cache
   1629 )
   1631 compiled_sql, extracted_params, cache_hit = elem._compile_w_cache(
   1632     dialect=dialect,
   1633     compiled_cache=compiled_cache,
   (...)
   1637     linting=self.dialect.compiler_linting | compiler.WARN_LINTING,
   1638 )
-> 1639 ret = self._execute_context(
   1640     dialect,
   1641     dialect.execution_ctx_cls._init_compiled,
   1642     compiled_sql,
   1643     distilled_parameters,
   1644     execution_options,
   1645     compiled_sql,
   1646     distilled_parameters,
   1647     elem,
   1648     extracted_params,
   1649     cache_hit=cache_hit,
   1650 )
   1651 if has_events:
   1652     self.dispatch.after_execute(
   1653         self,
   1654         elem,
   (...)
   1658         ret,
   1659     )

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/base.py:1848, in Connection._execute_context(self, dialect, constructor, statement, parameters, execution_options, *args, **kw)
   1843     return self._exec_insertmany_context(
   1844         dialect,
   1845         context,
   1846     )
   1847 else:
-> 1848     return self._exec_single_context(
   1849         dialect, context, statement, parameters
   1850     )

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/base.py:1988, in Connection._exec_single_context(self, dialect, context, statement, parameters)
   1985     result = context._setup_result_proxy()
   1987 except BaseException as e:
-> 1988     self._handle_dbapi_exception(
   1989         e, str_statement, effective_parameters, cursor, context
   1990     )
   1992 return result

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/base.py:2346, in Connection._handle_dbapi_exception(self, e, statement, parameters, cursor, context, is_sub_exec)
   2344 else:
   2345     assert exc_info[1] is not None
-> 2346     raise exc_info[1].with_traceback(exc_info[2])
   2347 finally:
   2348     del self._reentrant_error

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/base.py:1969, in Connection._exec_single_context(self, dialect, context, statement, parameters)
   1967     break
   1968 if not evt_handled:
-> 1969     self.dialect.do_execute(
   1970         cursor, str_statement, effective_parameters, context
   1971     )
   1973 if self._has_events or self.engine._has_events:
   1974     self.dispatch.after_cursor_execute(
   1975         self,
   1976         cursor,
   (...)
   1980     context.executemany,
   1981 )

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/default.py:922, in DefaultDialect.do_execute(self, cursor, statement, parameters, context)
    921 def do_execute(self, cursor, statement, parameters, context=None):
--> 922     cursor.execute(statement, parameters)

Warning: You can only execute one statement at a time.
```

**Issue:** When generating the SQL query, the LLM is given `"\nSQLResult:"` as a stop sequence, yet for this user query the generated response still comes back as **SELECT Time FROM men_butterfly_100m WHERE Swimmer = 'Lance Larson';\nSQLResult:**. The `SQLResult:` suffix therefore has to be stripped from the LLM response before the statement is executed against the database.

```
llm_inputs = {
    "input": input_text,
    "top_k": str(self.top_k),
    "dialect": self.database.dialect,
    "table_info": table_info,
    "stop": ["\nSQLResult:"],
}

sql_cmd = self.llm_chain.predict(
    callbacks=_run_manager.get_child(),
    **llm_inputs,
).strip()

if SQL_RESULT in sql_cmd:
    sql_cmd = sql_cmd.split(SQL_RESULT)[0].strip()
result = self.database.run(sql_cmd)
```
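For reference, here is a minimal standalone sketch of the stripping behaviour the patch adds. It is plain Python with no LLM or database involved; the `raw` string is a hypothetical model output shaped like the one above, used only for illustration:

```
# Hypothetical model output used only for illustration.
SQL_QUERY = "SQLQuery:"
SQL_RESULT = "SQLResult:"

raw = "SELECT Time FROM men_butterfly_100m WHERE Swimmer = 'Lance Larson';\nSQLResult:"

sql_cmd = raw.strip()
if SQL_QUERY in sql_cmd:
    # Keep only the text after an echoed "SQLQuery:" prefix, if present.
    sql_cmd = sql_cmd.split(SQL_QUERY)[1].strip()
if SQL_RESULT in sql_cmd:
    # Drop the trailing "SQLResult:" marker so SQLite sees a single statement.
    sql_cmd = sql_cmd.split(SQL_RESULT)[0].strip()

assert sql_cmd == "SELECT Time FROM men_butterfly_100m WHERE Swimmer = 'Lance Larson';"
```

The `_call` implementation in the file below applies the same prefix/suffix handling just before `self.database.run(sql_cmd)` is called.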
---------

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
319 lines
13 KiB
Python
"""Chain for interacting with SQL Database."""
|
|
from __future__ import annotations
|
|
|
|
import warnings
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from langchain.callbacks.manager import CallbackManagerForChainRun
|
|
from langchain.chains.base import Chain
|
|
from langchain.chains.llm import LLMChain
|
|
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS
|
|
from langchain.prompts.prompt import PromptTemplate
|
|
from langchain.schema import BasePromptTemplate
|
|
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
|
|
from langchain_community.utilities.sql_database import SQLDatabase
|
|
from langchain_core.language_models import BaseLanguageModel
|
|
|
|
from langchain_experimental.pydantic_v1 import Extra, Field, root_validator
|
|
|
|
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
|
|
SQL_QUERY = "SQLQuery:"
|
|
SQL_RESULT = "SQLResult:"
|
|
|
|
|
|
class SQLDatabaseChain(Chain):
    """Chain for interacting with SQL Database.

    Example:
        .. code-block:: python

            from langchain_experimental.sql import SQLDatabaseChain
            from langchain_community.llms import OpenAI, SQLDatabase
            db = SQLDatabase(...)
            db_chain = SQLDatabaseChain.from_llm(OpenAI(), db)

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include the permissions this chain needs.
        Failure to do so may result in data corruption or loss, since this chain may
        attempt commands like `DROP TABLE` or `INSERT` if appropriately prompted.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this chain.
        This issue shows an example negative outcome if these steps are not taken:
        https://github.com/langchain-ai/langchain/issues/5923
    """

    llm_chain: LLMChain
    llm: Optional[BaseLanguageModel] = None
    """[Deprecated] LLM wrapper to use."""
    database: SQLDatabase = Field(exclude=True)
    """SQL Database to connect to."""
    prompt: Optional[BasePromptTemplate] = None
    """[Deprecated] Prompt to use to translate natural language to SQL."""
    top_k: int = 5
    """Number of results to return from the query"""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_sql: bool = False
    """Will return sql-command directly without executing it"""
    return_intermediate_steps: bool = False
    """Whether or not to return the intermediate steps along with the final answer."""
    return_direct: bool = False
    """Whether or not to return the result of querying the SQL table directly."""
    use_query_checker: bool = False
    """Whether or not the query checker tool should be used to attempt
    to fix the initial SQL from the LLM."""
    query_checker_prompt: Optional[BasePromptTemplate] = None
    """The prompt template that should be used by the query checker"""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def raise_deprecation(cls, values: Dict) -> Dict:
        if "llm" in values:
            warnings.warn(
                "Directly instantiating an SQLDatabaseChain with an llm is deprecated. "
                "Please instantiate with llm_chain argument or using the from_llm "
                "class method."
            )
            if "llm_chain" not in values and values["llm"] is not None:
                database = values["database"]
                prompt = values.get("prompt") or SQL_PROMPTS.get(
                    database.dialect, PROMPT
                )
                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
        return values

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        if not self.return_intermediate_steps:
            return [self.output_key]
        else:
            return [self.output_key, INTERMEDIATE_STEPS_KEY]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        input_text = f"{inputs[self.input_key]}\n{SQL_QUERY}"
        _run_manager.on_text(input_text, verbose=self.verbose)
        # If not present, then defaults to None which is all tables.
        table_names_to_use = inputs.get("table_names_to_use")
        table_info = self.database.get_table_info(table_names=table_names_to_use)
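        # "\nSQLResult:" is passed as a stop sequence so the model should emit only
        # the SQL statement; responses that still contain the marker are cleaned up
        # below before the statement is executed.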
        llm_inputs = {
            "input": input_text,
            "top_k": str(self.top_k),
            "dialect": self.database.dialect,
            "table_info": table_info,
            "stop": ["\nSQLResult:"],
        }
        if self.memory is not None:
            for k in self.memory.memory_variables:
                llm_inputs[k] = inputs[k]
        intermediate_steps: List = []
        try:
            intermediate_steps.append(llm_inputs.copy())  # input: sql generation
            sql_cmd = self.llm_chain.predict(
                callbacks=_run_manager.get_child(),
                **llm_inputs,
            ).strip()
            if self.return_sql:
                return {self.output_key: sql_cmd}
            if not self.use_query_checker:
                _run_manager.on_text(sql_cmd, color="green", verbose=self.verbose)
                intermediate_steps.append(
                    sql_cmd
                )  # output: sql generation (no checker)
                intermediate_steps.append({"sql_cmd": sql_cmd})  # input: sql exec
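                # Strip an echoed "SQLQuery:" prefix and a trailing "SQLResult:"
                # marker so that only a single bare SQL statement reaches the
                # database (otherwise SQLite rejects the command with
                # "You can only execute one statement at a time.").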
                if SQL_QUERY in sql_cmd:
                    sql_cmd = sql_cmd.split(SQL_QUERY)[1].strip()
                if SQL_RESULT in sql_cmd:
                    sql_cmd = sql_cmd.split(SQL_RESULT)[0].strip()
                result = self.database.run(sql_cmd)
                intermediate_steps.append(str(result))  # output: sql exec
            else:
                query_checker_prompt = self.query_checker_prompt or PromptTemplate(
                    template=QUERY_CHECKER, input_variables=["query", "dialect"]
                )
                query_checker_chain = LLMChain(
                    llm=self.llm_chain.llm, prompt=query_checker_prompt
                )
                query_checker_inputs = {
                    "query": sql_cmd,
                    "dialect": self.database.dialect,
                }
                checked_sql_command: str = query_checker_chain.predict(
                    callbacks=_run_manager.get_child(), **query_checker_inputs
                ).strip()
                intermediate_steps.append(
                    checked_sql_command
                )  # output: sql generation (checker)
                _run_manager.on_text(
                    checked_sql_command, color="green", verbose=self.verbose
                )
                intermediate_steps.append(
                    {"sql_cmd": checked_sql_command}
                )  # input: sql exec
                result = self.database.run(checked_sql_command)
                intermediate_steps.append(str(result))  # output: sql exec
                sql_cmd = checked_sql_command

            _run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
            _run_manager.on_text(str(result), color="yellow", verbose=self.verbose)
            # If return direct, we just set the final result equal to
            # the result of the sql query result, otherwise try to get a human readable
            # final answer
            if self.return_direct:
                final_result = result
            else:
                _run_manager.on_text("\nAnswer:", verbose=self.verbose)
                input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
                llm_inputs["input"] = input_text
                intermediate_steps.append(llm_inputs.copy())  # input: final answer
                final_result = self.llm_chain.predict(
                    callbacks=_run_manager.get_child(),
                    **llm_inputs,
                ).strip()
                intermediate_steps.append(final_result)  # output: final answer
                _run_manager.on_text(final_result, color="green", verbose=self.verbose)
            chain_result: Dict[str, Any] = {self.output_key: final_result}
            if self.return_intermediate_steps:
                chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
            return chain_result
        except Exception as exc:
            # Append intermediate steps to exception, to aid in logging and later
            # improvement of few shot prompt seeds
            exc.intermediate_steps = intermediate_steps  # type: ignore
            raise exc

    @property
    def _chain_type(self) -> str:
        return "sql_database_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        db: SQLDatabase,
        prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> SQLDatabaseChain:
        """Create a SQLDatabaseChain from an LLM and a database connection.

        *Security note*: Make sure that the database connection uses credentials
            that are narrowly-scoped to only include the permissions this chain needs.
            Failure to do so may result in data corruption or loss, since this chain may
            attempt commands like `DROP TABLE` or `INSERT` if appropriately prompted.
            The best way to guard against such negative outcomes is to (as appropriate)
            limit the permissions granted to the credentials used with this chain.
            This issue shows an example negative outcome if these steps are not taken:
            https://github.com/langchain-ai/langchain/issues/5923
        """
        prompt = prompt or SQL_PROMPTS.get(db.dialect, PROMPT)
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, database=db, **kwargs)


class SQLDatabaseSequentialChain(Chain):
    """Chain for querying SQL database that is a sequential chain.

    The chain is as follows:
    1. Based on the query, determine which tables to use.
    2. Based on those tables, call the normal SQL database chain.

    This is useful in cases where the number of tables in the database is large.
    """

    decider_chain: LLMChain
    sql_chain: SQLDatabaseChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_intermediate_steps: bool = False

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        db: SQLDatabase,
        query_prompt: BasePromptTemplate = PROMPT,
        decider_prompt: BasePromptTemplate = DECIDER_PROMPT,
        **kwargs: Any,
    ) -> SQLDatabaseSequentialChain:
        """Load the necessary chains."""
        sql_chain = SQLDatabaseChain.from_llm(llm, db, prompt=query_prompt, **kwargs)
        decider_chain = LLMChain(
            llm=llm, prompt=decider_prompt, output_key="table_names"
        )
        return cls(sql_chain=sql_chain, decider_chain=decider_chain, **kwargs)

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        if not self.return_intermediate_steps:
            return [self.output_key]
        else:
            return [self.output_key, INTERMEDIATE_STEPS_KEY]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        _table_names = self.sql_chain.database.get_usable_table_names()
        table_names = ", ".join(_table_names)
        llm_inputs = {
            "query": inputs[self.input_key],
            "table_names": table_names,
        }
        _lowercased_table_names = [name.lower() for name in _table_names]
        table_names_from_chain = self.decider_chain.predict_and_parse(**llm_inputs)
        table_names_to_use = [
            name
            for name in table_names_from_chain
            if name.lower() in _lowercased_table_names
        ]
        _run_manager.on_text("Table names to use:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            str(table_names_to_use), color="yellow", verbose=self.verbose
        )
        new_inputs = {
            self.sql_chain.input_key: inputs[self.input_key],
            "table_names_to_use": table_names_to_use,
        }
        return self.sql_chain(
            new_inputs, callbacks=_run_manager.get_child(), return_only_outputs=True
        )

    @property
    def _chain_type(self) -> str:
        return "sql_database_sequential_chain"