From e88e66f9822ee2d3813cf5c638bbab9895774a10 Mon Sep 17 00:00:00 2001 From: Parth Chadha Date: Thu, 29 Dec 2022 05:22:31 -0800 Subject: [PATCH] Pass verbose argument to LLMChains when using *DocumentsChain (#458) When using chains such as Summarization chain (`load_summarize_chain`), the verbose flag wasn't propagated to the `LLMChain`. --- langchain/chains/qa_with_sources/__init__.py | 23 +++++++++++------- .../chains/question_answering/__init__.py | 23 +++++++++++------- langchain/chains/summarize/__init__.py | 24 ++++++++++++------- 3 files changed, 46 insertions(+), 24 deletions(-) diff --git a/langchain/chains/qa_with_sources/__init__.py b/langchain/chains/qa_with_sources/__init__.py index 763a35d5af..c3a3fa2b06 100644 --- a/langchain/chains/qa_with_sources/__init__.py +++ b/langchain/chains/qa_with_sources/__init__.py @@ -26,9 +26,10 @@ def _load_stuff_chain( llm: BaseLLM, prompt: BasePromptTemplate = stuff_prompt.PROMPT, document_variable_name: str = "summaries", + verbose: bool = False, **kwargs: Any, ) -> StuffDocumentsChain: - llm_chain = LLMChain(llm=llm, prompt=prompt) + llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) return StuffDocumentsChain( llm_chain=llm_chain, document_variable_name=document_variable_name, @@ -47,11 +48,12 @@ def _load_map_reduce_chain( collapse_prompt: Optional[BasePromptTemplate] = None, reduce_llm: Optional[BaseLLM] = None, collapse_llm: Optional[BaseLLM] = None, + verbose: bool = False, **kwargs: Any, ) -> MapReduceDocumentsChain: - map_chain = LLMChain(llm=llm, prompt=question_prompt) + map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) _reduce_llm = reduce_llm or llm - reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt) + reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose) combine_document_chain = StuffDocumentsChain( llm_chain=reduce_chain, document_variable_name=combine_document_variable_name, @@ -67,7 +69,11 @@ def _load_map_reduce_chain( else: _collapse_llm = collapse_llm or llm collapse_chain = StuffDocumentsChain( - llm_chain=LLMChain(llm=_collapse_llm, prompt=collapse_prompt), + llm_chain=LLMChain( + llm=_collapse_llm, + prompt=collapse_prompt, + verbose=verbose, + ), document_variable_name=combine_document_variable_name, document_prompt=document_prompt, ) @@ -88,11 +94,12 @@ def _load_refine_chain( document_variable_name: str = "context_str", initial_response_name: str = "existing_answer", refine_llm: Optional[BaseLLM] = None, + verbose: bool = False, **kwargs: Any, ) -> RefineDocumentsChain: - initial_chain = LLMChain(llm=llm, prompt=question_prompt) + initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) _refine_llm = refine_llm or llm - refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt) + refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose) return RefineDocumentsChain( initial_llm_chain=initial_chain, refine_llm_chain=refine_chain, @@ -104,7 +111,7 @@ def _load_refine_chain( def load_qa_with_sources_chain( - llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any + llm: BaseLLM, chain_type: str = "stuff", verbose: bool = False, **kwargs: Any ) -> BaseCombineDocumentsChain: """Load question answering with sources chain. @@ -127,4 +134,4 @@ def load_qa_with_sources_chain( f"Should be one of {loader_mapping.keys()}" ) _func: LoadingCallable = loader_mapping[chain_type] - return _func(llm, **kwargs) + return _func(llm, verbose=verbose, **kwargs) diff --git a/langchain/chains/question_answering/__init__.py b/langchain/chains/question_answering/__init__.py index 350c96918a..f01bd0c696 100644 --- a/langchain/chains/question_answering/__init__.py +++ b/langchain/chains/question_answering/__init__.py @@ -26,9 +26,10 @@ def _load_stuff_chain( llm: BaseLLM, prompt: BasePromptTemplate = stuff_prompt.PROMPT, document_variable_name: str = "context", + verbose: bool = False, **kwargs: Any, ) -> StuffDocumentsChain: - llm_chain = LLMChain(llm=llm, prompt=prompt) + llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) # TODO: document prompt return StuffDocumentsChain( llm_chain=llm_chain, document_variable_name=document_variable_name, **kwargs @@ -44,11 +45,12 @@ def _load_map_reduce_chain( collapse_prompt: Optional[BasePromptTemplate] = None, reduce_llm: Optional[BaseLLM] = None, collapse_llm: Optional[BaseLLM] = None, + verbose: bool = False, **kwargs: Any, ) -> MapReduceDocumentsChain: - map_chain = LLMChain(llm=llm, prompt=question_prompt) + map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) _reduce_llm = reduce_llm or llm - reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt) + reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose) # TODO: document prompt combine_document_chain = StuffDocumentsChain( llm_chain=reduce_chain, document_variable_name=combine_document_variable_name @@ -63,7 +65,11 @@ def _load_map_reduce_chain( else: _collapse_llm = collapse_llm or llm collapse_chain = StuffDocumentsChain( - llm_chain=LLMChain(llm=_collapse_llm, prompt=collapse_prompt), + llm_chain=LLMChain( + llm=_collapse_llm, + prompt=collapse_prompt, + verbose=verbose, + ), document_variable_name=combine_document_variable_name, ) return MapReduceDocumentsChain( @@ -82,11 +88,12 @@ def _load_refine_chain( document_variable_name: str = "context_str", initial_response_name: str = "existing_answer", refine_llm: Optional[BaseLLM] = None, + verbose: bool = False, **kwargs: Any, ) -> RefineDocumentsChain: - initial_chain = LLMChain(llm=llm, prompt=question_prompt) + initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) _refine_llm = refine_llm or llm - refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt) + refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose) return RefineDocumentsChain( initial_llm_chain=initial_chain, refine_llm_chain=refine_chain, @@ -97,7 +104,7 @@ def _load_refine_chain( def load_qa_chain( - llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any + llm: BaseLLM, chain_type: str = "stuff", verbose: bool = False, **kwargs: Any ) -> BaseCombineDocumentsChain: """Load question answering chain. @@ -119,4 +126,4 @@ def load_qa_chain( f"Got unsupported chain type: {chain_type}. " f"Should be one of {loader_mapping.keys()}" ) - return loader_mapping[chain_type](llm, **kwargs) + return loader_mapping[chain_type](llm, verbose=verbose, **kwargs) diff --git a/langchain/chains/summarize/__init__.py b/langchain/chains/summarize/__init__.py index cecb3de701..68621e2728 100644 --- a/langchain/chains/summarize/__init__.py +++ b/langchain/chains/summarize/__init__.py @@ -22,9 +22,10 @@ def _load_stuff_chain( llm: BaseLLM, prompt: BasePromptTemplate = stuff_prompt.PROMPT, document_variable_name: str = "text", + verbose: bool = False, **kwargs: Any, ) -> StuffDocumentsChain: - llm_chain = LLMChain(llm=llm, prompt=prompt) + llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) # TODO: document prompt return StuffDocumentsChain( llm_chain=llm_chain, document_variable_name=document_variable_name, **kwargs @@ -40,11 +41,12 @@ def _load_map_reduce_chain( collapse_prompt: Optional[BasePromptTemplate] = None, reduce_llm: Optional[BaseLLM] = None, collapse_llm: Optional[BaseLLM] = None, + verbose: bool = False, **kwargs: Any, ) -> MapReduceDocumentsChain: - map_chain = LLMChain(llm=llm, prompt=map_prompt) + map_chain = LLMChain(llm=llm, prompt=map_prompt, verbose=verbose) _reduce_llm = reduce_llm or llm - reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt) + reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose) # TODO: document prompt combine_document_chain = StuffDocumentsChain( llm_chain=reduce_chain, document_variable_name=combine_document_variable_name @@ -59,7 +61,11 @@ def _load_map_reduce_chain( else: _collapse_llm = collapse_llm or llm collapse_chain = StuffDocumentsChain( - llm_chain=LLMChain(llm=_collapse_llm, prompt=collapse_prompt), + llm_chain=LLMChain( + llm=_collapse_llm, + prompt=collapse_prompt, + verbose=verbose, + ), document_variable_name=combine_document_variable_name, ) return MapReduceDocumentsChain( @@ -78,11 +84,13 @@ def _load_refine_chain( document_variable_name: str = "text", initial_response_name: str = "existing_answer", refine_llm: Optional[BaseLLM] = None, + verbose: bool = False, **kwargs: Any, ) -> RefineDocumentsChain: - initial_chain = LLMChain(llm=llm, prompt=question_prompt) + + initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) _refine_llm = refine_llm or llm - refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt) + refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose) return RefineDocumentsChain( initial_llm_chain=initial_chain, refine_llm_chain=refine_chain, @@ -93,7 +101,7 @@ def _load_refine_chain( def load_summarize_chain( - llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any + llm: BaseLLM, chain_type: str = "stuff", verbose: bool = False, **kwargs: Any ) -> BaseCombineDocumentsChain: """Load summarizing chain. @@ -115,4 +123,4 @@ def load_summarize_chain( f"Got unsupported chain type: {chain_type}. " f"Should be one of {loader_mapping.keys()}" ) - return loader_mapping[chain_type](llm, **kwargs) + return loader_mapping[chain_type](llm, verbose=verbose, **kwargs)