diff --git a/libs/langchain/langchain/callbacks/base.py b/libs/langchain/langchain/callbacks/base.py
index dcf3766d8f..71025955f8 100644
--- a/libs/langchain/langchain/callbacks/base.py
+++ b/libs/langchain/langchain/callbacks/base.py
@@ -242,6 +242,11 @@ class BaseCallbackHandler(
         """Whether to ignore LLM callbacks."""
         return False
 
+    @property
+    def ignore_retry(self) -> bool:
+        """Whether to ignore retry callbacks."""
+        return False
+
     @property
     def ignore_chain(self) -> bool:
         """Whether to ignore chain callbacks."""
diff --git a/libs/langchain/langchain/callbacks/manager.py b/libs/langchain/langchain/callbacks/manager.py
index 3e2ba86b71..1127b55b7d 100644
--- a/libs/langchain/langchain/callbacks/manager.py
+++ b/libs/langchain/langchain/callbacks/manager.py
@@ -23,6 +23,8 @@ from typing import (
 )
 from uuid import UUID
 
+from tenacity import RetryCallState
+
 import langchain
 from langchain.callbacks.base import (
     BaseCallbackHandler,
@@ -572,6 +574,22 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
             **kwargs,
         )
 
+    def on_retry(
+        self,
+        retry_state: RetryCallState,
+        **kwargs: Any,
+    ) -> None:
+        _handle_event(
+            self.handlers,
+            "on_retry",
+            "ignore_retry",
+            retry_state,
+            run_id=self.run_id,
+            parent_run_id=self.parent_run_id,
+            tags=self.tags,
+            **kwargs,
+        )
+
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Run when LLM ends running.
 
@@ -635,6 +653,22 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
             **kwargs,
         )
 
+    async def on_retry(
+        self,
+        retry_state: RetryCallState,
+        **kwargs: Any,
+    ) -> None:
+        await _ahandle_event(
+            self.handlers,
+            "on_retry",
+            "ignore_retry",
+            retry_state,
+            run_id=self.run_id,
+            parent_run_id=self.parent_run_id,
+            tags=self.tags,
+            **kwargs,
+        )
+
     async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Run when LLM ends running.
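For reference, a handler consumes the new hook by simply defining on_retry; a minimal sketch (the RetryCountingHandler name and its print are illustrative, not part of this change):

from typing import Any

from tenacity import RetryCallState

from langchain.callbacks.base import BaseCallbackHandler


class RetryCountingHandler(BaseCallbackHandler):
    """Illustrative handler that records each retry surfaced via on_retry."""

    def __init__(self) -> None:
        self.retries = 0

    def on_retry(self, retry_state: RetryCallState, **kwargs: Any) -> None:
        # retry_state comes straight from tenacity; attempt_number and idle_for
        # are the same fields the tracer records.
        self.retries += 1
        print(f"retry #{retry_state.attempt_number} after {retry_state.idle_for}s")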
diff --git a/libs/langchain/langchain/callbacks/tracers/base.py b/libs/langchain/langchain/callbacks/tracers/base.py
index db55127782..cc37bb47e8 100644
--- a/libs/langchain/langchain/callbacks/tracers/base.py
+++ b/libs/langchain/langchain/callbacks/tracers/base.py
@@ -7,6 +7,8 @@ from datetime import datetime
 from typing import Any, Dict, List, Optional, Sequence, Union, cast
 from uuid import UUID
 
+from tenacity import RetryCallState
+
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
 from langchain.load.dump import dumpd
@@ -138,6 +140,41 @@ class BaseTracer(BaseCallbackHandler, ABC):
             },
         )
 
+    def on_retry(
+        self,
+        retry_state: RetryCallState,
+        *,
+        run_id: UUID,
+        **kwargs: Any,
+    ) -> None:
+        if not run_id:
+            raise TracerException("No run_id provided for on_retry callback.")
+        run_id_ = str(run_id)
+        llm_run = self.run_map.get(run_id_)
+        if llm_run is None or llm_run.run_type != RunTypeEnum.llm:
+            raise TracerException("No LLM Run found to be traced for on_retry")
+        retry_d: Dict[str, Any] = {
+            "slept": retry_state.idle_for,
+            "attempt": retry_state.attempt_number,
+        }
+        if retry_state.outcome is None:
+            retry_d["outcome"] = "N/A"
+        elif retry_state.outcome.failed:
+            retry_d["outcome"] = "failed"
+            exception = retry_state.outcome.exception()
+            retry_d["exception"] = str(exception)
+            retry_d["exception_type"] = exception.__class__.__name__
+        else:
+            retry_d["outcome"] = "success"
+            retry_d["result"] = str(retry_state.outcome.result())
+        llm_run.events.append(
+            {
+                "name": "retry",
+                "time": datetime.utcnow(),
+                "kwargs": retry_d,
+            },
+        )
+
     def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> None:
         """End a trace for an LLM run."""
         if not run_id:
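The retry is recorded as a plain event on the LLM run. For a failed attempt the appended entry looks roughly like this (values are illustrative):

{
    "name": "retry",
    "time": datetime.utcnow(),
    "kwargs": {
        "slept": 4.0,  # retry_state.idle_for
        "attempt": 2,  # retry_state.attempt_number
        "outcome": "failed",
        "exception": "Rate limit reached",  # str(exception), illustrative
        "exception_type": "RateLimitError",
    },
}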
diff --git a/libs/langchain/langchain/chat_models/openai.py b/libs/langchain/langchain/chat_models/openai.py
index 815e3011bb..8d1137b6fb 100644
--- a/libs/langchain/langchain/chat_models/openai.py
+++ b/libs/langchain/langchain/chat_models/openai.py
@@ -18,23 +18,14 @@ from typing import (
 )
 
 from pydantic import Field, root_validator
-from tenacity import (
-    before_sleep_log,
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_exponential,
-)
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain.chat_models.base import BaseChatModel
-from langchain.schema import (
-    ChatGeneration,
-    ChatResult,
-)
+from langchain.llms.base import create_base_retry_decorator
+from langchain.schema import ChatGeneration, ChatResult
 from langchain.schema.messages import (
     AIMessage,
     AIMessageChunk,
@@ -70,31 +61,33 @@ def _import_tiktoken() -> Any:
     return tiktoken
 
 
-def _create_retry_decorator(llm: ChatOpenAI) -> Callable[[Any], Any]:
+def _create_retry_decorator(
+    llm: ChatOpenAI,
+    run_manager: Optional[
+        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+    ] = None,
+) -> Callable[[Any], Any]:
     import openai
 
-    min_seconds = 1
-    max_seconds = 60
-    # Wait 2^x * 1 second between each retry starting with
-    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
-    return retry(
-        reraise=True,
-        stop=stop_after_attempt(llm.max_retries),
-        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
-        retry=(
-            retry_if_exception_type(openai.error.Timeout)
-            | retry_if_exception_type(openai.error.APIError)
-            | retry_if_exception_type(openai.error.APIConnectionError)
-            | retry_if_exception_type(openai.error.RateLimitError)
-            | retry_if_exception_type(openai.error.ServiceUnavailableError)
-        ),
-        before_sleep=before_sleep_log(logger, logging.WARNING),
+    errors = [
+        openai.error.Timeout,
+        openai.error.APIError,
+        openai.error.APIConnectionError,
+        openai.error.RateLimitError,
+        openai.error.ServiceUnavailableError,
+    ]
+    return create_base_retry_decorator(
+        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
     )
 
 
-async def acompletion_with_retry(llm: ChatOpenAI, **kwargs: Any) -> Any:
+async def acompletion_with_retry(
+    llm: ChatOpenAI,
+    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+    **kwargs: Any,
+) -> Any:
     """Use tenacity to retry the async completion call."""
-    retry_decorator = _create_retry_decorator(llm)
+    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 
     @retry_decorator
     async def _completion_with_retry(**kwargs: Any) -> Any:
@@ -322,9 +315,11 @@ class ChatOpenAI(BaseChatModel):
             **self.model_kwargs,
         }
 
-    def completion_with_retry(self, **kwargs: Any) -> Any:
+    def completion_with_retry(
+        self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
+    ) -> Any:
         """Use tenacity to retry the completion call."""
-        retry_decorator = _create_retry_decorator(self)
+        retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
 
         @retry_decorator
         def _completion_with_retry(**kwargs: Any) -> Any:
@@ -357,7 +352,9 @@ class ChatOpenAI(BaseChatModel):
         params = {**params, **kwargs, "stream": True}
 
         default_chunk_class = AIMessageChunk
-        for chunk in self.completion_with_retry(messages=message_dicts, **params):
+        for chunk in self.completion_with_retry(
+            messages=message_dicts, run_manager=run_manager, **params
+        ):
             if len(chunk["choices"]) == 0:
                 continue
             delta = chunk["choices"][0]["delta"]
@@ -388,7 +385,9 @@ class ChatOpenAI(BaseChatModel):
 
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
-        response = self.completion_with_retry(messages=message_dicts, **params)
+        response = self.completion_with_retry(
+            messages=message_dicts, run_manager=run_manager, **params
+        )
         return self._create_chat_result(response)
 
     def _create_message_dicts(
@@ -427,7 +426,7 @@ class ChatOpenAI(BaseChatModel):
 
         default_chunk_class = AIMessageChunk
         async for chunk in await acompletion_with_retry(
-            self, messages=message_dicts, **params
+            self, messages=message_dicts, run_manager=run_manager, **params
         ):
             if len(chunk["choices"]) == 0:
                 continue
@@ -459,7 +458,9 @@ class ChatOpenAI(BaseChatModel):
 
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
-        response = await acompletion_with_retry(self, messages=message_dicts, **params)
+        response = await acompletion_with_retry(
+            self, messages=message_dicts, run_manager=run_manager, **params
+        )
         return self._create_chat_result(response)
 
     @property
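With the run manager threaded through completion_with_retry and acompletion_with_retry, a transient OpenAI error during an ordinary call now reaches attached handlers. A minimal sketch, assuming the illustrative RetryCountingHandler above and a configured OpenAI API key:

from langchain.chat_models import ChatOpenAI

handler = RetryCountingHandler()
chat = ChatOpenAI(max_retries=2)
# A transient openai.error failure now fires handler.on_retry before each
# re-attempt, in addition to the existing before_sleep warning log.
print(chat.predict("Say hi", callbacks=[handler]))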
diff --git a/libs/langchain/langchain/llms/base.py b/libs/langchain/langchain/llms/base.py
index 10595e8319..a40b79de80 100644
--- a/libs/langchain/langchain/llms/base.py
+++ b/libs/langchain/langchain/llms/base.py
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import asyncio
+import functools
 import inspect
 import json
 import logging
@@ -28,6 +29,7 @@ from typing import (
 import yaml
 from pydantic import Field, root_validator, validator
 from tenacity import (
+    RetryCallState,
     before_sleep_log,
     retry,
     retry_base,
@@ -66,11 +68,40 @@ def _get_verbosity() -> bool:
     return langchain.verbose
 
 
+@functools.lru_cache
+def _log_error_once(msg: str) -> None:
+    """Log an error once."""
+    logger.error(msg)
+
+
 def create_base_retry_decorator(
-    error_types: List[Type[BaseException]], max_retries: int = 1
+    error_types: List[Type[BaseException]],
+    max_retries: int = 1,
+    run_manager: Optional[
+        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+    ] = None,
 ) -> Callable[[Any], Any]:
     """Create a retry decorator for a given LLM and provided list of error types."""
+    _logging = before_sleep_log(logger, logging.WARNING)
+
+    def _before_sleep(retry_state: RetryCallState) -> None:
+        _logging(retry_state)
+        if run_manager:
+            if isinstance(run_manager, AsyncCallbackManagerForLLMRun):
+                coro = run_manager.on_retry(retry_state)
+                try:
+                    # asyncio.run() raises inside a running loop, which is the
+                    # normal case for the async retry path, so schedule a task
+                    # there and only fall back to asyncio.run() otherwise.
+                    loop = asyncio.get_event_loop()
+                    if loop.is_running():
+                        loop.create_task(coro)
+                    else:
+                        asyncio.run(coro)
+                except Exception as e:
+                    _log_error_once(f"Error in on_retry: {e}")
+            else:
+                run_manager.on_retry(retry_state)
+        return None
+
     min_seconds = 4
     max_seconds = 10
     # Wait 2^x * 1 second between each retry starting with
@@ -83,7 +114,7 @@ def create_base_retry_decorator(
         stop=stop_after_attempt(max_retries),
         wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
         retry=retry_instance,
-        before_sleep=before_sleep_log(logger, logging.WARNING),
+        before_sleep=_before_sleep,
     )
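For other provider integrations, wiring the new parameter through looks roughly like this; MyProviderError, llm.client and _call are placeholders rather than anything defined in this diff:

from typing import Any, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import create_base_retry_decorator


class MyProviderError(Exception):
    """Placeholder transient error from an imaginary provider SDK."""


def provider_completion_with_retry(
    llm: Any, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
) -> Any:
    # Passing run_manager means every retry fires on_retry on attached handlers
    # (and is recorded by tracers), in addition to the before_sleep warning log.
    retry_decorator = create_base_retry_decorator(
        error_types=[MyProviderError],
        max_retries=llm.max_retries,
        run_manager=run_manager,
    )

    @retry_decorator
    def _call(**kwargs: Any) -> Any:
        return llm.client.create(**kwargs)  # placeholder client call

    return _call(**kwargs)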
diff --git a/libs/langchain/langchain/llms/openai.py b/libs/langchain/langchain/llms/openai.py
index 2c664a1c87..52741d7f53 100644
--- a/libs/langchain/langchain/llms/openai.py
+++ b/libs/langchain/langchain/llms/openai.py
@@ -80,7 +80,12 @@ def _streaming_response_template() -> Dict[str, Any]:
     }
 
 
-def _create_retry_decorator(llm: Union[BaseOpenAI, OpenAIChat]) -> Callable[[Any], Any]:
+def _create_retry_decorator(
+    llm: Union[BaseOpenAI, OpenAIChat],
+    run_manager: Optional[
+        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+    ] = None,
+) -> Callable[[Any], Any]:
     import openai
 
     errors = [
@@ -90,12 +95,18 @@ def _create_retry_decorator(llm: Union[BaseOpenAI, OpenAIChat]) -> Callable[[Any
         openai.error.RateLimitError,
         openai.error.ServiceUnavailableError,
     ]
-    return create_base_retry_decorator(error_types=errors, max_retries=llm.max_retries)
+    return create_base_retry_decorator(
+        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
+    )
 
 
-def completion_with_retry(llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any) -> Any:
+def completion_with_retry(
+    llm: Union[BaseOpenAI, OpenAIChat],
+    run_manager: Optional[CallbackManagerForLLMRun] = None,
+    **kwargs: Any,
+) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm)
+    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 
     @retry_decorator
     def _completion_with_retry(**kwargs: Any) -> Any:
@@ -105,10 +116,12 @@ def completion_with_retry(llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any) ->
 
 
 async def acompletion_with_retry(
-    llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any
+    llm: Union[BaseOpenAI, OpenAIChat],
+    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+    **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the async completion call."""
-    retry_decorator = _create_retry_decorator(llm)
+    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 
     @retry_decorator
     async def _completion_with_retry(**kwargs: Any) -> Any:
@@ -291,8 +304,10 @@ class BaseOpenAI(BaseLLM):
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
         params = {**self._invocation_params, **kwargs, "stream": True}
-        self.get_sub_prompts(params, [prompt], stop)  # this mutate params
-        for stream_resp in completion_with_retry(self, prompt=prompt, **params):
+        self.get_sub_prompts(params, [prompt], stop)  # this mutates params
+        for stream_resp in completion_with_retry(
+            self, prompt=prompt, run_manager=run_manager, **params
+        ):
             chunk = _stream_response_to_generation_chunk(stream_resp)
             yield chunk
             if run_manager:
@@ -314,7 +329,7 @@ class BaseOpenAI(BaseLLM):
         params = {**self._invocation_params, **kwargs, "stream": True}
         self.get_sub_prompts(params, [prompt], stop)  # this mutate params
         async for stream_resp in await acompletion_with_retry(
-            self, prompt=prompt, **params
+            self, prompt=prompt, run_manager=run_manager, **params
        ):
             chunk = _stream_response_to_generation_chunk(stream_resp)
             yield chunk
@@ -381,7 +396,9 @@ class BaseOpenAI(BaseLLM):
                     }
                 )
             else:
-                response = completion_with_retry(self, prompt=_prompts, **params)
+                response = completion_with_retry(
+                    self, prompt=_prompts, run_manager=run_manager, **params
+                )
                 choices.extend(response["choices"])
                 update_token_usage(_keys, response, token_usage)
         return self.create_llm_result(choices, prompts, token_usage)
@@ -428,7 +445,9 @@ class BaseOpenAI(BaseLLM):
                     }
                 )
             else:
-                response = await acompletion_with_retry(self, prompt=_prompts, **params)
+                response = await acompletion_with_retry(
+                    self, prompt=_prompts, run_manager=run_manager, **params
+                )
                 choices.extend(response["choices"])
                 update_token_usage(_keys, response, token_usage)
         return self.create_llm_result(choices, prompts, token_usage)
@@ -818,7 +837,9 @@ class OpenAIChat(BaseLLM):
     ) -> Iterator[GenerationChunk]:
         messages, params = self._get_chat_params([prompt], stop)
         params = {**params, **kwargs, "stream": True}
-        for stream_resp in completion_with_retry(self, messages=messages, **params):
+        for stream_resp in completion_with_retry(
+            self, messages=messages, run_manager=run_manager, **params
+        ):
             token = stream_resp["choices"][0]["delta"].get("content", "")
             yield GenerationChunk(text=token)
             if run_manager:
@@ -834,7 +855,7 @@ class OpenAIChat(BaseLLM):
         messages, params = self._get_chat_params([prompt], stop)
         params = {**params, **kwargs, "stream": True}
         async for stream_resp in await acompletion_with_retry(
-            self, messages=messages, **params
+            self, messages=messages, run_manager=run_manager, **params
         ):
             token = stream_resp["choices"][0]["delta"].get("content", "")
             yield GenerationChunk(text=token)
@@ -860,7 +881,9 @@ class OpenAIChat(BaseLLM):
 
         messages, params = self._get_chat_params(prompts, stop)
         params = {**params, **kwargs}
-        full_response = completion_with_retry(self, messages=messages, **params)
+        full_response = completion_with_retry(
+            self, messages=messages, run_manager=run_manager, **params
+        )
         llm_output = {
             "token_usage": full_response["usage"],
             "model_name": self.model_name,
@@ -891,7 +914,9 @@ class OpenAIChat(BaseLLM):
 
         messages, params = self._get_chat_params(prompts, stop)
         params = {**params, **kwargs}
-        full_response = await acompletion_with_retry(self, messages=messages, **params)
+        full_response = await acompletion_with_retry(
+            self, messages=messages, run_manager=run_manager, **params
+        )
         llm_output = {
             "token_usage": full_response["usage"],
             "model_name": self.model_name,
diff --git a/libs/langchain/tests/integration_tests/llms/test_openai.py b/libs/langchain/tests/integration_tests/llms/test_openai.py
index 0844faa6aa..6b584ae154 100644
--- a/libs/langchain/tests/integration_tests/llms/test_openai.py
+++ b/libs/langchain/tests/integration_tests/llms/test_openai.py
@@ -1,7 +1,7 @@
 """Test OpenAI API wrapper."""
-
 from pathlib import Path
-from typing import Generator
+from typing import Any, Generator
+from unittest.mock import MagicMock, patch
 
 import pytest
 
@@ -10,7 +10,10 @@ from langchain.chat_models.openai import ChatOpenAI
 from langchain.llms.loading import load_llm
 from langchain.llms.openai import OpenAI, OpenAIChat
 from langchain.schema import LLMResult
-from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
+from tests.unit_tests.callbacks.fake_callback_handler import (
+    FakeAsyncCallbackHandler,
+    FakeCallbackHandler,
+)
 
 
 def test_openai_call() -> None:
@@ -334,3 +337,77 @@ def test_chat_openai_get_num_tokens(model: str) -> None:
     """Test get_tokens."""
     llm = ChatOpenAI(model=model)
     assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model]
+
+
+@pytest.fixture
+def mock_completion() -> dict:
+    return {
+        "id": "cmpl-3evkmQda5Hu7fcZavknQda3SQ",
+        "object": "text_completion",
+        "created": 1689989000,
+        "model": "text-davinci-003",
+        "choices": [
+            {"text": "Bar Baz", "index": 0, "logprobs": None, "finish_reason": "length"}
+        ],
+        "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
+    }
+
+
+@pytest.mark.requires("openai")
+def test_openai_retries(mock_completion: dict) -> None:
+    llm = OpenAI()
+    mock_client = MagicMock()
+    completed = False
+    raised = False
+    import openai
+
+    def raise_once(*args: Any, **kwargs: Any) -> Any:
+        nonlocal completed, raised
+        if not raised:
+            raised = True
+            raise openai.error.APIError
+        completed = True
+        return mock_completion
+
+    mock_client.create = raise_once
+    callback_handler = FakeCallbackHandler()
+    with patch.object(
+        llm,
+        "client",
+        mock_client,
+    ):
+        res = llm.predict("bar", callbacks=[callback_handler])
+        assert res == "Bar Baz"
+    assert completed
+    assert raised
+    assert callback_handler.retries == 1
+
+
+@pytest.mark.requires("openai")
+async def test_openai_async_retries(mock_completion: dict) -> None:
+    llm = OpenAI()
+    mock_client = MagicMock()
+    completed = False
+    raised = False
+    import openai
+
+    async def araise_once(*args: Any, **kwargs: Any) -> Any:
+        nonlocal completed, raised
+        if not raised:
+            raised = True
+            raise openai.error.APIError
+        completed = True
+        return mock_completion
+
+    mock_client.acreate = araise_once
+    callback_handler = FakeAsyncCallbackHandler()
+    with patch.object(
+        llm,
+        "client",
+        mock_client,
+    ):
+        res = await llm.apredict("bar", callbacks=[callback_handler])
+        assert res == "Bar Baz"
+    assert completed
+    assert raised
+    assert callback_handler.retries == 1
diff --git a/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py b/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py
index a5e3d3ef69..87b56a9bff 100644
--- a/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py
+++ b/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py
@@ -39,6 +39,7 @@ class BaseFakeCallbackHandler(BaseModel):
     retriever_starts: int = 0
     retriever_ends: int = 0
     retriever_errors: int = 0
+    retries: int = 0
 
 
 class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
@@ -58,8 +59,10 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
     def on_llm_new_token_common(self) -> None:
         self.llm_streams += 1
 
+    def on_retry_common(self) -> None:
+        self.retries += 1
+
     def on_chain_start_common(self) -> None:
-        ("CHAIN START")
         self.chain_starts += 1
         self.starts += 1
 
@@ -82,7 +85,6 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
         self.errors += 1
 
     def on_agent_action_common(self) -> None:
-        print("AGENT ACTION")
         self.agent_actions += 1
         self.starts += 1
 
@@ -91,7 +93,6 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
         self.ends += 1
 
     def on_chat_model_start_common(self) -> None:
-        print("STARTING CHAT MODEL")
         self.chat_model_starts += 1
         self.starts += 1
 
@@ -162,6 +163,13 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
     ) -> Any:
         self.on_llm_error_common()
 
+    def on_retry(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> Any:
+        self.on_retry_common()
+
     def on_chain_start(
         self,
         *args: Any,
diff --git a/libs/langchain/tests/unit_tests/chat_models/test_openai.py b/libs/langchain/tests/unit_tests/chat_models/test_openai.py
index ad05133f18..7719358d72 100644
--- a/libs/langchain/tests/unit_tests/chat_models/test_openai.py
+++ b/libs/langchain/tests/unit_tests/chat_models/test_openai.py
@@ -1,8 +1,12 @@
 """Test OpenAI Chat API wrapper."""
-
 import json
+from typing import Any
+from unittest.mock import MagicMock, patch
+
+import pytest
 
 from langchain.chat_models.openai import (
+    ChatOpenAI,
     _convert_dict_to_message,
 )
 from langchain.schema.messages import FunctionMessage
@@ -21,3 +25,67 @@ def test_function_message_dict_to_function_message() -> None:
     assert isinstance(result, FunctionMessage)
     assert result.name == name
     assert result.content == content
+
+
+@pytest.fixture
+def mock_completion() -> dict:
+    return {
+        "id": "chatcmpl-7fcZavknQda3SQ",
+        "object": "chat.completion",
+        "created": 1689989000,
+        "model": "gpt-3.5-turbo-0613",
+        "choices": [
+            {
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": "Bar Baz",
+                },
+                "finish_reason": "stop",
+            }
+        ],
+    }
+
+
+@pytest.mark.requires("openai")
+def test_openai_predict(mock_completion: dict) -> None:
+    llm = ChatOpenAI()
+    mock_client = MagicMock()
+    completed = False
+
+    def mock_create(*args: Any, **kwargs: Any) -> Any:
+        nonlocal completed
+        completed = True
+        return mock_completion
+
+    mock_client.create = mock_create
+    with patch.object(
+        llm,
+        "client",
+        mock_client,
+    ):
+        res = llm.predict("bar")
+        assert res == "Bar Baz"
+    assert completed
+
+
+@pytest.mark.requires("openai")
+async def test_openai_apredict(mock_completion: dict) -> None:
+    llm = ChatOpenAI()
+    mock_client = MagicMock()
+    completed = False
+
+    async def mock_acreate(*args: Any, **kwargs: Any) -> Any:
+        nonlocal completed
+        completed = True
+        return mock_completion
+
+    mock_client.acreate = mock_acreate
+    with patch.object(
+        llm,
+        "client",
+        mock_client,
+    ):
+        res = await llm.apredict("bar")
+        assert res == "Bar Baz"
+    assert completed
diff --git a/libs/langchain/tests/unit_tests/llms/test_openai.py b/libs/langchain/tests/unit_tests/llms/test_openai.py
index ef311ea878..cc0fc74c1f 100644
--- a/libs/langchain/tests/unit_tests/llms/test_openai.py
+++ b/libs/langchain/tests/unit_tests/llms/test_openai.py
@@ -1,4 +1,6 @@
 import os
+from typing import Any
+from unittest.mock import MagicMock, patch
 
 import pytest
 
@@ -26,3 +28,61 @@ def test_openai_incorrect_field() -> None:
     with pytest.warns(match="not default parameter"):
         llm = OpenAI(foo="bar")
     assert llm.model_kwargs == {"foo": "bar"}
+
+
+@pytest.fixture
+def mock_completion() -> dict:
+    return {
+        "id": "cmpl-3evkmQda5Hu7fcZavknQda3SQ",
+        "object": "text_completion",
+        "created": 1689989000,
+        "model": "text-davinci-003",
+        "choices": [
+            {"text": "Bar Baz", "index": 0, "logprobs": None, "finish_reason": "length"}
+        ],
+        "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
+    }
+
+
+@pytest.mark.requires("openai")
+def test_openai_calls(mock_completion: dict) -> None:
+    llm = OpenAI()
+    mock_client = MagicMock()
+    completed = False
+
+    def mock_create(*args: Any, **kwargs: Any) -> Any:
+        nonlocal completed
+        completed = True
+        return mock_completion
+
+    mock_client.create = mock_create
+    with patch.object(
+        llm,
+        "client",
+        mock_client,
+    ):
+        res = llm.predict("bar")
+        assert res == "Bar Baz"
+    assert completed
+
+
+@pytest.mark.requires("openai")
+async def test_openai_async_calls(mock_completion: dict) -> None:
+    llm = OpenAI()
+    mock_client = MagicMock()
+    completed = False
+
+    async def mock_acreate(*args: Any, **kwargs: Any) -> Any:
+        nonlocal completed
+        completed = True
+        return mock_completion
+
+    mock_client.acreate = mock_acreate
+    with patch.object(
+        llm,
+        "client",
+        mock_client,
+    ):
+        res = await llm.apredict("bar")
+        assert res == "Bar Baz"
+    assert completed