more complex sql chain (#619)

add a more complex sql chain that first subsets the necessary tables
pull/618/head^2
Harrison Chase 2 years ago committed by GitHub
parent 49b3d6c78c
commit 1c71fadfdc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -179,12 +179,82 @@
"db_chain.run(\"How many employees are there in the foobar table?\")"
]
},
{
"cell_type": "markdown",
"id": "c12ae15a",
"metadata": {},
"source": [
"## SQLDatabaseSequentialChain\n",
"\n",
"Chain for querying SQL database that is a sequential chain.\n",
"\n",
"The chain is as follows:\n",
"\n",
" 1. Based on the query, determine which tables to use.\n",
" 2. Based on those tables, call the normal SQL database chain.\n",
"\n",
"This is useful in cases where the number of tables in the database is large."
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"id": "e59a4740",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains.sql_database.base import SQLDatabaseSequentialChain"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "58bb49b6",
"metadata": {},
"outputs": [],
"source": [
"chain = SQLDatabaseSequentialChain.from_llm(llm, db, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "95017b1a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new SQLDatabaseSequentialChain chain...\u001b[0m\n",
"Table names to use:\n",
"\u001b[33;1m\u001b[1;3m['Employee', 'Customer']\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"' 0 employees are also customers.'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.run(\"How many employees are also customers?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2998b03",
"metadata": {},
"outputs": [],
"source": []
}
],

@ -47,7 +47,10 @@ class SequentialChain(Chain, BaseModel):
for chain in chains:
missing_vars = set(chain.input_keys).difference(known_variables)
if missing_vars:
raise ValueError(f"Missing required input keys: {missing_vars}")
raise ValueError(
f"Missing required input keys: {missing_vars}, "
f"only had {known_variables}"
)
overlapping_keys = known_variables.intersection(chain.output_keys)
if overlapping_keys:
raise ValueError(

@ -1,11 +1,13 @@
"""Chain for interacting with SQL Database."""
from typing import Dict, List
from __future__ import annotations
from typing import Any, Dict, List
from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.sql_database.prompt import PROMPT
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.sql_database import SQLDatabase
@ -53,15 +55,18 @@ class SQLDatabaseChain(Chain, BaseModel):
"""
return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)
input_text = f"{inputs[self.input_key]} \nSQLQuery:"
if self.verbose:
self.callback_manager.on_text(input_text)
# If not present, then defaults to None which is all tables.
table_names_to_use = inputs.get("table_names_to_use")
table_info = self.database.get_table_info(table_names=table_names_to_use)
llm_inputs = {
"input": input_text,
"dialect": self.database.dialect,
"table_info": self.database.table_info,
"table_info": table_info,
"stop": ["\nSQLResult:"],
}
sql_cmd = llm_chain.predict(**llm_inputs)
@ -78,3 +83,68 @@ class SQLDatabaseChain(Chain, BaseModel):
if self.verbose:
self.callback_manager.on_text(final_result, color="green")
return {self.output_key: final_result}
class SQLDatabaseSequentialChain(Chain, BaseModel):
    """Two-step chain for answering questions against a SQL database.

    Every call performs the following steps:

    1. Ask the decider chain which of the available tables are relevant
       to the query.
    2. Run the regular SQL database chain, passing it only those tables.

    Narrowing the tables first is helpful when the database has many tables.
    """

    decider_chain: LLMChain
    sql_chain: SQLDatabaseChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    @classmethod
    def from_llm(
        cls,
        llm: BaseLLM,
        database: SQLDatabase,
        query_prompt: BasePromptTemplate = PROMPT,
        decider_prompt: BasePromptTemplate = DECIDER_PROMPT,
        **kwargs: Any,
    ) -> SQLDatabaseSequentialChain:
        """Build both sub-chains from a single LLM and database.

        Args:
            llm: Language model shared by the decider and the SQL chain.
            database: Database the SQL chain will query.
            query_prompt: Prompt used by the SQL chain.
            decider_prompt: Prompt used to pick the relevant tables.
            **kwargs: Forwarded unchanged to the class constructor.

        Returns:
            A chain wired with the freshly created sub-chains.
        """
        query_chain = SQLDatabaseChain(
            llm=llm,
            database=database,
            prompt=query_prompt,
        )
        table_decider = LLMChain(
            llm=llm,
            prompt=decider_prompt,
            output_key="table_names",
        )
        return cls(sql_chain=query_chain, decider_chain=table_decider, **kwargs)

    @property
    def input_keys(self) -> List[str]:
        """List containing the single input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """List containing the single output key.

        :meta private:
        """
        return [self.output_key]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Pick the relevant tables, then delegate to the SQL chain."""
        # Step 1: show the decider every available table name next to the query.
        candidate_tables = ", ".join(self.sql_chain.database.get_table_names())
        question = inputs[self.input_key]
        table_names_to_use = self.decider_chain.predict_and_parse(
            query=question,
            table_names=candidate_tables,
        )
        if self.verbose:
            self.callback_manager.on_text("Table names to use:", end="\n")
            self.callback_manager.on_text(str(table_names_to_use), color="yellow")
        # Step 2: run the SQL chain restricted to the chosen tables.
        return self.sql_chain(
            {
                self.sql_chain.input_key: question,
                "table_names_to_use": table_names_to_use,
            },
            return_only_outputs=True,
        )

@ -1,4 +1,5 @@
# flake8: noqa
from langchain.prompts.base import CommaSeparatedListOutputParser
from langchain.prompts.prompt import PromptTemplate
_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
@ -17,3 +18,16 @@ Question: {input}"""
PROMPT = PromptTemplate(
input_variables=["input", "table_info", "dialect"], template=_DEFAULT_TEMPLATE
)
# Prompt for the table-selection step: the model is shown the question and the
# candidate table names and is asked for a comma separated list of tables.
_DECIDER_TEMPLATE = """Given the below input question and list of potential tables, output a comma separated list of the table names that may be necessary to answer this question.
Question: {query}
Table Names: {table_names}
Relevant Table Names:"""
# The attached output parser converts the comma separated reply into a list.
DECIDER_PROMPT = PromptTemplate(
    input_variables=["query", "table_names"],
    template=_DECIDER_TEMPLATE,
    output_parser=CommaSeparatedListOutputParser(),
)

@ -64,6 +64,14 @@ class ListOutputParser(BaseOutputParser):
"""Parse the output of an LLM call."""
class CommaSeparatedListOutputParser(ListOutputParser):
    """Parse out comma separated lists."""

    def parse(self, text: str) -> List[str]:
        """Parse the output of an LLM call.

        The text is split on bare commas and each item is trimmed, so the
        spacing around separators does not matter: ``"a, b"``, ``"a,b"`` and
        ``" a ,  b "`` all produce ``["a", "b"]``.

        Args:
            text: Raw comma separated text.

        Returns:
            The items in their original order, each with surrounding
            whitespace removed.
        """
        # Splitting on ", " alone would leave "a,b" as a single item.
        return [item.strip() for item in text.split(",")]
class RegexParser(BaseOutputParser, BaseModel):
"""Class to parse the output into a dictionary."""

@ -50,7 +50,8 @@ class SQLDatabase:
"""Return string representation of dialect to use."""
return self._engine.dialect.name
def _get_table_names(self) -> Iterable[str]:
def get_table_names(self) -> Iterable[str]:
    """Return the names of the tables this wrapper exposes.

    A non-empty include list is returned as is; otherwise the result is
    every known table minus the ignored ones.
    """
    included = self._include_tables
    if not included:
        return set(self._all_tables).difference(set(self._ignore_tables))
    return included
@ -58,9 +59,19 @@ class SQLDatabase:
@property
def table_info(self) -> str:
    """Describe all available tables.

    Delegates to ``get_table_info`` without a table filter.
    """
    description = self.get_table_info()
    return description
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
"""Get information about specified tables."""
all_table_names = self.get_table_names()
if table_names is not None:
missing_tables = set(table_names).difference(all_table_names)
if missing_tables:
raise ValueError(f"table_names {missing_tables} not found in database")
all_table_names = table_names
template = "Table '{table_name}' has columns: {columns}."
tables = []
for table_name in self._get_table_names():
for table_name in all_table_names:
columns = []
for column in self._inspector.get_columns(table_name, schema=self._schema):
columns.append(f"{column['name']} ({str(column['type'])})")

Loading…
Cancel
Save