From 8b19f6a0da7bd3b79fa481cf2e8c4b334d189d20 Mon Sep 17 00:00:00 2001
From: Leonid Kuligin
Date: Thu, 6 Jul 2023 15:38:01 +0200
Subject: [PATCH] Added retries for Vertex LLM (#7219)

#7217

---------

Co-authored-by: Leonid Kuligin
---
 langchain/llms/base.py     | 47 +++++++++++++++++++++++++++++++++++++-
 langchain/llms/openai.py   | 34 ++++++++-------------------
 langchain/llms/vertexai.py | 36 ++++++++++++++++++++++++++---
 3 files changed, 88 insertions(+), 29 deletions(-)

diff --git a/langchain/llms/base.py b/langchain/llms/base.py
index cdb5ea5973..5a3ab0aafb 100644
--- a/langchain/llms/base.py
+++ b/langchain/llms/base.py
@@ -1,14 +1,36 @@
 """Base interface for large language models to expose."""
+from __future__ import annotations
+
 import asyncio
 import inspect
 import json
+import logging
 import warnings
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+)
 
 import yaml
 from pydantic import Field, root_validator, validator
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_base,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
 
 import langchain
 from langchain.base_language import BaseLanguageModel
@@ -29,11 +51,34 @@ from langchain.schema import (
 )
 from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string
 
+logger = logging.getLogger(__name__)
+
 
 def _get_verbosity() -> bool:
     return langchain.verbose
 
 
+def create_base_retry_decorator(
+    error_types: List[Type[BaseException]], max_retries: int = 1
+) -> Callable[[Any], Any]:
+    """Create a retry decorator for the provided list of error types."""
+
+    min_seconds = 4
+    max_seconds = 10
+    # Wait 2^x * 1 second between each retry starting with
+    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
+    retry_instance: "retry_base" = retry_if_exception_type(error_types[0])
+    for error in error_types[1:]:
+        retry_instance = retry_instance | retry_if_exception_type(error)
+    return retry(
+        reraise=True,
+        stop=stop_after_attempt(max_retries),
+        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+        retry=retry_instance,
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+    )
+
+
 def get_prompts(
     params: Dict[str, Any], prompts: List[str]
 ) -> Tuple[Dict[int, List], str, List[int], List[str]]:
diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py
index 19bce8bddc..3b497b2fd9 100644
--- a/langchain/llms/openai.py
+++ b/langchain/llms/openai.py
@@ -21,19 +21,12 @@ from typing import (
 )
 
 from pydantic import Field, root_validator
-from tenacity import (
-    before_sleep_log,
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_exponential,
-)
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.llms.base import BaseLLM
+from langchain.llms.base import BaseLLM, create_base_retry_decorator
 from langchain.schema import Generation, LLMResult
 from langchain.utils import get_from_dict_or_env
@@ -76,23 +69,14 @@ def _streaming_response_template() -> Dict[str, Any]:
 def _create_retry_decorator(llm: Union[BaseOpenAI, OpenAIChat]) -> Callable[[Any], Any]:
     import openai
 
-    min_seconds = 4
-    max_seconds = 10
-    # Wait 2^x * 1 second between each retry starting with
-    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
-    return retry(
-        reraise=True,
-        stop=stop_after_attempt(llm.max_retries),
-        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
-        retry=(
-            retry_if_exception_type(openai.error.Timeout)
-            | retry_if_exception_type(openai.error.APIError)
-            | retry_if_exception_type(openai.error.APIConnectionError)
-            | retry_if_exception_type(openai.error.RateLimitError)
-            | retry_if_exception_type(openai.error.ServiceUnavailableError)
-        ),
-        before_sleep=before_sleep_log(logger, logging.WARNING),
-    )
+    errors = [
+        openai.error.Timeout,
+        openai.error.APIError,
+        openai.error.APIConnectionError,
+        openai.error.RateLimitError,
+        openai.error.ServiceUnavailableError,
+    ]
+    return create_base_retry_decorator(error_types=errors, max_retries=llm.max_retries)
 
 
 def completion_with_retry(llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any) -> Any:
diff --git a/langchain/llms/vertexai.py b/langchain/llms/vertexai.py
index 1e69830823..14ce8dbbee 100644
--- a/langchain/llms/vertexai.py
+++ b/langchain/llms/vertexai.py
@@ -1,7 +1,9 @@
 """Wrapper around Google VertexAI models."""
+from __future__ import annotations
+
 import asyncio
 from concurrent.futures import Executor, ThreadPoolExecutor
-from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional
 
 from pydantic import BaseModel, root_validator
 
@@ -9,7 +11,7 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.llms.base import LLM
+from langchain.llms.base import LLM, create_base_retry_decorator
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utilities.vertexai import (
     init_vertexai,
@@ -24,6 +26,32 @@ def is_codey_model(model_name: str) -> bool:
     return "code" in model_name
 
 
+def _create_retry_decorator(llm: VertexAI) -> Callable[[Any], Any]:
+    import google.api_core
+
+    errors = [
+        google.api_core.exceptions.ResourceExhausted,
+        google.api_core.exceptions.ServiceUnavailable,
+        google.api_core.exceptions.Aborted,
+        google.api_core.exceptions.DeadlineExceeded,
+    ]
+    decorator = create_base_retry_decorator(
+        error_types=errors, max_retries=llm.max_retries  # type: ignore
+    )
+    return decorator
+
+
+def completion_with_retry(llm: VertexAI, *args: Any, **kwargs: Any) -> Any:
+    """Use tenacity to retry the completion call."""
+    retry_decorator = _create_retry_decorator(llm)
+
+    @retry_decorator
+    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
+        return llm.client.predict(*args, **kwargs)
+
+    return _completion_with_retry(*args, **kwargs)
+
+
 class _VertexAICommon(BaseModel):
     client: "_LanguageModel" = None  #: :meta private:
     model_name: str
@@ -51,6 +79,8 @@ class _VertexAICommon(BaseModel):
     request_parallelism: int = 5
     "The amount of parallelism allowed for requests issued to VertexAI models. "
     "Default is 5."
+    max_retries: int = 6
+    """The maximum number of retries to make when generating."""
     task_executor: ClassVar[Optional[Executor]] = None
 
     @property
@@ -76,7 +106,7 @@ class _VertexAICommon(BaseModel):
         self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
     ) -> str:
         params = {**self._default_params, **kwargs}
-        res = self.client.predict(prompt, **params)
+        res = completion_with_retry(self, prompt, **params)  # type: ignore
         return self._enforce_stop_words(res.text, stop)
 
     def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:
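
Reviewer note: create_base_retry_decorator is deliberately provider-agnostic, so other
integrations can reuse it the same way openai.py and vertexai.py do above. A minimal
sketch of the intended usage pattern, runnable against this branch; FlakyAPIError and
flaky_call below are made-up stand-ins for a provider SDK, not part of the patch:

    import random

    from langchain.llms.base import create_base_retry_decorator


    class FlakyAPIError(Exception):
        """Hypothetical transient error, standing in for a provider SDK exception."""


    # Exponential backoff between 4 and 10 seconds, at most 3 attempts,
    # retrying only on FlakyAPIError; any other exception propagates immediately.
    retry_decorator = create_base_retry_decorator(
        error_types=[FlakyAPIError], max_retries=3
    )


    @retry_decorator
    def flaky_call(prompt: str) -> str:
        if random.random() < 0.5:  # simulate an intermittent provider failure
            raise FlakyAPIError("transient failure")
        return f"completion for {prompt!r}"


    print(flaky_call("hello"))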
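From the caller's side, the only visible change on the Vertex wrapper is the new
max_retries field. A hedged example of how a user would exercise the retry path
(assumes Vertex AI credentials and project are already configured; parameter values
are illustrative, not defaults from this patch):

    from langchain.llms import VertexAI

    # max_retries is the field added by this patch; model_name and temperature
    # already existed. Values here are illustrative only.
    llm = VertexAI(model_name="text-bison", temperature=0.2, max_retries=10)

    # The prediction call now routes through completion_with_retry(self, prompt,
    # **params), so ResourceExhausted, ServiceUnavailable, Aborted, and
    # DeadlineExceeded from google.api_core trigger exponential backoff
    # instead of failing on the first error.
    print(llm("Tell me a joke about retries"))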