langchain_google_vertexai[patch]: Add support for SystemMessage for Gemini chat model (#15933)

- **Description:** In Google Vertex AI, Gemini Chat models currently
doesn't have a support for SystemMessage. This PR adds support for it
only if a user provides additional convert_system_message_to_human flag
during model initialization (in this case, SystemMessage would be
prepended to the first HumanMessage). **NOTE:** The implementation is
similar to #14824


- **Twitter handle:** rajesh_thallam

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
pull/16202/head^2
Rajesh Thallam 6 months ago committed by GitHub
parent 65b231d40b
commit 6bc6d64a12
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -11,7 +11,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@ -95,7 +94,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"If we want to construct a simple chain that takes user specified parameters:"
"Gemini doesn't support SystemMessage at the moment, but it can be added to the first human message in the row. If you want such behavior, just set the `convert_system_message_to_human` to `True`:"
]
},
{
@ -106,7 +105,7 @@
{
"data": {
"text/plain": [
"AIMessage(content=' プログラミングが大好きです')"
"AIMessage(content=\"J'aime la programmation.\")"
]
},
"execution_count": 9,
@ -114,6 +113,40 @@
"output_type": "execute_result"
}
],
"source": [
"system = \"You are a helpful assistant who translate English to French\"\n",
"human = \"Translate this sentence from English to French. I love programming.\"\n",
"prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", human)])\n",
"\n",
"chat = ChatVertexAI(model_name=\"gemini-pro\", convert_system_message_to_human=True)\n",
"\n",
"chain = prompt | chat\n",
"chain.invoke({})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we want to construct a simple chain that takes user specified parameters:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=' プログラミングが大好きです')"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"system = (\n",
" \"You are a helpful assistant that translates {input_language} to {output_language}.\"\n",
@ -121,6 +154,8 @@
"human = \"{text}\"\n",
"prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", human)])\n",
"\n",
"chat = ChatVertexAI()\n",
"\n",
"chain = prompt | chat\n",
"\n",
"chain.invoke(\n",
@ -133,7 +168,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"execution": {
@ -352,7 +386,7 @@
"AIMessage(content=' Why do you love programming?')"
]
},
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@ -428,8 +462,14 @@
}
],
"metadata": {
"environment": {
"kernel": "python3",
"name": "common-cpu.m108",
"type": "gcloud",
"uri": "gcr.io/deeplearning-platform-release/base-cpu:m108"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
@ -443,7 +483,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.10"
},
"vscode": {
"interpreter": {

@ -111,7 +111,9 @@ def _is_url(s: str) -> bool:
def _parse_chat_history_gemini(
history: List[BaseMessage], project: Optional[str]
history: List[BaseMessage],
project: Optional[str] = None,
convert_system_message_to_human: Optional[bool] = False,
) -> List[Content]:
def _convert_to_prompt(part: Union[str, Dict]) -> Part:
if isinstance(part, str):
@ -155,9 +157,25 @@ def _parse_chat_history_gemini(
return [_convert_to_prompt(part) for part in raw_content]
vertex_messages = []
raw_system_message = None
for i, message in enumerate(history):
if i == 0 and isinstance(message, SystemMessage):
raise ValueError("SystemMessages are not yet supported!")
if (
i == 0
and isinstance(message, SystemMessage)
and not convert_system_message_to_human
):
raise ValueError(
"""SystemMessages are not yet supported!
To automatically convert the leading SystemMessage to a HumanMessage,
set `convert_system_message_to_human` to True. Example:
llm = ChatVertexAI(model_name="gemini-pro", convert_system_message_to_human=True)
"""
)
elif i == 0 and isinstance(message, SystemMessage):
raw_system_message = message
continue
elif isinstance(message, AIMessage):
raw_function_call = message.additional_kwargs.get("function_call")
role = "model"
@ -170,6 +188,8 @@ def _parse_chat_history_gemini(
)
gapic_part = GapicPart(function_call=function_call)
parts = [Part._from_gapic(gapic_part)]
else:
parts = _convert_to_parts(message)
elif isinstance(message, HumanMessage):
role = "user"
parts = _convert_to_parts(message)
@ -188,6 +208,15 @@ def _parse_chat_history_gemini(
f"Unexpected message with type {type(message)} at the position {i}."
)
if raw_system_message:
if role == "model":
raise ValueError(
"SystemMessage should be followed by a HumanMessage and "
"not by AIMessage."
)
parts = _convert_to_parts(raw_system_message) + parts
raw_system_message = None
vertex_message = Content(role=role, parts=parts)
vertex_messages.append(vertex_message)
return vertex_messages
@ -258,6 +287,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
model_name: str = "chat-bison"
"Underlying model name."
examples: Optional[List[BaseMessage]] = None
convert_system_message_to_human: bool = False
"""Whether to merge any leading SystemMessage into the following HumanMessage.
Gemini does not support system messages; any unsupported messages will
raise an error."""
@classmethod
def is_lc_serializable(self) -> bool:
@ -327,7 +361,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
msg_params["candidate_count"] = params.pop("candidate_count")
if self._is_gemini_model:
history_gemini = _parse_chat_history_gemini(messages, project=self.project)
history_gemini = _parse_chat_history_gemini(
messages,
project=self.project,
convert_system_message_to_human=self.convert_system_message_to_human,
)
message = history_gemini.pop()
chat = self.client.start_chat(history=history_gemini)
@ -396,7 +434,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
msg_params["candidate_count"] = params.pop("candidate_count")
if self._is_gemini_model:
history_gemini = _parse_chat_history_gemini(messages, project=self.project)
history_gemini = _parse_chat_history_gemini(
messages,
project=self.project,
convert_system_message_to_human=self.convert_system_message_to_human,
)
message = history_gemini.pop()
chat = self.client.start_chat(history=history_gemini)
# set param to `functions` until core tool/function calling implemented
@ -441,7 +483,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
) -> Iterator[ChatGenerationChunk]:
params = self._prepare_params(stop=stop, stream=True, **kwargs)
if self._is_gemini_model:
history_gemini = _parse_chat_history_gemini(messages, project=self.project)
history_gemini = _parse_chat_history_gemini(
messages,
project=self.project,
convert_system_message_to_human=self.convert_system_message_to_human,
)
message = history_gemini.pop()
chat = self.client.start_chat(history=history_gemini)
# set param to `functions` until core tool/function calling implemented

@ -182,3 +182,36 @@ def test_vertexai_single_call_fails_no_message() -> None:
str(exc_info.value)
== "You should provide at least one message to start the chat!"
)
@pytest.mark.parametrize("model_name", ["gemini-pro"])
def test_chat_vertexai_gemini_system_message_error(model_name: str) -> None:
model = ChatVertexAI(model_name=model_name)
text_question1, text_answer1 = "How much is 2+2?", "4"
text_question2 = "How much is 3+3?"
system_message = SystemMessage(content="You're supposed to answer math questions.")
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
with pytest.raises(ValueError):
model([system_message, message1, message2, message3])
@pytest.mark.parametrize("model_name", model_names_to_test)
def test_chat_vertexai_system_message(model_name: str) -> None:
if model_name:
model = ChatVertexAI(
model_name=model_name, convert_system_message_to_human=True
)
else:
model = ChatVertexAI()
text_question1, text_answer1 = "How much is 2+2?", "4"
text_question2 = "How much is 3+3?"
system_message = SystemMessage(content="You're supposed to answer math questions.")
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
response = model([system_message, message1, message2, message3])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)

@ -13,6 +13,7 @@ from vertexai.language_models import ChatMessage, InputOutputTextPair # type: i
from langchain_google_vertexai.chat_models import (
ChatVertexAI,
_parse_chat_history,
_parse_chat_history_gemini,
_parse_examples,
)
@ -112,6 +113,24 @@ def test_parse_chat_history_correct() -> None:
]
def test_parse_history_gemini() -> None:
system_input = "You're supposed to answer math questions."
text_question1, text_answer1 = "How much is 2+2?", "4"
text_question2 = "How much is 3+3?"
system_message = SystemMessage(content=system_input)
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
messages = [system_message, message1, message2, message3]
history = _parse_chat_history_gemini(messages, convert_system_message_to_human=True)
assert len(history) == 3
assert history[0].role == "user"
assert history[0].parts[0].text == system_input
assert history[0].parts[1].text == text_question1
assert history[1].role == "model"
assert history[1].parts[0].text == text_answer1
def test_default_params_palm() -> None:
user_prompt = "Hello"

Loading…
Cancel
Save