|
|
|
@ -1,7 +1,9 @@
|
|
|
|
|
"""Wrapper around Google VertexAI models."""
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import asyncio
|
|
|
|
|
from concurrent.futures import Executor, ThreadPoolExecutor
|
|
|
|
|
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional
|
|
|
|
|
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, root_validator
|
|
|
|
|
|
|
|
|
@ -9,7 +11,7 @@ from langchain.callbacks.manager import (
|
|
|
|
|
AsyncCallbackManagerForLLMRun,
|
|
|
|
|
CallbackManagerForLLMRun,
|
|
|
|
|
)
|
|
|
|
|
from langchain.llms.base import LLM
|
|
|
|
|
from langchain.llms.base import LLM, create_base_retry_decorator
|
|
|
|
|
from langchain.llms.utils import enforce_stop_tokens
|
|
|
|
|
from langchain.utilities.vertexai import (
|
|
|
|
|
init_vertexai,
|
|
|
|
@ -24,6 +26,32 @@ def is_codey_model(model_name: str) -> bool:
|
|
|
|
|
return "code" in model_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_retry_decorator(llm: VertexAI) -> Callable[[Any], Any]:
    """Build the retry decorator used for VertexAI completion calls.

    Args:
        llm: Model wrapper whose ``max_retries`` value is forwarded to
            ``create_base_retry_decorator``.

    Returns:
        The decorator returned by ``create_base_retry_decorator``, configured
        with the Google API error types listed in the body.

    Raises:
        ImportError: If ``google.api_core`` is not installed.
    """
    # Import the submodule explicitly. ``import google.api_core`` on its own
    # only loads the package; ``google.api_core.exceptions`` would be reachable
    # as an attribute only if some other module had already imported it.
    import google.api_core.exceptions

    errors = [
        google.api_core.exceptions.ResourceExhausted,
        google.api_core.exceptions.ServiceUnavailable,
        google.api_core.exceptions.Aborted,
        google.api_core.exceptions.DeadlineExceeded,
    ]
    decorator = create_base_retry_decorator(
        error_types=errors, max_retries=llm.max_retries  # type: ignore
    )
    return decorator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def completion_with_retry(llm: VertexAI, *args: Any, **kwargs: Any) -> Any:
    """Call ``llm.client.predict`` through the model's retry decorator.

    Args:
        llm: Model wrapper that supplies both the client and the retry
            configuration (via ``_create_retry_decorator``).
        *args: Positional arguments forwarded to ``llm.client.predict``.
        **kwargs: Keyword arguments forwarded to ``llm.client.predict``.

    Returns:
        Whatever ``llm.client.predict`` returns.
    """
    with_retries = _create_retry_decorator(llm)

    # The inner name is kept as-is: the retry wrapper may surface the wrapped
    # callable's name (e.g. in log output), so renaming it could be visible.
    def _completion_with_retry(*call_args: Any, **call_kwargs: Any) -> Any:
        return llm.client.predict(*call_args, **call_kwargs)

    guarded_predict = with_retries(_completion_with_retry)
    return guarded_predict(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _VertexAICommon(BaseModel):
|
|
|
|
|
client: "_LanguageModel" = None #: :meta private:
|
|
|
|
|
model_name: str
|
|
|
|
@ -51,6 +79,8 @@ class _VertexAICommon(BaseModel):
|
|
|
|
|
request_parallelism: int = 5
|
|
|
|
|
"The amount of parallelism allowed for requests issued to VertexAI models. "
|
|
|
|
|
"Default is 5."
|
|
|
|
|
max_retries: int = 6
|
|
|
|
|
"""The maximum number of retries to make when generating."""
|
|
|
|
|
task_executor: ClassVar[Optional[Executor]] = None
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
@ -76,7 +106,7 @@ class _VertexAICommon(BaseModel):
|
|
|
|
|
self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
|
|
|
|
|
) -> str:
|
|
|
|
|
params = {**self._default_params, **kwargs}
|
|
|
|
|
res = self.client.predict(prompt, **params)
|
|
|
|
|
res = completion_with_retry(self, prompt, **params) # type: ignore
|
|
|
|
|
return self._enforce_stop_words(res.text, stop)
|
|
|
|
|
|
|
|
|
|
def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:
|
|
|
|
|