@@ -1,62 +1,32 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 from langchain_core.outputs import Generation, LLMResult
 from langchain_core.pydantic_v1 import BaseModel, root_validator
-from tenacity import (
-    before_sleep_log,
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_exponential,
-)
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms import BaseLLM
+from langchain.utilities.vertexai import create_retry_decorator
 from langchain.utils import get_from_dict_or_env
 
 logger = logging.getLogger(__name__)
 
 
-def _create_retry_decorator() -> Callable[[Any], Any]:
-    """Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions"""
-    try:
-        import google.api_core.exceptions
-    except ImportError:
-        raise ImportError(
-            "Could not import google-api-core python package. "
-            "Please install it with `pip install google-api-core`."
-        )
-
-    multiplier = 2
-    min_seconds = 1
-    max_seconds = 60
-    max_retries = 10
-
-    return retry(
-        reraise=True,
-        stop=stop_after_attempt(max_retries),
-        wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
-        retry=(
-            retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
-            | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
-            | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
-        ),
-        before_sleep=before_sleep_log(logger, logging.WARNING),
-    )
-
-
-def generate_with_retry(llm: GooglePalm, **kwargs: Any) -> Any:
+def completion_with_retry(
+    llm: GooglePalm,
+    *args: Any,
+    run_manager: Optional[CallbackManagerForLLMRun] = None,
+    **kwargs: Any,
+) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator()
+    retry_decorator = create_retry_decorator(
+        llm, max_retries=llm.max_retries, run_manager=run_manager
+    )
 
     @retry_decorator
-    def _generate_with_retry(**kwargs: Any) -> Any:
-        return llm.client.generate_text(**kwargs)
+    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
+        return llm.client.generate_text(*args, **kwargs)
 
-    return _generate_with_retry(**kwargs)
+    return _completion_with_retry(*args, **kwargs)
 
 
 def _strip_erroneous_leading_spaces(text: str) -> str:
@@ -94,6 +64,8 @@ class GooglePalm(BaseLLM, BaseModel):
     n: int = 1
     """Number of chat completions to generate for each prompt. Note that the API may
        not return the full n completions if duplicates are generated."""
+    max_retries: int = 6
+    """The maximum number of retries to make when generating."""
 
     @property
     def lc_secrets(self) -> Dict[str, str]:
@@ -144,7 +116,7 @@ class GooglePalm(BaseLLM, BaseModel):
     ) -> LLMResult:
         generations = []
         for prompt in prompts:
-            completion = generate_with_retry(
+            completion = completion_with_retry(
                 self,
                 model=self.model_name,
                 prompt=prompt,
@@ -170,3 +142,17 @@ class GooglePalm(BaseLLM, BaseModel):
     def _llm_type(self) -> str:
         """Return type of llm."""
         return "google_palm"
+
+    def get_num_tokens(self, text: str) -> int:
+        """Get the number of tokens present in the text.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            text: The string input to tokenize.
+
+        Returns:
+            The integer number of tokens in the text.
+        """
+        result = self.client.count_text_tokens(model=self.model_name, prompt=text)
+        return result["token_count"]
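
Usage sketch (not part of the patch above): a minimal example of how the changed code paths would be exercised, assuming the `google-generativeai` backend is installed. Only names visible in this diff plus the `google_api_key` field are used; the key and prompts are placeholders.

    from langchain.llms import GooglePalm

    # max_retries is the field added by this patch; completion_with_retry
    # passes it to the shared create_retry_decorator.
    llm = GooglePalm(google_api_key="<YOUR-API-KEY>", max_retries=6)

    # get_num_tokens (added above) asks the PaLM client to count tokens,
    # useful for checking that a prompt fits the model's context window.
    n_tokens = llm.get_num_tokens("How many tokens is this sentence?")

    # generate() routes each prompt through completion_with_retry, so
    # transient quota/availability errors are retried with backoff.
    result = llm.generate(["Tell me a joke."])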