Pass verbose argument to LLMChains when using *DocumentsChain (#458)

When using chains such as the summarization chain (`load_summarize_chain`),
the `verbose` flag was not propagated to the inner `LLMChain`s. This change
adds a `verbose` parameter to each loader and forwards it to every `LLMChain`
it constructs.
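
After this change a single flag covers the whole pipeline. A minimal usage sketch (the `OpenAI` model and the one-document input are illustrative, not part of this commit):

```python
from langchain.chains.summarize import load_summarize_chain
from langchain.docstore.document import Document
from langchain.llms import OpenAI  # illustrative model choice; any BaseLLM works

llm = OpenAI(temperature=0)
# verbose=True is now forwarded to the inner map and reduce LLMChains,
# so each formatted prompt is printed before the LLM is called.
chain = load_summarize_chain(llm, chain_type="map_reduce", verbose=True)
docs = [Document(page_content="LangChain composes LLM calls into chains.")]
print(chain.run(docs))
```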
Authored by Parth Chadha on 2022-12-29 05:22:31 -08:00; committed by GitHub
parent d0f194de73
commit e88e66f982
3 changed files with 46 additions and 24 deletions

File 1 of 3: langchain/chains/qa_with_sources/loading.py

@@ -26,9 +26,10 @@ def _load_stuff_chain(
     llm: BaseLLM,
     prompt: BasePromptTemplate = stuff_prompt.PROMPT,
     document_variable_name: str = "summaries",
+    verbose: bool = False,
     **kwargs: Any,
 ) -> StuffDocumentsChain:
-    llm_chain = LLMChain(llm=llm, prompt=prompt)
+    llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
     return StuffDocumentsChain(
         llm_chain=llm_chain,
         document_variable_name=document_variable_name,
@@ -47,11 +48,12 @@ def _load_map_reduce_chain(
     collapse_prompt: Optional[BasePromptTemplate] = None,
     reduce_llm: Optional[BaseLLM] = None,
     collapse_llm: Optional[BaseLLM] = None,
+    verbose: bool = False,
     **kwargs: Any,
 ) -> MapReduceDocumentsChain:
-    map_chain = LLMChain(llm=llm, prompt=question_prompt)
+    map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
     _reduce_llm = reduce_llm or llm
-    reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt)
+    reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
     combine_document_chain = StuffDocumentsChain(
         llm_chain=reduce_chain,
         document_variable_name=combine_document_variable_name,
@@ -67,7 +69,11 @@ def _load_map_reduce_chain(
     else:
         _collapse_llm = collapse_llm or llm
         collapse_chain = StuffDocumentsChain(
-            llm_chain=LLMChain(llm=_collapse_llm, prompt=collapse_prompt),
+            llm_chain=LLMChain(
+                llm=_collapse_llm,
+                prompt=collapse_prompt,
+                verbose=verbose,
+            ),
             document_variable_name=combine_document_variable_name,
             document_prompt=document_prompt,
         )
@@ -88,11 +94,12 @@ def _load_refine_chain(
     document_variable_name: str = "context_str",
     initial_response_name: str = "existing_answer",
     refine_llm: Optional[BaseLLM] = None,
+    verbose: bool = False,
     **kwargs: Any,
 ) -> RefineDocumentsChain:
-    initial_chain = LLMChain(llm=llm, prompt=question_prompt)
+    initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
     _refine_llm = refine_llm or llm
-    refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt)
+    refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)
     return RefineDocumentsChain(
         initial_llm_chain=initial_chain,
         refine_llm_chain=refine_chain,
@@ -104,7 +111,7 @@ def _load_refine_chain(
 def load_qa_with_sources_chain(
-    llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any
+    llm: BaseLLM, chain_type: str = "stuff", verbose: bool = False, **kwargs: Any
 ) -> BaseCombineDocumentsChain:
     """Load question answering with sources chain.
@@ -127,4 +134,4 @@ def load_qa_with_sources_chain(
             f"Should be one of {loader_mapping.keys()}"
         )
     _func: LoadingCallable = loader_mapping[chain_type]
-    return _func(llm, **kwargs)
+    return _func(llm, verbose=verbose, **kwargs)
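
A usage sketch for this loader (model choice illustrative): with the new parameter, `verbose` reaches the map, reduce, and collapse `LLMChain`s instead of stopping at the outer chain.

```python
from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain
from langchain.llms import OpenAI  # illustrative

chain = load_qa_with_sources_chain(
    OpenAI(temperature=0), chain_type="map_reduce", verbose=True
)
# Two input keys, so pass a dict; `docs` is a hypothetical list of Documents:
# result = chain({"input_documents": docs, "question": "..."})
```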

File 2 of 3: langchain/chains/question_answering/__init__.py

@@ -26,9 +26,10 @@ def _load_stuff_chain(
     llm: BaseLLM,
     prompt: BasePromptTemplate = stuff_prompt.PROMPT,
     document_variable_name: str = "context",
+    verbose: bool = False,
     **kwargs: Any,
 ) -> StuffDocumentsChain:
-    llm_chain = LLMChain(llm=llm, prompt=prompt)
+    llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
     # TODO: document prompt
     return StuffDocumentsChain(
         llm_chain=llm_chain, document_variable_name=document_variable_name, **kwargs
@@ -44,11 +45,12 @@ def _load_map_reduce_chain(
     collapse_prompt: Optional[BasePromptTemplate] = None,
     reduce_llm: Optional[BaseLLM] = None,
     collapse_llm: Optional[BaseLLM] = None,
+    verbose: bool = False,
     **kwargs: Any,
 ) -> MapReduceDocumentsChain:
-    map_chain = LLMChain(llm=llm, prompt=question_prompt)
+    map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
     _reduce_llm = reduce_llm or llm
-    reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt)
+    reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
     # TODO: document prompt
     combine_document_chain = StuffDocumentsChain(
         llm_chain=reduce_chain, document_variable_name=combine_document_variable_name
@@ -63,7 +65,11 @@ def _load_map_reduce_chain(
     else:
         _collapse_llm = collapse_llm or llm
         collapse_chain = StuffDocumentsChain(
-            llm_chain=LLMChain(llm=_collapse_llm, prompt=collapse_prompt),
+            llm_chain=LLMChain(
+                llm=_collapse_llm,
+                prompt=collapse_prompt,
+                verbose=verbose,
+            ),
             document_variable_name=combine_document_variable_name,
         )
     return MapReduceDocumentsChain(
@@ -82,11 +88,12 @@ def _load_refine_chain(
     document_variable_name: str = "context_str",
     initial_response_name: str = "existing_answer",
     refine_llm: Optional[BaseLLM] = None,
+    verbose: bool = False,
     **kwargs: Any,
 ) -> RefineDocumentsChain:
-    initial_chain = LLMChain(llm=llm, prompt=question_prompt)
+    initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
     _refine_llm = refine_llm or llm
-    refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt)
+    refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)
     return RefineDocumentsChain(
         initial_llm_chain=initial_chain,
         refine_llm_chain=refine_chain,
@@ -97,7 +104,7 @@ def _load_refine_chain(
 def load_qa_chain(
-    llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any
+    llm: BaseLLM, chain_type: str = "stuff", verbose: bool = False, **kwargs: Any
 ) -> BaseCombineDocumentsChain:
     """Load question answering chain.
@@ -119,4 +126,4 @@ def load_qa_chain(
             f"Got unsupported chain type: {chain_type}. "
             f"Should be one of {loader_mapping.keys()}"
         )
-    return loader_mapping[chain_type](llm, **kwargs)
+    return loader_mapping[chain_type](llm, verbose=verbose, **kwargs)
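
The QA loader behaves the same way; a sketch (illustrative model, hypothetical `docs`):

```python
from langchain.chains.question_answering import load_qa_chain
from langchain.llms import OpenAI  # illustrative

# verbose=True now also reaches the initial and refine LLMChains.
chain = load_qa_chain(OpenAI(temperature=0), chain_type="refine", verbose=True)
# result = chain({"input_documents": docs, "question": "..."})
```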

File 3 of 3: langchain/chains/summarize/__init__.py

@@ -22,9 +22,10 @@ def _load_stuff_chain(
     llm: BaseLLM,
     prompt: BasePromptTemplate = stuff_prompt.PROMPT,
     document_variable_name: str = "text",
+    verbose: bool = False,
     **kwargs: Any,
 ) -> StuffDocumentsChain:
-    llm_chain = LLMChain(llm=llm, prompt=prompt)
+    llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
     # TODO: document prompt
     return StuffDocumentsChain(
         llm_chain=llm_chain, document_variable_name=document_variable_name, **kwargs
@@ -40,11 +41,12 @@ def _load_map_reduce_chain(
     collapse_prompt: Optional[BasePromptTemplate] = None,
     reduce_llm: Optional[BaseLLM] = None,
     collapse_llm: Optional[BaseLLM] = None,
+    verbose: bool = False,
     **kwargs: Any,
 ) -> MapReduceDocumentsChain:
-    map_chain = LLMChain(llm=llm, prompt=map_prompt)
+    map_chain = LLMChain(llm=llm, prompt=map_prompt, verbose=verbose)
     _reduce_llm = reduce_llm or llm
-    reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt)
+    reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
     # TODO: document prompt
     combine_document_chain = StuffDocumentsChain(
         llm_chain=reduce_chain, document_variable_name=combine_document_variable_name
@@ -59,7 +61,11 @@ def _load_map_reduce_chain(
     else:
         _collapse_llm = collapse_llm or llm
         collapse_chain = StuffDocumentsChain(
-            llm_chain=LLMChain(llm=_collapse_llm, prompt=collapse_prompt),
+            llm_chain=LLMChain(
+                llm=_collapse_llm,
+                prompt=collapse_prompt,
+                verbose=verbose,
+            ),
             document_variable_name=combine_document_variable_name,
         )
     return MapReduceDocumentsChain(
@@ -78,11 +84,13 @@ def _load_refine_chain(
     document_variable_name: str = "text",
     initial_response_name: str = "existing_answer",
     refine_llm: Optional[BaseLLM] = None,
+    verbose: bool = False,
     **kwargs: Any,
 ) -> RefineDocumentsChain:
-    initial_chain = LLMChain(llm=llm, prompt=question_prompt)
+
+    initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
     _refine_llm = refine_llm or llm
-    refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt)
+    refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)
     return RefineDocumentsChain(
         initial_llm_chain=initial_chain,
         refine_llm_chain=refine_chain,
@@ -93,7 +101,7 @@ def _load_refine_chain(
 def load_summarize_chain(
-    llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any
+    llm: BaseLLM, chain_type: str = "stuff", verbose: bool = False, **kwargs: Any
 ) -> BaseCombineDocumentsChain:
     """Load summarizing chain.
@@ -115,4 +123,4 @@ def load_summarize_chain(
             f"Got unsupported chain type: {chain_type}. "
             f"Should be one of {loader_mapping.keys()}"
         )
-    return loader_mapping[chain_type](llm, **kwargs)
+    return loader_mapping[chain_type](llm, verbose=verbose, **kwargs)
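
For comparison, the pre-fix workaround was to assemble the inner chain by hand so that `verbose` reached the LLM calls. A sketch of what this commit makes unnecessary (the imports and model choice are assumptions, not taken from the diff):

```python
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.summarize import stuff_prompt
from langchain.llms import OpenAI  # illustrative

# Roughly equivalent to load_summarize_chain(llm, chain_type="stuff", verbose=True),
# built manually: verbose is set on the inner LLMChain directly.
llm_chain = LLMChain(
    llm=OpenAI(temperature=0), prompt=stuff_prompt.PROMPT, verbose=True
)
chain = StuffDocumentsChain(
    llm_chain=llm_chain, document_variable_name="text", verbose=True
)
```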