diff --git a/langchain/chat_models/google_palm.py b/langchain/chat_models/google_palm.py
index 0e9a2a15..cdfbd8e1 100644
--- a/langchain/chat_models/google_palm.py
+++ b/langchain/chat_models/google_palm.py
@@ -1,9 +1,17 @@
 """Wrapper around Google's PaLM Chat API."""
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional
+import logging
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
 
 from pydantic import BaseModel, root_validator
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
@@ -24,6 +32,8 @@ from langchain.utils import get_from_dict_or_env
 if TYPE_CHECKING:
     import google.generativeai as genai
 
+logger = logging.getLogger(__name__)
+
 
 class ChatGooglePalmError(Exception):
     pass
@@ -156,6 +166,51 @@ def _messages_to_prompt_dict(
     )
 
 
+def _create_retry_decorator() -> Callable[[Any], Any]:
+    """Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions"""
+    import google.api_core.exceptions
+
+    multiplier = 2
+    min_seconds = 1
+    max_seconds = 60
+    max_retries = 10
+
+    return retry(
+        reraise=True,
+        stop=stop_after_attempt(max_retries),
+        wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
+        retry=(
+            retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
+            | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
+            | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
+        ),
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+    )
+
+
+def chat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
+    """Use tenacity to retry the completion call."""
+    retry_decorator = _create_retry_decorator()
+
+    @retry_decorator
+    def _chat_with_retry(**kwargs: Any) -> Any:
+        return llm.client.chat(**kwargs)
+
+    return _chat_with_retry(**kwargs)
+
+
+async def achat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
+    """Use tenacity to retry the async completion call."""
+    retry_decorator = _create_retry_decorator()
+
+    @retry_decorator
+    async def _achat_with_retry(**kwargs: Any) -> Any:
+        # Use the PaLM client's async chat API.
+        return await llm.client.chat_async(**kwargs)
+
+    return await _achat_with_retry(**kwargs)
+
+
 class ChatGooglePalm(BaseChatModel, BaseModel):
     """Wrapper around Google's PaLM Chat API.
 
@@ -227,7 +282,8 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
     ) -> ChatResult:
         prompt = _messages_to_prompt_dict(messages)
 
-        response: genai.types.ChatResponse = self.client.chat(
+        response: genai.types.ChatResponse = chat_with_retry(
+            self,
             model=self.model_name,
             prompt=prompt,
             temperature=self.temperature,
@@ -246,7 +302,8 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
     ) -> ChatResult:
         prompt = _messages_to_prompt_dict(messages)
 
-        response: genai.types.ChatResponse = await self.client.chat_async(
+        response: genai.types.ChatResponse = await achat_with_retry(
+            self,
             model=self.model_name,
             prompt=prompt,
             temperature=self.temperature,
diff --git a/langchain/embeddings/google_palm.py b/langchain/embeddings/google_palm.py
index 0d198137..5be7e736 100644
--- a/langchain/embeddings/google_palm.py
+++ b/langchain/embeddings/google_palm.py
@@ -1,16 +1,64 @@
 """Wrapper arround Google's PaLM Embeddings APIs."""
-from typing import Any, Dict, List, Optional
+from __future__ import annotations
+
+import logging
+from typing import Any, Callable, Dict, List, Optional
 
 from pydantic import BaseModel, root_validator
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
 
 from langchain.embeddings.base import Embeddings
 from langchain.utils import get_from_dict_or_env
 
+logger = logging.getLogger(__name__)
+
+
+def _create_retry_decorator() -> Callable[[Any], Any]:
+    """Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions"""
+    import google.api_core.exceptions
+
+    multiplier = 2
+    min_seconds = 1
+    max_seconds = 60
+    max_retries = 10
+
+    return retry(
+        reraise=True,
+        stop=stop_after_attempt(max_retries),
+        wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
+        retry=(
+            retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
+            | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
+            | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
+        ),
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+    )
+
+
+def embed_with_retry(
+    embeddings: GooglePalmEmbeddings, *args: Any, **kwargs: Any
+) -> Any:
+    """Use tenacity to retry the embedding call."""
+    retry_decorator = _create_retry_decorator()
+
+    @retry_decorator
+    def _embed_with_retry(*args: Any, **kwargs: Any) -> Any:
+        return embeddings.client.generate_embeddings(*args, **kwargs)
+
+    return _embed_with_retry(*args, **kwargs)
+
 
 class GooglePalmEmbeddings(BaseModel, Embeddings):
     client: Any
     google_api_key: Optional[str]
     model_name: str = "models/embedding-gecko-001"
+    """Model name to use."""
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
@@ -34,5 +82,5 @@ class GooglePalmEmbeddings(BaseModel, Embeddings):
 
     def embed_query(self, text: str) -> List[float]:
         """Embed query text."""
-        embedding = self.client.generate_embeddings(self.model_name, text)
+        embedding = embed_with_retry(self, self.model_name, text)
         return embedding["embedding"]
diff --git a/langchain/llms/google_palm.py b/langchain/llms/google_palm.py
index ef418532..ee709cba 100644
--- a/langchain/llms/google_palm.py
+++ b/langchain/llms/google_palm.py
@@ -1,9 +1,17 @@
 """Wrapper arround Google's PaLM Text APIs."""
 from __future__ import annotations
 
-from typing import Any, Dict, List, Optional
+import logging
+from typing import Any, Callable, Dict, List, Optional
 
 from pydantic import BaseModel, root_validator
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
@@ -13,6 +21,44 @@ from langchain.llms import BaseLLM
 from langchain.schema import Generation, LLMResult
 from langchain.utils import get_from_dict_or_env
 
+logger = logging.getLogger(__name__)
+
+
+def _create_retry_decorator() -> Callable[[Any], Any]:
+    """Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions"""
+    try:
+        import google.api_core.exceptions
+    except ImportError as exc:
+        raise ImportError("google.api_core is required for PaLM retries") from exc
+
+    multiplier = 2
+    min_seconds = 1
+    max_seconds = 60
+    max_retries = 10
+
+    return retry(
+        reraise=True,
+        stop=stop_after_attempt(max_retries),
+        wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
+        retry=(
+            retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
+            | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
+            | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
+        ),
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+    )
+
+
+def generate_with_retry(llm: GooglePalm, **kwargs: Any) -> Any:
+    """Use tenacity to retry the completion call."""
+    retry_decorator = _create_retry_decorator()
+
+    @retry_decorator
+    def _generate_with_retry(**kwargs: Any) -> Any:
+        return llm.client.generate_text(**kwargs)
+
+    return _generate_with_retry(**kwargs)
+
 
 def _strip_erroneous_leading_spaces(text: str) -> str:
     """Strip erroneous leading spaces from text.
@@ -85,7 +131,8 @@ class GooglePalm(BaseLLM, BaseModel):
     ) -> LLMResult:
         generations = []
         for prompt in prompts:
-            completion = self.client.generate_text(
+            completion = generate_with_retry(
+                self,
                 model=self.model_name,
                 prompt=prompt,
                 stop_sequences=stop,