|
|
|
@ -6,7 +6,6 @@ from pydantic import BaseModel, Extra
|
|
|
|
|
from langchain.chains.base import Chain
|
|
|
|
|
from langchain.chains.llm import LLMChain
|
|
|
|
|
from langchain.chains.sql_database.prompt import PROMPT
|
|
|
|
|
from langchain.input import print_text
|
|
|
|
|
from langchain.llms.base import BaseLLM
|
|
|
|
|
from langchain.sql_database import SQLDatabase
|
|
|
|
|
|
|
|
|
@ -55,7 +54,7 @@ class SQLDatabaseChain(Chain, BaseModel):
|
|
|
|
|
llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
|
|
|
|
|
input_text = f"{inputs[self.input_key]} \nSQLQuery:"
|
|
|
|
|
if self.verbose:
|
|
|
|
|
print_text(input_text)
|
|
|
|
|
self.callback_manager.on_text(input_text)
|
|
|
|
|
llm_inputs = {
|
|
|
|
|
"input": input_text,
|
|
|
|
|
"dialect": self.database.dialect,
|
|
|
|
@ -64,15 +63,15 @@ class SQLDatabaseChain(Chain, BaseModel):
|
|
|
|
|
}
|
|
|
|
|
sql_cmd = llm_chain.predict(**llm_inputs)
|
|
|
|
|
if self.verbose:
|
|
|
|
|
print_text(sql_cmd, color="green")
|
|
|
|
|
self.callback_manager.on_text(sql_cmd, color="green")
|
|
|
|
|
result = self.database.run(sql_cmd)
|
|
|
|
|
if self.verbose:
|
|
|
|
|
print_text("\nSQLResult: ")
|
|
|
|
|
print_text(result, color="yellow")
|
|
|
|
|
print_text("\nAnswer:")
|
|
|
|
|
self.callback_manager.on_text("\nSQLResult: ")
|
|
|
|
|
self.callback_manager.on_text(result, color="yellow")
|
|
|
|
|
self.callback_manager.on_text("\nAnswer:")
|
|
|
|
|
input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
|
|
|
|
|
llm_inputs["input"] = input_text
|
|
|
|
|
final_result = llm_chain.predict(**llm_inputs)
|
|
|
|
|
if self.verbose:
|
|
|
|
|
print_text(final_result, color="green")
|
|
|
|
|
self.callback_manager.on_text(final_result, color="green")
|
|
|
|
|
return {self.output_key: final_result}
|
|
|
|
|