mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Added retries for Vertex LLM (#7219)
#7217 --------- Co-authored-by: Leonid Kuligin <kuligin@google.com>
This commit is contained in:
parent
ec66d5188c
commit
8b19f6a0da
@ -1,14 +1,36 @@
|
||||
"""Base interface for large language models to expose."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
import yaml
|
||||
from pydantic import Field, root_validator, validator
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_base,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
import langchain
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
@ -29,11 +51,34 @@ from langchain.schema import (
|
||||
)
|
||||
from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_verbosity() -> bool:
|
||||
return langchain.verbose
|
||||
|
||||
|
||||
def create_base_retry_decorator(
|
||||
error_types: List[Type[BaseException]], max_retries: int = 1
|
||||
) -> Callable[[Any], Any]:
|
||||
"""Create a retry decorator for a given LLM and provided list of error types."""
|
||||
|
||||
min_seconds = 4
|
||||
max_seconds = 10
|
||||
# Wait 2^x * 1 second between each retry starting with
|
||||
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
|
||||
retry_instance: "retry_base" = retry_if_exception_type(error_types[0])
|
||||
for error in error_types[1:]:
|
||||
retry_instance = retry_instance | retry_if_exception_type(error)
|
||||
return retry(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(max_retries),
|
||||
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
|
||||
retry=retry_instance,
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
)
|
||||
|
||||
|
||||
def get_prompts(
|
||||
params: Dict[str, Any], prompts: List[str]
|
||||
) -> Tuple[Dict[int, List], str, List[int], List[str]]:
|
||||
|
@ -21,19 +21,12 @@ from typing import (
|
||||
)
|
||||
|
||||
from pydantic import Field, root_validator
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.llms.base import BaseLLM, create_base_retry_decorator
|
||||
from langchain.schema import Generation, LLMResult
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
@ -76,23 +69,14 @@ def _streaming_response_template() -> Dict[str, Any]:
|
||||
def _create_retry_decorator(llm: Union[BaseOpenAI, OpenAIChat]) -> Callable[[Any], Any]:
|
||||
import openai
|
||||
|
||||
min_seconds = 4
|
||||
max_seconds = 10
|
||||
# Wait 2^x * 1 second between each retry starting with
|
||||
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
|
||||
return retry(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(llm.max_retries),
|
||||
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
|
||||
retry=(
|
||||
retry_if_exception_type(openai.error.Timeout)
|
||||
| retry_if_exception_type(openai.error.APIError)
|
||||
| retry_if_exception_type(openai.error.APIConnectionError)
|
||||
| retry_if_exception_type(openai.error.RateLimitError)
|
||||
| retry_if_exception_type(openai.error.ServiceUnavailableError)
|
||||
),
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
)
|
||||
errors = [
|
||||
openai.error.Timeout,
|
||||
openai.error.APIError,
|
||||
openai.error.APIConnectionError,
|
||||
openai.error.RateLimitError,
|
||||
openai.error.ServiceUnavailableError,
|
||||
]
|
||||
return create_base_retry_decorator(error_types=errors, max_retries=llm.max_retries)
|
||||
|
||||
|
||||
def completion_with_retry(llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any) -> Any:
|
||||
|
@ -1,7 +1,9 @@
|
||||
"""Wrapper around Google VertexAI models."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
@ -9,7 +11,7 @@ from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import LLM, create_base_retry_decorator
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utilities.vertexai import (
|
||||
init_vertexai,
|
||||
@ -24,6 +26,32 @@ def is_codey_model(model_name: str) -> bool:
|
||||
return "code" in model_name
|
||||
|
||||
|
||||
def _create_retry_decorator(llm: VertexAI) -> Callable[[Any], Any]:
|
||||
import google.api_core
|
||||
|
||||
errors = [
|
||||
google.api_core.exceptions.ResourceExhausted,
|
||||
google.api_core.exceptions.ServiceUnavailable,
|
||||
google.api_core.exceptions.Aborted,
|
||||
google.api_core.exceptions.DeadlineExceeded,
|
||||
]
|
||||
decorator = create_base_retry_decorator(
|
||||
error_types=errors, max_retries=llm.max_retries # type: ignore
|
||||
)
|
||||
return decorator
|
||||
|
||||
|
||||
def completion_with_retry(llm: VertexAI, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
retry_decorator = _create_retry_decorator(llm)
|
||||
|
||||
@retry_decorator
|
||||
def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
|
||||
return llm.client.predict(*args, **kwargs)
|
||||
|
||||
return _completion_with_retry(*args, **kwargs)
|
||||
|
||||
|
||||
class _VertexAICommon(BaseModel):
|
||||
client: "_LanguageModel" = None #: :meta private:
|
||||
model_name: str
|
||||
@ -51,6 +79,8 @@ class _VertexAICommon(BaseModel):
|
||||
request_parallelism: int = 5
|
||||
"The amount of parallelism allowed for requests issued to VertexAI models. "
|
||||
"Default is 5."
|
||||
max_retries: int = 6
|
||||
"""The maximum number of retries to make when generating."""
|
||||
task_executor: ClassVar[Optional[Executor]] = None
|
||||
|
||||
@property
|
||||
@ -76,7 +106,7 @@ class _VertexAICommon(BaseModel):
|
||||
self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
params = {**self._default_params, **kwargs}
|
||||
res = self.client.predict(prompt, **params)
|
||||
res = completion_with_retry(self, prompt, **params) # type: ignore
|
||||
return self._enforce_stop_words(res.text, stop)
|
||||
|
||||
def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:
|
||||
|
Loading…
Reference in New Issue
Block a user