Fix Async Retry Event Handling (#8659)

Async retries currently fail because `asyncio.run` is called while the event loop is already running.

The `retry` decorator already infers an `AsyncRetrying` handler for
coroutines (see [tenacity
line](aa6f8f0a24/tenacity/__init__.py (L535))).
However, `before_sleep` always gets called synchronously (see [tenacity
line](aa6f8f0a24/tenacity/__init__.py (L338))).
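
For illustration only (not part of this PR), here is a minimal standalone sketch of the failure mode, with made-up `before_sleep`/`on_retry` stand-ins: calling `asyncio.run()` from a synchronous hook that is itself executing inside a running event loop raises a `RuntimeError`.

```python
import asyncio


async def on_retry() -> None:
    # Stand-in for the async callback that should fire on each retry.
    print("notifying callback handler of retry")


def before_sleep() -> None:
    # tenacity invokes this hook synchronously, even under AsyncRetrying,
    # so blocking on the coroutine here blows up inside a live loop:
    asyncio.run(on_retry())


async def main() -> None:
    before_sleep()  # RuntimeError: asyncio.run() cannot be called from a running event loop


asyncio.run(main())
```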


Instead, check for a running loop and use it if one exists. Of course,
this still drives an async method from synchronous code, which is not
_nice_. Given how important LLMs are, it may make sense to have a task
list or something, but I'd want to chat with @nfcampos on where that
would live.
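
A minimal sketch of that pattern, using a hypothetical `fire_and_forget` helper rather than the actual LangChain code (the real change is in the diff below): schedule the coroutine as a task when a loop is already running, and fall back to `asyncio.run` otherwise.

```python
import asyncio
from typing import Any, Coroutine


def fire_and_forget(coro: Coroutine[Any, Any, Any]) -> None:
    """Hypothetical helper: run an async callback from synchronous retry code."""
    try:
        loop = asyncio.get_event_loop()
        if loop.is_running():
            # Inside an async retry: schedule the callback and keep going
            # (fire-and-forget; it may finish after the retry proceeds).
            loop.create_task(coro)
        else:
            # Purely synchronous caller: safe to block until the callback completes.
            asyncio.run(coro)
    except Exception as e:
        # A broken callback shouldn't break the retry loop itself.
        print(f"Error in on_retry: {e}")
```

Note the trade-off: with `create_task` the callback is fire-and-forget, so nothing guarantees it finishes before the next attempt starts.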

This PR also fixes the unit tests to check that the handler is called and to
make sure the async test actually runs (it looks like it had just been
getting skipped). It would have failed prior to the proposed fixes but passes
now.
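
As background (illustrative, not from this PR): without an async test plugin collecting it, a bare `async def` test is never awaited, so pytest just emits a warning and skips the body, which is roughly what was happening here. Marking it for pytest-asyncio (assumed to be the plugin in use) makes the coroutine actually run.

```python
import pytest


async def test_silently_skipped() -> None:
    # With no async plugin handling it, pytest warns that async def tests
    # are not natively supported and skips this body entirely.
    assert False  # never evaluated


@pytest.mark.asyncio
async def test_actually_runs() -> None:
    # Collected and awaited by pytest-asyncio, so the assertion is checked.
    assert 1 + 1 == 2
```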
William FH 2023-08-03 15:02:16 -07:00 committed by GitHub
parent 8ef7e14a85
commit f81e613086
4 changed files with 43 additions and 71 deletions


@@ -91,6 +91,10 @@ def create_base_retry_decorator(
if isinstance(run_manager, AsyncCallbackManagerForLLMRun):
    coro = run_manager.on_retry(retry_state)
    try:
        loop = asyncio.get_event_loop()
        if loop.is_running():
            loop.create_task(coro)
        else:
            asyncio.run(coro)
    except Exception as e:
        _log_error_once(f"Error in on_retry: {e}")


@@ -1,7 +1,6 @@
"""Test OpenAI API wrapper."""
from pathlib import Path
from typing import Any, Generator
from unittest.mock import MagicMock, patch
from typing import Generator
import pytest
@@ -11,7 +10,6 @@ from langchain.llms.loading import load_llm
from langchain.llms.openai import OpenAI, OpenAIChat
from langchain.schema import LLMResult
from tests.unit_tests.callbacks.fake_callback_handler import (
    FakeAsyncCallbackHandler,
    FakeCallbackHandler,
)
@@ -351,63 +349,3 @@ def mock_completion() -> dict:
        ],
        "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
    }


@pytest.mark.requires("openai")
def test_openai_retries(mock_completion: dict) -> None:
    llm = OpenAI()
    mock_client = MagicMock()
    completed = False
    raised = False
    import openai

    def raise_once(*args: Any, **kwargs: Any) -> Any:
        nonlocal completed, raised
        if not raised:
            raised = True
            raise openai.error.APIError
        completed = True
        return mock_completion

    mock_client.create = raise_once
    callback_handler = FakeCallbackHandler()
    with patch.object(
        llm,
        "client",
        mock_client,
    ):
        res = llm.predict("bar", callbacks=[callback_handler])
        assert res == "Bar Baz"
    assert completed
    assert raised
    assert callback_handler.retries == 1


@pytest.mark.requires("openai")
async def test_openai_async_retries(mock_completion: dict) -> None:
    llm = OpenAI()
    mock_client = MagicMock()
    completed = False
    raised = False
    import openai

    def raise_once(*args: Any, **kwargs: Any) -> Any:
        nonlocal completed, raised
        if not raised:
            raised = True
            raise openai.error.APIError
        completed = True
        return mock_completion

    mock_client.create = raise_once
    callback_handler = FakeAsyncCallbackHandler()
    with patch.object(
        llm,
        "client",
        mock_client,
    ):
        res = llm.apredict("bar", callbacks=[callback_handler])
        assert res == "Bar Baz"
    assert completed
    assert raised
    assert callback_handler.retries == 1


@@ -290,6 +290,13 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
        """Whether to ignore agent callbacks."""
        return self.ignore_agent_

    async def on_retry(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        self.on_retry_common()

    async def on_llm_start(
        self,
        *args: Any,


@@ -1,3 +1,4 @@
import asyncio
import os
from typing import Any
from unittest.mock import MagicMock, patch
@@ -5,6 +6,10 @@ from unittest.mock import MagicMock, patch
import pytest

from langchain.llms.openai import OpenAI
from tests.unit_tests.callbacks.fake_callback_handler import (
    FakeAsyncCallbackHandler,
    FakeCallbackHandler,
)

os.environ["OPENAI_API_KEY"] = "foo"
@@ -45,44 +50,62 @@ def mock_completion() -> dict:
@pytest.mark.requires("openai")
def test_openai_calls(mock_completion: dict) -> None:
def test_openai_retries(mock_completion: dict) -> None:
    llm = OpenAI()
    mock_client = MagicMock()
    completed = False
    raised = False
    import openai

    def raise_once(*args: Any, **kwargs: Any) -> Any:
        nonlocal completed
        nonlocal completed, raised
        if not raised:
            raised = True
            raise openai.error.APIError
        completed = True
        return mock_completion

    mock_client.create = raise_once
    callback_handler = FakeCallbackHandler()
    with patch.object(
        llm,
        "client",
        mock_client,
    ):
        res = llm.predict("bar")
        res = llm.predict("bar", callbacks=[callback_handler])
        assert res == "Bar Baz"
    assert completed
    assert raised
    assert callback_handler.retries == 1


@pytest.mark.requires("openai")
@pytest.mark.asyncio
async def test_openai_async_retries(mock_completion: dict) -> None:
    llm = OpenAI()
    mock_client = MagicMock()
    completed = False
    raised = False
    import openai

    def raise_once(*args: Any, **kwargs: Any) -> Any:
        nonlocal completed
    async def araise_once(*args: Any, **kwargs: Any) -> Any:
        nonlocal completed, raised
        if not raised:
            raised = True
            raise openai.error.APIError
        await asyncio.sleep(0)
        completed = True
        return mock_completion

    mock_client.create = raise_once
    mock_client.acreate = araise_once
    callback_handler = FakeAsyncCallbackHandler()
    with patch.object(
        llm,
        "client",
        mock_client,
    ):
        res = llm.apredict("bar")
        res = await llm.apredict("bar", callbacks=[callback_handler])
        assert res == "Bar Baz"
    assert completed
    assert raised
    assert callback_handler.retries == 1