You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/partners/google-vertexai/tests/unit_tests/test_chat_models.py

319 lines
10 KiB
Python

"""Test chat model integration."""
import json
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from unittest.mock import MagicMock, Mock, patch
import pytest
from google.cloud.aiplatform_v1beta1.types import (
Content,
FunctionCall,
Part,
)
from google.cloud.aiplatform_v1beta1.types import (
content as gapic_content_types,
)
from langchain_core.messages import (
AIMessage,
HumanMessage,
SystemMessage,
)
from vertexai.language_models import ChatMessage, InputOutputTextPair # type: ignore
from vertexai.preview.generative_models import ( # type: ignore
Candidate,
)
from langchain_google_vertexai.chat_models import (
ChatVertexAI,
_parse_chat_history,
_parse_chat_history_gemini,
_parse_examples,
_parse_response_candidate,
)
def test_parse_examples_correct() -> None:
text_question = (
"Hello, could you recommend a good movie for me to watch this evening, please?"
)
question = HumanMessage(content=text_question)
text_answer = (
"Sure, You might enjoy The Lord of the Rings: The Fellowship of the Ring "
"(2001): This is the first movie in the Lord of the Rings trilogy."
)
answer = AIMessage(content=text_answer)
examples = _parse_examples([question, answer, question, answer])
assert len(examples) == 2
assert examples == [
InputOutputTextPair(input_text=text_question, output_text=text_answer),
InputOutputTextPair(input_text=text_question, output_text=text_answer),
]
def test_parse_examples_failes_wrong_sequence() -> None:
with pytest.raises(ValueError) as exc_info:
_ = _parse_examples([AIMessage(content="a")])
assert (
str(exc_info.value)
== "Expect examples to have an even amount of messages, got 1."
)
@dataclass
class StubTextChatResponse:
"""Stub text-chat response from VertexAI for testing."""
text: str
@pytest.mark.parametrize("stop", [None, "stop1"])
def test_vertexai_args_passed(stop: Optional[str]) -> None:
response_text = "Goodbye"
user_prompt = "Hello"
prompt_params: Dict[str, Any] = {
"max_output_tokens": 1,
"temperature": 10000.0,
"top_k": 10,
"top_p": 0.5,
}
# Mock the library to ensure the args are passed correctly
with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg:
mock_response = MagicMock()
mock_response.candidates = [StubTextChatResponse(text=response_text)]
mock_chat = MagicMock()
mock_send_message = MagicMock(return_value=mock_response)
mock_chat.send_message = mock_send_message
mock_model = MagicMock()
mock_start_chat = MagicMock(return_value=mock_chat)
mock_model.start_chat = mock_start_chat
mg.return_value = mock_model
model = ChatVertexAI(**prompt_params)
message = HumanMessage(content=user_prompt)
if stop:
response = model([message], stop=[stop])
else:
response = model([message])
assert response.content == response_text
mock_send_message.assert_called_once_with(user_prompt, candidate_count=1)
expected_stop_sequence = [stop] if stop else None
mock_start_chat.assert_called_once_with(
context=None,
message_history=[],
**prompt_params,
stop_sequences=expected_stop_sequence,
)
def test_parse_chat_history_correct() -> None:
text_context = (
"My name is Ned. You are my personal assistant. My "
"favorite movies are Lord of the Rings and Hobbit."
)
context = SystemMessage(content=text_context)
text_question = (
"Hello, could you recommend a good movie for me to watch this evening, please?"
)
question = HumanMessage(content=text_question)
text_answer = (
"Sure, You might enjoy The Lord of the Rings: The Fellowship of the Ring "
"(2001): This is the first movie in the Lord of the Rings trilogy."
)
answer = AIMessage(content=text_answer)
history = _parse_chat_history([context, question, answer, question, answer])
assert history.context == context.content
assert len(history.history) == 4
assert history.history == [
ChatMessage(content=text_question, author="user"),
ChatMessage(content=text_answer, author="bot"),
ChatMessage(content=text_question, author="user"),
ChatMessage(content=text_answer, author="bot"),
]
def test_parse_history_gemini() -> None:
system_input = "You're supposed to answer math questions."
text_question1, text_answer1 = "How much is 2+2?", "4"
text_question2 = "How much is 3+3?"
system_message = SystemMessage(content=system_input)
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
messages = [system_message, message1, message2, message3]
history = _parse_chat_history_gemini(messages, convert_system_message_to_human=True)
assert len(history) == 3
assert history[0].role == "user"
assert history[0].parts[0].text == system_input
assert history[0].parts[1].text == text_question1
assert history[1].role == "model"
assert history[1].parts[0].text == text_answer1
def test_default_params_palm() -> None:
user_prompt = "Hello"
with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg:
mock_response = MagicMock()
mock_response.candidates = [StubTextChatResponse(text="Goodbye")]
mock_chat = MagicMock()
mock_send_message = MagicMock(return_value=mock_response)
mock_chat.send_message = mock_send_message
mock_model = MagicMock()
mock_start_chat = MagicMock(return_value=mock_chat)
mock_model.start_chat = mock_start_chat
mg.return_value = mock_model
model = ChatVertexAI(model_name="text-bison@001")
message = HumanMessage(content=user_prompt)
_ = model([message])
mock_start_chat.assert_called_once_with(
context=None,
message_history=[],
max_output_tokens=128,
top_k=40,
top_p=0.95,
stop_sequences=None,
)
@dataclass
class StubGeminiResponse:
"""Stub gemini response from VertexAI for testing."""
text: str
content: Any
citation_metadata: Any
safety_ratings: List[Any] = field(default_factory=list)
def test_default_params_gemini() -> None:
user_prompt = "Hello"
with patch("langchain_google_vertexai.chat_models.GenerativeModel") as gm:
mock_response = MagicMock()
mock_response.candidates = [
StubGeminiResponse(
text="Goodbye",
content=Mock(parts=[Mock(function_call=None)]),
citation_metadata=None,
)
]
mock_chat = MagicMock()
mock_send_message = MagicMock(return_value=mock_response)
mock_chat.send_message = mock_send_message
mock_model = MagicMock()
mock_start_chat = MagicMock(return_value=mock_chat)
mock_model.start_chat = mock_start_chat
gm.return_value = mock_model
model = ChatVertexAI(model_name="gemini-pro")
message = HumanMessage(content=user_prompt)
_ = model([message])
mock_start_chat.assert_called_once_with(history=[])
@pytest.mark.parametrize(
"raw_candidate, expected",
[
(
gapic_content_types.Candidate(
content=Content(
role="model",
parts=[
Part(
function_call=FunctionCall(
name="Information",
args={"name": "Ben"},
),
)
],
)
),
{
"name": "Information",
"arguments": {"name": "Ben"},
},
),
(
gapic_content_types.Candidate(
content=Content(
role="model",
parts=[
Part(
function_call=FunctionCall(
name="Information",
args={"info": ["A", "B", "C"]},
),
)
],
)
),
{
"name": "Information",
"arguments": {"info": ["A", "B", "C"]},
},
),
(
gapic_content_types.Candidate(
content=Content(
role="model",
parts=[
Part(
function_call=FunctionCall(
name="Information",
args={
"people": [
{"name": "Joe", "age": 30},
{"name": "Martha"},
]
},
),
)
],
)
),
{
"name": "Information",
"arguments": {
"people": [
{"name": "Joe", "age": 30},
{"name": "Martha"},
]
},
},
),
(
gapic_content_types.Candidate(
content=Content(
role="model",
parts=[
Part(
function_call=FunctionCall(
name="Information",
args={"info": [[1, 2, 3], [4, 5, 6]]},
),
)
],
)
),
{
"name": "Information",
"arguments": {"info": [[1, 2, 3], [4, 5, 6]]},
},
),
],
)
def test_parse_response_candidate(raw_candidate, expected) -> None:
response_candidate = Candidate._from_gapic(raw_candidate)
result = _parse_response_candidate(response_candidate)
result_arguments = json.loads(
result.additional_kwargs["function_call"]["arguments"]
)
assert result_arguments == expected["arguments"]