diff --git a/libs/langchain/langchain/callbacks/manager.py b/libs/langchain/langchain/callbacks/manager.py index 1127b55b7d..1a0f6ac0a0 100644 --- a/libs/langchain/langchain/callbacks/manager.py +++ b/libs/langchain/langchain/callbacks/manager.py @@ -477,6 +477,22 @@ class RunManager(BaseRunManager): **kwargs, ) + def on_retry( + self, + retry_state: RetryCallState, + **kwargs: Any, + ) -> None: + _handle_event( + self.handlers, + "on_retry", + "ignore_retry", + retry_state, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + class ParentRunManager(RunManager): """Sync Parent Run Manager.""" @@ -527,6 +543,22 @@ class AsyncRunManager(BaseRunManager): **kwargs, ) + async def on_retry( + self, + retry_state: RetryCallState, + **kwargs: Any, + ) -> None: + await _ahandle_event( + self.handlers, + "on_retry", + "ignore_retry", + retry_state, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + class AsyncParentRunManager(AsyncRunManager): """Async Parent Run Manager.""" @@ -574,22 +606,6 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): **kwargs, ) - def on_retry( - self, - retry_state: RetryCallState, - **kwargs: Any, - ) -> None: - _handle_event( - self.handlers, - "on_retry", - "ignore_retry", - retry_state, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: """Run when LLM ends running. @@ -653,22 +669,6 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): **kwargs, ) - async def on_retry( - self, - retry_state: RetryCallState, - **kwargs: Any, - ) -> None: - await _ahandle_event( - self.handlers, - "on_retry", - "ignore_retry", - retry_state, - run_id=self.run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: """Run when LLM ends running. diff --git a/libs/langchain/langchain/callbacks/tracers/base.py b/libs/langchain/langchain/callbacks/tracers/base.py index cc37bb47e8..b1244ff412 100644 --- a/libs/langchain/langchain/callbacks/tracers/base.py +++ b/libs/langchain/langchain/callbacks/tracers/base.py @@ -151,8 +151,8 @@ class BaseTracer(BaseCallbackHandler, ABC): raise TracerException("No run_id provided for on_retry callback.") run_id_ = str(run_id) llm_run = self.run_map.get(run_id_) - if llm_run is None or llm_run.run_type != RunTypeEnum.llm: - raise TracerException("No LLM Run found to be traced for on_retry") + if llm_run is None: + raise TracerException("No Run found to be traced for on_retry") retry_d: Dict[str, Any] = { "slept": retry_state.idle_for, "attempt": retry_state.attempt_number,