langchain/libs/experimental/langchain_experimental/sql/base.py
Kirushikesh DB 12861273e1
experimental[patch]: Removed 'SQLResults:' from the LLMResponse in SQLDatabaseChain (#17104)
**Description:** 
When using the SQLDatabaseChain with the Llama2-70b LLM and a SQLite
database, I was getting `Warning: You can only execute one statement at
a time.`

```
from langchain.sql_database import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain

sql_database_path = '/dccstor/mmdataretrieval/mm_dataset/swimming_record/rag_data/swimmingdataset.db'
sql_db = get_database(sql_database_path)
db_chain = SQLDatabaseChain.from_llm(mistral, sql_db, verbose=True, callbacks = [callback_obj])
db_chain.invoke({
    "query": "What is the best time of Lance Larson in men's 100 meter butterfly competition?"
})
```
Error:
```
Warning                                   Traceback (most recent call last)
Cell In[31], line 3
      1 import langchain
      2 langchain.debug=False
----> 3 db_chain.invoke({
      4     "query": "What is the best time of Lance Larson in men's 100 meter butterfly competition?"
      5 })

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/langchain/chains/base.py:162, in Chain.invoke(self, input, config, **kwargs)
    160 except BaseException as e:
    161     run_manager.on_chain_error(e)
--> 162     raise e
    163 run_manager.on_chain_end(outputs)
    164 final_outputs: Dict[str, Any] = self.prep_outputs(
    165     inputs, outputs, return_only_outputs
    166 )

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/langchain/chains/base.py:156, in Chain.invoke(self, input, config, **kwargs)
    149 run_manager = callback_manager.on_chain_start(
    150     dumpd(self),
    151     inputs,
    152     name=run_name,
    153 )
    154 try:
    155     outputs = (
--> 156         self._call(inputs, run_manager=run_manager)
    157         if new_arg_supported
    158         else self._call(inputs)
    159     )
    160 except BaseException as e:
    161     run_manager.on_chain_error(e)

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/langchain_experimental/sql/base.py:198, in SQLDatabaseChain._call(self, inputs, run_manager)
    194 except Exception as exc:
    195     # Append intermediate steps to exception, to aid in logging and later
    196     # improvement of few shot prompt seeds
    197     exc.intermediate_steps = intermediate_steps  # type: ignore
--> 198     raise exc

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/langchain_experimental/sql/base.py:143, in SQLDatabaseChain._call(self, inputs, run_manager)
    139     intermediate_steps.append(
    140         sql_cmd
    141     )  # output: sql generation (no checker)
    142     intermediate_steps.append({"sql_cmd": sql_cmd})  # input: sql exec
--> 143     result = self.database.run(sql_cmd)
    144     intermediate_steps.append(str(result))  # output: sql exec
    145 else:

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/langchain_community/utilities/sql_database.py:436, in SQLDatabase.run(self, command, fetch, include_columns)
    425 def run(
    426     self,
    427     command: str,
    428     fetch: Literal["all", "one"] = "all",
    429     include_columns: bool = False,
    430 ) -> str:
    431     """Execute a SQL command and return a string representing the results.
    432 
    433     If the statement returns rows, a string of the results is returned.
    434     If the statement returns no rows, an empty string is returned.
    435     """
--> 436     result = self._execute(command, fetch)
    438     res = [
    439         {
    440             column: truncate_word(value, length=self._max_string_length)
   (...)
    443         for r in result
    444     ]
    446     if not include_columns:

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/langchain_community/utilities/sql_database.py:413, in SQLDatabase._execute(self, command, fetch)
    410     elif self.dialect == "postgresql":  # postgresql
    411         connection.exec_driver_sql("SET search_path TO %s", (self._schema,))
--> 413 cursor = connection.execute(text(command))
    414 if cursor.returns_rows:
    415     if fetch == "all":

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/base.py:1416, in Connection.execute(self, statement, parameters, execution_options)
   1414     raise exc.ObjectNotExecutableError(statement) from err
   1415 else:
-> 1416     return meth(
   1417         self,
   1418         distilled_parameters,
   1419         execution_options or NO_OPTIONS,
   1420     )

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/sql/elements.py:516, in ClauseElement._execute_on_connection(self, connection, distilled_params, execution_options)
    514     if TYPE_CHECKING:
    515         assert isinstance(self, Executable)
--> 516     return connection._execute_clauseelement(
    517         self, distilled_params, execution_options
    518     )
    519 else:
    520     raise exc.ObjectNotExecutableError(self)

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/base.py:1639, in Connection._execute_clauseelement(self, elem, distilled_parameters, execution_options)
   1627 compiled_cache: Optional[CompiledCacheType] = execution_options.get(
   1628     "compiled_cache", self.engine._compiled_cache
   1629 )
   1631 compiled_sql, extracted_params, cache_hit = elem._compile_w_cache(
   1632     dialect=dialect,
   1633     compiled_cache=compiled_cache,
   (...)
   1637     linting=self.dialect.compiler_linting | compiler.WARN_LINTING,
   1638 )
-> 1639 ret = self._execute_context(
   1640     dialect,
   1641     dialect.execution_ctx_cls._init_compiled,
   1642     compiled_sql,
   1643     distilled_parameters,
   1644     execution_options,
   1645     compiled_sql,
   1646     distilled_parameters,
   1647     elem,
   1648     extracted_params,
   1649     cache_hit=cache_hit,
   1650 )
   1651 if has_events:
   1652     self.dispatch.after_execute(
   1653         self,
   1654         elem,
   (...)
   1658         ret,
   1659     )

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/base.py:1848, in Connection._execute_context(self, dialect, constructor, statement, parameters, execution_options, *args, **kw)
   1843     return self._exec_insertmany_context(
   1844         dialect,
   1845         context,
   1846     )
   1847 else:
-> 1848     return self._exec_single_context(
   1849         dialect, context, statement, parameters
   1850     )

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/base.py:1988, in Connection._exec_single_context(self, dialect, context, statement, parameters)
   1985     result = context._setup_result_proxy()
   1987 except BaseException as e:
-> 1988     self._handle_dbapi_exception(
   1989         e, str_statement, effective_parameters, cursor, context
   1990     )
   1992 return result

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/base.py:2346, in Connection._handle_dbapi_exception(self, e, statement, parameters, cursor, context, is_sub_exec)
   2344     else:
   2345         assert exc_info[1] is not None
-> 2346         raise exc_info[1].with_traceback(exc_info[2])
   2347 finally:
   2348     del self._reentrant_error

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/base.py:1969, in Connection._exec_single_context(self, dialect, context, statement, parameters)
   1967                 break
   1968     if not evt_handled:
-> 1969         self.dialect.do_execute(
   1970             cursor, str_statement, effective_parameters, context
   1971         )
   1973 if self._has_events or self.engine._has_events:
   1974     self.dispatch.after_cursor_execute(
   1975         self,
   1976         cursor,
   (...)
   1980         context.executemany,
   1981     )

File ~/.conda/envs/guardrails1/lib/python3.9/site-packages/sqlalchemy/engine/default.py:922, in DefaultDialect.do_execute(self, cursor, statement, parameters, context)
    921 def do_execute(self, cursor, statement, parameters, context=None):
--> 922     cursor.execute(statement, parameters)

Warning: You can only execute one statement at a time.
```
**Issue:** 
The error occurs because, when generating the SQL query, the `llm_inputs`
include the stop sequence "\nSQLResult:", and for this user query the
LLM-generated response is **SELECT Time FROM men_butterfly_100m
WHERE Swimmer = 'Lance Larson';\nSQLResult:**. The SQLResult suffix
must therefore be removed from the LLM response before executing it on
the database.

```
llm_inputs = {
            "input": input_text,
            "top_k": str(self.top_k),
            "dialect": self.database.dialect,
            "table_info": table_info,
            "stop": ["\nSQLResult:"],
        }

sql_cmd = self.llm_chain.predict(
                callbacks=_run_manager.get_child(),
                **llm_inputs,
            ).strip()

if SQL_RESULT in sql_cmd:
    sql_cmd = sql_cmd.split(SQL_RESULT)[0].strip()
result = self.database.run(sql_cmd)
```


<!-- Thank you for contributing to LangChain!

Please title your PR "<package>: <description>", where <package> is
whichever of langchain, community, core, experimental, etc. is being
modified.

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes if applicable,
  - **Dependencies:** any dependencies required for this change,
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` from the root
of the package you've modified to check this locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc: https://python.langchain.com/docs/contributing/

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->

---------

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
2024-03-29 01:22:35 -07:00

319 lines
13 KiB
Python

"""Chain for interacting with SQL Database."""
from __future__ import annotations
import warnings
from typing import Any, Dict, List, Optional
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BasePromptTemplate
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.language_models import BaseLanguageModel
from langchain_experimental.pydantic_v1 import Extra, Field, root_validator
# Output key under which the chains in this module return their recorded
# intermediate steps (only when `return_intermediate_steps` is enabled).
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
# Marker appended to the user's question to cue the LLM to write a SQL query;
# also looked for in the LLM response so an echoed marker can be stripped.
SQL_QUERY = "SQLQuery:"
# Marker that follows the query in the prompt format; text from this marker
# onward is dropped from the LLM response before the query is executed.
SQL_RESULT = "SQLResult:"
class SQLDatabaseChain(Chain):
    """Chain for interacting with SQL Database.

    Example:
        .. code-block:: python

            from langchain_experimental.sql import SQLDatabaseChain
            from langchain_community.llms import OpenAI
            from langchain_community.utilities.sql_database import SQLDatabase

            db = SQLDatabase(...)
            db_chain = SQLDatabaseChain.from_llm(OpenAI(), db)

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include the permissions this chain needs.
        Failure to do so may result in data corruption or loss, since this chain may
        attempt commands like `DROP TABLE` or `INSERT` if appropriately prompted.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this chain.
        This issue shows an example negative outcome if these steps are not taken:
        https://github.com/langchain-ai/langchain/issues/5923
    """

    llm_chain: LLMChain
    llm: Optional[BaseLanguageModel] = None
    """[Deprecated] LLM wrapper to use."""
    database: SQLDatabase = Field(exclude=True)
    """SQL Database to connect to."""
    prompt: Optional[BasePromptTemplate] = None
    """[Deprecated] Prompt to use to translate natural language to SQL."""
    top_k: int = 5
    """Number of results to return from the query"""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_sql: bool = False
    """Will return sql-command directly without executing it"""
    return_intermediate_steps: bool = False
    """Whether or not to return the intermediate steps along with the final answer."""
    return_direct: bool = False
    """Whether or not to return the result of querying the SQL table directly."""
    use_query_checker: bool = False
    """Whether or not the query checker tool should be used to attempt
    to fix the initial SQL from the LLM."""
    query_checker_prompt: Optional[BasePromptTemplate] = None
    """The prompt template that should be used by the query checker"""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Warn about the deprecated `llm` argument and build `llm_chain` from it.

        When `llm` is supplied without `llm_chain`, an `LLMChain` is created
        from the given `prompt`, or from the dialect-specific default prompt
        when no prompt is given.
        """
        if "llm" in values:
            warnings.warn(
                "Directly instantiating an SQLDatabaseChain with an llm is deprecated. "
                "Please instantiate with llm_chain argument or using the from_llm "
                "class method."
            )
            if "llm_chain" not in values and values["llm"] is not None:
                database = values["database"]
                prompt = values.get("prompt") or SQL_PROMPTS.get(
                    database.dialect, PROMPT
                )
                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
        return values

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        if not self.return_intermediate_steps:
            return [self.output_key]
        else:
            return [self.output_key, INTERMEDIATE_STEPS_KEY]

    @staticmethod
    def _strip_sql_markers(sql_cmd: str) -> str:
        """Return `sql_cmd` without echoed `SQLQuery:` / `SQLResult:` markers.

        Text up to and including the first `SQLQuery:` marker is dropped (only
        the segment up to the next such marker, if any, is kept), and
        everything from the first `SQLResult:` marker onward is dropped. Both
        steps trim surrounding whitespace. A string containing neither marker
        is returned unchanged.
        """
        if SQL_QUERY in sql_cmd:
            sql_cmd = sql_cmd.split(SQL_QUERY)[1].strip()
        if SQL_RESULT in sql_cmd:
            sql_cmd = sql_cmd.split(SQL_RESULT)[0].strip()
        return sql_cmd

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate SQL for the question, run it, and produce the final answer.

        Args:
            inputs: Must contain `self.input_key` (the question). May contain
                `table_names_to_use` to restrict the table info put in the
                prompt, plus one entry per memory variable when memory is set.
            run_manager: Callback manager for this run; a no-op manager is
                used when omitted.

        Returns:
            A dict with `self.output_key` set to the generated SQL (when
            `return_sql`), the raw query result (when `return_direct`), or the
            LLM's final answer; plus `INTERMEDIATE_STEPS_KEY` when
            `return_intermediate_steps` is enabled.

        Raises:
            Exception: Any exception raised while generating or executing the
                query is re-raised with an `intermediate_steps` attribute
                attached.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        input_text = f"{inputs[self.input_key]}\n{SQL_QUERY}"
        _run_manager.on_text(input_text, verbose=self.verbose)
        # If not present, then defaults to None which is all tables.
        table_names_to_use = inputs.get("table_names_to_use")
        table_info = self.database.get_table_info(table_names=table_names_to_use)
        llm_inputs = {
            "input": input_text,
            "top_k": str(self.top_k),
            "dialect": self.database.dialect,
            "table_info": table_info,
            "stop": ["\nSQLResult:"],
        }
        if self.memory is not None:
            for k in self.memory.memory_variables:
                llm_inputs[k] = inputs[k]
        intermediate_steps: List = []
        try:
            intermediate_steps.append(llm_inputs.copy())  # input: sql generation
            raw_sql_cmd = self.llm_chain.predict(
                callbacks=_run_manager.get_child(),
                **llm_inputs,
            ).strip()
            # The response may still carry the "SQLQuery:" prefix or a trailing
            # "SQLResult:" marker; clean it once so that every path below
            # (return_sql, direct execution, query checker) uses plain SQL.
            sql_cmd = self._strip_sql_markers(raw_sql_cmd)
            if self.return_sql:
                return {self.output_key: sql_cmd}
            if not self.use_query_checker:
                # The raw LLM text is what gets logged and recorded in the
                # intermediate steps; the cleaned command is what is executed.
                _run_manager.on_text(raw_sql_cmd, color="green", verbose=self.verbose)
                intermediate_steps.append(
                    raw_sql_cmd
                )  # output: sql generation (no checker)
                intermediate_steps.append({"sql_cmd": raw_sql_cmd})  # input: sql exec
                result = self.database.run(sql_cmd)
                intermediate_steps.append(str(result))  # output: sql exec
            else:
                query_checker_prompt = self.query_checker_prompt or PromptTemplate(
                    template=QUERY_CHECKER, input_variables=["query", "dialect"]
                )
                query_checker_chain = LLMChain(
                    llm=self.llm_chain.llm, prompt=query_checker_prompt
                )
                query_checker_inputs = {
                    "query": sql_cmd,
                    "dialect": self.database.dialect,
                }
                checked_sql_command: str = query_checker_chain.predict(
                    callbacks=_run_manager.get_child(), **query_checker_inputs
                ).strip()
                intermediate_steps.append(
                    checked_sql_command
                )  # output: sql generation (checker)
                _run_manager.on_text(
                    checked_sql_command, color="green", verbose=self.verbose
                )
                intermediate_steps.append(
                    {"sql_cmd": checked_sql_command}
                )  # input: sql exec
                # The checker is an LLM call too, so its output gets the same
                # marker clean-up before it reaches the database.
                sql_cmd = self._strip_sql_markers(checked_sql_command)
                result = self.database.run(sql_cmd)
                intermediate_steps.append(str(result))  # output: sql exec
            _run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
            _run_manager.on_text(str(result), color="yellow", verbose=self.verbose)
            # If return direct, we just set the final result equal to
            # the result of the sql query result, otherwise try to get a human readable
            # final answer
            if self.return_direct:
                final_result = result
            else:
                _run_manager.on_text("\nAnswer:", verbose=self.verbose)
                input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
                llm_inputs["input"] = input_text
                intermediate_steps.append(llm_inputs.copy())  # input: final answer
                final_result = self.llm_chain.predict(
                    callbacks=_run_manager.get_child(),
                    **llm_inputs,
                ).strip()
                intermediate_steps.append(final_result)  # output: final answer
                _run_manager.on_text(final_result, color="green", verbose=self.verbose)
            chain_result: Dict[str, Any] = {self.output_key: final_result}
            if self.return_intermediate_steps:
                chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
            return chain_result
        except Exception as exc:
            # Append intermediate steps to exception, to aid in logging and later
            # improvement of few shot prompt seeds
            exc.intermediate_steps = intermediate_steps  # type: ignore
            raise

    @property
    def _chain_type(self) -> str:
        """Return the identifier string for this chain type."""
        return "sql_database_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        db: SQLDatabase,
        prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> SQLDatabaseChain:
        """Create a SQLDatabaseChain from an LLM and a database connection.

        *Security note*: Make sure that the database connection uses credentials
            that are narrowly-scoped to only include the permissions this chain needs.
            Failure to do so may result in data corruption or loss, since this chain may
            attempt commands like `DROP TABLE` or `INSERT` if appropriately prompted.
            The best way to guard against such negative outcomes is to (as appropriate)
            limit the permissions granted to the credentials used with this chain.
            This issue shows an example negative outcome if these steps are not taken:
            https://github.com/langchain-ai/langchain/issues/5923
        """
        prompt = prompt or SQL_PROMPTS.get(db.dialect, PROMPT)
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, database=db, **kwargs)
class SQLDatabaseSequentialChain(Chain):
    """Chain for querying SQL database that is a sequential chain.

    The chain is as follows:
    1. Based on the query, determine which tables to use.
    2. Based on those tables, call the normal SQL database chain.

    This is useful in cases where the number of tables in the database is large.
    """

    decider_chain: LLMChain
    sql_chain: SQLDatabaseChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_intermediate_steps: bool = False

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        db: SQLDatabase,
        query_prompt: BasePromptTemplate = PROMPT,
        decider_prompt: BasePromptTemplate = DECIDER_PROMPT,
        **kwargs: Any,
    ) -> SQLDatabaseSequentialChain:
        """Load the necessary chains.

        Args:
            llm: Language model used by both the table decider and the SQL chain.
            db: Database the SQL chain runs against.
            query_prompt: Prompt for the inner `SQLDatabaseChain`.
            decider_prompt: Prompt for the table-selection chain.
            **kwargs: Forwarded both to `SQLDatabaseChain.from_llm` and to this
                class's constructor.
        """
        sql_chain = SQLDatabaseChain.from_llm(llm, db, prompt=query_prompt, **kwargs)
        decider_chain = LLMChain(
            llm=llm, prompt=decider_prompt, output_key="table_names"
        )
        return cls(sql_chain=sql_chain, decider_chain=decider_chain, **kwargs)

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        if not self.return_intermediate_steps:
            return [self.output_key]
        else:
            return [self.output_key, INTERMEDIATE_STEPS_KEY]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Pick the relevant tables with the decider chain, then run the SQL chain.

        Args:
            inputs: Must contain `self.input_key` (the question).
            run_manager: Callback manager for this run; a no-op manager is
                used when omitted.

        Returns:
            The outputs of the inner SQL chain, invoked with the question and
            the selected `table_names_to_use`.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        _table_names = self.sql_chain.database.get_usable_table_names()
        table_names = ", ".join(_table_names)
        llm_inputs = {
            "query": inputs[self.input_key],
            "table_names": table_names,
        }
        # Set for constant-time, case-insensitive membership checks below.
        _lowercased_table_names = {name.lower() for name in _table_names}
        # Pass child callbacks so the decider's LLM call is attached to this
        # run, as is done for every LLM call in SQLDatabaseChain.
        table_names_from_chain = self.decider_chain.predict_and_parse(
            callbacks=_run_manager.get_child(), **llm_inputs
        )
        # Keep only names that match a usable table, ignoring case.
        table_names_to_use = [
            name
            for name in table_names_from_chain
            if name.lower() in _lowercased_table_names
        ]
        _run_manager.on_text("Table names to use:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            str(table_names_to_use), color="yellow", verbose=self.verbose
        )
        new_inputs = {
            self.sql_chain.input_key: inputs[self.input_key],
            "table_names_to_use": table_names_to_use,
        }
        return self.sql_chain(
            new_inputs, callbacks=_run_manager.get_child(), return_only_outputs=True
        )

    @property
    def _chain_type(self) -> str:
        """Return the identifier string for this chain type."""
        return "sql_database_sequential_chain"