diff --git a/langchain/chat_models/vertexai.py b/langchain/chat_models/vertexai.py index 4b090be66a..3c56410a4e 100644 --- a/langchain/chat_models/vertexai.py +++ b/langchain/chat_models/vertexai.py @@ -1,6 +1,6 @@ """Wrapper around Google VertexAI chat-based models.""" from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional from pydantic import root_validator @@ -22,55 +22,46 @@ from langchain.schema.messages import ( ) from langchain.utilities.vertexai import raise_vertex_import_error - -@dataclass -class _MessagePair: - """InputOutputTextPair represents a pair of input and output texts.""" - - question: HumanMessage - answer: AIMessage +if TYPE_CHECKING: + from vertexai.language_models import ChatMessage @dataclass class _ChatHistory: - """InputOutputTextPair represents a pair of input and output texts.""" + """Represents a context and a history of messages.""" - history: List[_MessagePair] = field(default_factory=list) - system_message: Optional[SystemMessage] = None + history: List["ChatMessage"] = field(default_factory=list) + context: Optional[str] = None def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory: """Parse a sequence of messages into history. - A sequence should be either (SystemMessage, HumanMessage, AIMessage, - HumanMessage, AIMessage, ...) or (HumanMessage, AIMessage, HumanMessage, - AIMessage, ...). CodeChat does not support SystemMessage. - Args: history: The list of messages to re-create the history of the chat. Returns: A parsed chat history. Raises: - ValueError: If a sequence of message is odd, or a human message is not followed - by a message from AI (e.g., Human, Human, AI or AI, AI, Human). + ValueError: If a sequence of message has a SystemMessage not at the + first place. """ - if not history: - return _ChatHistory() - first_message = history[0] - system_message = first_message if isinstance(first_message, SystemMessage) else None - chat_history = _ChatHistory(system_message=system_message) - messages_left = history[1:] if system_message else history - if len(messages_left) % 2 != 0: - raise ValueError( - f"Amount of messages in history should be even, got {len(messages_left)}!" - ) - for question, answer in zip(messages_left[::2], messages_left[1::2]): - if not isinstance(question, HumanMessage) or not isinstance(answer, AIMessage): + from vertexai.language_models import ChatMessage + + vertex_messages, context = [], None + for i, message in enumerate(history): + if i == 0 and isinstance(message, SystemMessage): + context = message.content + elif isinstance(message, AIMessage): + vertex_message = ChatMessage(content=message.content, author="bot") + vertex_messages.append(vertex_message) + elif isinstance(message, HumanMessage): + vertex_message = ChatMessage(content=message.content, author="user") + vertex_messages.append(vertex_message) + else: raise ValueError( - "A human message should follow a bot one, " - f"got {question.type}, {answer.type}." + f"Unexpected message with type {type(message)} at the position {i}." ) - chat_history.history.append(_MessagePair(question=question, answer=answer)) + chat_history = _ChatHistory(context=context, history=vertex_messages) return chat_history @@ -126,16 +117,15 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel): raise ValueError( f"Last message in the list should be from human, got {question.type}." ) - history = _parse_chat_history(messages[:-1]) - context = history.system_message.content if history.system_message else None + context = history.context if history.context else None params = {**self._default_params, **kwargs} if not self.is_codey_model: - chat = self.client.start_chat(context=context, **params) + chat = self.client.start_chat( + context=context, message_history=history.history, **params + ) else: chat = self.client.start_chat(**params) - for pair in history.history: - chat._history.append((pair.question.content, pair.answer.content)) response = chat.send_message(question.content, **params) text = self._enforce_stop_words(response.text, stop) return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))]) diff --git a/langchain/utilities/vertexai.py b/langchain/utilities/vertexai.py index 7af3fdb68b..a934987cd8 100644 --- a/langchain/utilities/vertexai.py +++ b/langchain/utilities/vertexai.py @@ -11,7 +11,7 @@ def raise_vertex_import_error() -> None: Raises: ImportError: an ImportError that mentions a required version of the SDK. """ - sdk = "'google-cloud-aiplatform>=1.26.0'" + sdk = "'google-cloud-aiplatform>=1.26.1'" raise ImportError( "Could not import VertexAI. Please, install it with " f"pip install {sdk}" ) diff --git a/tests/integration_tests/chat_models/test_vertexai.py b/tests/integration_tests/chat_models/test_vertexai.py index d4d9ed9778..9e5682b655 100644 --- a/tests/integration_tests/chat_models/test_vertexai.py +++ b/tests/integration_tests/chat_models/test_vertexai.py @@ -12,7 +12,7 @@ from unittest.mock import Mock, patch import pytest from langchain.chat_models import ChatVertexAI -from langchain.chat_models.vertexai import _MessagePair, _parse_chat_history +from langchain.chat_models.vertexai import _parse_chat_history from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage @@ -43,6 +43,8 @@ def test_vertexai_single_call_with_context() -> None: def test_parse_chat_history_correct() -> None: + from vertexai.language_models import ChatMessage + text_context = ( "My name is Ned. You are my personal assistant. My " "favorite movies are Lord of the Rings and Hobbit." @@ -58,22 +60,14 @@ def test_parse_chat_history_correct() -> None: ) answer = AIMessage(content=text_answer) history = _parse_chat_history([context, question, answer, question, answer]) - assert history.system_message == context - assert len(history.history) == 2 - assert history.history[0] == _MessagePair(question=question, answer=answer) - - -def test_parse_chat_history_wrong_sequence() -> None: - text_question = ( - "Hello, could you recommend a good movie for me to watch this evening, please?" - ) - question = HumanMessage(content=text_question) - with pytest.raises(ValueError) as exc_info: - _ = _parse_chat_history([question, question]) - assert ( - str(exc_info.value) - == "A human message should follow a bot one, got human, human." - ) + assert history.context == context.content + assert len(history.history) == 4 + assert history.history == [ + ChatMessage(content=text_question, author="user"), + ChatMessage(content=text_answer, author="bot"), + ChatMessage(content=text_question, author="user"), + ChatMessage(content=text_answer, author="bot"), + ] def test_vertexai_single_call_failes_no_message() -> None: