Fix load map reduce documents chain (#7915)

This PR updates _load_reduce_documents_chain to handle
`reduce_documents_chain` and `combine_documents_chain` config

Please review @hwchase17, @baskaryan

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Moonsik Kang 2023-08-03 23:27:38 -07:00 committed by GitHub
parent 0f0ccfe7f6
commit 5b7ff215e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -114,17 +114,55 @@ def _load_map_reduce_documents_chain(
if not isinstance(llm_chain, LLMChain):
raise ValueError(f"Expected LLMChain, got {llm_chain}")
if "combine_document_chain" in config:
if "reduce_documents_chain" in config:
reduce_documents_chain = load_chain_from_config(
config.pop("reduce_documents_chain")
)
elif "reduce_documents_chain_path" in config:
reduce_documents_chain = load_chain(config.pop("reduce_documents_chain_path"))
else:
reduce_documents_chain = _load_reduce_documents_chain(config)
return MapReduceDocumentsChain(
llm_chain=llm_chain,
reduce_documents_chain=reduce_documents_chain,
**config,
)
def _load_reduce_documents_chain(config: dict, **kwargs: Any) -> ReduceDocumentsChain:
combine_documents_chain = None
collapse_documents_chain = None
if "combine_documents_chain" in config:
combine_document_chain_config = config.pop("combine_documents_chain")
combine_documents_chain = load_chain_from_config(combine_document_chain_config)
elif "combine_document_chain" in config:
combine_document_chain_config = config.pop("combine_document_chain")
combine_documents_chain = load_chain_from_config(combine_document_chain_config)
elif "combine_documents_chain_path" in config:
combine_documents_chain = load_chain(config.pop("combine_documents_chain_path"))
elif "combine_document_chain_path" in config:
combine_documents_chain = load_chain(config.pop("combine_document_chain_path"))
else:
raise ValueError(
"One of `combine_document_chain` or "
"`combine_document_chain_path` must be present."
"One of `combine_documents_chain` or "
"`combine_documents_chain_path` must be present."
)
if "collapse_document_chain" in config:
if "collapse_documents_chain" in config:
collapse_document_chain_config = config.pop("collapse_documents_chain")
if collapse_document_chain_config is None:
collapse_documents_chain = None
else:
collapse_documents_chain = load_chain_from_config(
collapse_document_chain_config
)
elif "collapse_documents_chain_path" in config:
collapse_documents_chain = load_chain(
config.pop("collapse_documents_chain_path")
)
elif "collapse_document_chain" in config:
collapse_document_chain_config = config.pop("collapse_document_chain")
if collapse_document_chain_config is None:
collapse_documents_chain = None
@ -136,15 +174,10 @@ def _load_map_reduce_documents_chain(
collapse_documents_chain = load_chain(
config.pop("collapse_document_chain_path")
)
else:
collapse_documents_chain = None
reduce_documents_chain = ReduceDocumentsChain(
return ReduceDocumentsChain(
combine_documents_chain=combine_documents_chain,
collapse_documents_chain=collapse_documents_chain,
)
return MapReduceDocumentsChain(
llm_chain=llm_chain,
reduce_documents_chain=reduce_documents_chain,
**config,
)
@ -497,6 +530,7 @@ type_to_loader_dict = {
"qa_with_sources_chain": _load_qa_with_sources_chain,
"stuff_documents_chain": _load_stuff_documents_chain,
"map_reduce_documents_chain": _load_map_reduce_documents_chain,
"reduce_documents_chain": _load_reduce_documents_chain,
"map_rerank_documents_chain": _load_map_rerank_documents_chain,
"refine_documents_chain": _load_refine_documents_chain,
"sql_database_chain": _load_sql_database_chain,