mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Fix Async Retry Event Handling (#8659)
It fails currently because the event loop is already running. The `retry` decorator already infers an `AsyncRetrying` handler for coroutines (see [tenacity line](aa6f8f0a24/tenacity/__init__.py (L535)
)) However before_sleep always gets called synchronously (see [tenacity line](aa6f8f0a24/tenacity/__init__.py (L338)
)). Instead, check for a running loop and use it if one exists. Of course, it's running an async method synchronously, which is not _nice_. Given how important LLMs are, it may make sense to have a task list or something, but I'd want to chat with @nfcampos on where that would live. This PR also fixes the unit tests to check that the handler is called and to make sure the async test is actually run (it looks like it has just been skipped until now). It would have failed prior to the proposed fixes but passes now.
This commit is contained in:
parent
8ef7e14a85
commit
f81e613086
@ -91,6 +91,10 @@ def create_base_retry_decorator(
|
||||
if isinstance(run_manager, AsyncCallbackManagerForLLMRun):
|
||||
coro = run_manager.on_retry(retry_state)
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
loop.create_task(coro)
|
||||
else:
|
||||
asyncio.run(coro)
|
||||
except Exception as e:
|
||||
_log_error_once(f"Error in on_retry: {e}")
|
||||
|
@ -1,7 +1,6 @@
|
||||
"""Test OpenAI API wrapper."""
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator
|
||||
from unittest.mock import MagicMock, patch
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
@ -11,7 +10,6 @@ from langchain.llms.loading import load_llm
|
||||
from langchain.llms.openai import OpenAI, OpenAIChat
|
||||
from langchain.schema import LLMResult
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import (
|
||||
FakeAsyncCallbackHandler,
|
||||
FakeCallbackHandler,
|
||||
)
|
||||
|
||||
@ -351,63 +349,3 @@ def mock_completion() -> dict:
|
||||
],
|
||||
"usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
def test_openai_retries(mock_completion: dict) -> None:
    """A single transient APIError is retried and reported to the callback handler."""
    llm = OpenAI()
    client_mock = MagicMock()
    completed = False
    raised = False
    import openai

    def raise_once(*args: Any, **kwargs: Any) -> Any:
        # Fail only on the first invocation; record success on the retry.
        nonlocal completed, raised
        if not raised:
            raised = True
            raise openai.error.APIError
        completed = True
        return mock_completion

    client_mock.create = raise_once
    handler = FakeCallbackHandler()
    with patch.object(llm, "client", client_mock):
        res = llm.predict("bar", callbacks=[handler])
        assert res == "Bar Baz"
    # The stub must have both failed once and then completed successfully,
    # and exactly one retry must have been surfaced to the callback handler.
    assert completed
    assert raised
    assert handler.retries == 1
||||
|
||||
@pytest.mark.requires("openai")
@pytest.mark.asyncio
async def test_openai_async_retries(mock_completion: dict) -> None:
    """A single transient APIError on the async path is retried and reported.

    Fixes three defects in the previous version:
    * ``llm.apredict`` returned an un-awaited coroutine, so ``res`` could never
      equal the expected string — the call must be awaited.
    * only ``mock_client.create`` was stubbed, but the async code path calls
      ``acreate``; an async raising stub is installed there instead.
    * without ``@pytest.mark.asyncio`` the coroutine test function is silently
      skipped rather than executed.
    """
    llm = OpenAI()
    mock_client = MagicMock()
    completed = False
    raised = False
    import openai

    async def araise_once(*args: Any, **kwargs: Any) -> Any:
        # Fail only on the first invocation; record success on the retry.
        nonlocal completed, raised
        if not raised:
            raised = True
            raise openai.error.APIError
        completed = True
        return mock_completion

    mock_client.acreate = araise_once
    callback_handler = FakeAsyncCallbackHandler()
    with patch.object(
        llm,
        "client",
        mock_client,
    ):
        res = await llm.apredict("bar", callbacks=[callback_handler])
        assert res == "Bar Baz"
    assert completed
    assert raised
    assert callback_handler.retries == 1
@ -290,6 +290,13 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return self.ignore_agent_
|
||||
|
||||
async def on_retry(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_retry_common()
|
||||
|
||||
async def on_llm_start(
|
||||
self,
|
||||
*args: Any,
|
||||
|
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
@ -5,6 +6,10 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from langchain.llms.openai import OpenAI
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import (
|
||||
FakeAsyncCallbackHandler,
|
||||
FakeCallbackHandler,
|
||||
)
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "foo"
|
||||
|
||||
@ -45,44 +50,62 @@ def mock_completion() -> dict:
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
def test_openai_retries(mock_completion: dict) -> None:
    """One transient APIError triggers exactly one retry, visible on the handler."""
    llm = OpenAI()
    client_mock = MagicMock()
    completed = False
    raised = False
    import openai

    def raise_once(*args: Any, **kwargs: Any) -> Any:
        # First call fails; the retried call records completion and succeeds.
        nonlocal completed, raised
        if not raised:
            raised = True
            raise openai.error.APIError
        completed = True
        return mock_completion

    client_mock.create = raise_once
    handler = FakeCallbackHandler()
    with patch.object(llm, "client", client_mock):
        res = llm.predict("bar", callbacks=[handler])
        assert res == "Bar Baz"
    assert completed
    assert raised
    assert handler.retries == 1
||||
|
||||
@pytest.mark.requires("openai")
@pytest.mark.asyncio
async def test_openai_async_retries(mock_completion: dict) -> None:
    """One transient APIError on the async path triggers exactly one retry."""
    llm = OpenAI()
    client_mock = MagicMock()
    completed = False
    raised = False
    import openai

    async def araise_once(*args: Any, **kwargs: Any) -> Any:
        # First call fails; the retried call yields once, records completion.
        nonlocal completed, raised
        if not raised:
            raised = True
            raise openai.error.APIError
        await asyncio.sleep(0)
        completed = True
        return mock_completion

    client_mock.acreate = araise_once
    handler = FakeAsyncCallbackHandler()
    with patch.object(llm, "client", client_mock):
        res = await llm.apredict("bar", callbacks=[handler])
        assert res == "Bar Baz"
    assert completed
    assert raised
    assert handler.retries == 1
Loading…
Reference in New Issue
Block a user