From 85e1c9b34835c8c6acdd4caea86821b761515c6d Mon Sep 17 00:00:00 2001 From: Leonid Kuligin Date: Fri, 14 Jul 2023 08:03:04 +0200 Subject: [PATCH] Added support for examples for VertexAI chat models. (#7636) #5278 Co-authored-by: Leonid Kuligin --- langchain/chat_models/vertexai.py | 37 ++++++++++- .../chat_models/test_vertexai.py | 63 ++++++++++++++++--- 2 files changed, 90 insertions(+), 10 deletions(-) diff --git a/langchain/chat_models/vertexai.py b/langchain/chat_models/vertexai.py index 3c56410a4e..84ea206570 100644 --- a/langchain/chat_models/vertexai.py +++ b/langchain/chat_models/vertexai.py @@ -23,7 +23,7 @@ from langchain.schema.messages import ( from langchain.utilities.vertexai import raise_vertex_import_error if TYPE_CHECKING: - from vertexai.language_models import ChatMessage + from vertexai.language_models import ChatMessage, InputOutputTextPair @dataclass @@ -65,6 +65,36 @@ def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory: return chat_history +def _parse_examples(examples: List[BaseMessage]) -> List["InputOutputTextPair"]: + from vertexai.language_models import InputOutputTextPair + + if len(examples) % 2 != 0: + raise ValueError( + f"Expect examples to have an even amount of messages, got {len(examples)}." + ) + example_pairs = [] + input_text = None + for i, example in enumerate(examples): + if i % 2 == 0: + if not isinstance(example, HumanMessage): + raise ValueError( + f"Expected the first message in a part to be from human, got " + f"{type(example)} for the {i}th message." + ) + input_text = example.content + if i % 2 == 1: + if not isinstance(example, AIMessage): + raise ValueError( + f"Expected the second message in a part to be from AI, got " + f"{type(example)} for the {i}th message." + ) + pair = InputOutputTextPair( + input_text=input_text, output_text=example.content + ) + example_pairs.append(pair) + return example_pairs + + class ChatVertexAI(_VertexAICommon, BaseChatModel): """Wrapper around Vertex AI large language models.""" @@ -120,13 +150,16 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel): history = _parse_chat_history(messages[:-1]) context = history.context if history.context else None params = {**self._default_params, **kwargs} + examples = kwargs.get("examples", None) + if examples: + params["examples"] = _parse_examples(examples) if not self.is_codey_model: chat = self.client.start_chat( context=context, message_history=history.history, **params ) else: chat = self.client.start_chat(**params) - response = chat.send_message(question.content, **params) + response = chat.send_message(question.content) text = self._enforce_stop_words(response.text, stop) return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))]) diff --git a/tests/integration_tests/chat_models/test_vertexai.py b/tests/integration_tests/chat_models/test_vertexai.py index 9e5682b655..3957f2a75c 100644 --- a/tests/integration_tests/chat_models/test_vertexai.py +++ b/tests/integration_tests/chat_models/test_vertexai.py @@ -7,12 +7,12 @@ pip install google-cloud-aiplatform>=1.25.0 Your end-user credentials would be used to make the calls (make sure you've run `gcloud auth login` first). """ -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import pytest from langchain.chat_models import ChatVertexAI -from langchain.chat_models.vertexai import _parse_chat_history +from langchain.chat_models.vertexai import _parse_chat_history, _parse_examples from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage @@ -42,6 +42,20 @@ def test_vertexai_single_call_with_context() -> None: assert isinstance(response.content, str) +def test_vertexai_single_call_with_examples() -> None: + model = ChatVertexAI() + raw_context = "My name is Ned. You are my personal assistant." + question = "2+2" + text_question, text_answer = "4+4", "8" + inp = HumanMessage(content=text_question) + output = AIMessage(content=text_answer) + context = SystemMessage(content=raw_context) + message = HumanMessage(content=question) + response = model([context, message], examples=[inp, output]) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + + def test_parse_chat_history_correct() -> None: from vertexai.language_models import ChatMessage @@ -92,17 +106,50 @@ def test_vertexai_args_passed() -> None: # Mock the library to ensure the args are passed correctly with patch( - "vertexai.language_models._language_models.ChatSession.send_message" - ) as send_message: + "vertexai.language_models._language_models.ChatModel.start_chat" + ) as start_chat: mock_response = Mock(text=response_text) - send_message.return_value = mock_response + mock_chat = MagicMock() + start_chat.return_value = mock_chat + mock_send_message = MagicMock(return_value=mock_response) + mock_chat.send_message = mock_send_message model = ChatVertexAI(**prompt_params) message = HumanMessage(content=user_prompt) response = model([message]) assert response.content == response_text - send_message.assert_called_once_with( - user_prompt, - **prompt_params, + mock_send_message.assert_called_once_with(user_prompt) + start_chat.assert_called_once_with( + context=None, message_history=[], **prompt_params ) + + +def test_parse_examples_correct() -> None: + from vertexai.language_models import InputOutputTextPair + + text_question = ( + "Hello, could you recommend a good movie for me to watch this evening, please?" + ) + question = HumanMessage(content=text_question) + text_answer = ( + "Sure, You might enjoy The Lord of the Rings: The Fellowship of the Ring " + "(2001): This is the first movie in the Lord of the Rings trilogy." + ) + answer = AIMessage(content=text_answer) + examples = _parse_examples([question, answer, question, answer]) + assert len(examples) == 2 + assert examples == [ + InputOutputTextPair(input_text=text_question, output_text=text_answer), + InputOutputTextPair(input_text=text_question, output_text=text_answer), + ] + + +def test_parse_exmaples_failes_wrong_sequence() -> None: + with pytest.raises(ValueError) as exc_info: + _ = _parse_examples([AIMessage(content="a")]) + print(str(exc_info.value)) + assert ( + str(exc_info.value) + == "Expect examples to have an even amount of messages, got 1." + )