@ -114,17 +114,55 @@ def _load_map_reduce_documents_chain(
if not isinstance ( llm_chain , LLMChain ) :
raise ValueError ( f " Expected LLMChain, got { llm_chain } " )
if " combine_document_chain " in config :
if " reduce_documents_chain " in config :
reduce_documents_chain = load_chain_from_config (
config . pop ( " reduce_documents_chain " )
)
elif " reduce_documents_chain_path " in config :
reduce_documents_chain = load_chain ( config . pop ( " reduce_documents_chain_path " ) )
else :
reduce_documents_chain = _load_reduce_documents_chain ( config )
return MapReduceDocumentsChain (
llm_chain = llm_chain ,
reduce_documents_chain = reduce_documents_chain ,
* * config ,
)
def _load_reduce_documents_chain ( config : dict , * * kwargs : Any ) - > ReduceDocumentsChain :
combine_documents_chain = None
collapse_documents_chain = None
if " combine_documents_chain " in config :
combine_document_chain_config = config . pop ( " combine_documents_chain " )
combine_documents_chain = load_chain_from_config ( combine_document_chain_config )
elif " combine_document_chain " in config :
combine_document_chain_config = config . pop ( " combine_document_chain " )
combine_documents_chain = load_chain_from_config ( combine_document_chain_config )
elif " combine_documents_chain_path " in config :
combine_documents_chain = load_chain ( config . pop ( " combine_documents_chain_path " ) )
elif " combine_document_chain_path " in config :
combine_documents_chain = load_chain ( config . pop ( " combine_document_chain_path " ) )
else :
raise ValueError (
" One of `combine_document_chain` or "
" `combine_document_chain_path` must be present. "
" One of `combine_documents_chain` or "
" `combine_documents_chain_path` must be present. "
)
if " collapse_documents_chain " in config :
collapse_document_chain_config = config . pop ( " collapse_documents_chain " )
if collapse_document_chain_config is None :
collapse_documents_chain = None
else :
collapse_documents_chain = load_chain_from_config (
collapse_document_chain_config
)
elif " collapse_documents_chain_path " in config :
collapse_documents_chain = load_chain (
config . pop ( " collapse_documents_chain_path " )
)
if " collapse_document_chain " in config :
el if " collapse_document_chain " in config :
collapse_document_chain_config = config . pop ( " collapse_document_chain " )
if collapse_document_chain_config is None :
collapse_documents_chain = None
@ -136,15 +174,10 @@ def _load_map_reduce_documents_chain(
collapse_documents_chain = load_chain (
config . pop ( " collapse_document_chain_path " )
)
else :
collapse_documents_chain = None
reduce_documents_chain = ReduceDocumentsChain (
return ReduceDocumentsChain (
combine_documents_chain = combine_documents_chain ,
collapse_documents_chain = collapse_documents_chain ,
)
return MapReduceDocumentsChain (
llm_chain = llm_chain ,
reduce_documents_chain = reduce_documents_chain ,
* * config ,
)
@ -497,6 +530,7 @@ type_to_loader_dict = {
" qa_with_sources_chain " : _load_qa_with_sources_chain ,
" stuff_documents_chain " : _load_stuff_documents_chain ,
" map_reduce_documents_chain " : _load_map_reduce_documents_chain ,
" reduce_documents_chain " : _load_reduce_documents_chain ,
" map_rerank_documents_chain " : _load_map_rerank_documents_chain ,
" refine_documents_chain " : _load_refine_documents_chain ,
" sql_database_chain " : _load_sql_database_chain ,