diff --git a/libs/langchain/langchain/chains/loading.py b/libs/langchain/langchain/chains/loading.py index f6255d2e61..c2e0b81397 100644 --- a/libs/langchain/langchain/chains/loading.py +++ b/libs/langchain/langchain/chains/loading.py @@ -114,17 +114,55 @@ def _load_map_reduce_documents_chain( if not isinstance(llm_chain, LLMChain): raise ValueError(f"Expected LLMChain, got {llm_chain}") - if "combine_document_chain" in config: + if "reduce_documents_chain" in config: + reduce_documents_chain = load_chain_from_config( + config.pop("reduce_documents_chain") + ) + elif "reduce_documents_chain_path" in config: + reduce_documents_chain = load_chain(config.pop("reduce_documents_chain_path")) + else: + reduce_documents_chain = _load_reduce_documents_chain(config) + + return MapReduceDocumentsChain( + llm_chain=llm_chain, + reduce_documents_chain=reduce_documents_chain, + **config, + ) + + +def _load_reduce_documents_chain(config: dict, **kwargs: Any) -> ReduceDocumentsChain: + combine_documents_chain = None + collapse_documents_chain = None + + if "combine_documents_chain" in config: + combine_document_chain_config = config.pop("combine_documents_chain") + combine_documents_chain = load_chain_from_config(combine_document_chain_config) + elif "combine_document_chain" in config: combine_document_chain_config = config.pop("combine_document_chain") combine_documents_chain = load_chain_from_config(combine_document_chain_config) + elif "combine_documents_chain_path" in config: + combine_documents_chain = load_chain(config.pop("combine_documents_chain_path")) elif "combine_document_chain_path" in config: combine_documents_chain = load_chain(config.pop("combine_document_chain_path")) else: raise ValueError( - "One of `combine_document_chain` or " - "`combine_document_chain_path` must be present." + "One of `combine_documents_chain` or " + "`combine_documents_chain_path` must be present." ) - if "collapse_document_chain" in config: + + if "collapse_documents_chain" in config: + collapse_document_chain_config = config.pop("collapse_documents_chain") + if collapse_document_chain_config is None: + collapse_documents_chain = None + else: + collapse_documents_chain = load_chain_from_config( + collapse_document_chain_config + ) + elif "collapse_documents_chain_path" in config: + collapse_documents_chain = load_chain( + config.pop("collapse_documents_chain_path") + ) + elif "collapse_document_chain" in config: collapse_document_chain_config = config.pop("collapse_document_chain") if collapse_document_chain_config is None: collapse_documents_chain = None @@ -136,15 +174,10 @@ def _load_map_reduce_documents_chain( collapse_documents_chain = load_chain( config.pop("collapse_document_chain_path") ) - else: - collapse_documents_chain = None - reduce_documents_chain = ReduceDocumentsChain( + + return ReduceDocumentsChain( combine_documents_chain=combine_documents_chain, collapse_documents_chain=collapse_documents_chain, - ) - return MapReduceDocumentsChain( - llm_chain=llm_chain, - reduce_documents_chain=reduce_documents_chain, **config, ) @@ -497,6 +530,7 @@ type_to_loader_dict = { "qa_with_sources_chain": _load_qa_with_sources_chain, "stuff_documents_chain": _load_stuff_documents_chain, "map_reduce_documents_chain": _load_map_reduce_documents_chain, + "reduce_documents_chain": _load_reduce_documents_chain, "map_rerank_documents_chain": _load_map_rerank_documents_chain, "refine_documents_chain": _load_refine_documents_chain, "sql_database_chain": _load_sql_database_chain,