|
|
|
@ -1,4 +1,6 @@
|
|
|
|
|
"""Test ChatGoogleVertexAI chat model."""
|
|
|
|
|
|
|
|
|
|
import json
|
|
|
|
|
from typing import Optional, cast
|
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
@ -9,6 +11,7 @@ from langchain_core.messages import (
|
|
|
|
|
SystemMessage,
|
|
|
|
|
)
|
|
|
|
|
from langchain_core.outputs import ChatGeneration, LLMResult
|
|
|
|
|
from langchain_core.pydantic_v1 import BaseModel
|
|
|
|
|
|
|
|
|
|
from langchain_google_vertexai.chat_models import ChatVertexAI
|
|
|
|
|
|
|
|
|
@ -220,3 +223,26 @@ def test_chat_vertexai_system_message(model_name: Optional[str]) -> None:
|
|
|
|
|
response = model([system_message, message1, message2, message3])
|
|
|
|
|
assert isinstance(response, AIMessage)
|
|
|
|
|
assert isinstance(response.content, str)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_chat_vertexai_gemini_function_calling() -> None:
    """Integration check: Gemini answers via a function call, not plain text.

    Binds a pydantic schema to the chat model as a callable function and
    verifies the model routes its reply through ``function_call`` in
    ``additional_kwargs``, with JSON arguments extracted from the prompt.
    Requires live Vertex AI credentials.
    """

    class MyModel(BaseModel):
        # Schema the model is expected to populate from the user's message.
        name: str
        age: int

    chat = ChatVertexAI(model_name="gemini-pro").bind(functions=[MyModel])
    prompt = HumanMessage(content="My name is Erick and I am 27 years old")
    response = chat.invoke([prompt])

    # A function-calling reply carries an empty textual body...
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)
    assert response.content == ""

    # ...and the structured payload lives in additional_kwargs instead.
    function_call = response.additional_kwargs.get("function_call")
    assert function_call
    assert function_call["name"] == "MyModel"

    raw_arguments = function_call.get("arguments")
    assert raw_arguments
    parsed = json.loads(raw_arguments)
    # Gemini serializes the integer field as a float (27.0), so compare
    # against the float form the original test pinned.
    assert parsed == {
        "name": "Erick",
        "age": 27.0,
    }
|
|
|
|
|