@ -1,11 +1,13 @@
""" Chain for interacting with SQL Database. """
from typing import Dict , List
from __future__ import annotations
from typing import Any , Dict , List
from pydantic import BaseModel , Extra
from langchain . chains . base import Chain
from langchain . chains . llm import LLMChain
from langchain . chains . sql_database . prompt import PROMPT
from langchain . chains . sql_database . prompt import DECIDER_PROMPT, PROMPT
from langchain . llms . base import BaseLLM
from langchain . prompts . base import BasePromptTemplate
from langchain . sql_database import SQLDatabase
@ -53,15 +55,18 @@ class SQLDatabaseChain(Chain, BaseModel):
"""
return [ self . output_key ]
def _call ( self , inputs : Dict [ str , str ] ) - > Dict [ str , str ] :
def _call ( self , inputs : Dict [ str , Any ] ) - > Dict [ str , str ] :
llm_chain = LLMChain ( llm = self . llm , prompt = self . prompt )
input_text = f " { inputs [ self . input_key ] } \n SQLQuery: "
if self . verbose :
self . callback_manager . on_text ( input_text )
# If not present, then defaults to None which is all tables.
table_names_to_use = inputs . get ( " table_names_to_use " )
table_info = self . database . get_table_info ( table_names = table_names_to_use )
llm_inputs = {
" input " : input_text ,
" dialect " : self . database . dialect ,
" table_info " : self . database . table_info ,
" table_info " : table_info ,
" stop " : [ " \n SQLResult: " ] ,
}
sql_cmd = llm_chain . predict ( * * llm_inputs )
@ -78,3 +83,68 @@ class SQLDatabaseChain(Chain, BaseModel):
if self . verbose :
self . callback_manager . on_text ( final_result , color = " green " )
return { self . output_key : final_result }
class SQLDatabaseSequentialChain ( Chain , BaseModel ) :
""" Chain for querying SQL database that is a sequential chain.
The chain is as follows :
1. Based on the query , determine which tables to use .
2. Based on those tables , call the normal SQL database chain .
This is useful in cases where the number of tables in the database is large .
"""
@classmethod
def from_llm (
cls ,
llm : BaseLLM ,
database : SQLDatabase ,
query_prompt : BasePromptTemplate = PROMPT ,
decider_prompt : BasePromptTemplate = DECIDER_PROMPT ,
* * kwargs : Any ,
) - > SQLDatabaseSequentialChain :
""" Load the necessary chains. """
sql_chain = SQLDatabaseChain ( llm = llm , database = database , prompt = query_prompt )
decider_chain = LLMChain (
llm = llm , prompt = decider_prompt , output_key = " table_names "
)
return cls ( sql_chain = sql_chain , decider_chain = decider_chain , * * kwargs )
decider_chain : LLMChain
sql_chain : SQLDatabaseChain
input_key : str = " query " #: :meta private:
output_key : str = " result " #: :meta private:
@property
def input_keys ( self ) - > List [ str ] :
""" Return the singular input key.
: meta private :
"""
return [ self . input_key ]
@property
def output_keys ( self ) - > List [ str ] :
""" Return the singular output key.
: meta private :
"""
return [ self . output_key ]
def _call ( self , inputs : Dict [ str , str ] ) - > Dict [ str , str ] :
_table_names = self . sql_chain . database . get_table_names ( )
table_names = " , " . join ( _table_names )
llm_inputs = {
" query " : inputs [ self . input_key ] ,
" table_names " : table_names ,
}
table_names_to_use = self . decider_chain . predict_and_parse ( * * llm_inputs )
if self . verbose :
self . callback_manager . on_text ( " Table names to use: " , end = " \n " )
self . callback_manager . on_text ( str ( table_names_to_use ) , color = " yellow " )
new_inputs = {
self . sql_chain . input_key : inputs [ self . input_key ] ,
" table_names_to_use " : table_names_to_use ,
}
return self . sql_chain ( new_inputs , return_only_outputs = True )