Exponential back-off support for Google PaLM API (#4001)

This PR adds exponential back-off to the Google PaLM API to gracefully
handle rate-limiting errors.

---------

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
dynamic_agent_tools
Mose Tronci 1 year ago committed by GitHub
parent a6f3ec94bc
commit a9dbe90447
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,9 +1,17 @@
"""Wrapper around Google's PaLM Chat API."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
from pydantic import BaseModel, root_validator
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
@ -24,6 +32,8 @@ from langchain.utils import get_from_dict_or_env
if TYPE_CHECKING:
import google.generativeai as genai
logger = logging.getLogger(__name__)
class ChatGooglePalmError(Exception):
    """Error type declared by the Google PaLM chat wrapper module.

    Adds nothing on top of ``Exception``; the places that raise it are
    outside this excerpt.
    """
@ -156,6 +166,51 @@ def _messages_to_prompt_dict(
)
def _create_retry_decorator() -> Callable[[Any], Any]:
    """Build the tenacity ``retry`` decorator applied to PaLM chat calls.

    The decorator makes at most 10 attempts, sleeping between them with an
    exponential back-off (multiplier 2, clamped to 1-60 seconds), logs a
    warning through the module logger before each sleep, and re-raises the
    last exception once the attempts are used up.

    Returns:
        A decorator that adds the retry behaviour to the callable it wraps.
    """
    # Imported lazily so google-api-core is only required when a retry
    # decorator is actually built.
    from google.api_core import exceptions as google_exceptions

    # NOTE(review): in google-api-core, GoogleAPIError is the base class of
    # the other two entries, so this effectively retries on every Google API
    # error -- confirm that is intended.
    retryable_errors = (
        google_exceptions.ResourceExhausted,
        google_exceptions.ServiceUnavailable,
        google_exceptions.GoogleAPIError,
    )

    return retry(
        retry=retry_if_exception_type(retryable_errors),
        wait=wait_exponential(multiplier=2, min=1, max=60),
        stop=stop_after_attempt(10),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        reraise=True,
    )
def chat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
    """Call ``llm.client.chat`` under the module's tenacity retry policy.

    Args:
        llm: Chat model whose ``client`` attribute exposes ``chat``.
        **kwargs: Forwarded unchanged to ``llm.client.chat``.

    Returns:
        Whatever ``llm.client.chat`` returns on the first successful attempt.
    """

    # The inner function keeps its own name so retry log messages still
    # identify the call being retried.
    def _chat_with_retry(**call_kwargs: Any) -> Any:
        return llm.client.chat(**call_kwargs)

    guarded_chat = _create_retry_decorator()(_chat_with_retry)
    return guarded_chat(**kwargs)
async def achat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
    """Await ``llm.client.chat_async`` under the module's retry policy.

    Args:
        llm: Chat model whose ``client`` attribute exposes ``chat_async``.
        **kwargs: Forwarded unchanged to ``llm.client.chat_async``.

    Returns:
        Whatever ``llm.client.chat_async`` resolves to on the first
        successful attempt.
    """

    # Kept as a coroutine function: the decorator wraps the async PaLM
    # client call, so each retry awaits a fresh ``chat_async`` invocation.
    async def _achat_with_retry(**call_kwargs: Any) -> Any:
        return await llm.client.chat_async(**call_kwargs)

    guarded_chat = _create_retry_decorator()(_achat_with_retry)
    return await guarded_chat(**kwargs)
class ChatGooglePalm(BaseChatModel, BaseModel):
"""Wrapper around Google's PaLM Chat API.
@ -227,7 +282,8 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
) -> ChatResult:
prompt = _messages_to_prompt_dict(messages)
response: genai.types.ChatResponse = self.client.chat(
response: genai.types.ChatResponse = chat_with_retry(
self,
model=self.model_name,
prompt=prompt,
temperature=self.temperature,
@ -246,7 +302,8 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
) -> ChatResult:
prompt = _messages_to_prompt_dict(messages)
response: genai.types.ChatResponse = await self.client.chat_async(
response: genai.types.ChatResponse = await achat_with_retry(
self,
model=self.model_name,
prompt=prompt,
temperature=self.temperature,

@ -1,16 +1,64 @@
"""Wrapper arround Google's PaLM Embeddings APIs."""
from typing import Any, Dict, List, Optional
from __future__ import annotations
import logging
from typing import Any, Callable, Dict, List, Optional
from pydantic import BaseModel, root_validator
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
def _create_retry_decorator() -> Callable[[Any], Any]:
    """Assemble the tenacity ``retry`` decorator for PaLM embedding calls.

    Policy: up to 10 attempts, exponential wait between attempts
    (multiplier 2, bounded to the 1-60 second range), a WARNING-level log
    entry before every sleep, and the last exception re-raised when the
    attempts run out.

    Returns:
        The configured decorator, ready to wrap a callable.
    """
    # Function-scope import: google-api-core is only touched when a
    # decorator is requested.
    import google.api_core.exceptions as gapi_errors

    backoff = wait_exponential(multiplier=2, min=1, max=60)
    give_up = stop_after_attempt(10)
    # NOTE(review): GoogleAPIError is the base class of the first two types
    # in google-api-core, so every Google API error is retried -- confirm
    # this breadth is intended.
    should_retry = retry_if_exception_type(
        (
            gapi_errors.ResourceExhausted,
            gapi_errors.ServiceUnavailable,
            gapi_errors.GoogleAPIError,
        )
    )
    log_before_sleep = before_sleep_log(logger, logging.WARNING)

    return retry(
        reraise=True,
        stop=give_up,
        wait=backoff,
        retry=should_retry,
        before_sleep=log_before_sleep,
    )
def embed_with_retry(
    embeddings: GooglePalmEmbeddings, *args: Any, **kwargs: Any
) -> Any:
    """Call ``embeddings.client.generate_embeddings`` with tenacity retries.

    Args:
        embeddings: Embeddings wrapper whose ``client`` performs the request.
        *args: Positional arguments forwarded to ``generate_embeddings``.
        **kwargs: Keyword arguments forwarded to ``generate_embeddings``.

    Returns:
        The value returned by ``generate_embeddings`` on the first
        successful attempt.
    """

    # Inner function name is preserved so retry log lines are unchanged.
    def _embed_with_retry(*call_args: Any, **call_kwargs: Any) -> Any:
        return embeddings.client.generate_embeddings(*call_args, **call_kwargs)

    guarded_embed = _create_retry_decorator()(_embed_with_retry)
    return guarded_embed(*args, **kwargs)
class GooglePalmEmbeddings(BaseModel, Embeddings):
client: Any
google_api_key: Optional[str]
model_name: str = "models/embedding-gecko-001"
"""Model name to use."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
@ -34,5 +82,5 @@ class GooglePalmEmbeddings(BaseModel, Embeddings):
def embed_query(self, text: str) -> List[float]:
    """Embed query text.

    The request goes through ``embed_with_retry`` only. The earlier direct
    ``self.client.generate_embeddings`` call was removed: its result was
    overwritten straight away, so it issued a second API request that
    bypassed the retry policy.

    Args:
        text: The query string to embed.

    Returns:
        The value stored under the ``"embedding"`` key of the response.
    """
    embedding = embed_with_retry(self, self.model_name, text)
    return embedding["embedding"]

@ -1,9 +1,17 @@
"""Wrapper arround Google's PaLM Text APIs."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
import logging
from typing import Any, Callable, Dict, List, Optional
from pydantic import BaseModel, root_validator
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
@ -13,6 +21,44 @@ from langchain.llms import BaseLLM
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
def _create_retry_decorator() -> Callable[[Any], Any]:
    """Return a tenacity retry decorator preconfigured for PaLM text calls.

    The decorator allows up to ``max_retries`` attempts, waits exponentially
    between them (bounded by ``min_seconds`` and ``max_seconds``), logs a
    warning through the module logger before each sleep, and re-raises the
    last exception when the attempts are exhausted.

    Returns:
        The configured decorator, ready to wrap a callable.

    Raises:
        ImportError: If ``google.api_core.exceptions`` cannot be imported.
    """
    try:
        import google.api_core.exceptions
    except ImportError as err:
        # Name the missing dependency and keep the original cause instead of
        # raising an empty ImportError.
        raise ImportError(
            "Could not import google.api_core.exceptions, which is needed to "
            "retry Google PaLM calls. "
            "Please install it with `pip install google-api-core`."
        ) from err

    multiplier = 2
    min_seconds = 1
    max_seconds = 60
    max_retries = 10

    return retry(
        reraise=True,
        stop=stop_after_attempt(max_retries),
        wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
        # NOTE(review): GoogleAPIError is the base class of the first two
        # types in google-api-core, so every Google API error is retried --
        # confirm this breadth is intended.
        retry=(
            retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
            | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
            | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
        ),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
def generate_with_retry(llm: GooglePalm, **kwargs: Any) -> Any:
    """Call ``llm.client.generate_text`` under the module's retry policy.

    Args:
        llm: Text model whose ``client`` attribute exposes ``generate_text``.
        **kwargs: Forwarded unchanged to ``llm.client.generate_text``.

    Returns:
        Whatever ``generate_text`` returns on the first successful attempt.
    """

    # Inner function name is preserved so retry log lines are unchanged.
    def _generate_with_retry(**call_kwargs: Any) -> Any:
        return llm.client.generate_text(**call_kwargs)

    guarded_generate = _create_retry_decorator()(_generate_with_retry)
    return guarded_generate(**kwargs)
def _strip_erroneous_leading_spaces(text: str) -> str:
"""Strip erroneous leading spaces from text.
@ -85,7 +131,8 @@ class GooglePalm(BaseLLM, BaseModel):
) -> LLMResult:
generations = []
for prompt in prompts:
completion = self.client.generate_text(
completion = generate_with_retry(
self,
model=self.model_name,
prompt=prompt,
stop_sequences=stop,

Loading…
Cancel
Save