diff --git a/langchain/llms/vertexai.py b/langchain/llms/vertexai.py index a7a9259c..d9d67b37 100644 --- a/langchain/llms/vertexai.py +++ b/langchain/llms/vertexai.py @@ -43,8 +43,8 @@ class _VertexAICommon(BaseModel): base_params = { "temperature": self.temperature, "max_output_tokens": self.max_output_tokens, - "top_k": self.top_p, - "top_p": self.top_k, + "top_k": self.top_k, + "top_p": self.top_p, } return {**base_params} diff --git a/tests/integration_tests/chat_models/test_vertexai.py b/tests/integration_tests/chat_models/test_vertexai.py index 10b9c698..cb69d68f 100644 --- a/tests/integration_tests/chat_models/test_vertexai.py +++ b/tests/integration_tests/chat_models/test_vertexai.py @@ -7,6 +7,8 @@ pip install google-cloud-aiplatform>=1.25.0 Your end-user credentials would be used to make the calls (make sure you've run `gcloud auth login` first). """ +from unittest.mock import Mock, patch + import pytest from langchain.chat_models import ChatVertexAI @@ -86,3 +88,31 @@ def test_vertexai_single_call_failes_no_message() -> None: str(exc_info.value) == "You should provide at least one message to start the chat!" ) + + +def test_vertexai_args_passed() -> None: + response_text = "Goodbye" + user_prompt = "Hello" + prompt_params = { + "max_output_tokens": 1, + "temperature": 10000.0, + "top_k": 10, + "top_p": 0.5, + } + + # Mock the library to ensure the args are passed correctly + with patch( + "vertexai.language_models._language_models.ChatSession.send_message" + ) as send_message: + mock_response = Mock(text=response_text) + send_message.return_value = mock_response + + model = ChatVertexAI(**prompt_params) + message = HumanMessage(content=user_prompt) + response = model([message]) + + assert response.content == response_text + send_message.assert_called_once_with( + user_prompt, + **prompt_params, + )