From 58f0ba306b26a28044a108035a900b39c0d97bc7 Mon Sep 17 00:00:00 2001
From: Leonid Kuligin
Date: Wed, 17 Jan 2024 21:19:18 +0100
Subject: [PATCH] changed default params for gemini (#16044)

- **Description:** changed default values for Vertex LLMs (to be handled on
  the SDK's side)
---
 .../langchain_google_vertexai/llms.py         | 33 +++++++++---
 .../tests/unit_tests/test_chat_models.py      | 51 ++++++++++++++++++-
 2 files changed, 77 insertions(+), 7 deletions(-)

diff --git a/libs/partners/google-vertexai/langchain_google_vertexai/llms.py b/libs/partners/google-vertexai/langchain_google_vertexai/llms.py
index c3ff12d906..6e02ae43f9 100644
--- a/libs/partners/google-vertexai/langchain_google_vertexai/llms.py
+++ b/libs/partners/google-vertexai/langchain_google_vertexai/llms.py
@@ -41,6 +41,11 @@ from langchain_google_vertexai._utils import (
     is_gemini_model,
 )
 
+_PALM_DEFAULT_MAX_OUTPUT_TOKENS = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
+_PALM_DEFAULT_TEMPERATURE = 0.0
+_PALM_DEFAULT_TOP_P = 0.95
+_PALM_DEFAULT_TOP_K = 40
+
 
 def _completion_with_retry(
     llm: VertexAI,
@@ -118,14 +123,14 @@ class _VertexAICommon(_VertexAIBase):
     client_preview: Any = None  #: :meta private:
     model_name: str
     "Underlying model name."
-    temperature: float = 0.0
+    temperature: Optional[float] = None
     "Sampling temperature, it controls the degree of randomness in token selection."
-    max_output_tokens: int = 128
+    max_output_tokens: Optional[int] = None
     "Token limit determines the maximum amount of text output from one prompt."
-    top_p: float = 0.95
+    top_p: Optional[float] = None
     "Tokens are selected from most probable to least until the sum of their "
     "probabilities equals the top-p value. Top-p is ignored for Codey models."
-    top_k: int = 40
+    top_k: Optional[int] = None
     "How the model selects tokens for output, the next token is selected from "
     "among the top-k most probable tokens. Top-k is ignored for Codey models."
     credentials: Any = Field(default=None, exclude=True)
@@ -156,6 +161,15 @@ class _VertexAICommon(_VertexAIBase):
 
     @property
     def _default_params(self) -> Dict[str, Any]:
+        if self._is_gemini_model:
+            default_params = {}
+        else:
+            default_params = {
+                "temperature": _PALM_DEFAULT_TEMPERATURE,
+                "max_output_tokens": _PALM_DEFAULT_MAX_OUTPUT_TOKENS,
+                "top_p": _PALM_DEFAULT_TOP_P,
+                "top_k": _PALM_DEFAULT_TOP_K,
+            }
         params = {
             "temperature": self.temperature,
             "max_output_tokens": self.max_output_tokens,
@@ -168,7 +182,14 @@ class _VertexAICommon(_VertexAIBase):
                 "top_p": self.top_p,
             }
         )
-        return params
+        updated_params = {}
+        for param_name, param_value in params.items():
+            default_value = default_params.get(param_name)
+            if param_value or default_value:
+                updated_params[param_name] = (
+                    param_value if param_value else default_value
+                )
+        return updated_params
 
     @classmethod
     def _init_vertexai(cls, values: Dict) -> None:
@@ -314,7 +335,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
         **kwargs: Any,
     ) -> LLMResult:
         params = self._prepare_params(stop=stop, **kwargs)
-        generations = []
+        generations: List[List[Generation]] = []
         for prompt in prompts:
             res = await _acompletion_with_retry(
                 self,
diff --git a/libs/partners/google-vertexai/tests/unit_tests/test_chat_models.py b/libs/partners/google-vertexai/tests/unit_tests/test_chat_models.py
index bff39ee431..d11a970d65 100644
--- a/libs/partners/google-vertexai/tests/unit_tests/test_chat_models.py
+++ b/libs/partners/google-vertexai/tests/unit_tests/test_chat_models.py
@@ -68,7 +68,7 @@ def test_vertexai_args_passed(stop: Optional[str]) -> None:
         mock_model.start_chat = mock_start_chat
         mg.return_value = mock_model
 
-        model = ChatVertexAI(**prompt_params)
+        model = ChatVertexAI(**prompt_params)  # type: ignore
         message = HumanMessage(content=user_prompt)
         if stop:
             response = model([message], stop=[stop])
@@ -110,3 +110,52 @@ def test_parse_chat_history_correct() -> None:
         ChatMessage(content=text_question, author="user"),
         ChatMessage(content=text_answer, author="bot"),
     ]
+
+
+def test_default_params_palm() -> None:
+    user_prompt = "Hello"
+
+    with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg:
+        mock_response = MagicMock()
+        mock_response.candidates = [Mock(text="Goodbye")]
+        mock_chat = MagicMock()
+        mock_send_message = MagicMock(return_value=mock_response)
+        mock_chat.send_message = mock_send_message
+
+        mock_model = MagicMock()
+        mock_start_chat = MagicMock(return_value=mock_chat)
+        mock_model.start_chat = mock_start_chat
+        mg.return_value = mock_model
+
+        model = ChatVertexAI(model_name="text-bison@001")
+        message = HumanMessage(content=user_prompt)
+        _ = model([message])
+        mock_start_chat.assert_called_once_with(
+            context=None,
+            message_history=[],
+            max_output_tokens=128,
+            top_k=40,
+            top_p=0.95,
+            stop_sequences=None,
+        )
+
+
+def test_default_params_gemini() -> None:
+    user_prompt = "Hello"
+
+    with patch("langchain_google_vertexai.chat_models.GenerativeModel") as gm:
+        mock_response = MagicMock()
+        content = Mock(parts=[Mock(function_call=None)])
+        mock_response.candidates = [Mock(text="Goodbye", content=content)]
+        mock_chat = MagicMock()
+        mock_send_message = MagicMock(return_value=mock_response)
+        mock_chat.send_message = mock_send_message
+
+        mock_model = MagicMock()
+        mock_start_chat = MagicMock(return_value=mock_chat)
+        mock_model.start_chat = mock_start_chat
+        gm.return_value = mock_model
+        model = ChatVertexAI(model_name="gemini-pro")
+        message = HumanMessage(content=user_prompt)
+        _ = model([message])
+        mock_start_chat.assert_called_once_with(history=[])
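
For readers without the Vertex AI SDK at hand, here is a minimal, self-contained sketch of the defaulting behaviour this patch introduces. It is an approximation, not code from the patch: `resolve_params` is a hypothetical stand-in for `_VertexAICommon._default_params`, the `is_gemini` flag replaces the `self._is_gemini_model` property, and the `128` token limit is assumed to equal `TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS`, matching what the new PaLM unit test asserts.

```python
from typing import Any, Dict, Optional

# Client-side defaults kept only for PaLM models (values mirror the patch).
_PALM_DEFAULT_MAX_OUTPUT_TOKENS = 128  # assumed TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
_PALM_DEFAULT_TEMPERATURE = 0.0
_PALM_DEFAULT_TOP_P = 0.95
_PALM_DEFAULT_TOP_K = 40


def resolve_params(
    is_gemini: bool,
    temperature: Optional[float] = None,
    max_output_tokens: Optional[int] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
) -> Dict[str, Any]:
    """Hypothetical stand-in for _VertexAICommon._default_params."""
    # Gemini models get no client-side defaults: unset params are simply
    # omitted so the SDK / service can apply its own defaults.
    default_params: Dict[str, Any] = (
        {}
        if is_gemini
        else {
            "temperature": _PALM_DEFAULT_TEMPERATURE,
            "max_output_tokens": _PALM_DEFAULT_MAX_OUTPUT_TOKENS,
            "top_p": _PALM_DEFAULT_TOP_P,
            "top_k": _PALM_DEFAULT_TOP_K,
        }
    )
    params = {
        "temperature": temperature,
        "max_output_tokens": max_output_tokens,
        "top_p": top_p,
        "top_k": top_k,
    }
    # Keep a param only if the user set it or a truthy PaLM default exists.
    updated_params: Dict[str, Any] = {}
    for name, value in params.items():
        default = default_params.get(name)
        if value or default:
            updated_params[name] = value if value else default
    return updated_params


if __name__ == "__main__":
    # Gemini: nothing is sent unless the caller sets it explicitly.
    assert resolve_params(is_gemini=True) == {}
    # PaLM (e.g. text-bison@001): the old client-side defaults still apply.
    assert resolve_params(is_gemini=False) == {
        "max_output_tokens": 128,
        "top_p": 0.95,
        "top_k": 40,
    }
    # Explicit user values always win over defaults.
    assert resolve_params(is_gemini=True, temperature=0.2) == {"temperature": 0.2}
```

Note that the PaLM temperature default of `0.0` is falsy, so the `if param_value or default_value` check drops it from the resolved params; that is why `test_default_params_palm` above does not expect a `temperature` kwarg in the `start_chat` call.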