From b0097f8908a3cf30778029b4ba2cfb6aaeea733a Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 4 Oct 2023 05:04:59 +0100 Subject: [PATCH] =?UTF-8?q?In=20ProgressBarCallback=20update=20the=20progr?= =?UTF-8?q?ess=20counter=20also=20when=20runs=20fin=E2=80=A6=20(#11332)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../langchain/smith/evaluation/progress.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/libs/langchain/langchain/smith/evaluation/progress.py b/libs/langchain/langchain/smith/evaluation/progress.py index 1ea51eee42..a0f8c4fc4c 100644 --- a/libs/langchain/langchain/smith/evaluation/progress.py +++ b/libs/langchain/langchain/smith/evaluation/progress.py @@ -37,6 +37,17 @@ class ProgressBarCallback(base_callbacks.BaseCallbackHandler): spaces = " " * (self.ncols - len(arrow)) print(f"\r[{arrow + spaces}] {self.counter}/{self.total}", end="") + def on_chain_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + if parent_run_id is None: + self.increment() + def on_chain_end( self, outputs: Dict[str, Any], @@ -48,6 +59,17 @@ class ProgressBarCallback(base_callbacks.BaseCallbackHandler): if parent_run_id is None: self.increment() + def on_retriever_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + if parent_run_id is None: + self.increment() + def on_retriever_end( self, documents: Sequence[Document], @@ -59,6 +81,17 @@ class ProgressBarCallback(base_callbacks.BaseCallbackHandler): if parent_run_id is None: self.increment() + def on_llm_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + if parent_run_id is None: + self.increment() + def on_llm_end( self, response: LLMResult, @@ -70,6 +103,17 @@ class ProgressBarCallback(base_callbacks.BaseCallbackHandler): if parent_run_id is None: self.increment() + def on_tool_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + if parent_run_id is None: + self.increment() + def on_tool_end( self, output: str,