@ -34,6 +34,7 @@ class SQLDatabaseChain(Chain, BaseModel):
""" Number of results to return from the query """
input_key : str = " query " #: :meta private:
output_key : str = " result " #: :meta private:
return_intermediate_steps : bool = False
class Config :
""" Configuration for this pydantic object. """
@ -55,9 +56,12 @@ class SQLDatabaseChain(Chain, BaseModel):
: meta private :
"""
return [ self . output_key ]
if not self . return_intermediate_steps :
return [ self . output_key ]
else :
return [ self . output_key , " intermediate_steps " ]
def _call ( self , inputs : Dict [ str , Any ] ) - > Dict [ str , str ] :
def _call ( self , inputs : Dict [ str , Any ] ) - > Dict [ str , Any ] :
llm_chain = LLMChain ( llm = self . llm , prompt = self . prompt )
input_text = f " { inputs [ self . input_key ] } \n SQLQuery: "
self . callback_manager . on_text ( input_text , verbose = self . verbose )
@ -71,10 +75,12 @@ class SQLDatabaseChain(Chain, BaseModel):
" table_info " : table_info ,
" stop " : [ " \n SQLResult: " ] ,
}
intermediate_steps = [ ]
sql_cmd = llm_chain . predict ( * * llm_inputs )
intermediate_steps . append ( sql_cmd )
self . callback_manager . on_text ( sql_cmd , color = " green " , verbose = self . verbose )
result = self . database . run ( sql_cmd )
intermediate_steps . append ( result )
self . callback_manager . on_text ( " \n SQLResult: " , verbose = self . verbose )
self . callback_manager . on_text ( result , color = " yellow " , verbose = self . verbose )
self . callback_manager . on_text ( " \n Answer: " , verbose = self . verbose )
@ -82,7 +88,10 @@ class SQLDatabaseChain(Chain, BaseModel):
llm_inputs [ " input " ] = input_text
final_result = llm_chain . predict ( * * llm_inputs )
self . callback_manager . on_text ( final_result , color = " green " , verbose = self . verbose )
return { self . output_key : final_result }
chain_result : Dict [ str , Any ] = { self . output_key : final_result }
if self . return_intermediate_steps :
chain_result [ " intermediate_steps " ] = intermediate_steps
return chain_result
@property
def _chain_type ( self ) - > str :