In ProgressBarCallback update the progress counter also when runs fin… (#11332)

pull/11356/head^2
Nuno Campos 10 months ago committed by GitHub
parent 06f39be1c2
commit b0097f8908
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -37,6 +37,17 @@ class ProgressBarCallback(base_callbacks.BaseCallbackHandler):
spaces = " " * (self.ncols - len(arrow))
print(f"\r[{arrow + spaces}] {self.counter}/{self.total}", end="")
def on_chain_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
if parent_run_id is None:
self.increment()
def on_chain_end(
self,
outputs: Dict[str, Any],
@ -48,6 +59,17 @@ class ProgressBarCallback(base_callbacks.BaseCallbackHandler):
if parent_run_id is None:
self.increment()
def on_retriever_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
if parent_run_id is None:
self.increment()
def on_retriever_end(
self,
documents: Sequence[Document],
@ -59,6 +81,17 @@ class ProgressBarCallback(base_callbacks.BaseCallbackHandler):
if parent_run_id is None:
self.increment()
def on_llm_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
if parent_run_id is None:
self.increment()
def on_llm_end(
self,
response: LLMResult,
@ -70,6 +103,17 @@ class ProgressBarCallback(base_callbacks.BaseCallbackHandler):
if parent_run_id is None:
self.increment()
def on_tool_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
if parent_run_id is None:
self.increment()
def on_tool_end(
self,
output: str,

Loading…
Cancel
Save