diff --git a/libs/langchain/langchain/callbacks/base.py b/libs/langchain/langchain/callbacks/base.py index 71025955f8..c03633e2d4 100644 --- a/libs/langchain/langchain/callbacks/base.py +++ b/libs/langchain/langchain/callbacks/base.py @@ -4,6 +4,8 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union from uuid import UUID +from tenacity import RetryCallState + if TYPE_CHECKING: from langchain.schema.agent import AgentAction, AgentFinish from langchain.schema.document import Document @@ -222,6 +224,16 @@ class RunManagerMixin: ) -> Any: """Run on arbitrary text.""" + def on_retry( + self, + retry_state: RetryCallState, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on a retry event.""" + class BaseCallbackHandler( LLMManagerMixin, @@ -414,6 +426,16 @@ class AsyncCallbackHandler(BaseCallbackHandler): ) -> None: """Run on arbitrary text.""" + async def on_retry( + self, + retry_state: RetryCallState, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on a retry event.""" + async def on_agent_action( self, action: AgentAction, diff --git a/libs/langchain/tests/unit_tests/callbacks/test_openai_info.py b/libs/langchain/tests/unit_tests/callbacks/test_openai_info.py index 1fa62a74ad..4580ccd488 100644 --- a/libs/langchain/tests/unit_tests/callbacks/test_openai_info.py +++ b/libs/langchain/tests/unit_tests/callbacks/test_openai_info.py @@ -1,3 +1,6 @@ +from unittest.mock import MagicMock +from uuid import uuid4 + import pytest from langchain.callbacks import OpenAICallbackHandler @@ -124,3 +127,7 @@ def test_on_llm_end_no_cost_invalid_model( ) handler.on_llm_end(response) assert handler.total_cost == 0 + + +def test_on_retry_works(handler: OpenAICallbackHandler) -> None: + handler.on_retry(MagicMock(), run_id=uuid4())