added get_num_tokens to GooglePalm (#14282)

added get_num_tokens to GooglePalm + a little bit of refactoring
Leonid Kuligin committed by GitHub
parent c215a4c9ec
commit fd5be55a7b

libs/langchain/langchain/llms/google_palm.py

@@ -1,62 +1,32 @@
 from __future__ import annotations

 import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict, List, Optional

 from langchain_core.outputs import Generation, LLMResult
 from langchain_core.pydantic_v1 import BaseModel, root_validator
-from tenacity import (
-    before_sleep_log,
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_exponential,
-)

 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms import BaseLLM
+from langchain.utilities.vertexai import create_retry_decorator
 from langchain.utils import get_from_dict_or_env

 logger = logging.getLogger(__name__)


-def _create_retry_decorator() -> Callable[[Any], Any]:
-    """Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions"""
-    try:
-        import google.api_core.exceptions
-    except ImportError:
-        raise ImportError(
-            "Could not import google-api-core python package. "
-            "Please install it with `pip install google-api-core`."
-        )
-
-    multiplier = 2
-    min_seconds = 1
-    max_seconds = 60
-    max_retries = 10
-
-    return retry(
-        reraise=True,
-        stop=stop_after_attempt(max_retries),
-        wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
-        retry=(
-            retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
-            | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
-            | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
-        ),
-        before_sleep=before_sleep_log(logger, logging.WARNING),
-    )
-
-
-def generate_with_retry(llm: GooglePalm, **kwargs: Any) -> Any:
+def completion_with_retry(
+    llm: GooglePalm,
+    *args: Any,
+    run_manager: Optional[CallbackManagerForLLMRun] = None,
+    **kwargs: Any,
+) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator()
+    retry_decorator = create_retry_decorator(
+        llm, max_retries=llm.max_retries, run_manager=run_manager
+    )

     @retry_decorator
-    def _generate_with_retry(**kwargs: Any) -> Any:
-        return llm.client.generate_text(**kwargs)
+    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
+        return llm.client.generate_text(*args, **kwargs)

-    return _generate_with_retry(**kwargs)
+    return _completion_with_retry(*args, **kwargs)


 def _strip_erroneous_leading_spaces(text: str) -> str:
@@ -94,6 +64,8 @@ class GooglePalm(BaseLLM, BaseModel):
     n: int = 1
     """Number of chat completions to generate for each prompt. Note that the API may
     not return the full n completions if duplicates are generated."""
+    max_retries: int = 6
+    """The maximum number of retries to make when generating."""

     @property
     def lc_secrets(self) -> Dict[str, str]:
@@ -144,7 +116,7 @@ class GooglePalm(BaseLLM, BaseModel):
     ) -> LLMResult:
         generations = []
         for prompt in prompts:
-            completion = generate_with_retry(
+            completion = completion_with_retry(
                 self,
                 model=self.model_name,
                 prompt=prompt,
@@ -170,3 +142,17 @@ class GooglePalm(BaseLLM, BaseModel):
     def _llm_type(self) -> str:
         """Return type of llm."""
         return "google_palm"
+
+    def get_num_tokens(self, text: str) -> int:
+        """Get the number of tokens present in the text.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            text: The string input to tokenize.
+
+        Returns:
+            The integer number of tokens in the text.
+        """
+        result = self.client.count_text_tokens(model=self.model_name, prompt=text)
+        return result["token_count"]

libs/langchain/langchain/llms/vertexai.py

@@ -4,13 +4,11 @@ from concurrent.futures import Executor, ThreadPoolExecutor
 from typing import (
     TYPE_CHECKING,
     Any,
-    Callable,
     ClassVar,
     Dict,
     Iterator,
     List,
     Optional,
-    Union,
 )

 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
@@ -20,8 +18,9 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.llms.base import BaseLLM, create_base_retry_decorator
+from langchain.llms.base import BaseLLM
 from langchain.utilities.vertexai import (
+    create_retry_decorator,
     get_client_info,
     init_vertexai,
     raise_vertex_import_error,
@@ -65,27 +64,6 @@ def is_codey_model(model_name: str) -> bool:
     return "code" in model_name


-def _create_retry_decorator(
-    llm: VertexAI,
-    *,
-    run_manager: Optional[
-        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
-    ] = None,
-) -> Callable[[Any], Any]:
-    import google.api_core
-
-    errors = [
-        google.api_core.exceptions.ResourceExhausted,
-        google.api_core.exceptions.ServiceUnavailable,
-        google.api_core.exceptions.Aborted,
-        google.api_core.exceptions.DeadlineExceeded,
-    ]
-    decorator = create_base_retry_decorator(
-        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
-    )
-    return decorator
-
-
 def completion_with_retry(
     llm: VertexAI,
     *args: Any,
@@ -93,7 +71,7 @@ def completion_with_retry(
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+    retry_decorator = create_retry_decorator(llm, run_manager=run_manager)

     @retry_decorator
     def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
@@ -109,7 +87,9 @@ def stream_completion_with_retry(
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+    retry_decorator = create_retry_decorator(
+        llm, max_retries=llm.max_retries, run_manager=run_manager
+    )

     @retry_decorator
     def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
@@ -125,7 +105,7 @@ async def acompletion_with_retry(
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+    retry_decorator = create_retry_decorator(llm, run_manager=run_manager)

     @retry_decorator
     async def _acompletion_with_retry(*args: Any, **kwargs: Any) -> Any:

libs/langchain/langchain/utilities/vertexai.py

@@ -1,12 +1,43 @@
 """Utilities to init Vertex AI."""
 from importlib import metadata
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+
+from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
+
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)

 if TYPE_CHECKING:
     from google.api_core.gapic_v1.client_info import ClientInfo
     from google.auth.credentials import Credentials


+def create_retry_decorator(
+    llm: BaseLLM,
+    *,
+    max_retries: int = 1,
+    run_manager: Optional[
+        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+    ] = None,
+) -> Callable[[Any], Any]:
+    """Creates a retry decorator for Vertex / Palm LLMs."""
+    import google.api_core
+
+    errors = [
+        google.api_core.exceptions.ResourceExhausted,
+        google.api_core.exceptions.ServiceUnavailable,
+        google.api_core.exceptions.Aborted,
+        google.api_core.exceptions.DeadlineExceeded,
+        google.api_core.exceptions.GoogleAPIError,
+    ]
+    decorator = create_base_retry_decorator(
+        error_types=errors, max_retries=max_retries, run_manager=run_manager
+    )
+    return decorator
+
+
 def raise_vertex_import_error(minimum_expected_version: str = "1.36.0") -> None:
     """Raise ImportError related to Vertex SDK being not available.

libs/langchain/tests/integration_tests/llms/test_google_palm.py

@@ -6,6 +6,8 @@ Note: This test must be run with the GOOGLE_API_KEY environment variable set to
 from pathlib import Path

+from langchain_core.outputs import LLMResult
+
 from langchain.llms.google_palm import GooglePalm
 from langchain.llms.loading import load_llm
@@ -15,6 +17,22 @@ def test_google_palm_call() -> None:
     llm = GooglePalm(max_output_tokens=10)
     output = llm("Say foo:")
     assert isinstance(output, str)
+    assert llm._llm_type == "google_palm"
+    assert llm.model_name == "models/text-bison-001"
+
+
+def test_google_palm_generate() -> None:
+    llm = GooglePalm(temperature=0.3, n=2)
+    output = llm.generate(["Say foo:"])
+    assert isinstance(output, LLMResult)
+    assert len(output.generations) == 1
+    assert len(output.generations[0]) == 2
+
+
+def test_google_palm_get_num_tokens() -> None:
+    llm = GooglePalm()
+    output = llm.get_num_tokens("How are you?")
+    assert output == 4


 def test_saving_loading_llm(tmp_path: Path) -> None:
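
Note (reviewer sketch, not part of the diff): the hard-coded expectation of 4 tokens depends on the live PaLM tokenizer, since the wrapper simply delegates to the legacy PaLM SDK, roughly:

    import google.generativeai as genai  # the module stored in llm.client

    genai.configure(api_key="...")  # assumes a valid PaLM API key
    result = genai.count_text_tokens(
        model="models/text-bison-001", prompt="How are you?"
    )
    print(result["token_count"])  # the test above expects 4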
