Add Retry Events (#8053)

![image](https://github.com/hwchase17/langchain/assets/13333726/59a5c3b4-4367-47e6-9f58-5b6557576a8a)

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
William FH 11 months ago committed by GitHub
parent 94a693e2ee
commit ff98fad2d9

@ -242,6 +242,11 @@ class BaseCallbackHandler(
"""Whether to ignore LLM callbacks."""
return False
@property
def ignore_retry(self) -> bool:
"""Whether to ignore retry callbacks."""
return False
@property
def ignore_chain(self) -> bool:
"""Whether to ignore chain callbacks."""

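For reference, a minimal handler that consumes the new retry events might look like the sketch below (the handler-side `on_retry` hook is added elsewhere in this PR). `on_retry` receives tenacity's `RetryCallState`, and since `ignore_retry` defaults to `False`, retry events are delivered unless a handler opts out.

```python
# Minimal sketch of a handler consuming retry events. The argument is the
# tenacity RetryCallState that triggered the retry.
from typing import Any

from tenacity import RetryCallState

from langchain.callbacks.base import BaseCallbackHandler


class RetryLoggingHandler(BaseCallbackHandler):
    def on_retry(self, retry_state: RetryCallState, **kwargs: Any) -> None:
        print(
            f"retry: attempt {retry_state.attempt_number}, "
            f"slept {retry_state.idle_for:.1f}s so far"
        )
```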
@ -23,6 +23,8 @@ from typing import (
)
from uuid import UUID
from tenacity import RetryCallState
import langchain
from langchain.callbacks.base import (
BaseCallbackHandler,
@ -572,6 +574,22 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
**kwargs,
)
def on_retry(
self,
retry_state: RetryCallState,
**kwargs: Any,
) -> None:
_handle_event(
self.handlers,
"on_retry",
"ignore_retry",
retry_state,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running.
@ -635,6 +653,22 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
**kwargs,
)
async def on_retry(
self,
retry_state: RetryCallState,
**kwargs: Any,
) -> None:
await _ahandle_event(
self.handlers,
"on_retry",
"ignore_retry",
retry_state,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running.

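The async manager mirrors the sync one: an `AsyncCallbackHandler` can implement `on_retry` as a coroutine, which `_ahandle_event` awaits. A minimal sketch:

```python
# Async counterpart: this coroutine is awaited by
# AsyncCallbackManagerForLLMRun.on_retry via _ahandle_event.
from typing import Any

from tenacity import RetryCallState

from langchain.callbacks.base import AsyncCallbackHandler


class AsyncRetryCounter(AsyncCallbackHandler):
    def __init__(self) -> None:
        self.retries = 0

    async def on_retry(self, retry_state: RetryCallState, **kwargs: Any) -> None:
        self.retries += 1
```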
@ -7,6 +7,8 @@ from datetime import datetime
from typing import Any, Dict, List, Optional, Sequence, Union, cast
from uuid import UUID
from tenacity import RetryCallState
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
from langchain.load.dump import dumpd
@ -138,6 +140,41 @@ class BaseTracer(BaseCallbackHandler, ABC):
},
)
def on_retry(
self,
retry_state: RetryCallState,
*,
run_id: UUID,
**kwargs: Any,
) -> None:
if not run_id:
raise TracerException("No run_id provided for on_retry callback.")
run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None or llm_run.run_type != RunTypeEnum.llm:
raise TracerException("No LLM Run found to be traced for on_retry")
retry_d: Dict[str, Any] = {
"slept": retry_state.idle_for,
"attempt": retry_state.attempt_number,
}
if retry_state.outcome is None:
retry_d["outcome"] = "N/A"
elif retry_state.outcome.failed:
retry_d["outcome"] = "failed"
exception = retry_state.outcome.exception()
retry_d["exception"] = str(exception)
retry_d["exception_type"] = exception.__class__.__name__
else:
retry_d["outcome"] = "success"
retry_d["result"] = str(retry_state.outcome.result())
llm_run.events.append(
{
"name": "retry",
"time": datetime.utcnow(),
"kwargs": retry_d,
},
)
def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> None:
"""End a trace for an LLM run."""
if not run_id:

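Each retry is appended to the traced LLM run's `events` list; for a failed attempt the recorded entry looks roughly like this (illustrative values):

```python
# Illustrative shape of a recorded retry event (values made up).
from datetime import datetime

retry_event = {
    "name": "retry",
    "time": datetime.utcnow(),
    "kwargs": {
        "slept": 4.0,                       # retry_state.idle_for
        "attempt": 2,                       # retry_state.attempt_number
        "outcome": "failed",
        "exception": "Rate limit reached",  # str(outcome.exception())
        "exception_type": "RateLimitError",
    },
}
```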
@ -18,23 +18,14 @@ from typing import (
)
from pydantic import Field, root_validator
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.schema import (
ChatGeneration,
ChatResult,
)
from langchain.llms.base import create_base_retry_decorator
from langchain.schema import ChatGeneration, ChatResult
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,
@ -70,31 +61,33 @@ def _import_tiktoken() -> Any:
return tiktoken
def _create_retry_decorator(llm: ChatOpenAI) -> Callable[[Any], Any]:
def _create_retry_decorator(
llm: ChatOpenAI,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
import openai
min_seconds = 1
max_seconds = 60
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
return retry(
reraise=True,
stop=stop_after_attempt(llm.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
errors = [
openai.error.Timeout,
openai.error.APIError,
openai.error.APIConnectionError,
openai.error.RateLimitError,
openai.error.ServiceUnavailableError,
]
return create_base_retry_decorator(
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
)
async def acompletion_with_retry(llm: ChatOpenAI, **kwargs: Any) -> Any:
async def acompletion_with_retry(
llm: ChatOpenAI,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the async completion call."""
retry_decorator = _create_retry_decorator(llm)
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
@ -322,9 +315,11 @@ class ChatOpenAI(BaseChatModel):
**self.model_kwargs,
}
def completion_with_retry(self, **kwargs: Any) -> Any:
def completion_with_retry(
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(self)
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
@ -357,7 +352,9 @@ class ChatOpenAI(BaseChatModel):
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
for chunk in self.completion_with_retry(messages=message_dicts, **params):
for chunk in self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
):
if len(chunk["choices"]) == 0:
continue
delta = chunk["choices"][0]["delta"]
@ -388,7 +385,9 @@ class ChatOpenAI(BaseChatModel):
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(messages=message_dicts, **params)
response = self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
)
return self._create_chat_result(response)
def _create_message_dicts(
@ -427,7 +426,7 @@ class ChatOpenAI(BaseChatModel):
default_chunk_class = AIMessageChunk
async for chunk in await acompletion_with_retry(
self, messages=message_dicts, **params
self, messages=message_dicts, run_manager=run_manager, **params
):
if len(chunk["choices"]) == 0:
continue
@ -459,7 +458,9 @@ class ChatOpenAI(BaseChatModel):
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = await acompletion_with_retry(self, messages=message_dicts, **params)
response = await acompletion_with_retry(
self, messages=message_dicts, run_manager=run_manager, **params
)
return self._create_chat_result(response)
@property

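Nothing changes in the user-facing call signature; callbacks attached to a request now also receive retry events. A hedged end-to-end sketch (assumes the `openai` package, a valid `OPENAI_API_KEY`, and that the client actually hits one of the retried errors such as `RateLimitError`):

```python
# Hedged end-to-end sketch; on_retry fires only when the OpenAI client raises
# one of the retried error types (Timeout, APIError, RateLimitError, ...).
from typing import Any

from langchain.callbacks.base import BaseCallbackHandler
from langchain.chat_models import ChatOpenAI


class CountRetries(BaseCallbackHandler):
    def __init__(self) -> None:
        self.retries = 0

    def on_retry(self, *args: Any, **kwargs: Any) -> None:
        self.retries += 1


handler = CountRetries()
chat = ChatOpenAI(max_retries=6)
chat.predict("Say hello", callbacks=[handler])
print(handler.retries)  # stays 0 unless a transient error was retried
```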
@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
import functools
import inspect
import json
import logging
@ -28,6 +29,7 @@ from typing import (
import yaml
from pydantic import Field, root_validator, validator
from tenacity import (
RetryCallState,
before_sleep_log,
retry,
retry_base,
@ -66,11 +68,36 @@ def _get_verbosity() -> bool:
return langchain.verbose
@functools.lru_cache
def _log_error_once(msg: str) -> None:
"""Log an error once."""
logger.error(msg)
def create_base_retry_decorator(
error_types: List[Type[BaseException]], max_retries: int = 1
error_types: List[Type[BaseException]],
max_retries: int = 1,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
"""Create a retry decorator for a given LLM and provided list of error types."""
_logging = before_sleep_log(logger, logging.WARNING)
def _before_sleep(retry_state: RetryCallState) -> None:
_logging(retry_state)
if run_manager:
if isinstance(run_manager, AsyncCallbackManagerForLLMRun):
coro = run_manager.on_retry(retry_state)
try:
asyncio.run(coro)
except Exception as e:
_log_error_once(f"Error in on_retry: {e}")
else:
run_manager.on_retry(retry_state)
return None
min_seconds = 4
max_seconds = 10
# Wait 2^x * 1 second between each retry starting with
@ -83,7 +110,7 @@ def create_base_retry_decorator(
stop=stop_after_attempt(max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=retry_instance,
before_sleep=before_sleep_log(logger, logging.WARNING),
before_sleep=_before_sleep,
)

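`create_base_retry_decorator` can also be used on its own; with no run manager, retries are only reported through the logging `before_sleep` hook. A minimal sketch:

```python
# Minimal sketch of the decorator without a run manager. Note each retry waits
# at least min_seconds (4s) because of the exponential backoff configured above.
from langchain.llms.base import create_base_retry_decorator

state = {"calls": 0}


@create_base_retry_decorator(error_types=[ConnectionError], max_retries=3)
def flaky() -> str:
    state["calls"] += 1
    if state["calls"] < 3:
        raise ConnectionError("simulated transient failure")
    return "ok"


print(flaky())  # "ok" after two retried attempts
```

When an `AsyncCallbackManagerForLLMRun` is supplied, the synchronous `_before_sleep` hook runs the `on_retry` coroutine with `asyncio.run` and, if that fails, logs the error once via the `lru_cache`d `_log_error_once` rather than raising, so a broken callback never interrupts the retry loop.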
@ -80,7 +80,12 @@ def _streaming_response_template() -> Dict[str, Any]:
}
def _create_retry_decorator(llm: Union[BaseOpenAI, OpenAIChat]) -> Callable[[Any], Any]:
def _create_retry_decorator(
llm: Union[BaseOpenAI, OpenAIChat],
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
import openai
errors = [
@ -90,12 +95,18 @@ def _create_retry_decorator(llm: Union[BaseOpenAI, OpenAIChat]) -> Callable[[Any
openai.error.RateLimitError,
openai.error.ServiceUnavailableError,
]
return create_base_retry_decorator(error_types=errors, max_retries=llm.max_retries)
return create_base_retry_decorator(
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
)
def completion_with_retry(llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any) -> Any:
def completion_with_retry(
llm: Union[BaseOpenAI, OpenAIChat],
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm)
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
@ -105,10 +116,12 @@ def completion_with_retry(llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any) ->
async def acompletion_with_retry(
llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any
llm: Union[BaseOpenAI, OpenAIChat],
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the async completion call."""
retry_decorator = _create_retry_decorator(llm)
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
@ -291,8 +304,10 @@ class BaseOpenAI(BaseLLM):
**kwargs: Any,
) -> Iterator[GenerationChunk]:
params = {**self._invocation_params, **kwargs, "stream": True}
self.get_sub_prompts(params, [prompt], stop) # this mutate params
for stream_resp in completion_with_retry(self, prompt=prompt, **params):
self.get_sub_prompts(params, [prompt], stop) # this mutates params
for stream_resp in completion_with_retry(
self, prompt=prompt, run_manager=run_manager, **params
):
chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk
if run_manager:
@ -314,7 +329,7 @@ class BaseOpenAI(BaseLLM):
params = {**self._invocation_params, **kwargs, "stream": True}
self.get_sub_prompts(params, [prompt], stop) # this mutate params
async for stream_resp in await acompletion_with_retry(
self, prompt=prompt, **params
self, prompt=prompt, run_manager=run_manager, **params
):
chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk
@ -381,7 +396,9 @@ class BaseOpenAI(BaseLLM):
}
)
else:
response = completion_with_retry(self, prompt=_prompts, **params)
response = completion_with_retry(
self, prompt=_prompts, run_manager=run_manager, **params
)
choices.extend(response["choices"])
update_token_usage(_keys, response, token_usage)
return self.create_llm_result(choices, prompts, token_usage)
@ -428,7 +445,9 @@ class BaseOpenAI(BaseLLM):
}
)
else:
response = await acompletion_with_retry(self, prompt=_prompts, **params)
response = await acompletion_with_retry(
self, prompt=_prompts, run_manager=run_manager, **params
)
choices.extend(response["choices"])
update_token_usage(_keys, response, token_usage)
return self.create_llm_result(choices, prompts, token_usage)
@ -818,7 +837,9 @@ class OpenAIChat(BaseLLM):
) -> Iterator[GenerationChunk]:
messages, params = self._get_chat_params([prompt], stop)
params = {**params, **kwargs, "stream": True}
for stream_resp in completion_with_retry(self, messages=messages, **params):
for stream_resp in completion_with_retry(
self, messages=messages, run_manager=run_manager, **params
):
token = stream_resp["choices"][0]["delta"].get("content", "")
yield GenerationChunk(text=token)
if run_manager:
@ -834,7 +855,7 @@ class OpenAIChat(BaseLLM):
messages, params = self._get_chat_params([prompt], stop)
params = {**params, **kwargs, "stream": True}
async for stream_resp in await acompletion_with_retry(
self, messages=messages, **params
self, messages=messages, run_manager=run_manager, **params
):
token = stream_resp["choices"][0]["delta"].get("content", "")
yield GenerationChunk(text=token)
@ -860,7 +881,9 @@ class OpenAIChat(BaseLLM):
messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs}
full_response = completion_with_retry(self, messages=messages, **params)
full_response = completion_with_retry(
self, messages=messages, run_manager=run_manager, **params
)
llm_output = {
"token_usage": full_response["usage"],
"model_name": self.model_name,
@ -891,7 +914,9 @@ class OpenAIChat(BaseLLM):
messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs}
full_response = await acompletion_with_retry(self, messages=messages, **params)
full_response = await acompletion_with_retry(
self, messages=messages, run_manager=run_manager, **params
)
llm_output = {
"token_usage": full_response["usage"],
"model_name": self.model_name,

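The same plumbing is available to any custom LLM: thread the per-run manager into `create_base_retry_decorator` and attached handlers receive `on_retry`. A hypothetical sketch (`FlakyEchoLLM` is not part of this PR):

```python
# Hypothetical custom LLM showing how run_manager is threaded into the retry
# decorator, mirroring the OpenAI wrappers above.
from typing import Any, List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM, create_base_retry_decorator


class FlakyEchoLLM(LLM):
    """Echoes the prompt, but simulates one transient failure per instance."""

    failures_left: int = 1

    @property
    def _llm_type(self) -> str:
        return "flaky-echo"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        retrying = create_base_retry_decorator(
            error_types=[ConnectionError],
            max_retries=2,
            run_manager=run_manager,  # handlers attached to this run now see retries
        )

        @retrying
        def _do_call() -> str:
            if self.failures_left > 0:
                self.failures_left -= 1
                raise ConnectionError("simulated transient failure")
            return prompt

        return _do_call()
```

Calling `FlakyEchoLLM().predict("hi", callbacks=[CountRetries()])` (handler as sketched earlier) returns `"hi"` after one ~4 s backoff and increments the handler's retry count.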
@ -1,7 +1,7 @@
"""Test OpenAI API wrapper."""
from pathlib import Path
from typing import Generator
from typing import Any, Generator
from unittest.mock import MagicMock, patch
import pytest
@ -10,7 +10,10 @@ from langchain.chat_models.openai import ChatOpenAI
from langchain.llms.loading import load_llm
from langchain.llms.openai import OpenAI, OpenAIChat
from langchain.schema import LLMResult
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
from tests.unit_tests.callbacks.fake_callback_handler import (
FakeAsyncCallbackHandler,
FakeCallbackHandler,
)
def test_openai_call() -> None:
@ -334,3 +337,77 @@ def test_chat_openai_get_num_tokens(model: str) -> None:
"""Test get_tokens."""
llm = ChatOpenAI(model=model)
assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model]
@pytest.fixture
def mock_completion() -> dict:
return {
"id": "cmpl-3evkmQda5Hu7fcZavknQda3SQ",
"object": "text_completion",
"created": 1689989000,
"model": "text-davinci-003",
"choices": [
{"text": "Bar Baz", "index": 0, "logprobs": None, "finish_reason": "length"}
],
"usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
}
@pytest.mark.requires("openai")
def test_openai_retries(mock_completion: dict) -> None:
llm = OpenAI()
mock_client = MagicMock()
completed = False
raised = False
import openai
def raise_once(*args: Any, **kwargs: Any) -> Any:
nonlocal completed, raised
if not raised:
raised = True
raise openai.error.APIError
completed = True
return mock_completion
mock_client.create = raise_once
callback_handler = FakeCallbackHandler()
with patch.object(
llm,
"client",
mock_client,
):
res = llm.predict("bar", callbacks=[callback_handler])
assert res == "Bar Baz"
assert completed
assert raised
assert callback_handler.retries == 1
@pytest.mark.requires("openai")
async def test_openai_async_retries(mock_completion: dict) -> None:
llm = OpenAI()
mock_client = MagicMock()
completed = False
raised = False
import openai
def raise_once(*args: Any, **kwargs: Any) -> Any:
nonlocal completed, raised
if not raised:
raised = True
raise openai.error.APIError
completed = True
return mock_completion
mock_client.create = raise_once
callback_handler = FakeAsyncCallbackHandler()
with patch.object(
llm,
"client",
mock_client,
):
res = await llm.apredict("bar", callbacks=[callback_handler])
assert res == "Bar Baz"
assert completed
assert raised
assert callback_handler.retries == 1

@ -39,6 +39,7 @@ class BaseFakeCallbackHandler(BaseModel):
retriever_starts: int = 0
retriever_ends: int = 0
retriever_errors: int = 0
retries: int = 0
class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
@ -58,8 +59,10 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
def on_llm_new_token_common(self) -> None:
self.llm_streams += 1
def on_retry_common(self) -> None:
self.retries += 1
def on_chain_start_common(self) -> None:
("CHAIN START")
self.chain_starts += 1
self.starts += 1
@ -82,7 +85,6 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
self.errors += 1
def on_agent_action_common(self) -> None:
print("AGENT ACTION")
self.agent_actions += 1
self.starts += 1
@ -91,7 +93,6 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
self.ends += 1
def on_chat_model_start_common(self) -> None:
print("STARTING CHAT MODEL")
self.chat_model_starts += 1
self.starts += 1
@ -162,6 +163,13 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_llm_error_common()
def on_retry(
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_retry_common()
def on_chain_start(
self,
*args: Any,

@ -1,8 +1,12 @@
"""Test OpenAI Chat API wrapper."""
import json
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from langchain.chat_models.openai import (
ChatOpenAI,
_convert_dict_to_message,
)
from langchain.schema.messages import FunctionMessage
@ -21,3 +25,67 @@ def test_function_message_dict_to_function_message() -> None:
assert isinstance(result, FunctionMessage)
assert result.name == name
assert result.content == content
@pytest.fixture
def mock_completion() -> dict:
return {
"id": "chatcmpl-7fcZavknQda3SQ",
"object": "chat.completion",
"created": 1689989000,
"model": "gpt-3.5-turbo-0613",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Bar Baz",
},
"finish_reason": "stop",
}
],
}
@pytest.mark.requires("openai")
def test_openai_predict(mock_completion: dict) -> None:
llm = ChatOpenAI()
mock_client = MagicMock()
completed = False
def mock_create(*args: Any, **kwargs: Any) -> Any:
nonlocal completed
completed = True
return mock_completion
mock_client.create = mock_create
with patch.object(
llm,
"client",
mock_client,
):
res = llm.predict("bar")
assert res == "Bar Baz"
assert completed
@pytest.mark.requires("openai")
async def test_openai_apredict(mock_completion: dict) -> None:
llm = ChatOpenAI()
mock_client = MagicMock()
completed = False
def mock_create(*args: Any, **kwargs: Any) -> Any:
nonlocal completed
completed = True
return mock_completion
mock_client.create = mock_create
with patch.object(
llm,
"client",
mock_client,
):
res = llm.predict("bar")
assert res == "Bar Baz"
assert completed

@ -1,4 +1,6 @@
import os
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
@ -26,3 +28,61 @@ def test_openai_incorrect_field() -> None:
with pytest.warns(match="not default parameter"):
llm = OpenAI(foo="bar")
assert llm.model_kwargs == {"foo": "bar"}
@pytest.fixture
def mock_completion() -> dict:
return {
"id": "cmpl-3evkmQda5Hu7fcZavknQda3SQ",
"object": "text_completion",
"created": 1689989000,
"model": "text-davinci-003",
"choices": [
{"text": "Bar Baz", "index": 0, "logprobs": None, "finish_reason": "length"}
],
"usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
}
@pytest.mark.requires("openai")
def test_openai_calls(mock_completion: dict) -> None:
llm = OpenAI()
mock_client = MagicMock()
completed = False
def raise_once(*args: Any, **kwargs: Any) -> Any:
nonlocal completed
completed = True
return mock_completion
mock_client.create = raise_once
with patch.object(
llm,
"client",
mock_client,
):
res = llm.predict("bar")
assert res == "Bar Baz"
assert completed
@pytest.mark.requires("openai")
async def test_openai_async_retries(mock_completion: dict) -> None:
llm = OpenAI()
mock_client = MagicMock()
completed = False
def raise_once(*args: Any, **kwargs: Any) -> Any:
nonlocal completed
completed = True
return mock_completion
mock_client.create = raise_once
with patch.object(
llm,
"client",
mock_client,
):
res = await llm.apredict("bar")
assert res == "Bar Baz"
assert completed
