|
|
|
@ -4,7 +4,9 @@ from typing import Any
|
|
|
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
from tenacity import wait_none
|
|
|
|
|
|
|
|
|
|
from langchain.llms import base
|
|
|
|
|
from langchain.llms.openai import OpenAI
|
|
|
|
|
from tests.unit_tests.callbacks.fake_callback_handler import (
|
|
|
|
|
FakeAsyncCallbackHandler,
|
|
|
|
@ -55,6 +57,16 @@ def mock_completion() -> dict:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _patched_retry(*args: Any, **kwargs: Any) -> Any:
|
|
|
|
|
"""Patched retry for unit tests that does not wait."""
|
|
|
|
|
from tenacity import retry
|
|
|
|
|
|
|
|
|
|
assert "wait" in kwargs
|
|
|
|
|
kwargs["wait"] = wait_none()
|
|
|
|
|
r = retry(*args, **kwargs)
|
|
|
|
|
return r
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.requires("openai")
|
|
|
|
|
def test_openai_retries(mock_completion: dict) -> None:
|
|
|
|
|
llm = OpenAI()
|
|
|
|
@ -73,6 +85,9 @@ def test_openai_retries(mock_completion: dict) -> None:
|
|
|
|
|
|
|
|
|
|
mock_client.create = raise_once
|
|
|
|
|
callback_handler = FakeCallbackHandler()
|
|
|
|
|
|
|
|
|
|
# Patch the retry to avoid waiting during a unit test
|
|
|
|
|
with patch.object(base, "retry", _patched_retry):
|
|
|
|
|
with patch.object(
|
|
|
|
|
llm,
|
|
|
|
|
"client",
|
|
|
|
@ -105,6 +120,8 @@ async def test_openai_async_retries(mock_completion: dict) -> None:
|
|
|
|
|
|
|
|
|
|
mock_client.acreate = araise_once
|
|
|
|
|
callback_handler = FakeAsyncCallbackHandler()
|
|
|
|
|
# Patch the retry to avoid waiting during a unit test
|
|
|
|
|
with patch.object(base, "retry", _patched_retry):
|
|
|
|
|
with patch.object(
|
|
|
|
|
llm,
|
|
|
|
|
"client",
|
|
|
|
|