diff --git a/libs/langchain/langchain/chains/loading.py b/libs/langchain/langchain/chains/loading.py index 9543f62988..f8fe1396c8 100644 --- a/libs/langchain/langchain/chains/loading.py +++ b/libs/langchain/langchain/chains/loading.py @@ -358,10 +358,16 @@ def _load_qa_with_sources_chain(config: dict, **kwargs: Any) -> QAWithSourcesCha def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any: + from langchain_experimental.sql import SQLDatabaseChain + if "database" in kwargs: database = kwargs.pop("database") else: raise ValueError("`database` must be present.") + if "llm_chain" in config: + llm_chain_config = config.pop("llm_chain") + chain = load_chain_from_config(llm_chain_config) + return SQLDatabaseChain(llm_chain=chain, database=database, **config) if "llm" in config: llm_config = config.pop("llm") llm = load_llm_from_config(llm_config) @@ -374,7 +380,6 @@ def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any: prompt = load_prompt_from_config(prompt_config) else: prompt = None - from langchain_experimental.sql import SQLDatabaseChain return SQLDatabaseChain.from_llm(llm, database, prompt=prompt, **config)