Added retries for Vertex LLM (#7219)

#7217

---------

Co-authored-by: Leonid Kuligin <kuligin@google.com>
This commit is contained in:
Leonid Kuligin 2023-07-06 15:38:01 +02:00 committed by GitHub
parent ec66d5188c
commit 8b19f6a0da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 88 additions and 29 deletions

View File

@ -1,14 +1,36 @@
"""Base interface for large language models to expose."""
from __future__ import annotations
import asyncio
import inspect
import json
import logging
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
from typing import (
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
)
import yaml
from pydantic import Field, root_validator, validator
from tenacity import (
before_sleep_log,
retry,
retry_base,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
import langchain
from langchain.base_language import BaseLanguageModel
@ -29,11 +51,34 @@ from langchain.schema import (
)
from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string
logger = logging.getLogger(__name__)
def _get_verbosity() -> bool:
return langchain.verbose
def create_base_retry_decorator(
error_types: List[Type[BaseException]], max_retries: int = 1
) -> Callable[[Any], Any]:
"""Create a retry decorator for a given LLM and provided list of error types."""
min_seconds = 4
max_seconds = 10
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
retry_instance: "retry_base" = retry_if_exception_type(error_types[0])
for error in error_types[1:]:
retry_instance = retry_instance | retry_if_exception_type(error)
return retry(
reraise=True,
stop=stop_after_attempt(max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=retry_instance,
before_sleep=before_sleep_log(logger, logging.WARNING),
)
def get_prompts(
params: Dict[str, Any], prompts: List[str]
) -> Tuple[Dict[int, List], str, List[int], List[str]]:

View File

@ -21,19 +21,12 @@ from typing import (
)
from pydantic import Field, root_validator
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import BaseLLM
from langchain.llms.base import BaseLLM, create_base_retry_decorator
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env
@ -76,23 +69,14 @@ def _streaming_response_template() -> Dict[str, Any]:
def _create_retry_decorator(llm: Union[BaseOpenAI, OpenAIChat]) -> Callable[[Any], Any]:
import openai
min_seconds = 4
max_seconds = 10
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
return retry(
reraise=True,
stop=stop_after_attempt(llm.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
errors = [
openai.error.Timeout,
openai.error.APIError,
openai.error.APIConnectionError,
openai.error.RateLimitError,
openai.error.ServiceUnavailableError,
]
return create_base_retry_decorator(error_types=errors, max_retries=llm.max_retries)
def completion_with_retry(llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any) -> Any:

View File

@ -1,7 +1,9 @@
"""Wrapper around Google VertexAI models."""
from __future__ import annotations
import asyncio
from concurrent.futures import Executor, ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional
from pydantic import BaseModel, root_validator
@ -9,7 +11,7 @@ from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.llms.base import LLM, create_base_retry_decorator
from langchain.llms.utils import enforce_stop_tokens
from langchain.utilities.vertexai import (
init_vertexai,
@ -24,6 +26,32 @@ def is_codey_model(model_name: str) -> bool:
return "code" in model_name
def _create_retry_decorator(llm: VertexAI) -> Callable[[Any], Any]:
import google.api_core
errors = [
google.api_core.exceptions.ResourceExhausted,
google.api_core.exceptions.ServiceUnavailable,
google.api_core.exceptions.Aborted,
google.api_core.exceptions.DeadlineExceeded,
]
decorator = create_base_retry_decorator(
error_types=errors, max_retries=llm.max_retries # type: ignore
)
return decorator
def completion_with_retry(llm: VertexAI, *args: Any, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm)
@retry_decorator
def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
return llm.client.predict(*args, **kwargs)
return _completion_with_retry(*args, **kwargs)
class _VertexAICommon(BaseModel):
client: "_LanguageModel" = None #: :meta private:
model_name: str
@ -51,6 +79,8 @@ class _VertexAICommon(BaseModel):
request_parallelism: int = 5
"The amount of parallelism allowed for requests issued to VertexAI models. "
"Default is 5."
max_retries: int = 6
"""The maximum number of retries to make when generating."""
task_executor: ClassVar[Optional[Executor]] = None
@property
@ -76,7 +106,7 @@ class _VertexAICommon(BaseModel):
self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
) -> str:
params = {**self._default_params, **kwargs}
res = self.client.predict(prompt, **params)
res = completion_with_retry(self, prompt, **params) # type: ignore
return self._enforce_stop_words(res.text, stop)
def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str: