From f81e613086d211327b67b0fb591fd4d5f9a85860 Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Thu, 3 Aug 2023 15:02:16 -0700 Subject: [PATCH] Fix Async Retry Event Handling (#8659) It fails currently because the event loop is already running. The `retry` decorator already infers an `AsyncRetrying` handler for coroutines (see [tenacity line](https://github.com/jd/tenacity/blob/aa6f8f0a2428de696b237d1a86bc131c1cdb707a/tenacity/__init__.py#L535)) However before_sleep always gets called synchronously (see [tenacity line](https://github.com/jd/tenacity/blob/aa6f8f0a2428de696b237d1a86bc131c1cdb707a/tenacity/__init__.py#L338)). Instead, check for a running loop and use that if it exists. Of course, it's running an async method synchronously which is not _nice_. Given how important LLMs are, it may make sense to have a task list or something but I'd want to chat with @nfcampos on where that would live. This PR also fixes the unit tests to check the handler is called and to make sure the async test is run (it looks like it's just been being skipped). It would have failed prior to the proposed fixes but passes now. 
--- libs/langchain/langchain/llms/base.py | 6 +- .../integration_tests/llms/test_openai.py | 64 +------------------ .../callbacks/fake_callback_handler.py | 7 ++ .../tests/unit_tests/llms/test_openai.py | 37 +++++++++-- 4 files changed, 43 insertions(+), 71 deletions(-) diff --git a/libs/langchain/langchain/llms/base.py b/libs/langchain/langchain/llms/base.py index 8ccaf0deb6..7da494de78 100644 --- a/libs/langchain/langchain/llms/base.py +++ b/libs/langchain/langchain/llms/base.py @@ -91,7 +91,11 @@ def create_base_retry_decorator( if isinstance(run_manager, AsyncCallbackManagerForLLMRun): coro = run_manager.on_retry(retry_state) try: - asyncio.run(coro) + loop = asyncio.get_event_loop() + if loop.is_running(): + loop.create_task(coro) + else: + asyncio.run(coro) except Exception as e: _log_error_once(f"Error in on_retry: {e}") else: diff --git a/libs/langchain/tests/integration_tests/llms/test_openai.py b/libs/langchain/tests/integration_tests/llms/test_openai.py index 6b584ae154..ca8911078a 100644 --- a/libs/langchain/tests/integration_tests/llms/test_openai.py +++ b/libs/langchain/tests/integration_tests/llms/test_openai.py @@ -1,7 +1,6 @@ """Test OpenAI API wrapper.""" from pathlib import Path -from typing import Any, Generator -from unittest.mock import MagicMock, patch +from typing import Generator import pytest @@ -11,7 +10,6 @@ from langchain.llms.loading import load_llm from langchain.llms.openai import OpenAI, OpenAIChat from langchain.schema import LLMResult from tests.unit_tests.callbacks.fake_callback_handler import ( - FakeAsyncCallbackHandler, FakeCallbackHandler, ) @@ -351,63 +349,3 @@ def mock_completion() -> dict: ], "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}, } - - -@pytest.mark.requires("openai") -def test_openai_retries(mock_completion: dict) -> None: - llm = OpenAI() - mock_client = MagicMock() - completed = False - raised = False - import openai - - def raise_once(*args: Any, **kwargs: Any) -> Any: - nonlocal 
completed, raised - if not raised: - raised = True - raise openai.error.APIError - completed = True - return mock_completion - - mock_client.create = raise_once - callback_handler = FakeCallbackHandler() - with patch.object( - llm, - "client", - mock_client, - ): - res = llm.predict("bar", callbacks=[callback_handler]) - assert res == "Bar Baz" - assert completed - assert raised - assert callback_handler.retries == 1 - - -@pytest.mark.requires("openai") -async def test_openai_async_retries(mock_completion: dict) -> None: - llm = OpenAI() - mock_client = MagicMock() - completed = False - raised = False - import openai - - def raise_once(*args: Any, **kwargs: Any) -> Any: - nonlocal completed, raised - if not raised: - raised = True - raise openai.error.APIError - completed = True - return mock_completion - - mock_client.create = raise_once - callback_handler = FakeAsyncCallbackHandler() - with patch.object( - llm, - "client", - mock_client, - ): - res = llm.apredict("bar", callbacks=[callback_handler]) - assert res == "Bar Baz" - assert completed - assert raised - assert callback_handler.retries == 1 diff --git a/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py b/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py index 87b56a9bff..f4819c6930 100644 --- a/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py +++ b/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py @@ -290,6 +290,13 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi """Whether to ignore agent callbacks.""" return self.ignore_agent_ + async def on_retry( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_retry_common() + async def on_llm_start( self, *args: Any, diff --git a/libs/langchain/tests/unit_tests/llms/test_openai.py b/libs/langchain/tests/unit_tests/llms/test_openai.py index cc0fc74c1f..54750a9592 100644 --- a/libs/langchain/tests/unit_tests/llms/test_openai.py +++ 
b/libs/langchain/tests/unit_tests/llms/test_openai.py @@ -1,3 +1,4 @@ +import asyncio import os from typing import Any from unittest.mock import MagicMock, patch @@ -5,6 +6,10 @@ from unittest.mock import MagicMock, patch import pytest from langchain.llms.openai import OpenAI +from tests.unit_tests.callbacks.fake_callback_handler import ( + FakeAsyncCallbackHandler, + FakeCallbackHandler, +) os.environ["OPENAI_API_KEY"] = "foo" @@ -45,44 +50,62 @@ def mock_completion() -> dict: @pytest.mark.requires("openai") -def test_openai_calls(mock_completion: dict) -> None: +def test_openai_retries(mock_completion: dict) -> None: llm = OpenAI() mock_client = MagicMock() completed = False + raised = False + import openai def raise_once(*args: Any, **kwargs: Any) -> Any: - nonlocal completed + nonlocal completed, raised + if not raised: + raised = True + raise openai.error.APIError completed = True return mock_completion mock_client.create = raise_once + callback_handler = FakeCallbackHandler() with patch.object( llm, "client", mock_client, ): - res = llm.predict("bar") + res = llm.predict("bar", callbacks=[callback_handler]) assert res == "Bar Baz" assert completed + assert raised + assert callback_handler.retries == 1 @pytest.mark.requires("openai") +@pytest.mark.asyncio async def test_openai_async_retries(mock_completion: dict) -> None: llm = OpenAI() mock_client = MagicMock() completed = False + raised = False + import openai - def raise_once(*args: Any, **kwargs: Any) -> Any: - nonlocal completed + async def araise_once(*args: Any, **kwargs: Any) -> Any: + nonlocal completed, raised + if not raised: + raised = True + raise openai.error.APIError + await asyncio.sleep(0) completed = True return mock_completion - mock_client.create = raise_once + mock_client.acreate = araise_once + callback_handler = FakeAsyncCallbackHandler() with patch.object( llm, "client", mock_client, ): - res = llm.apredict("bar") + res = await llm.apredict("bar", callbacks=[callback_handler]) assert 
res == "Bar Baz" assert completed + assert raised + assert callback_handler.retries == 1