mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
parent
4074ea4c41
commit
1bc3244db9
@ -358,10 +358,16 @@ def _load_qa_with_sources_chain(config: dict, **kwargs: Any) -> QAWithSourcesCha
|
|||||||
|
|
||||||
|
|
||||||
def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any:
|
def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any:
|
||||||
|
from langchain_experimental.sql import SQLDatabaseChain
|
||||||
|
|
||||||
if "database" in kwargs:
|
if "database" in kwargs:
|
||||||
database = kwargs.pop("database")
|
database = kwargs.pop("database")
|
||||||
else:
|
else:
|
||||||
raise ValueError("`database` must be present.")
|
raise ValueError("`database` must be present.")
|
||||||
|
if "llm_chain" in config:
|
||||||
|
llm_chain_config = config.pop("llm_chain")
|
||||||
|
chain = load_chain_from_config(llm_chain_config)
|
||||||
|
return SQLDatabaseChain(llm_chain=chain, database=database, **config)
|
||||||
if "llm" in config:
|
if "llm" in config:
|
||||||
llm_config = config.pop("llm")
|
llm_config = config.pop("llm")
|
||||||
llm = load_llm_from_config(llm_config)
|
llm = load_llm_from_config(llm_config)
|
||||||
@ -374,7 +380,6 @@ def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any:
|
|||||||
prompt = load_prompt_from_config(prompt_config)
|
prompt = load_prompt_from_config(prompt_config)
|
||||||
else:
|
else:
|
||||||
prompt = None
|
prompt = None
|
||||||
from langchain_experimental.sql import SQLDatabaseChain
|
|
||||||
|
|
||||||
return SQLDatabaseChain.from_llm(llm, database, prompt=prompt, **config)
|
return SQLDatabaseChain.from_llm(llm, database, prompt=prompt, **config)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user