fix loading of sql chain (#10860)

Closing #6889
pull/10861/head
Harrison Chase 11 months ago committed by GitHub
parent 4074ea4c41
commit 1bc3244db9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -358,10 +358,16 @@ def _load_qa_with_sources_chain(config: dict, **kwargs: Any) -> QAWithSourcesCha
def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any:
from langchain_experimental.sql import SQLDatabaseChain
if "database" in kwargs:
database = kwargs.pop("database")
else:
raise ValueError("`database` must be present.")
if "llm_chain" in config:
llm_chain_config = config.pop("llm_chain")
chain = load_chain_from_config(llm_chain_config)
return SQLDatabaseChain(llm_chain=chain, database=database, **config)
if "llm" in config:
llm_config = config.pop("llm")
llm = load_llm_from_config(llm_config)
@ -374,7 +380,6 @@ def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any:
prompt = load_prompt_from_config(prompt_config)
else:
prompt = None
from langchain_experimental.sql import SQLDatabaseChain
return SQLDatabaseChain.from_llm(llm, database, prompt=prompt, **config)

Loading…
Cancel
Save