@@ -111,7 +111,9 @@ def _is_url(s: str) -> bool:
def _parse_chat_history_gemini(
def _parse_chat_history_gemini(
    history: List[BaseMessage], project: Optional[str]
    history: List[BaseMessage],
    project: Optional[str] = None,
    convert_system_message_to_human: Optional[bool] = False,
) -> List[Content]:
) -> List[Content]:
    def _convert_to_prompt(part: Union[str, Dict]) -> Part:
    def _convert_to_prompt(part: Union[str, Dict]) -> Part:
        if isinstance(part, str):
        if isinstance(part, str):
@@ -155,9 +157,25 @@ def _parse_chat_history_gemini(
        return [_convert_to_prompt(part) for part in raw_content]
        return [_convert_to_prompt(part) for part in raw_content]
    vertex_messages = []
    vertex_messages = []
    raw_system_message = None
    for i, message in enumerate(history):
    for i, message in enumerate(history):
        if i == 0 and isinstance(message, SystemMessage):
        if (
            raise ValueError("SystemMessages are not yet supported!")
            i == 0
            and isinstance(message, SystemMessage)
            and not convert_system_message_to_human
        ):
            raise ValueError(
                """SystemMessages are not yet supported!
To automatically convert the leading SystemMessage to a HumanMessage,
set `convert_system_message_to_human` to True. Example:
llm = ChatVertexAI(model_name="gemini-pro", convert_system_message_to_human=True)
"""
            )
        elif i == 0 and isinstance(message, SystemMessage):
            raw_system_message = message
            continue
        elif isinstance(message, AIMessage):
        elif isinstance(message, AIMessage):
            raw_function_call = message.additional_kwargs.get("function_call")
            raw_function_call = message.additional_kwargs.get("function_call")
            role = "model"
            role = "model"
@@ -170,6 +188,8 @@ def _parse_chat_history_gemini(
                )
                )
                gapic_part = GapicPart(function_call=function_call)
                gapic_part = GapicPart(function_call=function_call)
                parts = [Part._from_gapic(gapic_part)]
                parts = [Part._from_gapic(gapic_part)]
            else:
                parts = _convert_to_parts(message)
        elif isinstance(message, HumanMessage):
        elif isinstance(message, HumanMessage):
            role = "user"
            role = "user"
            parts = _convert_to_parts(message)
            parts = _convert_to_parts(message)
@@ -188,6 +208,15 @@ def _parse_chat_history_gemini(
                f"Unexpected message with type {type(message)} at the position {i}."
                f"Unexpected message with type {type(message)} at the position {i}."
            )
            )
        if raw_system_message:
            if role == "model":
                raise ValueError(
                    "SystemMessage should be followed by a HumanMessage and "
                    "not by AIMessage."
                )
            parts = _convert_to_parts(raw_system_message) + parts
            raw_system_message = None
        vertex_message = Content(role=role, parts=parts)
        vertex_message = Content(role=role, parts=parts)
        vertex_messages.append(vertex_message)
        vertex_messages.append(vertex_message)
    return vertex_messages
    return vertex_messages
@@ -258,6 +287,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
    model_name: str = "chat-bison"
    model_name: str = "chat-bison"
    "Underlying model name."
    "Underlying model name."
    examples: Optional[List[BaseMessage]] = None
    examples: Optional[List[BaseMessage]] = None
    convert_system_message_to_human: bool = False
    """Whether to merge any leading SystemMessage into the following HumanMessage.
    Gemini does not support system messages; any unsupported messages will
    raise an error."""
    @classmethod
    @classmethod
    def is_lc_serializable(self) -> bool:
    def is_lc_serializable(self) -> bool:
@@ -327,7 +361,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
            msg_params["candidate_count"] = params.pop("candidate_count")
            msg_params["candidate_count"] = params.pop("candidate_count")
        if self._is_gemini_model:
        if self._is_gemini_model:
            history_gemini = _parse_chat_history_gemini(messages, project=self.project)
            history_gemini = _parse_chat_history_gemini(
                messages,
                project=self.project,
                convert_system_message_to_human=self.convert_system_message_to_human,
            )
            message = history_gemini.pop()
            message = history_gemini.pop()
            chat = self.client.start_chat(history=history_gemini)
            chat = self.client.start_chat(history=history_gemini)
@@ -396,7 +434,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
            msg_params["candidate_count"] = params.pop("candidate_count")
            msg_params["candidate_count"] = params.pop("candidate_count")
        if self._is_gemini_model:
        if self._is_gemini_model:
            history_gemini = _parse_chat_history_gemini(messages, project=self.project)
            history_gemini = _parse_chat_history_gemini(
                messages,
                project=self.project,
                convert_system_message_to_human=self.convert_system_message_to_human,
            )
            message = history_gemini.pop()
            message = history_gemini.pop()
            chat = self.client.start_chat(history=history_gemini)
            chat = self.client.start_chat(history=history_gemini)
            # set param to `functions` until core tool/function calling implemented
            # set param to `functions` until core tool/function calling implemented
@@ -441,7 +483,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
    ) -> Iterator[ChatGenerationChunk]:
    ) -> Iterator[ChatGenerationChunk]:
        params = self._prepare_params(stop=stop, stream=True, **kwargs)
        params = self._prepare_params(stop=stop, stream=True, **kwargs)
        if self._is_gemini_model:
        if self._is_gemini_model:
            history_gemini = _parse_chat_history_gemini(messages, project=self.project)
            history_gemini = _parse_chat_history_gemini(
                messages,
                project=self.project,
                convert_system_message_to_human=self.convert_system_message_to_human,
            )
            message = history_gemini.pop()
            message = history_gemini.pop()
            chat = self.client.start_chat(history=history_gemini)
            chat = self.client.start_chat(history=history_gemini)
            # set param to `functions` until core tool/function calling implemented
            # set param to `functions` until core tool/function calling implemented