diff --git a/libs/langchain/tests/unit_tests/llms/test_openai.py b/libs/langchain/tests/unit_tests/llms/test_openai.py
index 7af941a432..db3c2b32c1 100644
--- a/libs/langchain/tests/unit_tests/llms/test_openai.py
+++ b/libs/langchain/tests/unit_tests/llms/test_openai.py
@@ -4,7 +4,9 @@ from typing import Any
 from unittest.mock import MagicMock, patch
 
 import pytest
+from tenacity import wait_none
 
+from langchain.llms import base
 from langchain.llms.openai import OpenAI
 from tests.unit_tests.callbacks.fake_callback_handler import (
     FakeAsyncCallbackHandler,
@@ -55,6 +57,16 @@ def mock_completion() -> dict:
     }
 
 
+def _patched_retry(*args: Any, **kwargs: Any) -> Any:
+    """Patched retry for unit tests that does not wait."""
+    from tenacity import retry
+
+    assert "wait" in kwargs
+    kwargs["wait"] = wait_none()
+    r = retry(*args, **kwargs)
+    return r
+
+
 @pytest.mark.requires("openai")
 def test_openai_retries(mock_completion: dict) -> None:
     llm = OpenAI()
@@ -73,13 +85,16 @@ def test_openai_retries(mock_completion: dict) -> None:
     mock_client.create = raise_once
     callback_handler = FakeCallbackHandler()
 
-    with patch.object(
-        llm,
-        "client",
-        mock_client,
-    ):
-        res = llm.predict("bar", callbacks=[callback_handler])
-        assert res == "Bar Baz"
+
+    # Patch the retry to avoid waiting during a unit test
+    with patch.object(base, "retry", _patched_retry):
+        with patch.object(
+            llm,
+            "client",
+            mock_client,
+        ):
+            res = llm.predict("bar", callbacks=[callback_handler])
+            assert res == "Bar Baz"
     assert completed
     assert raised
     assert callback_handler.retries == 1
@@ -105,13 +120,15 @@ async def test_openai_async_retries(mock_completion: dict) -> None:
     mock_client.acreate = araise_once
     callback_handler = FakeAsyncCallbackHandler()
 
-    with patch.object(
-        llm,
-        "client",
-        mock_client,
-    ):
-        res = await llm.apredict("bar", callbacks=[callback_handler])
-        assert res == "Bar Baz"
+    # Patch the retry to avoid waiting during a unit test
+    with patch.object(base, "retry", _patched_retry):
+        with patch.object(
+            llm,
+            "client",
+            mock_client,
+        ):
+            res = await llm.apredict("bar", callbacks=[callback_handler])
+            assert res == "Bar Baz"
     assert completed
     assert raised
     assert callback_handler.retries == 1