From c478fc208ed4c29e979abeb7a532eb4d01431e1b Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Mon, 14 Aug 2023 16:45:17 -0700 Subject: [PATCH] Default On Retry (#9230) Base callbacks don't have a default on retry event Fix #8542 --------- Co-authored-by: landonsilla --- libs/langchain/langchain/callbacks/base.py | 22 +++++++++++++++++++ .../unit_tests/callbacks/test_openai_info.py | 7 ++++++ 2 files changed, 29 insertions(+) diff --git a/libs/langchain/langchain/callbacks/base.py b/libs/langchain/langchain/callbacks/base.py index 71025955f8..c03633e2d4 100644 --- a/libs/langchain/langchain/callbacks/base.py +++ b/libs/langchain/langchain/callbacks/base.py @@ -4,6 +4,8 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union from uuid import UUID +from tenacity import RetryCallState + if TYPE_CHECKING: from langchain.schema.agent import AgentAction, AgentFinish from langchain.schema.document import Document @@ -222,6 +224,16 @@ class RunManagerMixin: ) -> Any: """Run on arbitrary text.""" + def on_retry( + self, + retry_state: RetryCallState, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on a retry event.""" + class BaseCallbackHandler( LLMManagerMixin, @@ -414,6 +426,16 @@ class AsyncCallbackHandler(BaseCallbackHandler): ) -> None: """Run on arbitrary text.""" + async def on_retry( + self, + retry_state: RetryCallState, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on a retry event.""" + async def on_agent_action( self, action: AgentAction, diff --git a/libs/langchain/tests/unit_tests/callbacks/test_openai_info.py b/libs/langchain/tests/unit_tests/callbacks/test_openai_info.py index 1fa62a74ad..4580ccd488 100644 --- a/libs/langchain/tests/unit_tests/callbacks/test_openai_info.py +++ b/libs/langchain/tests/unit_tests/callbacks/test_openai_info.py @@ -1,3 +1,6 @@ +from unittest.mock import MagicMock +from uuid import uuid4 + import pytest from langchain.callbacks import OpenAICallbackHandler @@ -124,3 +127,7 @@ def test_on_llm_end_no_cost_invalid_model( ) handler.on_llm_end(response) assert handler.total_cost == 0 + + +def test_on_retry_works(handler: OpenAICallbackHandler) -> None: + handler.on_retry(MagicMock(), run_id=uuid4())