Add Retry Events (#8053)

![image](https://github.com/hwchase17/langchain/assets/13333726/59a5c3b4-4367-47e6-9f58-5b6557576a8a)

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
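
This commit surfaces tenacity retries as callback events: `create_base_retry_decorator` now accepts a run manager, the OpenAI LLM and chat wrappers pass theirs through, and `BaseTracer` records each retry on the LLM run's `events` list. Below is a minimal sketch of how a custom handler could consume the new hook; the `RetryLoggingHandler` class is hypothetical and not part of this PR.

```python
from typing import Any

from tenacity import RetryCallState

from langchain.callbacks.base import BaseCallbackHandler
from langchain.llms import OpenAI


class RetryLoggingHandler(BaseCallbackHandler):
    """Hypothetical handler that counts and logs retry events."""

    def __init__(self) -> None:
        self.retries = 0

    def on_retry(self, retry_state: RetryCallState, **kwargs: Any) -> Any:
        # Called by the run manager each time tenacity schedules another attempt.
        self.retries += 1
        print(
            f"retry attempt={retry_state.attempt_number}, "
            f"slept={retry_state.idle_for}s"
        )


handler = RetryLoggingHandler()
llm = OpenAI(max_retries=3)  # requires OPENAI_API_KEY in the environment
# Retry events, if any occur, are forwarded to handler.on_retry.
print(llm.predict("Say hello", callbacks=[handler]))
print(f"retries observed: {handler.retries}")
```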
Branch: pull/8375/head
Authored by William FH, committed by GitHub
Commit: ff98fad2d9 (parent 94a693e2ee)

@@ -242,6 +242,11 @@ class BaseCallbackHandler(
         """Whether to ignore LLM callbacks."""
         return False
 
+    @property
+    def ignore_retry(self) -> bool:
+        """Whether to ignore retry callbacks."""
+        return False
+
     @property
     def ignore_chain(self) -> bool:
         """Whether to ignore chain callbacks."""

@@ -23,6 +23,8 @@ from typing import (
 )
 from uuid import UUID
 
+from tenacity import RetryCallState
+
 import langchain
 from langchain.callbacks.base import (
     BaseCallbackHandler,
@@ -572,6 +574,22 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
             **kwargs,
         )
 
+    def on_retry(
+        self,
+        retry_state: RetryCallState,
+        **kwargs: Any,
+    ) -> None:
+        _handle_event(
+            self.handlers,
+            "on_retry",
+            "ignore_retry",
+            retry_state,
+            run_id=self.run_id,
+            parent_run_id=self.parent_run_id,
+            tags=self.tags,
+            **kwargs,
+        )
+
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Run when LLM ends running.
@@ -635,6 +653,22 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
             **kwargs,
         )
 
+    async def on_retry(
+        self,
+        retry_state: RetryCallState,
+        **kwargs: Any,
+    ) -> None:
+        await _ahandle_event(
+            self.handlers,
+            "on_retry",
+            "ignore_retry",
+            retry_state,
+            run_id=self.run_id,
+            parent_run_id=self.parent_run_id,
+            tags=self.tags,
+            **kwargs,
+        )
+
     async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Run when LLM ends running.

@@ -7,6 +7,8 @@ from datetime import datetime
 from typing import Any, Dict, List, Optional, Sequence, Union, cast
 from uuid import UUID
 
+from tenacity import RetryCallState
+
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
 from langchain.load.dump import dumpd
@@ -138,6 +140,41 @@ class BaseTracer(BaseCallbackHandler, ABC):
             },
         )
 
+    def on_retry(
+        self,
+        retry_state: RetryCallState,
+        *,
+        run_id: UUID,
+        **kwargs: Any,
+    ) -> None:
+        if not run_id:
+            raise TracerException("No run_id provided for on_retry callback.")
+        run_id_ = str(run_id)
+        llm_run = self.run_map.get(run_id_)
+        if llm_run is None or llm_run.run_type != RunTypeEnum.llm:
+            raise TracerException("No LLM Run found to be traced for on_retry")
+        retry_d: Dict[str, Any] = {
+            "slept": retry_state.idle_for,
+            "attempt": retry_state.attempt_number,
+        }
+        if retry_state.outcome is None:
+            retry_d["outcome"] = "N/A"
+        elif retry_state.outcome.failed:
+            retry_d["outcome"] = "failed"
+            exception = retry_state.outcome.exception()
+            retry_d["exception"] = str(exception)
+            retry_d["exception_type"] = exception.__class__.__name__
+        else:
+            retry_d["outcome"] = "success"
+            retry_d["result"] = str(retry_state.outcome.result())
+        llm_run.events.append(
+            {
+                "name": "retry",
+                "time": datetime.utcnow(),
+                "kwargs": retry_d,
+            },
+        )
+
     def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> None:
         """End a trace for an LLM run."""
         if not run_id:
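
For reference, a sketch of the event entry that `BaseTracer.on_retry` appends to the traced LLM run; the concrete values below are illustrative, not taken from this PR.

```python
from datetime import datetime

# Illustrative values only: shape of the entry appended to `llm_run.events`
# for a retry whose previous attempt failed.
retry_event = {
    "name": "retry",
    "time": datetime.utcnow(),
    "kwargs": {
        "slept": 4.0,      # retry_state.idle_for
        "attempt": 2,      # retry_state.attempt_number
        "outcome": "failed",
        "exception": "Rate limit reached",
        "exception_type": "RateLimitError",
    },
}
print(retry_event["kwargs"])
```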

@@ -18,23 +18,14 @@ from typing import (
 )
 
 from pydantic import Field, root_validator
-from tenacity import (
-    before_sleep_log,
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_exponential,
-)
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain.chat_models.base import BaseChatModel
-from langchain.schema import (
-    ChatGeneration,
-    ChatResult,
-)
+from langchain.llms.base import create_base_retry_decorator
+from langchain.schema import ChatGeneration, ChatResult
 from langchain.schema.messages import (
     AIMessage,
     AIMessageChunk,
@@ -70,31 +61,33 @@ def _import_tiktoken() -> Any:
     return tiktoken
 
 
-def _create_retry_decorator(llm: ChatOpenAI) -> Callable[[Any], Any]:
+def _create_retry_decorator(
+    llm: ChatOpenAI,
+    run_manager: Optional[
+        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+    ] = None,
+) -> Callable[[Any], Any]:
     import openai
 
-    min_seconds = 1
-    max_seconds = 60
-    # Wait 2^x * 1 second between each retry starting with
-    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
-    return retry(
-        reraise=True,
-        stop=stop_after_attempt(llm.max_retries),
-        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
-        retry=(
-            retry_if_exception_type(openai.error.Timeout)
-            | retry_if_exception_type(openai.error.APIError)
-            | retry_if_exception_type(openai.error.APIConnectionError)
-            | retry_if_exception_type(openai.error.RateLimitError)
-            | retry_if_exception_type(openai.error.ServiceUnavailableError)
-        ),
-        before_sleep=before_sleep_log(logger, logging.WARNING),
+    errors = [
+        openai.error.Timeout,
+        openai.error.APIError,
+        openai.error.APIConnectionError,
+        openai.error.RateLimitError,
+        openai.error.ServiceUnavailableError,
+    ]
+    return create_base_retry_decorator(
+        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
     )
 
 
-async def acompletion_with_retry(llm: ChatOpenAI, **kwargs: Any) -> Any:
+async def acompletion_with_retry(
+    llm: ChatOpenAI,
+    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+    **kwargs: Any,
+) -> Any:
     """Use tenacity to retry the async completion call."""
-    retry_decorator = _create_retry_decorator(llm)
+    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 
     @retry_decorator
     async def _completion_with_retry(**kwargs: Any) -> Any:
@@ -322,9 +315,11 @@ class ChatOpenAI(BaseChatModel):
             **self.model_kwargs,
         }
 
-    def completion_with_retry(self, **kwargs: Any) -> Any:
+    def completion_with_retry(
+        self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
+    ) -> Any:
         """Use tenacity to retry the completion call."""
-        retry_decorator = _create_retry_decorator(self)
+        retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
 
         @retry_decorator
         def _completion_with_retry(**kwargs: Any) -> Any:
@@ -357,7 +352,9 @@ class ChatOpenAI(BaseChatModel):
         params = {**params, **kwargs, "stream": True}
 
         default_chunk_class = AIMessageChunk
-        for chunk in self.completion_with_retry(messages=message_dicts, **params):
+        for chunk in self.completion_with_retry(
+            messages=message_dicts, run_manager=run_manager, **params
+        ):
             if len(chunk["choices"]) == 0:
                 continue
             delta = chunk["choices"][0]["delta"]
@@ -388,7 +385,9 @@ class ChatOpenAI(BaseChatModel):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
-        response = self.completion_with_retry(messages=message_dicts, **params)
+        response = self.completion_with_retry(
+            messages=message_dicts, run_manager=run_manager, **params
+        )
         return self._create_chat_result(response)
 
     def _create_message_dicts(
@@ -427,7 +426,7 @@ class ChatOpenAI(BaseChatModel):
 
         default_chunk_class = AIMessageChunk
         async for chunk in await acompletion_with_retry(
-            self, messages=message_dicts, **params
+            self, messages=message_dicts, run_manager=run_manager, **params
         ):
             if len(chunk["choices"]) == 0:
                 continue
@@ -459,7 +458,9 @@ class ChatOpenAI(BaseChatModel):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
-        response = await acompletion_with_retry(self, messages=message_dicts, **params)
+        response = await acompletion_with_retry(
+            self, messages=message_dicts, run_manager=run_manager, **params
+        )
         return self._create_chat_result(response)
 
     @property

@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import asyncio
+import functools
 import inspect
 import json
 import logging
@@ -28,6 +29,7 @@ from typing import (
 import yaml
 from pydantic import Field, root_validator, validator
 from tenacity import (
+    RetryCallState,
     before_sleep_log,
     retry,
     retry_base,
@@ -66,11 +68,36 @@ def _get_verbosity() -> bool:
     return langchain.verbose
 
 
+@functools.lru_cache
+def _log_error_once(msg: str) -> None:
+    """Log an error once."""
+    logger.error(msg)
+
+
 def create_base_retry_decorator(
-    error_types: List[Type[BaseException]], max_retries: int = 1
+    error_types: List[Type[BaseException]],
+    max_retries: int = 1,
+    run_manager: Optional[
+        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+    ] = None,
 ) -> Callable[[Any], Any]:
     """Create a retry decorator for a given LLM and provided list of error types."""
 
+    _logging = before_sleep_log(logger, logging.WARNING)
+
+    def _before_sleep(retry_state: RetryCallState) -> None:
+        _logging(retry_state)
+        if run_manager:
+            if isinstance(run_manager, AsyncCallbackManagerForLLMRun):
+                coro = run_manager.on_retry(retry_state)
+                try:
+                    asyncio.run(coro)
+                except Exception as e:
+                    _log_error_once(f"Error in on_retry: {e}")
+            else:
+                run_manager.on_retry(retry_state)
+        return None
+
     min_seconds = 4
     max_seconds = 10
     # Wait 2^x * 1 second between each retry starting with
@@ -83,7 +110,7 @@ def create_base_retry_decorator(
         stop=stop_after_attempt(max_retries),
         wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
         retry=retry_instance,
-        before_sleep=before_sleep_log(logger, logging.WARNING),
+        before_sleep=_before_sleep,
     )
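
The decorator keeps the existing `before_sleep_log` behavior and, when a run manager is supplied, also forwards the `RetryCallState` to `on_retry`. A hedged sketch of how an integration can reuse it, with a hypothetical provider error type; the OpenAI wrappers changed below follow this same pattern.

```python
from typing import Any, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import create_base_retry_decorator


class MyProviderTimeout(Exception):
    """Hypothetical provider-specific error worth retrying."""


def completion_with_retry(
    llm: Any,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    # Retries MyProviderTimeout up to llm.max_retries times and emits an
    # on_retry callback event before each sleep.
    retry_decorator = create_base_retry_decorator(
        error_types=[MyProviderTimeout],
        max_retries=llm.max_retries,
        run_manager=run_manager,
    )

    @retry_decorator
    def _completion_with_retry(**kwargs: Any) -> Any:
        return llm.client.create(**kwargs)

    return _completion_with_retry(**kwargs)
```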

@@ -80,7 +80,12 @@ def _streaming_response_template() -> Dict[str, Any]:
     }
 
 
-def _create_retry_decorator(llm: Union[BaseOpenAI, OpenAIChat]) -> Callable[[Any], Any]:
+def _create_retry_decorator(
+    llm: Union[BaseOpenAI, OpenAIChat],
+    run_manager: Optional[
+        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+    ] = None,
+) -> Callable[[Any], Any]:
     import openai
 
     errors = [
@@ -90,12 +95,18 @@ def _create_retry_decorator(
         openai.error.RateLimitError,
         openai.error.ServiceUnavailableError,
     ]
-    return create_base_retry_decorator(error_types=errors, max_retries=llm.max_retries)
+    return create_base_retry_decorator(
+        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
+    )
 
 
-def completion_with_retry(llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any) -> Any:
+def completion_with_retry(
+    llm: Union[BaseOpenAI, OpenAIChat],
+    run_manager: Optional[CallbackManagerForLLMRun] = None,
+    **kwargs: Any,
+) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm)
+    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 
     @retry_decorator
     def _completion_with_retry(**kwargs: Any) -> Any:
@@ -105,10 +116,12 @@ def completion_with_retry(
 async def acompletion_with_retry(
-    llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any
+    llm: Union[BaseOpenAI, OpenAIChat],
+    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+    **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the async completion call."""
-    retry_decorator = _create_retry_decorator(llm)
+    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 
     @retry_decorator
     async def _completion_with_retry(**kwargs: Any) -> Any:
@@ -291,8 +304,10 @@ class BaseOpenAI(BaseLLM):
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
         params = {**self._invocation_params, **kwargs, "stream": True}
-        self.get_sub_prompts(params, [prompt], stop)  # this mutate params
-        for stream_resp in completion_with_retry(self, prompt=prompt, **params):
+        self.get_sub_prompts(params, [prompt], stop)  # this mutates params
+        for stream_resp in completion_with_retry(
+            self, prompt=prompt, run_manager=run_manager, **params
+        ):
             chunk = _stream_response_to_generation_chunk(stream_resp)
             yield chunk
             if run_manager:
@@ -314,7 +329,7 @@ class BaseOpenAI(BaseLLM):
         params = {**self._invocation_params, **kwargs, "stream": True}
         self.get_sub_prompts(params, [prompt], stop)  # this mutate params
         async for stream_resp in await acompletion_with_retry(
-            self, prompt=prompt, **params
+            self, prompt=prompt, run_manager=run_manager, **params
         ):
             chunk = _stream_response_to_generation_chunk(stream_resp)
             yield chunk
@@ -381,7 +396,9 @@ class BaseOpenAI(BaseLLM):
                     }
                 )
             else:
-                response = completion_with_retry(self, prompt=_prompts, **params)
+                response = completion_with_retry(
+                    self, prompt=_prompts, run_manager=run_manager, **params
+                )
                 choices.extend(response["choices"])
                 update_token_usage(_keys, response, token_usage)
         return self.create_llm_result(choices, prompts, token_usage)
@@ -428,7 +445,9 @@ class BaseOpenAI(BaseLLM):
                     }
                 )
             else:
-                response = await acompletion_with_retry(self, prompt=_prompts, **params)
+                response = await acompletion_with_retry(
+                    self, prompt=_prompts, run_manager=run_manager, **params
+                )
                 choices.extend(response["choices"])
                 update_token_usage(_keys, response, token_usage)
         return self.create_llm_result(choices, prompts, token_usage)
@@ -818,7 +837,9 @@ class OpenAIChat(BaseLLM):
     ) -> Iterator[GenerationChunk]:
         messages, params = self._get_chat_params([prompt], stop)
         params = {**params, **kwargs, "stream": True}
-        for stream_resp in completion_with_retry(self, messages=messages, **params):
+        for stream_resp in completion_with_retry(
+            self, messages=messages, run_manager=run_manager, **params
+        ):
             token = stream_resp["choices"][0]["delta"].get("content", "")
             yield GenerationChunk(text=token)
             if run_manager:
@@ -834,7 +855,7 @@ class OpenAIChat(BaseLLM):
         messages, params = self._get_chat_params([prompt], stop)
         params = {**params, **kwargs, "stream": True}
         async for stream_resp in await acompletion_with_retry(
-            self, messages=messages, **params
+            self, messages=messages, run_manager=run_manager, **params
         ):
             token = stream_resp["choices"][0]["delta"].get("content", "")
             yield GenerationChunk(text=token)
@@ -860,7 +881,9 @@ class OpenAIChat(BaseLLM):
         messages, params = self._get_chat_params(prompts, stop)
         params = {**params, **kwargs}
-        full_response = completion_with_retry(self, messages=messages, **params)
+        full_response = completion_with_retry(
+            self, messages=messages, run_manager=run_manager, **params
+        )
         llm_output = {
             "token_usage": full_response["usage"],
             "model_name": self.model_name,
@@ -891,7 +914,9 @@ class OpenAIChat(BaseLLM):
         messages, params = self._get_chat_params(prompts, stop)
         params = {**params, **kwargs}
-        full_response = await acompletion_with_retry(self, messages=messages, **params)
+        full_response = await acompletion_with_retry(
+            self, messages=messages, run_manager=run_manager, **params
+        )
         llm_output = {
             "token_usage": full_response["usage"],
             "model_name": self.model_name,

@@ -1,7 +1,7 @@
 """Test OpenAI API wrapper."""
 from pathlib import Path
-from typing import Generator
+from typing import Any, Generator
+from unittest.mock import MagicMock, patch
 
 import pytest
@@ -10,7 +10,10 @@ from langchain.chat_models.openai import ChatOpenAI
 from langchain.llms.loading import load_llm
 from langchain.llms.openai import OpenAI, OpenAIChat
 from langchain.schema import LLMResult
-from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
+from tests.unit_tests.callbacks.fake_callback_handler import (
+    FakeAsyncCallbackHandler,
+    FakeCallbackHandler,
+)
 
 
 def test_openai_call() -> None:
@@ -334,3 +337,77 @@ def test_chat_openai_get_num_tokens(model: str) -> None:
     """Test get_tokens."""
     llm = ChatOpenAI(model=model)
     assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model]
+
+
+@pytest.fixture
+def mock_completion() -> dict:
+    return {
+        "id": "cmpl-3evkmQda5Hu7fcZavknQda3SQ",
+        "object": "text_completion",
+        "created": 1689989000,
+        "model": "text-davinci-003",
+        "choices": [
+            {"text": "Bar Baz", "index": 0, "logprobs": None, "finish_reason": "length"}
+        ],
+        "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
+    }
+
+
+@pytest.mark.requires("openai")
+def test_openai_retries(mock_completion: dict) -> None:
+    llm = OpenAI()
+    mock_client = MagicMock()
+    completed = False
+    raised = False
+    import openai
+
+    def raise_once(*args: Any, **kwargs: Any) -> Any:
+        nonlocal completed, raised
+        if not raised:
+            raised = True
+            raise openai.error.APIError
+        completed = True
+        return mock_completion
+
+    mock_client.create = raise_once
+    callback_handler = FakeCallbackHandler()
+    with patch.object(
+        llm,
+        "client",
+        mock_client,
+    ):
+        res = llm.predict("bar", callbacks=[callback_handler])
+        assert res == "Bar Baz"
+    assert completed
+    assert raised
+    assert callback_handler.retries == 1
+
+
+@pytest.mark.requires("openai")
+async def test_openai_async_retries(mock_completion: dict) -> None:
+    llm = OpenAI()
+    mock_client = MagicMock()
+    completed = False
+    raised = False
+    import openai
+
+    def raise_once(*args: Any, **kwargs: Any) -> Any:
+        nonlocal completed, raised
+        if not raised:
+            raised = True
+            raise openai.error.APIError
+        completed = True
+        return mock_completion
+
+    mock_client.create = raise_once
+    callback_handler = FakeAsyncCallbackHandler()
+    with patch.object(
+        llm,
+        "client",
+        mock_client,
+    ):
+        res = llm.apredict("bar", callbacks=[callback_handler])
+        assert res == "Bar Baz"
+    assert completed
+    assert raised
+    assert callback_handler.retries == 1

@@ -39,6 +39,7 @@ class BaseFakeCallbackHandler(BaseModel):
     retriever_starts: int = 0
     retriever_ends: int = 0
     retriever_errors: int = 0
+    retries: int = 0
 
 
 class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
@@ -58,8 +59,10 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
     def on_llm_new_token_common(self) -> None:
         self.llm_streams += 1
 
+    def on_retry_common(self) -> None:
+        self.retries += 1
+
     def on_chain_start_common(self) -> None:
-        ("CHAIN START")
         self.chain_starts += 1
         self.starts += 1
@@ -82,7 +85,6 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
         self.errors += 1
 
     def on_agent_action_common(self) -> None:
-        print("AGENT ACTION")
         self.agent_actions += 1
         self.starts += 1
@@ -91,7 +93,6 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
         self.ends += 1
 
     def on_chat_model_start_common(self) -> None:
-        print("STARTING CHAT MODEL")
         self.chat_model_starts += 1
         self.starts += 1
@@ -162,6 +163,13 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
     ) -> Any:
         self.on_llm_error_common()
 
+    def on_retry(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> Any:
+        self.on_retry_common()
+
     def on_chain_start(
         self,
         *args: Any,

@@ -1,8 +1,12 @@
 """Test OpenAI Chat API wrapper."""
 import json
+from typing import Any
+from unittest.mock import MagicMock, patch
+
+import pytest
 
 from langchain.chat_models.openai import (
+    ChatOpenAI,
     _convert_dict_to_message,
 )
 from langchain.schema.messages import FunctionMessage
@@ -21,3 +25,67 @@ def test_function_message_dict_to_function_message() -> None:
     assert isinstance(result, FunctionMessage)
     assert result.name == name
     assert result.content == content
+
+
+@pytest.fixture
+def mock_completion() -> dict:
+    return {
+        "id": "chatcmpl-7fcZavknQda3SQ",
+        "object": "chat.completion",
+        "created": 1689989000,
+        "model": "gpt-3.5-turbo-0613",
+        "choices": [
+            {
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": "Bar Baz",
+                },
+                "finish_reason": "stop",
+            }
+        ],
+    }
+
+
+@pytest.mark.requires("openai")
+def test_openai_predict(mock_completion: dict) -> None:
+    llm = ChatOpenAI()
+    mock_client = MagicMock()
+    completed = False
+
+    def mock_create(*args: Any, **kwargs: Any) -> Any:
+        nonlocal completed
+        completed = True
+        return mock_completion
+
+    mock_client.create = mock_create
+    with patch.object(
+        llm,
+        "client",
+        mock_client,
+    ):
+        res = llm.predict("bar")
+        assert res == "Bar Baz"
+    assert completed
+
+
+@pytest.mark.requires("openai")
+async def test_openai_apredict(mock_completion: dict) -> None:
+    llm = ChatOpenAI()
+    mock_client = MagicMock()
+    completed = False
+
+    def mock_create(*args: Any, **kwargs: Any) -> Any:
+        nonlocal completed
+        completed = True
+        return mock_completion
+
+    mock_client.create = mock_create
+    with patch.object(
+        llm,
+        "client",
+        mock_client,
+    ):
+        res = llm.predict("bar")
+        assert res == "Bar Baz"
+    assert completed

@@ -1,4 +1,6 @@
 import os
+from typing import Any
+from unittest.mock import MagicMock, patch
 
 import pytest
@@ -26,3 +28,61 @@ def test_openai_incorrect_field() -> None:
     with pytest.warns(match="not default parameter"):
         llm = OpenAI(foo="bar")
     assert llm.model_kwargs == {"foo": "bar"}
+
+
+@pytest.fixture
+def mock_completion() -> dict:
+    return {
+        "id": "cmpl-3evkmQda5Hu7fcZavknQda3SQ",
+        "object": "text_completion",
+        "created": 1689989000,
+        "model": "text-davinci-003",
+        "choices": [
+            {"text": "Bar Baz", "index": 0, "logprobs": None, "finish_reason": "length"}
+        ],
+        "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
+    }
+
+
+@pytest.mark.requires("openai")
+def test_openai_calls(mock_completion: dict) -> None:
+    llm = OpenAI()
+    mock_client = MagicMock()
+    completed = False
+
+    def raise_once(*args: Any, **kwargs: Any) -> Any:
+        nonlocal completed
+        completed = True
+        return mock_completion
+
+    mock_client.create = raise_once
+    with patch.object(
+        llm,
+        "client",
+        mock_client,
+    ):
+        res = llm.predict("bar")
+        assert res == "Bar Baz"
+    assert completed
+
+
+@pytest.mark.requires("openai")
+async def test_openai_async_retries(mock_completion: dict) -> None:
+    llm = OpenAI()
+    mock_client = MagicMock()
+    completed = False
+
+    def raise_once(*args: Any, **kwargs: Any) -> Any:
+        nonlocal completed
+        completed = True
+        return mock_completion
+
+    mock_client.create = raise_once
+    with patch.object(
+        llm,
+        "client",
+        mock_client,
+    ):
+        res = llm.apredict("bar")
+        assert res == "Bar Baz"
+    assert completed
