|
|
@ -97,11 +97,11 @@ class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
|
|
|
|
values["llm_chain"] = LLMChain(
|
|
|
|
values["llm_chain"] = LLMChain(
|
|
|
|
llm=values.get("llm"),
|
|
|
|
llm=values.get("llm"),
|
|
|
|
prompt=PromptTemplate(
|
|
|
|
prompt=PromptTemplate(
|
|
|
|
template=QUERY_CHECKER, input_variables=["query", "dialect"]
|
|
|
|
template=QUERY_CHECKER, input_variables=["dialect", "query"]
|
|
|
|
),
|
|
|
|
),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if values["llm_chain"].prompt.input_variables != ["query", "dialect"]:
|
|
|
|
if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
|
|
|
|
raise ValueError(
|
|
|
|
raise ValueError(
|
|
|
|
"LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
|
|
|
|
"LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
|
|
|
|
)
|
|
|
|
)
|
|
|
|