mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Fix load map reduce documents chain (#7915)
This PR updates _load_reduce_documents_chain to handle `reduce_documents_chain` and `combine_documents_chain` config Please review @hwchase17, @baskaryan Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
0f0ccfe7f6
commit
5b7ff215e8
@ -114,17 +114,55 @@ def _load_map_reduce_documents_chain(
|
|||||||
if not isinstance(llm_chain, LLMChain):
|
if not isinstance(llm_chain, LLMChain):
|
||||||
raise ValueError(f"Expected LLMChain, got {llm_chain}")
|
raise ValueError(f"Expected LLMChain, got {llm_chain}")
|
||||||
|
|
||||||
if "combine_document_chain" in config:
|
if "reduce_documents_chain" in config:
|
||||||
|
reduce_documents_chain = load_chain_from_config(
|
||||||
|
config.pop("reduce_documents_chain")
|
||||||
|
)
|
||||||
|
elif "reduce_documents_chain_path" in config:
|
||||||
|
reduce_documents_chain = load_chain(config.pop("reduce_documents_chain_path"))
|
||||||
|
else:
|
||||||
|
reduce_documents_chain = _load_reduce_documents_chain(config)
|
||||||
|
|
||||||
|
return MapReduceDocumentsChain(
|
||||||
|
llm_chain=llm_chain,
|
||||||
|
reduce_documents_chain=reduce_documents_chain,
|
||||||
|
**config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_reduce_documents_chain(config: dict, **kwargs: Any) -> ReduceDocumentsChain:
|
||||||
|
combine_documents_chain = None
|
||||||
|
collapse_documents_chain = None
|
||||||
|
|
||||||
|
if "combine_documents_chain" in config:
|
||||||
|
combine_document_chain_config = config.pop("combine_documents_chain")
|
||||||
|
combine_documents_chain = load_chain_from_config(combine_document_chain_config)
|
||||||
|
elif "combine_document_chain" in config:
|
||||||
combine_document_chain_config = config.pop("combine_document_chain")
|
combine_document_chain_config = config.pop("combine_document_chain")
|
||||||
combine_documents_chain = load_chain_from_config(combine_document_chain_config)
|
combine_documents_chain = load_chain_from_config(combine_document_chain_config)
|
||||||
|
elif "combine_documents_chain_path" in config:
|
||||||
|
combine_documents_chain = load_chain(config.pop("combine_documents_chain_path"))
|
||||||
elif "combine_document_chain_path" in config:
|
elif "combine_document_chain_path" in config:
|
||||||
combine_documents_chain = load_chain(config.pop("combine_document_chain_path"))
|
combine_documents_chain = load_chain(config.pop("combine_document_chain_path"))
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"One of `combine_document_chain` or "
|
"One of `combine_documents_chain` or "
|
||||||
"`combine_document_chain_path` must be present."
|
"`combine_documents_chain_path` must be present."
|
||||||
)
|
)
|
||||||
if "collapse_document_chain" in config:
|
|
||||||
|
if "collapse_documents_chain" in config:
|
||||||
|
collapse_document_chain_config = config.pop("collapse_documents_chain")
|
||||||
|
if collapse_document_chain_config is None:
|
||||||
|
collapse_documents_chain = None
|
||||||
|
else:
|
||||||
|
collapse_documents_chain = load_chain_from_config(
|
||||||
|
collapse_document_chain_config
|
||||||
|
)
|
||||||
|
elif "collapse_documents_chain_path" in config:
|
||||||
|
collapse_documents_chain = load_chain(
|
||||||
|
config.pop("collapse_documents_chain_path")
|
||||||
|
)
|
||||||
|
elif "collapse_document_chain" in config:
|
||||||
collapse_document_chain_config = config.pop("collapse_document_chain")
|
collapse_document_chain_config = config.pop("collapse_document_chain")
|
||||||
if collapse_document_chain_config is None:
|
if collapse_document_chain_config is None:
|
||||||
collapse_documents_chain = None
|
collapse_documents_chain = None
|
||||||
@ -136,15 +174,10 @@ def _load_map_reduce_documents_chain(
|
|||||||
collapse_documents_chain = load_chain(
|
collapse_documents_chain = load_chain(
|
||||||
config.pop("collapse_document_chain_path")
|
config.pop("collapse_document_chain_path")
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
collapse_documents_chain = None
|
return ReduceDocumentsChain(
|
||||||
reduce_documents_chain = ReduceDocumentsChain(
|
|
||||||
combine_documents_chain=combine_documents_chain,
|
combine_documents_chain=combine_documents_chain,
|
||||||
collapse_documents_chain=collapse_documents_chain,
|
collapse_documents_chain=collapse_documents_chain,
|
||||||
)
|
|
||||||
return MapReduceDocumentsChain(
|
|
||||||
llm_chain=llm_chain,
|
|
||||||
reduce_documents_chain=reduce_documents_chain,
|
|
||||||
**config,
|
**config,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -497,6 +530,7 @@ type_to_loader_dict = {
|
|||||||
"qa_with_sources_chain": _load_qa_with_sources_chain,
|
"qa_with_sources_chain": _load_qa_with_sources_chain,
|
||||||
"stuff_documents_chain": _load_stuff_documents_chain,
|
"stuff_documents_chain": _load_stuff_documents_chain,
|
||||||
"map_reduce_documents_chain": _load_map_reduce_documents_chain,
|
"map_reduce_documents_chain": _load_map_reduce_documents_chain,
|
||||||
|
"reduce_documents_chain": _load_reduce_documents_chain,
|
||||||
"map_rerank_documents_chain": _load_map_rerank_documents_chain,
|
"map_rerank_documents_chain": _load_map_rerank_documents_chain,
|
||||||
"refine_documents_chain": _load_refine_documents_chain,
|
"refine_documents_chain": _load_refine_documents_chain,
|
||||||
"sql_database_chain": _load_sql_database_chain,
|
"sql_database_chain": _load_sql_database_chain,
|
||||||
|
Loading…
Reference in New Issue
Block a user