From 1bc3244db92faef6ac7961f935a6b45b6df7eab3 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Wed, 20 Sep 2023 14:37:49 -0700 Subject: [PATCH] fix loading of sql chain (#10860) Closing #6889 --- libs/langchain/langchain/chains/loading.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/libs/langchain/langchain/chains/loading.py b/libs/langchain/langchain/chains/loading.py index 9543f62988..f8fe1396c8 100644 --- a/libs/langchain/langchain/chains/loading.py +++ b/libs/langchain/langchain/chains/loading.py @@ -358,10 +358,16 @@ def _load_qa_with_sources_chain(config: dict, **kwargs: Any) -> QAWithSourcesCha def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any: + from langchain_experimental.sql import SQLDatabaseChain + if "database" in kwargs: database = kwargs.pop("database") else: raise ValueError("`database` must be present.") + if "llm_chain" in config: + llm_chain_config = config.pop("llm_chain") + chain = load_chain_from_config(llm_chain_config) + return SQLDatabaseChain(llm_chain=chain, database=database, **config) if "llm" in config: llm_config = config.pop("llm") llm = load_llm_from_config(llm_config) @@ -374,7 +380,6 @@ def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any: prompt = load_prompt_from_config(prompt_config) else: prompt = None - from langchain_experimental.sql import SQLDatabaseChain return SQLDatabaseChain.from_llm(llm, database, prompt=prompt, **config)