improve sql prompt (#4611)

Co-authored-by: Taqi Jaffri <tjaffri@docugami.com>
Co-authored-by: Taqi Jaffri <tjaffri@gmail.com>
textloader_autodetect_encodings
Harrison Chase 1 year ago committed by GitHub
parent 01531cb16d
commit 7d425cbf38
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

File diff suppressed because one or more lines are too long

@ -12,7 +12,11 @@ from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.sql_database import SQLDatabase
from langchain.tools.sql_database.prompt import QUERY_CHECKER
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
class SQLDatabaseChain(Chain):
@ -41,6 +45,11 @@ class SQLDatabaseChain(Chain):
"""Whether or not to return the intermediate steps along with the final answer."""
return_direct: bool = False
"""Whether or not to return the result of querying the SQL table directly."""
use_query_checker: bool = False
"""Whether or not the query checker tool should be used to attempt
to fix the initial SQL from the LLM."""
query_checker_prompt: Optional[BasePromptTemplate] = None
"""The prompt template that should be used by the query checker"""
class Config:
"""Configuration for this pydantic object."""
@ -81,7 +90,7 @@ class SQLDatabaseChain(Chain):
if not self.return_intermediate_steps:
return [self.output_key]
else:
return [self.output_key, "intermediate_steps"]
return [self.output_key, INTERMEDIATE_STEPS_KEY]
def _call(
self,
@ -96,36 +105,80 @@ class SQLDatabaseChain(Chain):
table_info = self.database.get_table_info(table_names=table_names_to_use)
llm_inputs = {
"input": input_text,
"top_k": self.top_k,
"top_k": str(self.top_k),
"dialect": self.database.dialect,
"table_info": table_info,
"stop": ["\nSQLResult:"],
}
intermediate_steps = []
sql_cmd = self.llm_chain.predict(
callbacks=_run_manager.get_child(), **llm_inputs
)
intermediate_steps.append(sql_cmd)
_run_manager.on_text(sql_cmd, color="green", verbose=self.verbose)
result = self.database.run(sql_cmd)
intermediate_steps.append(result)
_run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
_run_manager.on_text(result, color="yellow", verbose=self.verbose)
# If return direct, we just set the final result equal to the sql query
if self.return_direct:
final_result = result
else:
_run_manager.on_text("\nAnswer:", verbose=self.verbose)
input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
llm_inputs["input"] = input_text
final_result = self.llm_chain.predict(
callbacks=_run_manager.get_child(), **llm_inputs
)
_run_manager.on_text(final_result, color="green", verbose=self.verbose)
chain_result: Dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:
chain_result["intermediate_steps"] = intermediate_steps
return chain_result
intermediate_steps: List = []
try:
intermediate_steps.append(llm_inputs) # input: sql generation
sql_cmd = self.llm_chain.predict(
callbacks=_run_manager.get_child(),
**llm_inputs,
).strip()
if not self.use_query_checker:
_run_manager.on_text(sql_cmd, color="green", verbose=self.verbose)
intermediate_steps.append(
sql_cmd
) # output: sql generation (no checker)
intermediate_steps.append({"sql_cmd": sql_cmd}) # input: sql exec
result = self.database.run(sql_cmd)
intermediate_steps.append(str(result)) # output: sql exec
else:
query_checker_prompt = self.query_checker_prompt or PromptTemplate(
template=QUERY_CHECKER, input_variables=["query", "dialect"]
)
query_checker_chain = LLMChain(
llm=self.llm, prompt=query_checker_prompt
)
query_checker_inputs = {
"query": sql_cmd,
"dialect": self.database.dialect,
}
checked_sql_command: str = query_checker_chain.predict(
callbacks=_run_manager.get_child(), **query_checker_inputs
).strip()
intermediate_steps.append(
checked_sql_command
) # output: sql generation (checker)
_run_manager.on_text(
checked_sql_command, color="green", verbose=self.verbose
)
intermediate_steps.append(
{"sql_cmd": checked_sql_command}
) # input: sql exec
result = self.database.run(checked_sql_command)
intermediate_steps.append(str(result)) # output: sql exec
sql_cmd = checked_sql_command
_run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
_run_manager.on_text(result, color="yellow", verbose=self.verbose)
# If return direct, we just set the final result equal to
# the result of the sql query result, otherwise try to get a human readable
# final answer
if self.return_direct:
final_result = result
else:
_run_manager.on_text("\nAnswer:", verbose=self.verbose)
input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
llm_inputs["input"] = input_text
intermediate_steps.append(llm_inputs) # input: final answer
final_result = self.llm_chain.predict(
callbacks=_run_manager.get_child(),
**llm_inputs,
).strip()
intermediate_steps.append(final_result) # output: final answer
_run_manager.on_text(final_result, color="green", verbose=self.verbose)
chain_result: Dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
return chain_result
except Exception as exc:
# Append intermediate steps to exception, to aid in logging and later
# improvement of few shot prompt seeds
exc.intermediate_steps = intermediate_steps # type: ignore
raise exc
@property
def _chain_type(self) -> str:
@ -195,7 +248,7 @@ class SQLDatabaseSequentialChain(Chain):
if not self.return_intermediate_steps:
return [self.output_key]
else:
return [self.output_key, "intermediate_steps"]
return [self.output_key, INTERMEDIATE_STEPS_KEY]
def _call(
self,
@ -209,9 +262,13 @@ class SQLDatabaseSequentialChain(Chain):
"query": inputs[self.input_key],
"table_names": table_names,
}
table_names_to_use = self.decider_chain.predict_and_parse(
callbacks=_run_manager.get_child(), **llm_inputs
)
_lowercased_table_names = [name.lower() for name in _table_names]
table_names_from_chain = self.decider_chain.predict_and_parse(**llm_inputs)
table_names_to_use = [
name
for name in table_names_from_chain
if name.lower() in _lowercased_table_names
]
_run_manager.on_text("Table names to use:", end="\n", verbose=self.verbose)
_run_manager.on_text(
str(table_names_to_use), color="yellow", verbose=self.verbose

@ -3,6 +3,11 @@ from langchain.output_parsers.list import CommaSeparatedListOutputParser
from langchain.prompts.prompt import PromptTemplate
PROMPT_SUFFIX = """Only use the following tables:
{table_info}
Question: {input}"""
_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.
@ -16,17 +21,14 @@ SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here
Only use the tables listed below.
{table_info}
Question: {input}"""
"""
PROMPT = PromptTemplate(
input_variables=["input", "table_info", "dialect", "top_k"],
template=_DEFAULT_TEMPLATE,
template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX,
)
_DECIDER_TEMPLATE = """Given the below input question and list of potential tables, output a comma separated list of the table names that may be necessary to answer this question.
Question: {query}
@ -53,14 +55,11 @@ SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here
Only use the following tables:
{table_info}
Question: {input}"""
"""
DUCKDB_PROMPT = PromptTemplate(
input_variables=["input", "table_info", "top_k"],
template=_duckdb_prompt,
template=_duckdb_prompt + PROMPT_SUFFIX,
)
_googlesql_prompt = """You are a GoogleSQL expert. Given an input question, first create a syntactically correct GoogleSQL query to run, then look at the results of the query and return the answer to the input question.
@ -76,14 +75,11 @@ SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here
Only use the following tables:
{table_info}
Question: {input}"""
"""
GOOGLESQL_PROMPT = PromptTemplate(
input_variables=["input", "table_info", "top_k"],
template=_googlesql_prompt,
template=_googlesql_prompt + PROMPT_SUFFIX,
)
@ -100,13 +96,11 @@ SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here
Only use the following tables:
{table_info}
Question: {input}"""
"""
MSSQL_PROMPT = PromptTemplate(
input_variables=["input", "table_info", "top_k"], template=_mssql_prompt
input_variables=["input", "table_info", "top_k"],
template=_mssql_prompt + PROMPT_SUFFIX,
)
@ -123,14 +117,11 @@ SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here
Only use the following tables:
{table_info}
Question: {input}"""
"""
MYSQL_PROMPT = PromptTemplate(
input_variables=["input", "table_info", "top_k"],
template=_mysql_prompt,
template=_mysql_prompt + PROMPT_SUFFIX,
)
@ -147,14 +138,11 @@ SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here
Only use the following tables:
{table_info}
Question: {input}"""
"""
MARIADB_PROMPT = PromptTemplate(
input_variables=["input", "table_info", "top_k"],
template=_mariadb_prompt,
template=_mariadb_prompt + PROMPT_SUFFIX,
)
@ -171,14 +159,11 @@ SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here
Only use the following tables:
{table_info}
Question: {input}"""
"""
ORACLE_PROMPT = PromptTemplate(
input_variables=["input", "table_info", "top_k"],
template=_oracle_prompt,
template=_oracle_prompt + PROMPT_SUFFIX,
)
@ -195,13 +180,11 @@ SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here
Only use the following tables:
{table_info}
Question: {input}"""
"""
POSTGRES_PROMPT = PromptTemplate(
input_variables=["input", "table_info", "top_k"], template=_postgres_prompt
input_variables=["input", "table_info", "top_k"],
template=_postgres_prompt + PROMPT_SUFFIX,
)
@ -218,14 +201,11 @@ SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here
Only use the following tables:
{table_info}
Question: {input}"""
"""
SQLITE_PROMPT = PromptTemplate(
input_variables=["input", "table_info", "top_k"],
template=_sqlite_prompt,
template=_sqlite_prompt + PROMPT_SUFFIX,
)
_clickhouse_prompt = """You are a ClickHouse expert. Given an input question, first create a syntactically correct Clic query to run, then look at the results of the query and return the answer to the input question.
@ -241,14 +221,11 @@ SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"
Only use the following tables:
{table_info}
Question: {input}"""
"""
CLICKHOUSE_PROMPT = PromptTemplate(
input_variables=["input", "table_info", "top_k"],
template=_clickhouse_prompt,
template=_clickhouse_prompt + PROMPT_SUFFIX,
)
_prestodb_prompt = """You are a PrestoDB expert. Given an input question, first create a syntactically correct PrestoDB query to run, then look at the results of the query and return the answer to the input question.
@ -264,14 +241,11 @@ SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"
Only use the following tables:
{table_info}
Question: {input}"""
"""
PRESTODB_PROMPT = PromptTemplate(
input_variables=["input", "table_info", "top_k"],
template=_prestodb_prompt,
template=_prestodb_prompt + PROMPT_SUFFIX,
)

Loading…
Cancel
Save