@@ -41,6 +41,11 @@ from langchain_google_vertexai._utils import (
     is_gemini_model,
 )
 
+_PALM_DEFAULT_MAX_OUTPUT_TOKENS = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
+_PALM_DEFAULT_TEMPERATURE = 0.0
+_PALM_DEFAULT_TOP_P = 0.95
+_PALM_DEFAULT_TOP_K = 40
+
 
 def _completion_with_retry(
     llm: VertexAI,
@@ -118,14 +123,14 @@ class _VertexAICommon(_VertexAIBase):
     client_preview: Any = None  #: :meta private:
     model_name: str
     "Underlying model name."
-    temperature: float = 0.0
+    temperature: Optional[float] = None
     "Sampling temperature, it controls the degree of randomness in token selection."
-    max_output_tokens: int = 128
+    max_output_tokens: Optional[int] = None
     "Token limit determines the maximum amount of text output from one prompt."
-    top_p: float = 0.95
+    top_p: Optional[float] = None
     "Tokens are selected from most probable to least until the sum of their "
     "probabilities equals the top-p value. Top-p is ignored for Codey models."
-    top_k: int = 40
+    top_k: Optional[int] = None
     "How the model selects tokens for output, the next token is selected from "
     "among the top-k most probable tokens. Top-k is ignored for Codey models."
     credentials: Any = Field(default=None, exclude=True)
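
The move from concrete defaults to `Optional[...] = None` is what lets the
`_default_params` logic below tell "never set" apart from an explicit value.
A minimal sketch of that distinction, using a hypothetical pydantic model with
the same field shapes (not part of this PR):

    from typing import Optional

    from pydantic import BaseModel


    class SamplingConfig(BaseModel):
        # None means "unset": PaLM models fall back to legacy defaults,
        # Gemini models simply omit the parameter from the request.
        temperature: Optional[float] = None
        top_k: Optional[int] = None


    assert SamplingConfig().temperature is None  # unset -> fallback applies
    assert SamplingConfig(temperature=0.7).temperature == 0.7  # explicit value wins
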
@@ -156,6 +161,15 @@ class _VertexAICommon(_VertexAIBase):
 
     @property
     def _default_params(self) -> Dict[str, Any]:
+        if self._is_gemini_model:
+            default_params = {}
+        else:
+            default_params = {
+                "temperature": _PALM_DEFAULT_TEMPERATURE,
+                "max_output_tokens": _PALM_DEFAULT_MAX_OUTPUT_TOKENS,
+                "top_p": _PALM_DEFAULT_TOP_P,
+                "top_k": _PALM_DEFAULT_TOP_K,
+            }
         params = {
             "temperature": self.temperature,
             "max_output_tokens": self.max_output_tokens,
@@ -168,7 +182,14 @@ class _VertexAICommon(_VertexAIBase):
                 "top_p": self.top_p,
             }
         )
-        return params
+        updated_params = {}
+        for param_name, param_value in params.items():
+            default_value = default_params.get(param_name)
+            if param_value or default_value:
+                updated_params[param_name] = (
+                    param_value if param_value else default_value
+                )
+        return updated_params
 
     @classmethod
     def _init_vertexai(cls, values: Dict) -> None:
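
For reference, a self-contained sketch of the fallback logic these two
`_default_params` hunks implement. The `resolve_params` helper and its
signature are illustrative (not part of the PR); the constants and the
truthiness-based merge mirror the diff:

    from typing import Any, Dict, Optional

    _PALM_DEFAULT_TEMPERATURE = 0.0
    _PALM_DEFAULT_MAX_OUTPUT_TOKENS = 128  # stand-in for the SDK constant
    _PALM_DEFAULT_TOP_P = 0.95
    _PALM_DEFAULT_TOP_K = 40


    def resolve_params(
        is_gemini: bool,
        temperature: Optional[float] = None,
        max_output_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
    ) -> Dict[str, Any]:
        # Gemini: no client-side defaults; only user-set values are sent.
        # PaLM: unset values fall back to the legacy client-side defaults.
        default_params: Dict[str, Any] = {} if is_gemini else {
            "temperature": _PALM_DEFAULT_TEMPERATURE,
            "max_output_tokens": _PALM_DEFAULT_MAX_OUTPUT_TOKENS,
            "top_p": _PALM_DEFAULT_TOP_P,
            "top_k": _PALM_DEFAULT_TOP_K,
        }
        params = {
            "temperature": temperature,
            "max_output_tokens": max_output_tokens,
            "top_p": top_p,
            "top_k": top_k,
        }
        # Keep a key only when the user set it or a default exists; user values win.
        updated_params: Dict[str, Any] = {}
        for param_name, param_value in params.items():
            default_value = default_params.get(param_name)
            if param_value or default_value:
                updated_params[param_name] = (
                    param_value if param_value else default_value
                )
        return updated_params

So `resolve_params(is_gemini=True)` returns `{}` and the request defers to
server-side defaults, while `resolve_params(is_gemini=False)` returns the PaLM
defaults for `max_output_tokens`, `top_p`, and `top_k`. Note the merge tests
truthiness rather than `is None`, so the falsy `0.0` temperature default never
survives it.
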
@@ -314,7 +335,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
         **kwargs: Any,
     ) -> LLMResult:
         params = self._prepare_params(stop=stop, **kwargs)
-        generations = []
+        generations: List[List[Generation]] = []
         for prompt in prompts:
             res = await _acompletion_with_retry(
                 self,
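
From the caller's side, the net effect looks like this (a usage sketch
assuming the public `VertexAI` class exported by `langchain_google_vertexai`;
model names are illustrative):

    from langchain_google_vertexai import VertexAI

    # Gemini model: with the Optional fields left at None, no sampling
    # parameters are injected client-side, so Vertex AI's own defaults apply.
    gemini_llm = VertexAI(model_name="gemini-pro")

    # PaLM text model: unset parameters fall back to the _PALM_DEFAULT_* values,
    # preserving the pre-existing behavior for non-Gemini models.
    palm_llm = VertexAI(model_name="text-bison", temperature=0.2)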