langchain/libs/experimental/langchain_experimental/sql/base.py

319 lines
13 KiB
Python
Raw Normal View History

"""Chain for interacting with SQL Database."""
from __future__ import annotations
import warnings
from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
2023-07-28 05:00:52 +00:00
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS
from langchain.schema import BasePromptTemplate
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.prompt import PromptTemplate
from langchain_experimental.pydantic_v1 import Extra, Field, root_validator
2023-07-22 01:44:32 +00:00
# Output key under which intermediate steps are returned when a chain is
# configured with ``return_intermediate_steps=True``.
INTERMEDIATE_STEPS_KEY: str = "intermediate_steps"
# Prompt marker appended after the user's question; the LLM is expected to
# continue with the SQL statement right after it.
SQL_QUERY: str = "SQLQuery:"
experimental[patch]: Removed 'SQLResults:' from the LLMResponse in SQLDatabaseChain (#17104) **Description:** When using the SQLDatabaseChain with Llama2-70b LLM and, SQLite database. I was getting `Warning: You can only execute one statement at a time.`. ``` from langchain.sql_database import SQLDatabase from langchain_experimental.sql import SQLDatabaseChain sql_database_path = '/dccstor/mmdataretrieval/mm_dataset/swimming_record/rag_data/swimmingdataset.db' sql_db = get_database(sql_database_path) db_chain = SQLDatabaseChain.from_llm(mistral, sql_db, verbose=True, callbacks = [callback_obj]) db_chain.invoke({ "query": "What is the best time of Lance Larson in men's 100 meter butterfly competition?" }) ``` Error: ``` Warning Traceback (most recent call last) Cell In[31], line 3 1 import langchain 2 langchain.debug=False ----> 3 db_chain.invoke({ 4 "query": "What is the best time of Lance Larson in men's 100 meter butterfly competition?" 5 }) File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/langchain/chains/base.py:162, in Chain.invoke(self, input, config, **kwargs) 160 except BaseException as e: 161 run_manager.on_chain_error(e) --> 162 raise e 163 run_manager.on_chain_end(outputs) 164 final_outputs: Dict[str, Any] = self.prep_outputs( 165 inputs, outputs, return_only_outputs 166 ) File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/langchain/chains/base.py:156, in Chain.invoke(self, input, config, **kwargs) 149 run_manager = callback_manager.on_chain_start( 150 dumpd(self), 151 inputs, 152 name=run_name, 153 ) 154 try: 155 outputs = ( --> 156 self._call(inputs, run_manager=run_manager) 157 if new_arg_supported 158 else self._call(inputs) 159 ) 160 except BaseException as e: 161 run_manager.on_chain_error(e) File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/langchain_experimental/sql/base.py:198, in SQLDatabaseChain._call(self, inputs, run_manager) 194 except Exception as exc: 195 # Append intermediate steps to exception, to aid in 
logging and later 196 # improvement of few shot prompt seeds 197 exc.intermediate_steps = intermediate_steps # type: ignore --> 198 raise exc File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/langchain_experimental/sql/base.py:143, in SQLDatabaseChain._call(self, inputs, run_manager) 139 intermediate_steps.append( 140 sql_cmd 141 ) # output: sql generation (no checker) 142 intermediate_steps.append({"sql_cmd": sql_cmd}) # input: sql exec --> 143 result = self.database.run(sql_cmd) 144 intermediate_steps.append(str(result)) # output: sql exec 145 else: File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/langchain_community/utilities/sql_database.py:436, in SQLDatabase.run(self, command, fetch, include_columns) 425 def run( 426 self, 427 command: str, 428 fetch: Literal["all", "one"] = "all", 429 include_columns: bool = False, 430 ) -> str: 431 """Execute a SQL command and return a string representing the results. 432 433 If the statement returns rows, a string of the results is returned. 434 If the statement returns no rows, an empty string is returned. 435 """ --> 436 result = self._execute(command, fetch) 438 res = [ 439 { 440 column: truncate_word(value, length=self._max_string_length) (...) 
443 for r in result 444 ] 446 if not include_columns: File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/langchain_community/utilities/sql_database.py:413, in SQLDatabase._execute(self, command, fetch) 410 elif self.dialect == "postgresql": # postgresql 411 connection.exec_driver_sql("SET search_path TO %s", (self._schema,)) --> 413 cursor = connection.execute(text(command)) 414 if cursor.returns_rows: 415 if fetch == "all": File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/base.py:1416, in Connection.execute(self, statement, parameters, execution_options) 1414 raise exc.ObjectNotExecutableError(statement) from err 1415 else: -> 1416 return meth( 1417 self, 1418 distilled_parameters, 1419 execution_options or NO_OPTIONS, 1420 ) File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/sql/elements.py:516, in ClauseElement._execute_on_connection(self, connection, distilled_params, execution_options) 514 if TYPE_CHECKING: 515 assert isinstance(self, Executable) --> 516 return connection._execute_clauseelement( 517 self, distilled_params, execution_options 518 ) 519 else: 520 raise exc.ObjectNotExecutableError(self) File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/base.py:1639, in Connection._execute_clauseelement(self, elem, distilled_parameters, execution_options) 1627 compiled_cache: Optional[CompiledCacheType] = execution_options.get( 1628 "compiled_cache", self.engine._compiled_cache 1629 ) 1631 compiled_sql, extracted_params, cache_hit = elem._compile_w_cache( 1632 dialect=dialect, 1633 compiled_cache=compiled_cache, (...) 
1637 linting=self.dialect.compiler_linting | compiler.WARN_LINTING, 1638 ) -> 1639 ret = self._execute_context( 1640 dialect, 1641 dialect.execution_ctx_cls._init_compiled, 1642 compiled_sql, 1643 distilled_parameters, 1644 execution_options, 1645 compiled_sql, 1646 distilled_parameters, 1647 elem, 1648 extracted_params, 1649 cache_hit=cache_hit, 1650 ) 1651 if has_events: 1652 self.dispatch.after_execute( 1653 self, 1654 elem, (...) 1658 ret, 1659 ) File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/base.py:1848, in Connection._execute_context(self, dialect, constructor, statement, parameters, execution_options, *args, **kw) 1843 return self._exec_insertmany_context( 1844 dialect, 1845 context, 1846 ) 1847 else: -> 1848 return self._exec_single_context( 1849 dialect, context, statement, parameters 1850 ) File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/base.py:1988, in Connection._exec_single_context(self, dialect, context, statement, parameters) 1985 result = context._setup_result_proxy() 1987 except BaseException as e: -> 1988 self._handle_dbapi_exception( 1989 e, str_statement, effective_parameters, cursor, context 1990 ) 1992 return result File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/base.py:2346, in Connection._handle_dbapi_exception(self, e, statement, parameters, cursor, context, is_sub_exec) 2344 else: 2345 assert exc_info[1] is not None -> 2346 raise exc_info[1].with_traceback(exc_info[2]) 2347 finally: 2348 del self._reentrant_error File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/base.py:1969, in Connection._exec_single_context(self, dialect, context, statement, parameters) 1967 break 1968 if not evt_handled: -> 1969 self.dialect.do_execute( 1970 cursor, str_statement, effective_parameters, context 1971 ) 1973 if self._has_events or self.engine._has_events: 1974 self.dispatch.after_cursor_execute( 1975 self, 1976 cursor, (...) 
1980 context.executemany, 1981 ) File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/default.py:922, in DefaultDialect.do_execute(self, cursor, statement, parameters, context) 921 def do_execute(self, cursor, statement, parameters, context=None): --> 922 cursor.execute(statement, parameters) Warning: You can only execute one statement at a time. ``` **Issue:** The Error occurs because when generating the SQLQuery, the llm_input includes the stop character of "\nSQLResult:", so for this user query the LLM generated response is **SELECT Time FROM men_butterfly_100m WHERE Swimmer = 'Lance Larson';\nSQLResult:** it is required to remove the SQLResult suffix on the llm response before executing it on the database. ``` llm_inputs = { "input": input_text, "top_k": str(self.top_k), "dialect": self.database.dialect, "table_info": table_info, "stop": ["\nSQLResult:"], } sql_cmd = self.llm_chain.predict( callbacks=_run_manager.get_child(), **llm_inputs, ).strip() if SQL_RESULT in sql_cmd: sql_cmd = sql_cmd.split(SQL_RESULT)[0].strip() result = self.database.run(sql_cmd) ``` <!-- Thank you for contributing to LangChain! Please title your PR "<package>: <description>", where <package> is whichever of langchain, community, core, experimental, etc. is being modified. Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes if applicable, - **Dependencies:** any dependencies required for this change, - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` from the root of the package you've modified to check this locally. 
See contribution guidelines for more information on how to write/run tests, lint, etc: https://python.langchain.com/docs/contributing/ If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. --> --------- Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
2024-03-29 08:22:35 +00:00
# Marker the LLM may echo after the generated SQL; text from this marker
# onward is not part of the statement to execute.
SQL_RESULT: str = "SQLResult:"
class SQLDatabaseChain(Chain):
    """Chain for interacting with SQL Database.

    Example:
        .. code-block:: python

            from langchain_experimental.sql import SQLDatabaseChain
            from langchain_community.llms import OpenAI
            from langchain_community.utilities.sql_database import SQLDatabase
            db = SQLDatabase(...)
            db_chain = SQLDatabaseChain.from_llm(OpenAI(), db)

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include the permissions this chain needs.
        Failure to do so may result in data corruption or loss, since this chain may
        attempt commands like `DROP TABLE` or `INSERT` if appropriately prompted.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this chain.
        This issue shows an example negative outcome if these steps are not taken:
        https://github.com/langchain-ai/langchain/issues/5923
    """

    llm_chain: LLMChain
    llm: Optional[BaseLanguageModel] = None
    """[Deprecated] LLM wrapper to use."""
    database: SQLDatabase = Field(exclude=True)
    """SQL Database to connect to."""
    prompt: Optional[BasePromptTemplate] = None
    """[Deprecated] Prompt to use to translate natural language to SQL."""
    top_k: int = 5
    """Number of results to return from the query"""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_sql: bool = False
    """Will return sql-command directly without executing it"""
    return_intermediate_steps: bool = False
    """Whether or not to return the intermediate steps along with the final answer."""
    return_direct: bool = False
    """Whether or not to return the result of querying the SQL table directly."""
    use_query_checker: bool = False
    """Whether or not the query checker tool should be used to attempt
    to fix the initial SQL from the LLM."""
    query_checker_prompt: Optional[BasePromptTemplate] = None
    """The prompt template that should be used by the query checker"""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Warn on the deprecated ``llm`` argument and build ``llm_chain`` from it.

        When ``llm`` is given without ``llm_chain``, an ``LLMChain`` is built
        using ``prompt`` or the dialect-specific default prompt for ``database``.
        If ``database`` is missing, the chain is not built here so that pydantic
        reports the missing required fields instead of a raw ``KeyError``.
        """
        if "llm" in values:
            warnings.warn(
                "Directly instantiating an SQLDatabaseChain with an llm is deprecated. "
                "Please instantiate with llm_chain argument or using the from_llm "
                "class method."
            )
            if (
                "llm_chain" not in values
                and values["llm"] is not None
                and "database" in values
            ):
                database = values["database"]
                prompt = values.get("prompt") or SQL_PROMPTS.get(
                    database.dialect, PROMPT
                )
                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
        return values

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        if not self.return_intermediate_steps:
            return [self.output_key]
        else:
            return [self.output_key, INTERMEDIATE_STEPS_KEY]

    @staticmethod
    def _clean_sql_cmd(sql_cmd: str) -> str:
        """Strip prompt markers the LLM may have echoed around the SQL.

        If ``SQLQuery:`` occurs, keep only the segment following its first
        occurrence (up to any further occurrence). Then, if ``SQLResult:``
        occurs, drop it and everything after it. Without this, trailing text
        such as ``"...;\\nSQLResult:"`` can make the database driver reject the
        command as more than one statement.
        """
        if SQL_QUERY in sql_cmd:
            sql_cmd = sql_cmd.split(SQL_QUERY)[1].strip()
        if SQL_RESULT in sql_cmd:
            sql_cmd = sql_cmd.split(SQL_RESULT)[0].strip()
        return sql_cmd

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate SQL for the question, optionally run it, and answer.

        Steps recorded in ``intermediate_steps`` (in order): LLM inputs for SQL
        generation, raw generated SQL (or checker output), ``{"sql_cmd": ...}``
        with the exact command executed, the stringified result, and, unless
        ``return_direct``, the LLM inputs and output for the final answer.

        Raises:
            Any exception from the LLM or database is re-raised with an
            ``intermediate_steps`` attribute attached.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        input_text = f"{inputs[self.input_key]}\n{SQL_QUERY}"
        _run_manager.on_text(input_text, verbose=self.verbose)
        # If not present, then defaults to None which is all tables.
        table_names_to_use = inputs.get("table_names_to_use")
        table_info = self.database.get_table_info(table_names=table_names_to_use)
        llm_inputs = {
            "input": input_text,
            "top_k": str(self.top_k),
            "dialect": self.database.dialect,
            "table_info": table_info,
            "stop": ["\nSQLResult:"],
        }
        if self.memory is not None:
            for k in self.memory.memory_variables:
                llm_inputs[k] = inputs[k]
        intermediate_steps: List = []
        try:
            # Copy: llm_inputs["input"] is mutated later for the final answer,
            # which would otherwise rewrite this recorded step too.
            intermediate_steps.append(llm_inputs.copy())  # input: sql generation
            sql_cmd = self.llm_chain.predict(
                callbacks=_run_manager.get_child(),
                **llm_inputs,
            ).strip()
            if self.return_sql:
                # Return the same cleaned statement that would have been run.
                return {self.output_key: self._clean_sql_cmd(sql_cmd)}
            if not self.use_query_checker:
                _run_manager.on_text(sql_cmd, color="green", verbose=self.verbose)
                # output: sql generation (no checker) -- raw LLM text
                intermediate_steps.append(sql_cmd)
                sql_cmd = self._clean_sql_cmd(sql_cmd)
                # Record exactly what is sent to the database.
                intermediate_steps.append({"sql_cmd": sql_cmd})  # input: sql exec
                result = self.database.run(sql_cmd)
                intermediate_steps.append(str(result))  # output: sql exec
            else:
                query_checker_prompt = self.query_checker_prompt or PromptTemplate(
                    template=QUERY_CHECKER, input_variables=["query", "dialect"]
                )
                query_checker_chain = LLMChain(
                    llm=self.llm_chain.llm, prompt=query_checker_prompt
                )
                query_checker_inputs = {
                    "query": self._clean_sql_cmd(sql_cmd),
                    "dialect": self.database.dialect,
                }
                checked_sql_command: str = query_checker_chain.predict(
                    callbacks=_run_manager.get_child(), **query_checker_inputs
                ).strip()
                # output: sql generation (checker) -- raw checker text
                intermediate_steps.append(checked_sql_command)
                _run_manager.on_text(
                    checked_sql_command, color="green", verbose=self.verbose
                )
                # The checker output gets the same cleanup as the no-checker
                # path so a single statement reaches the database.
                sql_cmd = self._clean_sql_cmd(checked_sql_command)
                intermediate_steps.append({"sql_cmd": sql_cmd})  # input: sql exec
                result = self.database.run(sql_cmd)
                intermediate_steps.append(str(result))  # output: sql exec
            _run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
            _run_manager.on_text(str(result), color="yellow", verbose=self.verbose)
            # If return direct, we just set the final result equal to
            # the result of the sql query result, otherwise try to get a human
            # readable final answer
            if self.return_direct:
                final_result = result
            else:
                _run_manager.on_text("\nAnswer:", verbose=self.verbose)
                input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
                llm_inputs["input"] = input_text
                intermediate_steps.append(llm_inputs.copy())  # input: final answer
                final_result = self.llm_chain.predict(
                    callbacks=_run_manager.get_child(),
                    **llm_inputs,
                ).strip()
                intermediate_steps.append(final_result)  # output: final answer
                _run_manager.on_text(final_result, color="green", verbose=self.verbose)
            chain_result: Dict[str, Any] = {self.output_key: final_result}
            if self.return_intermediate_steps:
                chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
            return chain_result
        except Exception as exc:
            # Append intermediate steps to exception, to aid in logging and later
            # improvement of few shot prompt seeds
            exc.intermediate_steps = intermediate_steps  # type: ignore
            # Bare raise keeps the original traceback untouched.
            raise

    @property
    def _chain_type(self) -> str:
        return "sql_database_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        db: SQLDatabase,
        prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> SQLDatabaseChain:
        """Create a SQLDatabaseChain from an LLM and a database connection.

        If ``prompt`` is not given, the dialect-specific prompt for ``db`` is
        used, falling back to the generic ``PROMPT``.

        *Security note*: Make sure that the database connection uses credentials
            that are narrowly-scoped to only include the permissions this chain needs.
            Failure to do so may result in data corruption or loss, since this chain may
            attempt commands like `DROP TABLE` or `INSERT` if appropriately prompted.
            The best way to guard against such negative outcomes is to (as appropriate)
            limit the permissions granted to the credentials used with this chain.
            This issue shows an example negative outcome if these steps are not taken:
            https://github.com/langchain-ai/langchain/issues/5923
        """
        prompt = prompt or SQL_PROMPTS.get(db.dialect, PROMPT)
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, database=db, **kwargs)
class SQLDatabaseSequentialChain(Chain):
    """Chain for querying SQL database that is a sequential chain.

    The chain is as follows:

    1. Based on the query, determine which tables to use.
    2. Based on those tables, call the normal SQL database chain.

    This is useful in cases where the number of tables in the database is large.
    """

    decider_chain: LLMChain
    sql_chain: SQLDatabaseChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_intermediate_steps: bool = False

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        db: SQLDatabase,
        query_prompt: BasePromptTemplate = PROMPT,
        decider_prompt: BasePromptTemplate = DECIDER_PROMPT,
        **kwargs: Any,
    ) -> SQLDatabaseSequentialChain:
        """Load the necessary chains.

        ``kwargs`` are forwarded both to ``SQLDatabaseChain.from_llm`` and to
        this chain's constructor, so shared settings (e.g. ``output_key``,
        ``return_intermediate_steps``) apply to both.

        NOTE(review): ``query_prompt`` defaults to ``PROMPT`` rather than None,
        so the dialect-specific prompt that ``SQLDatabaseChain.from_llm`` would
        otherwise pick is only used if passed explicitly.
        """
        sql_chain = SQLDatabaseChain.from_llm(llm, db, prompt=query_prompt, **kwargs)
        decider_chain = LLMChain(
            llm=llm, prompt=decider_prompt, output_key="table_names"
        )
        return cls(sql_chain=sql_chain, decider_chain=decider_chain, **kwargs)

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        if not self.return_intermediate_steps:
            return [self.output_key]
        else:
            return [self.output_key, INTERMEDIATE_STEPS_KEY]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Pick relevant tables with the decider chain, then run ``sql_chain``.

        Table names returned by the decider are matched case-insensitively
        against the database's usable tables; unknown names are dropped.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        _table_names = self.sql_chain.database.get_usable_table_names()
        table_names = ", ".join(_table_names)
        llm_inputs = {
            "query": inputs[self.input_key],
            "table_names": table_names,
        }
        # A set gives O(1) membership checks in the filter below.
        _lowercased_table_names = {name.lower() for name in _table_names}
        table_names_from_chain = self.decider_chain.predict_and_parse(**llm_inputs)
        # Discard any table names the LLM produced that the database lacks.
        table_names_to_use = [
            name
            for name in table_names_from_chain
            if name.lower() in _lowercased_table_names
        ]
        _run_manager.on_text("Table names to use:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            str(table_names_to_use), color="yellow", verbose=self.verbose
        )
        new_inputs = {
            self.sql_chain.input_key: inputs[self.input_key],
            "table_names_to_use": table_names_to_use,
        }
        sql_outputs = self.sql_chain(
            new_inputs, callbacks=_run_manager.get_child(), return_only_outputs=True
        )
        # The inner chain reports its answer under its own output_key; expose
        # it under this chain's key so output validation passes when they differ.
        if self.output_key != self.sql_chain.output_key:
            sql_outputs = dict(sql_outputs)
            sql_outputs[self.output_key] = sql_outputs.pop(self.sql_chain.output_key)
        return sql_outputs

    @property
    def _chain_type(self) -> str:
        return "sql_database_sequential_chain"