small fixes to Vertex (#10934)

Fixed tests, updated the required version of the SDK and a few minor
changes after the recent improvement
(https://github.com/langchain-ai/langchain/pull/10910)
pull/10938/head
Leonid Kuligin 11 months ago committed by GitHub
parent 4e58b78102
commit 9d4b710a48
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -138,7 +138,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
values["client"] = ChatModel.from_pretrained(values["model_name"])
except ImportError:
raise_vertex_import_error(minimum_expected_version="1.29.0")
raise_vertex_import_error()
return values
def _generate(
@ -173,15 +173,16 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
question = _get_question(messages)
history = _parse_chat_history(messages[:-1])
params = {**self._default_params, **kwargs}
params = self._prepare_params(stop=stop, **kwargs)
examples = kwargs.get("examples", None)
if examples:
params["examples"] = _parse_examples(examples)
chat = self._start_chat(history, params)
response = chat.send_message(question.content)
text = self._enforce_stop_words(response.text, stop)
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
return ChatResult(
generations=[ChatGeneration(message=AIMessage(content=response.text))]
)
async def _agenerate(
self,
@ -209,15 +210,16 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
logger.warning("ChatVertexAI does not currently support async streaming.")
question = _get_question(messages)
history = _parse_chat_history(messages[:-1])
params = {**self._default_params, **kwargs}
params = self._prepare_params(stop=stop, **kwargs)
examples = kwargs.get("examples", None)
if examples:
params["examples"] = _parse_examples(examples)
chat = self._start_chat(history, params)
response = await chat.send_message_async(question.content)
text = self._enforce_stop_words(response.text, stop)
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
return ChatResult(
generations=[ChatGeneration(message=AIMessage(content=response.text))]
)
def _stream(
self,
@ -228,7 +230,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
) -> Iterator[ChatGenerationChunk]:
question = _get_question(messages)
history = _parse_chat_history(messages[:-1])
params = {**self._default_params, **kwargs}
params = self._prepare_params(stop=stop, **kwargs)
examples = kwargs.get("examples", None)
if examples:
params["examples"] = _parse_examples(examples)
@ -236,10 +238,9 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
chat = self._start_chat(history, params)
responses = chat.send_message_streaming(question.content, **params)
for response in responses:
text = self._enforce_stop_words(response.text, stop)
if run_manager:
run_manager.on_llm_new_token(text)
yield ChatGenerationChunk(message=AIMessageChunk(content=text))
run_manager.on_llm_new_token(response.text)
yield ChatGenerationChunk(message=AIMessageChunk(content=response.text))
def _start_chat(
self, history: _ChatHistory, params: dict

@ -18,7 +18,6 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.llms.base import BaseLLM, create_base_retry_decorator
from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.schema import (
Generation,
@ -151,13 +150,6 @@ class _VertexAIBase(BaseModel):
model_name: Optional[str] = None
"Underlying model name."
def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:
if stop is None and self.stop is not None:
stop = self.stop
if stop:
return enforce_stop_tokens(text, stop)
return text
@classmethod
def _get_task_executor(cls, request_parallelism: int = 5) -> Executor:
if cls.task_executor is None:
@ -220,6 +212,14 @@ class _VertexAICommon(_VertexAIBase):
init_vertexai(**params)
return None
def _prepare_params(
self,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> dict:
stop_sequences = stop or self.stop
return {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
class VertexAI(_VertexAICommon, BaseLLM):
"""Google Vertex AI large language models."""
@ -228,7 +228,6 @@ class VertexAI(_VertexAICommon, BaseLLM):
"The name of the Vertex AI large language model."
tuned_model_name: Optional[str] = None
"The name of a tuned model. If provided, model_name is ignored."
streaming: bool = False
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
@ -267,10 +266,8 @@ class VertexAI(_VertexAICommon, BaseLLM):
stream: Optional[bool] = None,
**kwargs: Any,
) -> LLMResult:
stop_sequences = stop or self.stop
should_stream = stream if stream is not None else self.streaming
params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
params = self._prepare_params(stop=stop, **kwargs)
generations = []
for prompt in prompts:
if should_stream:
@ -294,8 +291,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
stop_sequences = stop or self.stop
params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
params = self._prepare_params(stop=stop, **kwargs)
generations = []
for prompt in prompts:
res = await acompletion_with_retry(
@ -311,8 +307,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
stop_sequences = stop or self.stop
params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
params = self._prepare_params(stop=stop, **kwargs)
for stream_resp in stream_completion_with_retry(
self, prompt, run_manager=run_manager, **params
):

@ -5,7 +5,7 @@ if TYPE_CHECKING:
from google.auth.credentials import Credentials
def raise_vertex_import_error(minimum_expected_version: str = "1.26.1") -> None:
def raise_vertex_import_error(minimum_expected_version: str = "1.33.0") -> None:
"""Raise ImportError related to Vertex SDK being not available.
Args:

@ -7,6 +7,7 @@ pip install google-cloud-aiplatform>=1.25.0
Your end-user credentials would be used to make the calls (make sure you've run
`gcloud auth login` first).
"""
from typing import Optional
from unittest.mock import MagicMock, Mock, patch
import pytest
@ -27,7 +28,7 @@ def test_vertexai_single_call(model_name: str) -> None:
response = model([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert model._llm_type == "chat-vertexai"
assert model._llm_type == "vertexai"
assert model.model_name == model.client._model_id
@ -127,7 +128,8 @@ def test_vertexai_single_call_failes_no_message() -> None:
)
def test_vertexai_args_passed() -> None:
@pytest.mark.parametrize("stop", [None, "stop1"])
def test_vertexai_args_passed(stop: Optional[str]) -> None:
response_text = "Goodbye"
user_prompt = "Hello"
prompt_params = {
@ -149,12 +151,19 @@ def test_vertexai_args_passed() -> None:
model = ChatVertexAI(**prompt_params)
message = HumanMessage(content=user_prompt)
response = model([message])
if stop:
response = model([message], stop=[stop])
else:
response = model([message])
assert response.content == response_text
mock_send_message.assert_called_once_with(user_prompt)
expected_stop_sequence = [stop] if stop else None
start_chat.assert_called_once_with(
context=None, message_history=[], **prompt_params
context=None,
message_history=[],
**prompt_params,
stop_sequences=expected_stop_sequence
)

Loading…
Cancel
Save