@@ -1,9 +1,17 @@
 """Wrapper around Google's PaLM Chat API."""
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional
+import logging
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
 
 from pydantic import BaseModel, root_validator
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
@@ -24,6 +32,8 @@ from langchain.utils import get_from_dict_or_env
 if TYPE_CHECKING:
     import google.generativeai as genai
 
+logger = logging.getLogger(__name__)
+
 
 class ChatGooglePalmError(Exception):
     pass
@@ -156,6 +166,51 @@ def _messages_to_prompt_dict(
     )
 
 
+def _create_retry_decorator() -> Callable[[Any], Any]:
+    """Return a tenacity retry decorator, preconfigured to handle PaLM exceptions."""
+    import google.api_core.exceptions
+
+    multiplier = 2
+    min_seconds = 1
+    max_seconds = 60
+    max_retries = 10
+
+    return retry(
+        reraise=True,
+        stop=stop_after_attempt(max_retries),
+        wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
+        retry=(
+            retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
+            | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
+            | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
+        ),
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+    )
+
+
+def chat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
+    """Use tenacity to retry the completion call."""
+    retry_decorator = _create_retry_decorator()
+
+    @retry_decorator
+    def _chat_with_retry(**kwargs: Any) -> Any:
+        return llm.client.chat(**kwargs)
+
+    return _chat_with_retry(**kwargs)
+
+
+async def achat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
+    """Use tenacity to retry the async completion call."""
+    retry_decorator = _create_retry_decorator()
+
+    @retry_decorator
+    async def _achat_with_retry(**kwargs: Any) -> Any:
+        # Call the PaLM client's async chat endpoint.
+        return await llm.client.chat_async(**kwargs)
+
+    return await _achat_with_retry(**kwargs)
+
+
 class ChatGooglePalm(BaseChatModel, BaseModel):
     """Wrapper around Google's PaLM Chat API.
@@ -227,7 +282,8 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
     ) -> ChatResult:
         prompt = _messages_to_prompt_dict(messages)
 
-        response: genai.types.ChatResponse = self.client.chat(
+        response: genai.types.ChatResponse = chat_with_retry(
+            self,
             model=self.model_name,
             prompt=prompt,
             temperature=self.temperature,
@@ -246,7 +302,8 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
     ) -> ChatResult:
         prompt = _messages_to_prompt_dict(messages)
 
-        response: genai.types.ChatResponse = await self.client.chat_async(
+        response: genai.types.ChatResponse = await achat_with_retry(
+            self,
             model=self.model_name,
             prompt=prompt,
             temperature=self.temperature,
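
Illustrative sketch (not part of the patch): a minimal, self-contained example of the tenacity retry pattern this diff introduces. `FlakyClient` and `TransientError` are hypothetical stand-ins for the PaLM client and the google.api_core exceptions, so the snippet runs with only tenacity installed; the decorator is assembled the same way as `_create_retry_decorator` above.

import logging

from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

logger = logging.getLogger(__name__)


class TransientError(Exception):
    """Hypothetical stand-in for google.api_core.exceptions.ServiceUnavailable."""


class FlakyClient:
    """Hypothetical client that fails twice before succeeding."""

    def __init__(self) -> None:
        self.calls = 0

    def chat(self, **kwargs):
        self.calls += 1
        if self.calls < 3:
            raise TransientError("temporarily unavailable")
        return {"reply": "ok", "attempts": self.calls}


def _create_retry_decorator():
    # Same shape as the helper in the diff, but retrying on TransientError
    # so the sketch does not require google-api-core.
    return retry(
        reraise=True,
        stop=stop_after_attempt(10),
        wait=wait_exponential(multiplier=2, min=1, max=60),
        retry=retry_if_exception_type(TransientError),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )


client = FlakyClient()


@_create_retry_decorator()
def chat_with_retry(**kwargs):
    return client.chat(**kwargs)


print(chat_with_retry(prompt="hello"))  # returns after the third attempt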