fix loading of sql chain (#10860)

Closing #6889
This commit is contained in:
Harrison Chase 2023-09-20 14:37:49 -07:00 committed by GitHub
parent 4074ea4c41
commit 1bc3244db9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -358,10 +358,16 @@ def _load_qa_with_sources_chain(config: dict, **kwargs: Any) -> QAWithSourcesCha
def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any: def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any:
from langchain_experimental.sql import SQLDatabaseChain
if "database" in kwargs: if "database" in kwargs:
database = kwargs.pop("database") database = kwargs.pop("database")
else: else:
raise ValueError("`database` must be present.") raise ValueError("`database` must be present.")
if "llm_chain" in config:
llm_chain_config = config.pop("llm_chain")
chain = load_chain_from_config(llm_chain_config)
return SQLDatabaseChain(llm_chain=chain, database=database, **config)
if "llm" in config: if "llm" in config:
llm_config = config.pop("llm") llm_config = config.pop("llm")
llm = load_llm_from_config(llm_config) llm = load_llm_from_config(llm_config)
@ -374,7 +380,6 @@ def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any:
prompt = load_prompt_from_config(prompt_config) prompt = load_prompt_from_config(prompt_config)
else: else:
prompt = None prompt = None
from langchain_experimental.sql import SQLDatabaseChain
return SQLDatabaseChain.from_llm(llm, database, prompt=prompt, **config) return SQLDatabaseChain.from_llm(llm, database, prompt=prompt, **config)