diff --git a/langchain/chains/loading.py b/langchain/chains/loading.py
index b7e2c2a6..f26125bd 100644
--- a/langchain/chains/loading.py
+++ b/langchain/chains/loading.py
@@ -138,19 +138,31 @@ def _load_map_reduce_documents_chain(


 def _load_llm_bash_chain(config: dict, **kwargs: Any) -> LLMBashChain:
-    if "llm" in config:
+    llm_chain = None
+    if "llm_chain" in config:
+        llm_chain_config = config.pop("llm_chain")
+        llm_chain = load_chain_from_config(llm_chain_config)
+    elif "llm_chain_path" in config:
+        llm_chain = load_chain(config.pop("llm_chain_path"))
+    # llm attribute is deprecated in favor of llm_chain, here to support old configs
+    elif "llm" in config:
         llm_config = config.pop("llm")
         llm = load_llm_from_config(llm_config)
+    # llm_path attribute is deprecated in favor of llm_chain_path,
+    # it's here to support old configs
     elif "llm_path" in config:
         llm = load_llm(config.pop("llm_path"))
     else:
-        raise ValueError("One of `llm` or `llm_path` must be present.")
+        raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.")
     if "prompt" in config:
         prompt_config = config.pop("prompt")
         prompt = load_prompt_from_config(prompt_config)
     elif "prompt_path" in config:
         prompt = load_prompt(config.pop("prompt_path"))
-    return LLMBashChain(llm=llm, prompt=prompt, **config)
+    if llm_chain:
+        return LLMBashChain(llm_chain=llm_chain, prompt=prompt, **config)
+    else:
+        return LLMBashChain(llm=llm, prompt=prompt, **config)


 def _load_llm_checker_chain(config: dict, **kwargs: Any) -> LLMCheckerChain:
@@ -200,19 +212,31 @@ def _load_llm_checker_chain(config: dict, **kwargs: Any) -> LLMCheckerChain:


 def _load_llm_math_chain(config: dict, **kwargs: Any) -> LLMMathChain:
-    if "llm" in config:
+    llm_chain = None
+    if "llm_chain" in config:
+        llm_chain_config = config.pop("llm_chain")
+        llm_chain = load_chain_from_config(llm_chain_config)
+    elif "llm_chain_path" in config:
+        llm_chain = load_chain(config.pop("llm_chain_path"))
+    # llm attribute is deprecated in favor of llm_chain, here to support old configs
+    elif "llm" in config:
         llm_config = config.pop("llm")
         llm = load_llm_from_config(llm_config)
+    # llm_path attribute is deprecated in favor of llm_chain_path,
+    # it's here to support old configs
     elif "llm_path" in config:
         llm = load_llm(config.pop("llm_path"))
     else:
-        raise ValueError("One of `llm` or `llm_path` must be present.")
+        raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.")
     if "prompt" in config:
         prompt_config = config.pop("prompt")
         prompt = load_prompt_from_config(prompt_config)
     elif "prompt_path" in config:
         prompt = load_prompt(config.pop("prompt_path"))
-    return LLMMathChain(llm=llm, prompt=prompt, **config)
+    if llm_chain:
+        return LLMMathChain(llm_chain=llm_chain, prompt=prompt, **config)
+    else:
+        return LLMMathChain(llm=llm, prompt=prompt, **config)


 def _load_map_rerank_documents_chain(
@@ -229,13 +253,22 @@ def _load_map_rerank_documents_chain(


 def _load_pal_chain(config: dict, **kwargs: Any) -> PALChain:
-    if "llm" in config:
+    llm_chain = None
+    if "llm_chain" in config:
+        llm_chain_config = config.pop("llm_chain")
+        llm_chain = load_chain_from_config(llm_chain_config)
+    elif "llm_chain_path" in config:
+        llm_chain = load_chain(config.pop("llm_chain_path"))
+    # llm attribute is deprecated in favor of llm_chain, here to support old configs
+    elif "llm" in config:
         llm_config = config.pop("llm")
         llm = load_llm_from_config(llm_config)
+    # llm_path attribute is deprecated in favor of llm_chain_path,
+    # it's here to support old configs
     elif "llm_path" in config:
         llm = load_llm(config.pop("llm_path"))
     else:
-        raise ValueError("One of `llm` or `llm_path` must be present.")
+        raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.")
     if "prompt" in config:
         prompt_config = config.pop("prompt")
         prompt = load_prompt_from_config(prompt_config)
@@ -243,7 +276,10 @@ def _load_pal_chain(config: dict, **kwargs: Any) -> PALChain:
         prompt = load_prompt(config.pop("prompt_path"))
     else:
         raise ValueError("One of `prompt` or `prompt_path` must be present.")
-    return PALChain(llm=llm, prompt=prompt, **config)
+    if llm_chain:
+        return PALChain(llm_chain=llm_chain, prompt=prompt, **config)
+    else:
+        return PALChain(llm=llm, prompt=prompt, **config)


 def _load_refine_documents_chain(config: dict, **kwargs: Any) -> RefineDocumentsChain:
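
Note on usage (not part of the diff): after this change, a serialized config for these chains can carry a fully serialized LLMChain under "llm_chain" (or a file path under "llm_chain_path"), while the bare "llm"/"llm_path" keys are kept only for configs saved before the change. Below is a minimal sketch of the two shapes _load_llm_math_chain now handles. The "_type" keys follow LangChain's serialization format, but the OpenAI payload and the prompt template are illustrative placeholders; also note that, as the diff stands, the loader still reads "prompt"/"prompt_path" at the top level and passes it through in both branches, so a top-level prompt is included here.

from langchain.chains.loading import load_chain_from_config

# Illustrative prompt payload; load_prompt_from_config dispatches on "_type".
math_prompt = {
    "_type": "prompt",
    "input_variables": ["question"],
    "template": "Question: {question}\nAnswer:",
}

# New shape: a fully serialized LLMChain under "llm_chain".
new_style = {
    "_type": "llm_math_chain",
    "llm_chain": {
        "_type": "llm_chain",
        "llm": {"_type": "openai", "temperature": 0.0},  # illustrative LLM payload
        "prompt": math_prompt,
    },
    "prompt": math_prompt,  # still read at the top level by the loader
}

# Deprecated shape, kept working for configs saved before this change.
old_style = {
    "_type": "llm_math_chain",
    "llm": {"_type": "openai", "temperature": 0.0},
    "prompt": math_prompt,
}

# Loading instantiates the underlying LLM, so it requires the provider to be
# installed and configured (e.g. an OpenAI API key); the call itself would be:
# chain = load_chain_from_config(new_style)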