changed default params for gemini (#16044)

Replace this entire comment with:
- **Description:** changed default values for Vertex LLMs (to be handled
on the SDK's side)
pull/16162/head
Leonid Kuligin 6 months ago committed by GitHub
parent ec9642d667
commit 58f0ba306b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -41,6 +41,11 @@ from langchain_google_vertexai._utils import (
is_gemini_model,
)
_PALM_DEFAULT_MAX_OUTPUT_TOKENS = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
_PALM_DEFAULT_TEMPERATURE = 0.0
_PALM_DEFAULT_TOP_P = 0.95
_PALM_DEFAULT_TOP_K = 40
def _completion_with_retry(
llm: VertexAI,
@ -118,14 +123,14 @@ class _VertexAICommon(_VertexAIBase):
client_preview: Any = None #: :meta private:
model_name: str
"Underlying model name."
temperature: float = 0.0
temperature: Optional[float] = None
"Sampling temperature, it controls the degree of randomness in token selection."
max_output_tokens: int = 128
max_output_tokens: Optional[int] = None
"Token limit determines the maximum amount of text output from one prompt."
top_p: float = 0.95
top_p: Optional[float] = None
"Tokens are selected from most probable to least until the sum of their "
"probabilities equals the top-p value. Top-p is ignored for Codey models."
top_k: int = 40
top_k: Optional[int] = None
"How the model selects tokens for output, the next token is selected from "
"among the top-k most probable tokens. Top-k is ignored for Codey models."
credentials: Any = Field(default=None, exclude=True)
@ -156,6 +161,15 @@ class _VertexAICommon(_VertexAIBase):
@property
def _default_params(self) -> Dict[str, Any]:
if self._is_gemini_model:
default_params = {}
else:
default_params = {
"temperature": _PALM_DEFAULT_TEMPERATURE,
"max_output_tokens": _PALM_DEFAULT_MAX_OUTPUT_TOKENS,
"top_p": _PALM_DEFAULT_TOP_P,
"top_k": _PALM_DEFAULT_TOP_K,
}
params = {
"temperature": self.temperature,
"max_output_tokens": self.max_output_tokens,
@ -168,7 +182,14 @@ class _VertexAICommon(_VertexAIBase):
"top_p": self.top_p,
}
)
return params
updated_params = {}
for param_name, param_value in params.items():
default_value = default_params.get(param_name)
if param_value or default_value:
updated_params[param_name] = (
param_value if param_value else default_value
)
return updated_params
@classmethod
def _init_vertexai(cls, values: Dict) -> None:
@ -314,7 +335,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
**kwargs: Any,
) -> LLMResult:
params = self._prepare_params(stop=stop, **kwargs)
generations = []
generations: List[List[Generation]] = []
for prompt in prompts:
res = await _acompletion_with_retry(
self,

@ -68,7 +68,7 @@ def test_vertexai_args_passed(stop: Optional[str]) -> None:
mock_model.start_chat = mock_start_chat
mg.return_value = mock_model
model = ChatVertexAI(**prompt_params)
model = ChatVertexAI(**prompt_params) # type: ignore
message = HumanMessage(content=user_prompt)
if stop:
response = model([message], stop=[stop])
@ -110,3 +110,52 @@ def test_parse_chat_history_correct() -> None:
ChatMessage(content=text_question, author="user"),
ChatMessage(content=text_answer, author="bot"),
]
def test_default_params_palm() -> None:
user_prompt = "Hello"
with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg:
mock_response = MagicMock()
mock_response.candidates = [Mock(text="Goodbye")]
mock_chat = MagicMock()
mock_send_message = MagicMock(return_value=mock_response)
mock_chat.send_message = mock_send_message
mock_model = MagicMock()
mock_start_chat = MagicMock(return_value=mock_chat)
mock_model.start_chat = mock_start_chat
mg.return_value = mock_model
model = ChatVertexAI(model_name="text-bison@001")
message = HumanMessage(content=user_prompt)
_ = model([message])
mock_start_chat.assert_called_once_with(
context=None,
message_history=[],
max_output_tokens=128,
top_k=40,
top_p=0.95,
stop_sequences=None,
)
def test_default_params_gemini() -> None:
user_prompt = "Hello"
with patch("langchain_google_vertexai.chat_models.GenerativeModel") as gm:
mock_response = MagicMock()
content = Mock(parts=[Mock(function_call=None)])
mock_response.candidates = [Mock(text="Goodbye", content=content)]
mock_chat = MagicMock()
mock_send_message = MagicMock(return_value=mock_response)
mock_chat.send_message = mock_send_message
mock_model = MagicMock()
mock_start_chat = MagicMock(return_value=mock_chat)
mock_model.start_chat = mock_start_chat
gm.return_value = mock_model
model = ChatVertexAI(model_name="gemini-pro")
message = HumanMessage(content=user_prompt)
_ = model([message])
mock_start_chat.assert_called_once_with(history=[])

Loading…
Cancel
Save