"""Wrapper around Google's PaLM Chat API.""" from __future__ import annotations import logging from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, cast from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import ( AIMessage, BaseMessage, ChatMessage, HumanMessage, SystemMessage, ) from langchain_core.outputs import ( ChatGeneration, ChatResult, ) from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from tenacity import ( before_sleep_log, retry, retry_if_exception_type, stop_after_attempt, wait_exponential, ) if TYPE_CHECKING: import google.generativeai as genai logger = logging.getLogger(__name__) class ChatGooglePalmError(Exception): """Error with the `Google PaLM` API.""" def _truncate_at_stop_tokens( text: str, stop: Optional[List[str]], ) -> str: """Truncates text at the earliest stop token found.""" if stop is None: return text for stop_token in stop: stop_token_idx = text.find(stop_token) if stop_token_idx != -1: text = text[:stop_token_idx] return text def _response_to_result( response: genai.types.ChatResponse, stop: Optional[List[str]], ) -> ChatResult: """Converts a PaLM API response into a LangChain ChatResult.""" if not response.candidates: raise ChatGooglePalmError("ChatResponse must have at least one candidate.") generations: List[ChatGeneration] = [] for candidate in response.candidates: author = candidate.get("author") if author is None: raise ChatGooglePalmError(f"ChatResponse must have an author: {candidate}") content = _truncate_at_stop_tokens(candidate.get("content", ""), stop) if content is None: raise ChatGooglePalmError(f"ChatResponse must have a content: {candidate}") if author == "ai": generations.append( ChatGeneration(text=content, message=AIMessage(content=content)) ) elif author == "human": generations.append( ChatGeneration( text=content, message=HumanMessage(content=content), ) ) else: generations.append( ChatGeneration( text=content, message=ChatMessage(role=author, content=content), ) ) return ChatResult(generations=generations) def _messages_to_prompt_dict( input_messages: List[BaseMessage], ) -> genai.types.MessagePromptDict: """Converts a list of LangChain messages into a PaLM API MessagePrompt structure.""" import google.generativeai as genai context: str = "" examples: List[genai.types.MessageDict] = [] messages: List[genai.types.MessageDict] = [] remaining = list(enumerate(input_messages)) while remaining: index, input_message = remaining.pop(0) if isinstance(input_message, SystemMessage): if index != 0: raise ChatGooglePalmError("System message must be first input message.") context = cast(str, input_message.content) elif isinstance(input_message, HumanMessage) and input_message.example: if messages: raise ChatGooglePalmError( "Message examples must come before other messages." ) _, next_input_message = remaining.pop(0) if isinstance(next_input_message, AIMessage) and next_input_message.example: examples.extend( [ genai.types.MessageDict( author="human", content=input_message.content ), genai.types.MessageDict( author="ai", content=next_input_message.content ), ] ) else: raise ChatGooglePalmError( "Human example message must be immediately followed by an " " AI example response." ) elif isinstance(input_message, AIMessage) and input_message.example: raise ChatGooglePalmError( "AI example message must be immediately preceded by a Human " "example message." ) elif isinstance(input_message, AIMessage): messages.append( genai.types.MessageDict(author="ai", content=input_message.content) ) elif isinstance(input_message, HumanMessage): messages.append( genai.types.MessageDict(author="human", content=input_message.content) ) elif isinstance(input_message, ChatMessage): messages.append( genai.types.MessageDict( author=input_message.role, content=input_message.content ) ) else: raise ChatGooglePalmError( "Messages without an explicit role not supported by PaLM API." ) return genai.types.MessagePromptDict( context=context, examples=examples, messages=messages, ) def _create_retry_decorator() -> Callable[[Any], Any]: """Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions""" import google.api_core.exceptions multiplier = 2 min_seconds = 1 max_seconds = 60 max_retries = 10 return retry( reraise=True, stop=stop_after_attempt(max_retries), wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds), retry=( retry_if_exception_type(google.api_core.exceptions.ResourceExhausted) | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable) | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError) ), before_sleep=before_sleep_log(logger, logging.WARNING), ) def chat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any: """Use tenacity to retry the completion call.""" retry_decorator = _create_retry_decorator() @retry_decorator def _chat_with_retry(**kwargs: Any) -> Any: return llm.client.chat(**kwargs) return _chat_with_retry(**kwargs) async def achat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any: """Use tenacity to retry the async completion call.""" retry_decorator = _create_retry_decorator() @retry_decorator async def _achat_with_retry(**kwargs: Any) -> Any: # Use OpenAI's async api https://github.com/openai/openai-python#async-api return await llm.client.chat_async(**kwargs) return await _achat_with_retry(**kwargs) class ChatGooglePalm(BaseChatModel, BaseModel): """`Google PaLM` Chat models API. To use you must have the google.generativeai Python package installed and either: 1. The ``GOOGLE_API_KEY``` environment variable set with your API key, or 2. Pass your API key using the google_api_key kwarg to the ChatGoogle constructor. Example: .. code-block:: python from langchain_community.chat_models import ChatGooglePalm chat = ChatGooglePalm() """ client: Any #: :meta private: model_name: str = "models/chat-bison-001" """Model name to use.""" google_api_key: Optional[SecretStr] = None temperature: Optional[float] = None """Run inference with this temperature. Must by in the closed interval [0.0, 1.0].""" top_p: Optional[float] = None """Decode using nucleus sampling: consider the smallest set of tokens whose probability sum is at least top_p. Must be in the closed interval [0.0, 1.0].""" top_k: Optional[int] = None """Decode using top-k sampling: consider the set of top_k most probable tokens. Must be positive.""" n: int = 1 """Number of chat completions to generate for each prompt. Note that the API may not return the full n completions if duplicates are generated.""" @property def lc_secrets(self) -> Dict[str, str]: return {"google_api_key": "GOOGLE_API_KEY"} @classmethod def is_lc_serializable(self) -> bool: return True @classmethod def get_lc_namespace(cls) -> List[str]: """Get the namespace of the langchain object.""" return ["langchain", "chat_models", "google_palm"] @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate api key, python package exists, temperature, top_p, and top_k.""" google_api_key = convert_to_secret_str( get_from_dict_or_env(values, "google_api_key", "GOOGLE_API_KEY") ) try: import google.generativeai as genai genai.configure(api_key=google_api_key.get_secret_value()) except ImportError: raise ChatGooglePalmError( "Could not import google.generativeai python package. " "Please install it with `pip install google-generativeai`" ) values["client"] = genai if values["temperature"] is not None and not 0 <= values["temperature"] <= 1: raise ValueError("temperature must be in the range [0.0, 1.0]") if values["top_p"] is not None and not 0 <= values["top_p"] <= 1: raise ValueError("top_p must be in the range [0.0, 1.0]") if values["top_k"] is not None and values["top_k"] <= 0: raise ValueError("top_k must be positive") return values def _generate( self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: prompt = _messages_to_prompt_dict(messages) response: genai.types.ChatResponse = chat_with_retry( self, model=self.model_name, prompt=prompt, temperature=self.temperature, top_p=self.top_p, top_k=self.top_k, candidate_count=self.n, **kwargs, ) return _response_to_result(response, stop) async def _agenerate( self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: prompt = _messages_to_prompt_dict(messages) response: genai.types.ChatResponse = await achat_with_retry( self, model=self.model_name, prompt=prompt, temperature=self.temperature, top_p=self.top_p, top_k=self.top_k, candidate_count=self.n, ) return _response_to_result(response, stop) @property def _identifying_params(self) -> Dict[str, Any]: """Get the identifying parameters.""" return { "model_name": self.model_name, "temperature": self.temperature, "top_p": self.top_p, "top_k": self.top_k, "n": self.n, } @property def _llm_type(self) -> str: return "google-palm-chat"