diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py
index bf19cab5..99fb2c1e 100644
--- a/langchain/agents/agent.py
+++ b/langchain/agents/agent.py
@@ -241,8 +241,9 @@ class AgentExecutor(Chain, BaseModel):
         return iterations < self.max_iterations

     def _return(self, output: AgentFinish, intermediate_steps: list) -> Dict[str, Any]:
-        if self.verbose:
-            self.callback_manager.on_agent_finish(output, color="green")
+        self.callback_manager.on_agent_finish(
+            output, color="green", verbose=self.verbose
+        )
         final_output = output.return_values
         if self.return_intermediate_steps:
             final_output["intermediate_steps"] = intermediate_steps
@@ -272,35 +273,35 @@ class AgentExecutor(Chain, BaseModel):
         # Otherwise we lookup the tool
         if output.tool in name_to_tool_map:
             tool = name_to_tool_map[output.tool]
-            if self.verbose:
-                self.callback_manager.on_tool_start(
-                    {"name": str(tool.func)[:60] + "..."}, output, color="green"
-                )
+            self.callback_manager.on_tool_start(
+                {"name": str(tool.func)[:60] + "..."},
+                output,
+                color="green",
+                verbose=self.verbose,
+            )
             try:
                 # We then call the tool on the tool input to get an observation
                 observation = tool.func(output.tool_input)
                 color = color_mapping[output.tool]
                 return_direct = tool.return_direct
             except Exception as e:
-                if self.verbose:
-                    self.callback_manager.on_tool_error(e)
+                self.callback_manager.on_tool_error(e, verbose=self.verbose)
                 raise e
         else:
-            if self.verbose:
-                self.callback_manager.on_tool_start(
-                    {"name": "N/A"}, output, color="green"
-                )
+            self.callback_manager.on_tool_start(
+                {"name": "N/A"}, output, color="green", verbose=self.verbose
+            )
             observation = f"{output.tool} is not a valid tool, try another one."
             color = None
             return_direct = False
-        if self.verbose:
-            llm_prefix = "" if return_direct else self.agent.llm_prefix
-            self.callback_manager.on_tool_end(
-                observation,
-                color=color,
-                observation_prefix=self.agent.observation_prefix,
-                llm_prefix=llm_prefix,
-            )
+        llm_prefix = "" if return_direct else self.agent.llm_prefix
+        self.callback_manager.on_tool_end(
+            observation,
+            color=color,
+            observation_prefix=self.agent.observation_prefix,
+            llm_prefix=llm_prefix,
+            verbose=self.verbose,
+        )
         intermediate_steps.append((output, observation))
         if return_direct:
             # Set the log to "" because we do not want to log it.
diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py
index 6b3b4ca9..d04bc1a6 100644
--- a/langchain/callbacks/base.py
+++ b/langchain/callbacks/base.py
@@ -15,6 +15,11 @@ class BaseCallbackHandler(BaseModel, ABC):
     ignore_chain: bool = False
     ignore_agent: bool = False

+    @property
+    def always_verbose(self) -> bool:
+        """Whether to call verbose callbacks even if verbose is False."""
+        return False
+
     @abstractmethod
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
@@ -22,14 +27,11 @@ class BaseCallbackHandler(BaseModel, ABC):
         """Run when LLM starts running."""

     @abstractmethod
-    def on_llm_end(
-        self,
-        response: LLMResult,
-    ) -> None:
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Run when LLM ends running."""

     @abstractmethod
-    def on_llm_error(self, error: Exception) -> None:
+    def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
         """Run when LLM errors."""

     @abstractmethod
@@ -39,11 +41,11 @@ class BaseCallbackHandler(BaseModel, ABC):
         """Run when chain starts running."""

     @abstractmethod
-    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
+    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
         """Run when chain ends running."""

     @abstractmethod
-    def on_chain_error(self, error: Exception) -> None:
+    def on_chain_error(self, error: Exception, **kwargs: Any) -> None:
         """Run when chain errors."""

     @abstractmethod
@@ -57,7 +59,7 @@ class BaseCallbackHandler(BaseModel, ABC):
         """Run when tool ends running."""

     @abstractmethod
-    def on_tool_error(self, error: Exception) -> None:
+    def on_tool_error(self, error: Exception, **kwargs: Any) -> None:
         """Run when tool errors."""

     @abstractmethod
@@ -91,78 +93,110 @@ class CallbackManager(BaseCallbackManager):

     handlers: List[BaseCallbackHandler]

     def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+        self,
+        serialized: Dict[str, Any],
+        prompts: List[str],
+        verbose: bool = False,
+        **kwargs: Any
     ) -> None:
         """Run when LLM starts running."""
         for handler in self.handlers:
             if not handler.ignore_llm:
-                handler.on_llm_start(serialized, prompts, **kwargs)
+                if verbose or handler.always_verbose:
+                    handler.on_llm_start(serialized, prompts, **kwargs)

     def on_llm_end(
-        self,
-        response: LLMResult,
+        self, response: LLMResult, verbose: bool = False, **kwargs: Any
     ) -> None:
         """Run when LLM ends running."""
         for handler in self.handlers:
             if not handler.ignore_llm:
-                handler.on_llm_end(response)
+                if verbose or handler.always_verbose:
+                    handler.on_llm_end(response)

-    def on_llm_error(self, error: Exception) -> None:
+    def on_llm_error(
+        self, error: Exception, verbose: bool = False, **kwargs: Any
+    ) -> None:
         """Run when LLM errors."""
         for handler in self.handlers:
             if not handler.ignore_llm:
-                handler.on_llm_error(error)
+                if verbose or handler.always_verbose:
+                    handler.on_llm_error(error)

     def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+        self,
+        serialized: Dict[str, Any],
+        inputs: Dict[str, Any],
+        verbose: bool = False,
+        **kwargs: Any
     ) -> None:
         """Run when chain starts running."""
         for handler in self.handlers:
             if not handler.ignore_chain:
-                handler.on_chain_start(serialized, inputs, **kwargs)
+                if verbose or handler.always_verbose:
+                    handler.on_chain_start(serialized, inputs, **kwargs)

-    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
+    def on_chain_end(
+        self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any
+    ) -> None:
         """Run when chain ends running."""
         for handler in self.handlers:
             if not handler.ignore_chain:
-                handler.on_chain_end(outputs)
+                if verbose or handler.always_verbose:
+                    handler.on_chain_end(outputs)

-    def on_chain_error(self, error: Exception) -> None:
+    def on_chain_error(
+        self, error: Exception, verbose: bool = False, **kwargs: Any
+    ) -> None:
         """Run when chain errors."""
         for handler in self.handlers:
             if not handler.ignore_chain:
-                handler.on_chain_error(error)
+                if verbose or handler.always_verbose:
+                    handler.on_chain_error(error)

     def on_tool_start(
-        self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
+        self,
+        serialized: Dict[str, Any],
+        action: AgentAction,
+        verbose: bool = False,
+        **kwargs: Any
     ) -> None:
         """Run when tool starts running."""
         for handler in self.handlers:
             if not handler.ignore_agent:
-                handler.on_tool_start(serialized, action, **kwargs)
+                if verbose or handler.always_verbose:
+                    handler.on_tool_start(serialized, action, **kwargs)

-    def on_tool_end(self, output: str, **kwargs: Any) -> None:
+    def on_tool_end(self, output: str, verbose: bool = False, **kwargs: Any) -> None:
         """Run when tool ends running."""
         for handler in self.handlers:
             if not handler.ignore_agent:
-                handler.on_tool_end(output, **kwargs)
+                if verbose or handler.always_verbose:
+                    handler.on_tool_end(output, **kwargs)

-    def on_tool_error(self, error: Exception) -> None:
+    def on_tool_error(
+        self, error: Exception, verbose: bool = False, **kwargs: Any
+    ) -> None:
         """Run when tool errors."""
         for handler in self.handlers:
             if not handler.ignore_agent:
-                handler.on_tool_error(error)
+                if verbose or handler.always_verbose:
+                    handler.on_tool_error(error)

-    def on_text(self, text: str, **kwargs: Any) -> None:
+    def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None:
         """Run on additional input from chains and agents."""
         for handler in self.handlers:
-            handler.on_text(text, **kwargs)
+            if verbose or handler.always_verbose:
+                handler.on_text(text, **kwargs)

-    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
+    def on_agent_finish(
+        self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
+    ) -> None:
         """Run on agent end."""
         for handler in self.handlers:
             if not handler.ignore_agent:
-                handler.on_agent_finish(finish, **kwargs)
+                if verbose or handler.always_verbose:
+                    handler.on_agent_finish(finish, **kwargs)

     def add_handler(self, handler: BaseCallbackHandler) -> None:
         """Add a handler to the callback manager."""
diff --git a/langchain/callbacks/shared.py b/langchain/callbacks/shared.py
index 576f4609..3ec7a686 100644
--- a/langchain/callbacks/shared.py
+++ b/langchain/callbacks/shared.py
@@ -41,18 +41,15 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
         with self._lock:
             self._callback_manager.on_llm_start(serialized, prompts, **kwargs)

-    def on_llm_end(
-        self,
-        response: LLMResult,
-    ) -> None:
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Run when LLM ends running."""
         with self._lock:
-            self._callback_manager.on_llm_end(response)
+            self._callback_manager.on_llm_end(response, **kwargs)

-    def on_llm_error(self, error: Exception) -> None:
+    def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
         """Run when LLM errors."""
         with self._lock:
-            self._callback_manager.on_llm_error(error)
+            self._callback_manager.on_llm_error(error, **kwargs)

     def on_chain_start(
         self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
@@ -61,15 +58,15 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
         with self._lock:
             self._callback_manager.on_chain_start(serialized, inputs, **kwargs)

-    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
+    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
         """Run when chain ends running."""
         with self._lock:
-            self._callback_manager.on_chain_end(outputs)
+            self._callback_manager.on_chain_end(outputs, **kwargs)

-    def on_chain_error(self, error: Exception) -> None:
+    def on_chain_error(self, error: Exception, **kwargs: Any) -> None:
         """Run when chain errors."""
         with self._lock:
-            self._callback_manager.on_chain_error(error)
+            self._callback_manager.on_chain_error(error, **kwargs)

     def on_tool_start(
         self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
@@ -83,10 +80,10 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
         with self._lock:
             self._callback_manager.on_tool_end(output, **kwargs)

-    def on_tool_error(self, error: Exception) -> None:
+    def on_tool_error(self, error: Exception, **kwargs: Any) -> None:
         """Run when tool errors."""
         with self._lock:
-            self._callback_manager.on_tool_error(error)
+            self._callback_manager.on_tool_error(error, **kwargs)

     def on_text(self, text: str, **kwargs: Any) -> None:
         """Run on arbitrary text."""
diff --git a/langchain/callbacks/stdout.py b/langchain/callbacks/stdout.py
index 3e6cd281..ff8ea2f4 100644
--- a/langchain/callbacks/stdout.py
+++ b/langchain/callbacks/stdout.py
@@ -15,11 +15,11 @@ class StdOutCallbackHandler(BaseCallbackHandler):
         """Print out the prompts."""
         pass

-    def on_llm_end(self, response: LLMResult) -> None:
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Do nothing."""
         pass

-    def on_llm_error(self, error: Exception) -> None:
+    def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
         """Do nothing."""
         pass

@@ -30,11 +30,11 @@ class StdOutCallbackHandler(BaseCallbackHandler):
         class_name = serialized["name"]
         print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")

-    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
+    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
         """Print out that we finished a chain."""
         print("\n\033[1m> Finished chain.\033[0m")

-    def on_chain_error(self, error: Exception) -> None:
+    def on_chain_error(self, error: Exception, **kwargs: Any) -> None:
         """Do nothing."""
         pass

@@ -61,7 +61,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
         print_text(output, color=color)
         print_text(f"\n{llm_prefix}")

-    def on_tool_error(self, error: Exception) -> None:
+    def on_tool_error(self, error: Exception, **kwargs: Any) -> None:
         """Do nothing."""
         pass
diff --git a/langchain/callbacks/streamlit.py b/langchain/callbacks/streamlit.py
index 2e781f61..99bc8d96 100644
--- a/langchain/callbacks/streamlit.py
+++ b/langchain/callbacks/streamlit.py
@@ -18,11 +18,11 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
         for prompt in prompts:
             st.write(prompt)

-    def on_llm_end(self, response: LLMResult) -> None:
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Do nothing."""
         pass

-    def on_llm_error(self, error: Exception) -> None:
+    def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
         """Do nothing."""
         pass

@@ -33,11 +33,11 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
         class_name = serialized["name"]
         st.write(f"Entering new {class_name} chain...")

-    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
+    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
         """Print out that we finished a chain."""
         st.write("Finished chain.")

-    def on_chain_error(self, error: Exception) -> None:
+    def on_chain_error(self, error: Exception, **kwargs: Any) -> None:
         """Do nothing."""
         pass

@@ -62,7 +62,7 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
         st.write(f"{observation_prefix}{output}")
         st.write(llm_prefix)

-    def on_tool_error(self, error: Exception) -> None:
+    def on_tool_error(self, error: Exception, **kwargs: Any) -> None:
         """Do nothing."""
         pass
diff --git a/langchain/chains/api/base.py b/langchain/chains/api/base.py
index 51838237..3bb5f917 100644
--- a/langchain/chains/api/base.py
+++ b/langchain/chains/api/base.py
@@ -66,11 +66,13 @@ class APIChain(Chain, BaseModel):
         api_url = self.api_request_chain.predict(
             question=question, api_docs=self.api_docs
         )
-        if self.verbose:
-            self.callback_manager.on_text(api_url, color="green", end="\n")
+        self.callback_manager.on_text(
+            api_url, color="green", end="\n", verbose=self.verbose
+        )
         api_response = self.requests_wrapper.run(api_url)
-        if self.verbose:
-            self.callback_manager.on_text(api_response, color="yellow", end="\n")
+        self.callback_manager.on_text(
+            api_response, color="yellow", end="\n", verbose=self.verbose
+        )
         answer = self.api_answer_chain.predict(
             question=question,
             api_docs=self.api_docs,
diff --git a/langchain/chains/base.py b/langchain/chains/base.py
index 83c1ea9f..b828b064 100644
--- a/langchain/chains/base.py
+++ b/langchain/chains/base.py
@@ -134,18 +134,17 @@ class Chain(BaseModel, ABC):
             external_context = self.memory.load_memory_variables(inputs)
             inputs = dict(inputs, **external_context)
         self._validate_inputs(inputs)
-        if self.verbose:
-            self.callback_manager.on_chain_start(
-                {"name": self.__class__.__name__}, inputs
-            )
+        self.callback_manager.on_chain_start(
+            {"name": self.__class__.__name__},
+            inputs,
+            verbose=self.verbose,
+        )
         try:
             outputs = self._call(inputs)
         except Exception as e:
-            if self.verbose:
-                self.callback_manager.on_chain_error(e)
+            self.callback_manager.on_chain_error(e, verbose=self.verbose)
             raise e
-        if self.verbose:
-            self.callback_manager.on_chain_end(outputs)
+        self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
         self._validate_outputs(outputs)
         if self.memory is not None:
             self.memory.save_context(inputs, outputs)
diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py
index 11c31638..9f713ed5 100644
--- a/langchain/chains/llm.py
+++ b/langchain/chains/llm.py
@@ -61,10 +61,9 @@ class LLMChain(Chain, BaseModel):
         for inputs in input_list:
             selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
             prompt = self.prompt.format(**selected_inputs)
-            if self.verbose:
-                _colored_text = get_colored_text(prompt, "green")
-                _text = "Prompt after formatting:\n" + _colored_text
-                self.callback_manager.on_text(_text, end="\n")
+            _colored_text = get_colored_text(prompt, "green")
+            _text = "Prompt after formatting:\n" + _colored_text
+            self.callback_manager.on_text(_text, end="\n", verbose=self.verbose)
             if "stop" in inputs and inputs["stop"] != stop:
                 raise ValueError(
                     "If `stop` is present in any inputs, should be present in all."
diff --git a/langchain/chains/llm_bash/__init__.py b/langchain/chains/llm_bash/__init__.py
new file mode 100644
index 00000000..e1e848a1
--- /dev/null
+++ b/langchain/chains/llm_bash/__init__.py
@@ -0,0 +1 @@
+"""Chain that interprets a prompt and executes bash code to perform bash operations."""
diff --git a/langchain/chains/llm_bash/base.py b/langchain/chains/llm_bash/base.py
index 47a46333..9cc657ea 100644
--- a/langchain/chains/llm_bash/base.py
+++ b/langchain/chains/llm_bash/base.py
@@ -52,12 +52,10 @@ class LLMBashChain(Chain, BaseModel):
     def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
         llm_executor = LLMChain(prompt=self.prompt, llm=self.llm)
         bash_executor = BashProcess()
-        if self.verbose:
-            self.callback_manager.on_text(inputs[self.input_key])
+        self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose)

         t = llm_executor.predict(question=inputs[self.input_key])
-        if self.verbose:
-            self.callback_manager.on_text(t, color="green")
+        self.callback_manager.on_text(t, color="green", verbose=self.verbose)

         t = t.strip()
         if t.startswith("```bash"):
@@ -69,9 +67,8 @@ class LLMBashChain(Chain, BaseModel):
             command_list = [s for s in command_list[1:-1]]

             output = bash_executor.run(command_list)
-            if self.verbose:
-                self.callback_manager.on_text("\nAnswer: ")
-                self.callback_manager.on_text(output, color="yellow")
+            self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose)
+            self.callback_manager.on_text(output, color="yellow", verbose=self.verbose)

         else:
             raise ValueError(f"unknown format from LLM: {t}")
diff --git a/langchain/chains/llm_math/base.py b/langchain/chains/llm_math/base.py
index 1b6628f4..c169ade3 100644
--- a/langchain/chains/llm_math/base.py
+++ b/langchain/chains/llm_math/base.py
@@ -53,18 +53,15 @@ class LLMMathChain(Chain, BaseModel):
     def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
         llm_executor = LLMChain(prompt=self.prompt, llm=self.llm)
         python_executor = PythonREPL()
-        if self.verbose:
-            self.callback_manager.on_text(inputs[self.input_key])
+        self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose)
         t = llm_executor.predict(question=inputs[self.input_key], stop=["```output"])
-        if self.verbose:
-            self.callback_manager.on_text(t, color="green")
+        self.callback_manager.on_text(t, color="green", verbose=self.verbose)
         t = t.strip()
         if t.startswith("```python"):
             code = t[9:-4]
             output = python_executor.run(code)
-            if self.verbose:
-                self.callback_manager.on_text("\nAnswer: ")
-                self.callback_manager.on_text(output, color="yellow")
+            self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose)
+            self.callback_manager.on_text(output, color="yellow", verbose=self.verbose)
             answer = "Answer: " + output
         elif t.startswith("Answer:"):
             answer = t
diff --git a/langchain/chains/pal/base.py b/langchain/chains/pal/base.py
index 858f6a1f..3b16ed86 100644
--- a/langchain/chains/pal/base.py
+++ b/langchain/chains/pal/base.py
@@ -51,8 +51,9 @@ class PALChain(Chain, BaseModel):
     def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
         llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)
         code = llm_chain.predict(stop=[self.stop], **inputs)
-        if self.verbose:
-            self.callback_manager.on_text(code, color="green", end="\n")
+        self.callback_manager.on_text(
+            code, color="green", end="\n", verbose=self.verbose
+        )
         repl = PythonREPL()
         res = repl.run(code + f"\n{self.get_answer_expr}")
         return {self.output_key: res.strip()}
diff --git a/langchain/chains/sequential.py b/langchain/chains/sequential.py
index a3ca8898..cee1a9b7 100644
--- a/langchain/chains/sequential.py
+++ b/langchain/chains/sequential.py
@@ -76,8 +76,6 @@ class SequentialChain(Chain, BaseModel):
         known_values = inputs.copy()
         for i, chain in enumerate(self.chains):
             outputs = chain(known_values, return_only_outputs=True)
-            if self.verbose:
-                print(f"\033[1mChain {i}\033[0m:\n{outputs}\n")
             known_values.update(outputs)
         return {k: known_values[k] for k in self.output_variables}

@@ -135,8 +133,7 @@ class SimpleSequentialChain(Chain, BaseModel):
             _input = chain.run(_input)
             if self.strip_outputs:
                 _input = _input.strip()
-            if self.verbose:
-                self.callback_manager.on_text(
-                    _input, color=color_mapping[str(i)], end="\n"
-                )
+            self.callback_manager.on_text(
+                _input, color=color_mapping[str(i)], end="\n", verbose=self.verbose
+            )
         return {self.output_key: _input}
diff --git a/langchain/chains/sql_database/base.py b/langchain/chains/sql_database/base.py
index 41ef484b..10377800 100644
--- a/langchain/chains/sql_database/base.py
+++ b/langchain/chains/sql_database/base.py
@@ -60,8 +60,7 @@ class SQLDatabaseChain(Chain, BaseModel):
     def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
         llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)
         input_text = f"{inputs[self.input_key]} \nSQLQuery:"
-        if self.verbose:
-            self.callback_manager.on_text(input_text)
+        self.callback_manager.on_text(input_text, verbose=self.verbose)
         # If not present, then defaults to None which is all tables.
         table_names_to_use = inputs.get("table_names_to_use")
         table_info = self.database.get_table_info(table_names=table_names_to_use)
@@ -74,18 +73,15 @@ class SQLDatabaseChain(Chain, BaseModel):
         }
         sql_cmd = llm_chain.predict(**llm_inputs)
-        if self.verbose:
-            self.callback_manager.on_text(sql_cmd, color="green")
+        self.callback_manager.on_text(sql_cmd, color="green", verbose=self.verbose)
         result = self.database.run(sql_cmd)
-        if self.verbose:
-            self.callback_manager.on_text("\nSQLResult: ")
-            self.callback_manager.on_text(result, color="yellow")
-            self.callback_manager.on_text("\nAnswer:")
+        self.callback_manager.on_text("\nSQLResult: ", verbose=self.verbose)
+        self.callback_manager.on_text(result, color="yellow", verbose=self.verbose)
+        self.callback_manager.on_text("\nAnswer:", verbose=self.verbose)
         input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
         llm_inputs["input"] = input_text
         final_result = llm_chain.predict(**llm_inputs)
-        if self.verbose:
-            self.callback_manager.on_text(final_result, color="green")
+        self.callback_manager.on_text(final_result, color="green", verbose=self.verbose)
         return {self.output_key: final_result}

@@ -146,9 +142,12 @@ class SQLDatabaseSequentialChain(Chain, BaseModel):
             "table_names": table_names,
         }
         table_names_to_use = self.decider_chain.predict_and_parse(**llm_inputs)
-        if self.verbose:
-            self.callback_manager.on_text("Table names to use:", end="\n")
-            self.callback_manager.on_text(str(table_names_to_use), color="yellow")
+        self.callback_manager.on_text(
+            "Table names to use:", end="\n", verbose=self.verbose
+        )
+        self.callback_manager.on_text(
+            str(table_names_to_use), color="yellow", verbose=self.verbose
+        )
         new_inputs = {
             self.sql_chain.input_key: inputs[self.input_key],
             "table_names_to_use": table_names_to_use,
diff --git a/langchain/llms/base.py b/langchain/llms/base.py
index 97c062f3..ebb1c58b 100644
--- a/langchain/llms/base.py
+++ b/langchain/llms/base.py
@@ -69,18 +69,15 @@ class BaseLLM(BaseModel, ABC):
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
-            if self.verbose:
-                self.callback_manager.on_llm_start(
-                    {"name": self.__class__.__name__}, prompts
-                )
+            self.callback_manager.on_llm_start(
+                {"name": self.__class__.__name__}, prompts, verbose=self.verbose
+            )
             try:
                 output = self._generate(prompts, stop=stop)
             except Exception as e:
-                if self.verbose:
-                    self.callback_manager.on_llm_error(e)
+                self.callback_manager.on_llm_error(e, verbose=self.verbose)
                 raise e
-            if self.verbose:
-                self.callback_manager.on_llm_end(output)
+            self.callback_manager.on_llm_end(output, verbose=self.verbose)
             return output
         params = self._llm_dict()
         params["stop"] = stop
@@ -95,18 +92,15 @@ class BaseLLM(BaseModel, ABC):
             else:
                 missing_prompts.append(prompt)
                 missing_prompt_idxs.append(i)
-        if self.verbose:
-            self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, missing_prompts
-            )
+        self.callback_manager.on_llm_start(
+            {"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose
+        )
         try:
             new_results = self._generate(missing_prompts, stop=stop)
         except Exception as e:
-            if self.verbose:
-                self.callback_manager.on_llm_error(e)
+            self.callback_manager.on_llm_error(e, verbose=self.verbose)
             raise e
-        if self.verbose:
-            self.callback_manager.on_llm_end(new_results)
+        self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
         for i, result in enumerate(new_results.generations):
             existing_prompts[missing_prompt_idxs[i]] = result
             prompt = prompts[missing_prompt_idxs[i]]
diff --git a/tests/unit_tests/callbacks/fake_callback_handler.py b/tests/unit_tests/callbacks/fake_callback_handler.py
index 07cf52f5..896b05aa 100644
--- a/tests/unit_tests/callbacks/fake_callback_handler.py
+++ b/tests/unit_tests/callbacks/fake_callback_handler.py
@@ -19,14 +19,11 @@ class FakeCallbackHandler(BaseCallbackHandler):
         """Run when LLM starts running."""
         self.starts += 1

-    def on_llm_end(
-        self,
-        response: LLMResult,
-    ) -> None:
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Run when LLM ends running."""
         self.ends += 1

-    def on_llm_error(self, error: Exception) -> None:
+    def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
         """Run when LLM errors."""
         self.errors += 1

@@ -36,11 +33,11 @@ class FakeCallbackHandler(BaseCallbackHandler):
         """Run when chain starts running."""
         self.starts += 1

-    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
+    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
         """Run when chain ends running."""
         self.ends += 1

-    def on_chain_error(self, error: Exception) -> None:
+    def on_chain_error(self, error: Exception, **kwargs: Any) -> None:
         """Run when chain errors."""
         self.errors += 1

@@ -54,7 +51,7 @@ class FakeCallbackHandler(BaseCallbackHandler):
         """Run when tool ends running."""
         self.ends += 1

-    def on_tool_error(self, error: Exception) -> None:
+    def on_tool_error(self, error: Exception, **kwargs: Any) -> None:
         """Run when tool errors."""
         self.errors += 1
diff --git a/tests/unit_tests/callbacks/test_callback_manager.py b/tests/unit_tests/callbacks/test_callback_manager.py
index c5c493ce..03d0181b 100644
--- a/tests/unit_tests/callbacks/test_callback_manager.py
+++ b/tests/unit_tests/callbacks/test_callback_manager.py
@@ -10,16 +10,16 @@ def _test_callback_manager(
     manager: BaseCallbackManager, *handlers: FakeCallbackHandler
 ) -> None:
     """Test the CallbackManager."""
-    manager.on_llm_start({}, [])
-    manager.on_llm_end(LLMResult(generations=[]))
-    manager.on_llm_error(Exception())
-    manager.on_chain_start({"name": "foo"}, {})
-    manager.on_chain_end({})
-    manager.on_chain_error(Exception())
-    manager.on_tool_start({}, AgentAction("", "", ""))
-    manager.on_tool_end("")
-    manager.on_tool_error(Exception())
-    manager.on_agent_finish(AgentFinish({}, ""))
+    manager.on_llm_start({}, [], verbose=True)
+    manager.on_llm_end(LLMResult(generations=[]), verbose=True)
+    manager.on_llm_error(Exception(), verbose=True)
+    manager.on_chain_start({"name": "foo"}, {}, verbose=True)
+    manager.on_chain_end({}, verbose=True)
+    manager.on_chain_error(Exception(), verbose=True)
+    manager.on_tool_start({}, AgentAction("", "", ""), verbose=True)
+    manager.on_tool_end("", verbose=True)
+    manager.on_tool_error(Exception(), verbose=True)
+    manager.on_agent_finish(AgentFinish({}, ""), verbose=True)
     for handler in handlers:
         assert handler.starts == 3
         assert handler.ends == 4
@@ -39,9 +39,9 @@ def test_ignore_llm() -> None:
     handler1 = FakeCallbackHandler(ignore_llm=True)
     handler2 = FakeCallbackHandler()
     manager = CallbackManager(handlers=[handler1, handler2])
-    manager.on_llm_start({}, [])
-    manager.on_llm_end(LLMResult(generations=[]))
-    manager.on_llm_error(Exception())
+    manager.on_llm_start({}, [], verbose=True)
+    manager.on_llm_end(LLMResult(generations=[]), verbose=True)
+    manager.on_llm_error(Exception(), verbose=True)
     assert handler1.starts == 0
     assert handler1.ends == 0
     assert handler1.errors == 0
@@ -55,9 +55,9 @@ def test_ignore_chain() -> None:
     handler1 = FakeCallbackHandler(ignore_chain=True)
     handler2 = FakeCallbackHandler()
     manager = CallbackManager(handlers=[handler1, handler2])
-    manager.on_chain_start({"name": "foo"}, {})
-    manager.on_chain_end({})
-    manager.on_chain_error(Exception())
+    manager.on_chain_start({"name": "foo"}, {}, verbose=True)
+    manager.on_chain_end({}, verbose=True)
+    manager.on_chain_error(Exception(), verbose=True)
     assert handler1.starts == 0
     assert handler1.ends == 0
     assert handler1.errors == 0
@@ -71,10 +71,10 @@ def test_ignore_agent() -> None:
     handler1 = FakeCallbackHandler(ignore_agent=True)
     handler2 = FakeCallbackHandler()
     manager = CallbackManager(handlers=[handler1, handler2])
-    manager.on_tool_start({}, AgentAction("", "", ""))
-    manager.on_tool_end("")
-    manager.on_tool_error(Exception())
-    manager.on_agent_finish(AgentFinish({}, ""))
+    manager.on_tool_start({}, AgentAction("", "", ""), verbose=True)
+    manager.on_tool_end("", verbose=True)
+    manager.on_tool_error(Exception(), verbose=True)
+    manager.on_agent_finish(AgentFinish({}, ""), verbose=True)
     assert handler1.starts == 0
     assert handler1.ends == 0
     assert handler1.errors == 0
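
Note: a minimal sketch of the behavior this patch enables, for reference. Call sites now always invoke the manager and pass `verbose=` through, and the manager dispatches to a handler only when `verbose` is true or the handler opts in via the new `always_verbose` property. `AlwaysOnHandler` below is a hypothetical subclass written for illustration; it is not part of the patch.

from langchain.callbacks.base import CallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler


class AlwaysOnHandler(StdOutCallbackHandler):
    """Hypothetical handler that receives callbacks even when verbose=False."""

    @property
    def always_verbose(self) -> bool:
        # Opt in: CallbackManager checks `verbose or handler.always_verbose`
        # before dispatching, so this handler fires regardless of `verbose`.
        return True


manager = CallbackManager(handlers=[StdOutCallbackHandler(), AlwaysOnHandler()])

# Only AlwaysOnHandler prints here; with verbose=True, both handlers would.
manager.on_chain_start({"name": "ExampleChain"}, {"input": "hi"}, verbose=False)
manager.on_chain_end({"output": "bye"}, verbose=False)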