diff --git a/docs/docs/integrations/chat/google_vertex_ai_palm.ipynb b/docs/docs/integrations/chat/google_vertex_ai_palm.ipynb index 0443dbf844..050a32f2bf 100644 --- a/docs/docs/integrations/chat/google_vertex_ai_palm.ipynb +++ b/docs/docs/integrations/chat/google_vertex_ai_palm.ipynb @@ -11,7 +11,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -95,7 +94,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "If we want to construct a simple chain that takes user specified parameters:" + "Gemini doesn't support SystemMessage at the moment, but its content can be prepended to the first human message in the sequence. If you want this behavior, set `convert_system_message_to_human` to `True`:" ] }, { @@ -106,7 +105,7 @@ { "data": { "text/plain": [ - "AIMessage(content=' プログラミングが大好きです')" + "AIMessage(content=\"J'aime la programmation.\")" ] }, "execution_count": 9, @@ -114,6 +113,40 @@ "output_type": "execute_result" } ], + "source": [ + "system = \"You are a helpful assistant who translate English to French\"\n", + "human = \"Translate this sentence from English to French. 
I love programming.\"\n", + "prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", human)])\n", + "\n", + "chat = ChatVertexAI(model_name=\"gemini-pro\", convert_system_message_to_human=True)\n", + "\n", + "chain = prompt | chat\n", + "chain.invoke({})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If we want to construct a simple chain that takes user specified parameters:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content=' プログラミングが大好きです')" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "system = (\n", " \"You are a helpful assistant that translates {input_language} to {output_language}.\"\n", @@ -121,6 +154,8 @@ "human = \"{text}\"\n", "prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", human)])\n", "\n", + "chat = ChatVertexAI()\n", + "\n", "chain = prompt | chat\n", "\n", "chain.invoke(\n", @@ -133,7 +168,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "execution": { @@ -352,7 +386,7 @@ "AIMessage(content=' Why do you love programming?')" ] }, - "execution_count": 6, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -428,8 +462,14 @@ } ], "metadata": { + "environment": { + "kernel": "python3", + "name": "common-cpu.m108", + "type": "gcloud", + "uri": "gcr.io/deeplearning-platform-release/base-cpu:m108" + }, "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -443,7 +483,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.10.10" }, "vscode": { "interpreter": { diff --git a/libs/partners/google-vertexai/langchain_google_vertexai/chat_models.py b/libs/partners/google-vertexai/langchain_google_vertexai/chat_models.py index 
a76ade2d22..1c62d76c75 100644 --- a/libs/partners/google-vertexai/langchain_google_vertexai/chat_models.py +++ b/libs/partners/google-vertexai/langchain_google_vertexai/chat_models.py @@ -111,7 +111,9 @@ def _is_url(s: str) -> bool: def _parse_chat_history_gemini( - history: List[BaseMessage], project: Optional[str] + history: List[BaseMessage], + project: Optional[str] = None, + convert_system_message_to_human: Optional[bool] = False, ) -> List[Content]: def _convert_to_prompt(part: Union[str, Dict]) -> Part: if isinstance(part, str): @@ -155,9 +157,25 @@ def _parse_chat_history_gemini( return [_convert_to_prompt(part) for part in raw_content] vertex_messages = [] + raw_system_message = None for i, message in enumerate(history): - if i == 0 and isinstance(message, SystemMessage): - raise ValueError("SystemMessages are not yet supported!") + if ( + i == 0 + and isinstance(message, SystemMessage) + and not convert_system_message_to_human + ): + raise ValueError( + """SystemMessages are not yet supported! + +To automatically convert the leading SystemMessage to a HumanMessage, +set `convert_system_message_to_human` to True. Example: + +llm = ChatVertexAI(model_name="gemini-pro", convert_system_message_to_human=True) +""" + ) + elif i == 0 and isinstance(message, SystemMessage): + raw_system_message = message + continue elif isinstance(message, AIMessage): raw_function_call = message.additional_kwargs.get("function_call") role = "model" @@ -170,6 +188,8 @@ def _parse_chat_history_gemini( ) gapic_part = GapicPart(function_call=function_call) parts = [Part._from_gapic(gapic_part)] + else: + parts = _convert_to_parts(message) elif isinstance(message, HumanMessage): role = "user" parts = _convert_to_parts(message) @@ -188,6 +208,15 @@ def _parse_chat_history_gemini( f"Unexpected message with type {type(message)} at the position {i}." 
) + if raw_system_message: + if role == "model": + raise ValueError( + "SystemMessage should be followed by a HumanMessage and " + "not by AIMessage." + ) + parts = _convert_to_parts(raw_system_message) + parts + raw_system_message = None + vertex_message = Content(role=role, parts=parts) vertex_messages.append(vertex_message) return vertex_messages @@ -258,6 +287,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel): model_name: str = "chat-bison" "Underlying model name." examples: Optional[List[BaseMessage]] = None + convert_system_message_to_human: bool = False + """Whether to merge any leading SystemMessage into the following HumanMessage. + + Gemini does not support system messages; any unsupported messages will + raise an error.""" @classmethod def is_lc_serializable(self) -> bool: @@ -327,7 +361,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel): msg_params["candidate_count"] = params.pop("candidate_count") if self._is_gemini_model: - history_gemini = _parse_chat_history_gemini(messages, project=self.project) + history_gemini = _parse_chat_history_gemini( + messages, + project=self.project, + convert_system_message_to_human=self.convert_system_message_to_human, + ) message = history_gemini.pop() chat = self.client.start_chat(history=history_gemini) @@ -396,7 +434,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel): msg_params["candidate_count"] = params.pop("candidate_count") if self._is_gemini_model: - history_gemini = _parse_chat_history_gemini(messages, project=self.project) + history_gemini = _parse_chat_history_gemini( + messages, + project=self.project, + convert_system_message_to_human=self.convert_system_message_to_human, + ) message = history_gemini.pop() chat = self.client.start_chat(history=history_gemini) # set param to `functions` until core tool/function calling implemented @@ -441,7 +483,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel): ) -> Iterator[ChatGenerationChunk]: params = self._prepare_params(stop=stop, 
stream=True, **kwargs) if self._is_gemini_model: - history_gemini = _parse_chat_history_gemini(messages, project=self.project) + history_gemini = _parse_chat_history_gemini( + messages, + project=self.project, + convert_system_message_to_human=self.convert_system_message_to_human, + ) message = history_gemini.pop() chat = self.client.start_chat(history=history_gemini) # set param to `functions` until core tool/function calling implemented diff --git a/libs/partners/google-vertexai/tests/integration_tests/test_chat_models.py b/libs/partners/google-vertexai/tests/integration_tests/test_chat_models.py index 2b28051563..030b484d06 100644 --- a/libs/partners/google-vertexai/tests/integration_tests/test_chat_models.py +++ b/libs/partners/google-vertexai/tests/integration_tests/test_chat_models.py @@ -182,3 +182,36 @@ def test_vertexai_single_call_fails_no_message() -> None: str(exc_info.value) == "You should provide at least one message to start the chat!" ) + + +@pytest.mark.parametrize("model_name", ["gemini-pro"]) +def test_chat_vertexai_gemini_system_message_error(model_name: str) -> None: + model = ChatVertexAI(model_name=model_name) + text_question1, text_answer1 = "How much is 2+2?", "4" + text_question2 = "How much is 3+3?" + system_message = SystemMessage(content="You're supposed to answer math questions.") + message1 = HumanMessage(content=text_question1) + message2 = AIMessage(content=text_answer1) + message3 = HumanMessage(content=text_question2) + with pytest.raises(ValueError): + model([system_message, message1, message2, message3]) + + +@pytest.mark.parametrize("model_name", model_names_to_test) +def test_chat_vertexai_system_message(model_name: str) -> None: + if model_name: + model = ChatVertexAI( + model_name=model_name, convert_system_message_to_human=True + ) + else: + model = ChatVertexAI() + + text_question1, text_answer1 = "How much is 2+2?", "4" + text_question2 = "How much is 3+3?" 
+ system_message = SystemMessage(content="You're supposed to answer math questions.") + message1 = HumanMessage(content=text_question1) + message2 = AIMessage(content=text_answer1) + message3 = HumanMessage(content=text_question2) + response = model([system_message, message1, message2, message3]) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) diff --git a/libs/partners/google-vertexai/tests/unit_tests/test_chat_models.py b/libs/partners/google-vertexai/tests/unit_tests/test_chat_models.py index d11a970d65..caed17118a 100644 --- a/libs/partners/google-vertexai/tests/unit_tests/test_chat_models.py +++ b/libs/partners/google-vertexai/tests/unit_tests/test_chat_models.py @@ -13,6 +13,7 @@ from vertexai.language_models import ChatMessage, InputOutputTextPair # type: i from langchain_google_vertexai.chat_models import ( ChatVertexAI, _parse_chat_history, + _parse_chat_history_gemini, _parse_examples, ) @@ -112,6 +113,24 @@ def test_parse_chat_history_correct() -> None: ] +def test_parse_history_gemini() -> None: + system_input = "You're supposed to answer math questions." + text_question1, text_answer1 = "How much is 2+2?", "4" + text_question2 = "How much is 3+3?" + system_message = SystemMessage(content=system_input) + message1 = HumanMessage(content=text_question1) + message2 = AIMessage(content=text_answer1) + message3 = HumanMessage(content=text_question2) + messages = [system_message, message1, message2, message3] + history = _parse_chat_history_gemini(messages, convert_system_message_to_human=True) + assert len(history) == 3 + assert history[0].role == "user" + assert history[0].parts[0].text == system_input + assert history[0].parts[1].text == text_question1 + assert history[1].role == "model" + assert history[1].parts[0].text == text_answer1 + + def test_default_params_palm() -> None: user_prompt = "Hello"