From 9d4b710a486ed877eff06ec3c31732d86c4149a0 Mon Sep 17 00:00:00 2001
From: Leonid Kuligin
Date: Fri, 22 Sep 2023 17:18:09 +0200
Subject: [PATCH] small fixes to Vertex (#10934)

Fixed tests, updated the required version of the SDK, and made a few
minor changes after the recent improvement
(https://github.com/langchain-ai/langchain/pull/10910)
---
 .../langchain/chat_models/vertexai.py          | 23 ++++++++--------
 libs/langchain/langchain/llms/vertexai.py      | 27 ++++++++-----------
 .../langchain/langchain/utilities/vertexai.py  |  2 +-
 .../chat_models/test_vertexai.py               | 17 +++++++++---
 4 files changed, 37 insertions(+), 32 deletions(-)

diff --git a/libs/langchain/langchain/chat_models/vertexai.py b/libs/langchain/langchain/chat_models/vertexai.py
index a2e2850305..3407620e69 100644
--- a/libs/langchain/langchain/chat_models/vertexai.py
+++ b/libs/langchain/langchain/chat_models/vertexai.py
@@ -138,7 +138,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
 
                 values["client"] = ChatModel.from_pretrained(values["model_name"])
         except ImportError:
-            raise_vertex_import_error(minimum_expected_version="1.29.0")
+            raise_vertex_import_error()
         return values
 
     def _generate(
@@ -173,15 +173,16 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
 
         question = _get_question(messages)
         history = _parse_chat_history(messages[:-1])
-        params = {**self._default_params, **kwargs}
+        params = self._prepare_params(stop=stop, **kwargs)
         examples = kwargs.get("examples", None)
         if examples:
             params["examples"] = _parse_examples(examples)
 
         chat = self._start_chat(history, params)
         response = chat.send_message(question.content)
-        text = self._enforce_stop_words(response.text, stop)
-        return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
+        return ChatResult(
+            generations=[ChatGeneration(message=AIMessage(content=response.text))]
+        )
 
     async def _agenerate(
         self,
@@ -209,15 +210,16 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
             logger.warning("ChatVertexAI does not currently support async streaming.")
         question = _get_question(messages)
         history = _parse_chat_history(messages[:-1])
-        params = {**self._default_params, **kwargs}
+        params = self._prepare_params(stop=stop, **kwargs)
         examples = kwargs.get("examples", None)
         if examples:
             params["examples"] = _parse_examples(examples)
 
         chat = self._start_chat(history, params)
         response = await chat.send_message_async(question.content)
-        text = self._enforce_stop_words(response.text, stop)
-        return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
+        return ChatResult(
+            generations=[ChatGeneration(message=AIMessage(content=response.text))]
+        )
 
     def _stream(
         self,
@@ -228,7 +230,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
     ) -> Iterator[ChatGenerationChunk]:
         question = _get_question(messages)
         history = _parse_chat_history(messages[:-1])
-        params = {**self._default_params, **kwargs}
+        params = self._prepare_params(stop=stop, **kwargs)
         examples = kwargs.get("examples", None)
         if examples:
             params["examples"] = _parse_examples(examples)
@@ -236,10 +238,9 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         chat = self._start_chat(history, params)
         responses = chat.send_message_streaming(question.content, **params)
         for response in responses:
-            text = self._enforce_stop_words(response.text, stop)
             if run_manager:
-                run_manager.on_llm_new_token(text)
-            yield ChatGenerationChunk(message=AIMessageChunk(content=text))
+                run_manager.on_llm_new_token(response.text)
+            yield ChatGenerationChunk(message=AIMessageChunk(content=response.text))
 
     def _start_chat(
         self, history: _ChatHistory, params: dict
diff --git a/libs/langchain/langchain/llms/vertexai.py b/libs/langchain/langchain/llms/vertexai.py
index 1367f9d6bb..b0e6ea2dd3 100644
--- a/libs/langchain/langchain/llms/vertexai.py
+++ b/libs/langchain/langchain/llms/vertexai.py
@@ -18,7 +18,6 @@ from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
 )
 from langchain.llms.base import BaseLLM, create_base_retry_decorator
-from langchain.llms.utils import enforce_stop_tokens
 from langchain.pydantic_v1 import BaseModel, root_validator
 from langchain.schema import (
     Generation,
@@ -151,13 +150,6 @@ class _VertexAIBase(BaseModel):
     model_name: Optional[str] = None
     "Underlying model name."
 
-    def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:
-        if stop is None and self.stop is not None:
-            stop = self.stop
-        if stop:
-            return enforce_stop_tokens(text, stop)
-        return text
-
     @classmethod
     def _get_task_executor(cls, request_parallelism: int = 5) -> Executor:
         if cls.task_executor is None:
@@ -220,6 +212,14 @@ class _VertexAICommon(_VertexAIBase):
         init_vertexai(**params)
         return None
 
+    def _prepare_params(
+        self,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> dict:
+        stop_sequences = stop or self.stop
+        return {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
+
 
 class VertexAI(_VertexAICommon, BaseLLM):
     """Google Vertex AI large language models."""
@@ -228,7 +228,6 @@ class VertexAI(_VertexAICommon, BaseLLM):
     "The name of the Vertex AI large language model."
     tuned_model_name: Optional[str] = None
     "The name of a tuned model. If provided, model_name is ignored."
-    streaming: bool = False
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
@@ -267,10 +266,8 @@ class VertexAI(_VertexAICommon, BaseLLM):
         stream: Optional[bool] = None,
         **kwargs: Any,
     ) -> LLMResult:
-        stop_sequences = stop or self.stop
         should_stream = stream if stream is not None else self.streaming
-
-        params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
+        params = self._prepare_params(stop=stop, **kwargs)
         generations = []
         for prompt in prompts:
             if should_stream:
@@ -294,8 +291,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> LLMResult:
-        stop_sequences = stop or self.stop
-        params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
+        params = self._prepare_params(stop=stop, **kwargs)
         generations = []
         for prompt in prompts:
             res = await acompletion_with_retry(
@@ -311,8 +307,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
-        stop_sequences = stop or self.stop
-        params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
+        params = self._prepare_params(stop=stop, **kwargs)
         for stream_resp in stream_completion_with_retry(
             self, prompt, run_manager=run_manager, **params
         ):
diff --git a/libs/langchain/langchain/utilities/vertexai.py b/libs/langchain/langchain/utilities/vertexai.py
index 23f6d00119..244292db43 100644
--- a/libs/langchain/langchain/utilities/vertexai.py
+++ b/libs/langchain/langchain/utilities/vertexai.py
@@ -5,7 +5,7 @@ if TYPE_CHECKING:
     from google.auth.credentials import Credentials
 
 
-def raise_vertex_import_error(minimum_expected_version: str = "1.26.1") -> None:
+def raise_vertex_import_error(minimum_expected_version: str = "1.33.0") -> None:
     """Raise ImportError related to Vertex SDK being not available.
 
     Args:
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_vertexai.py b/libs/langchain/tests/integration_tests/chat_models/test_vertexai.py
index b3e3feca95..cc5163e525 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_vertexai.py
+++ b/libs/langchain/tests/integration_tests/chat_models/test_vertexai.py
@@ -7,6 +7,7 @@ pip install google-cloud-aiplatform>=1.25.0
 Your end-user credentials would be used to make the calls (make sure you've run
 `gcloud auth login` first).
 """
+from typing import Optional
 from unittest.mock import MagicMock, Mock, patch
 
 import pytest
@@ -27,7 +28,7 @@ def test_vertexai_single_call(model_name: str) -> None:
     response = model([message])
     assert isinstance(response, AIMessage)
     assert isinstance(response.content, str)
-    assert model._llm_type == "chat-vertexai"
+    assert model._llm_type == "vertexai"
     assert model.model_name == model.client._model_id
 
 
@@ -127,7 +128,8 @@ def test_vertexai_single_call_failes_no_message() -> None:
     )
 
 
-def test_vertexai_args_passed() -> None:
+@pytest.mark.parametrize("stop", [None, "stop1"])
+def test_vertexai_args_passed(stop: Optional[str]) -> None:
     response_text = "Goodbye"
     user_prompt = "Hello"
     prompt_params = {
@@ -149,12 +151,19 @@ def test_vertexai_args_passed() -> None:
 
         model = ChatVertexAI(**prompt_params)
         message = HumanMessage(content=user_prompt)
-        response = model([message])
+        if stop:
+            response = model([message], stop=[stop])
+        else:
+            response = model([message])
 
         assert response.content == response_text
         mock_send_message.assert_called_once_with(user_prompt)
+        expected_stop_sequence = [stop] if stop else None
         start_chat.assert_called_once_with(
-            context=None, message_history=[], **prompt_params
+            context=None,
+            message_history=[],
+            **prompt_params,
+            stop_sequences=expected_stop_sequence
         )
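The net effect of the patch: stop words are no longer stripped out of the
generated text after the fact with enforce_stop_tokens; they are forwarded to
the Vertex AI SDK as stop_sequences by the new _prepare_params helper, with a
per-call stop argument taking precedence over the instance-level stop
attribute. Below is a minimal standalone sketch of that merging logic; the
function and parameter names (prepare_params, instance_stop) are illustrative
stand-ins, not the library code itself.

    from typing import Any, Dict, List, Optional


    def prepare_params(
        default_params: Dict[str, Any],
        stop: Optional[List[str]] = None,
        instance_stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        # Mirror the precedence in _VertexAICommon._prepare_params: an
        # explicit per-call `stop` wins over the instance-level default,
        # and per-call kwargs override the model defaults.
        stop_sequences = stop or instance_stop
        return {**default_params, "stop_sequences": stop_sequences, **kwargs}


    if __name__ == "__main__":
        defaults = {"temperature": 0.0, "max_output_tokens": 128}
        # The per-call stop list takes precedence over the instance default.
        print(prepare_params(defaults, stop=["Observation:"], instance_stop=["END"]))
        # -> {'temperature': 0.0, 'max_output_tokens': 128,
        #     'stop_sequences': ['Observation:']}

Because the stop sequences now travel inside the request parameters, the model
stops generating at the server rather than producing extra tokens that are
discarded client-side, which is also why the streaming path no longer needs
per-chunk post-processing.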