2023-05-24 22:51:12 +00:00
|
|
|
"""Test Vertex AI API wrapper.
|
|
|
|
In order to run this test, you need to install VertexAI SDK (that is the private
|
|
|
|
preview) and be whitelisted to list the models themselves:
|
|
|
|
In order to run this test, you need to install VertexAI SDK
|
2023-10-30 22:10:05 +00:00
|
|
|
pip install google-cloud-aiplatform>=1.35.0
|
2023-05-24 22:51:12 +00:00
|
|
|
|
|
|
|
Your end-user credentials would be used to make the calls (make sure you've run
|
|
|
|
`gcloud auth login` first).
|
|
|
|
"""
|
2023-09-22 15:18:09 +00:00
|
|
|
from typing import Optional
|
2023-07-14 06:03:04 +00:00
|
|
|
from unittest.mock import MagicMock, Mock, patch
|
2023-06-04 23:59:53 +00:00
|
|
|
|
2023-05-24 22:51:12 +00:00
|
|
|
import pytest
|
2023-12-07 17:46:11 +00:00
|
|
|
from langchain_core.messages import (
|
|
|
|
AIMessage,
|
|
|
|
AIMessageChunk,
|
|
|
|
HumanMessage,
|
|
|
|
SystemMessage,
|
|
|
|
)
|
2023-11-21 16:35:29 +00:00
|
|
|
from langchain_core.outputs import LLMResult
|
2023-05-24 22:51:12 +00:00
|
|
|
|
2023-12-11 21:53:30 +00:00
|
|
|
from langchain_community.chat_models import ChatVertexAI
|
|
|
|
from langchain_community.chat_models.vertexai import (
|
|
|
|
_parse_chat_history,
|
|
|
|
_parse_examples,
|
|
|
|
)
|
2023-05-24 22:51:12 +00:00
|
|
|
|
|
|
|
|
2023-09-23 22:51:59 +00:00
|
|
|
@pytest.mark.parametrize("model_name", [None, "codechat-bison", "chat-bison"])
def test_vertexai_instantiation(model_name: Optional[str]) -> None:
    """Smoke-test that ChatVertexAI can be constructed for each model name.

    The parametrization includes ``None`` (use the default model), so the
    parameter is ``Optional[str]``, not ``str``.
    """
    if model_name:
        model = ChatVertexAI(model_name=model_name)
    else:
        model = ChatVertexAI()
    assert model._llm_type == "vertexai"
    # The wrapper's model_name must round-trip to the underlying SDK client.
    assert model.model_name == model.client._model_id
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.scheduled
@pytest.mark.parametrize("model_name", [None, "codechat-bison", "chat-bison"])
def test_vertexai_single_call(model_name: Optional[str]) -> None:
    """A single human message yields an AIMessage with string content.

    ``model_name`` may be ``None`` (default model), so the annotation is
    ``Optional[str]``.
    """
    if model_name:
        model = ChatVertexAI(model_name=model_name)
    else:
        model = ChatVertexAI()
    message = HumanMessage(content="Hello")
    response = model([message])
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)
|
|
|
|
|
|
|
|
|
2023-12-08 19:00:37 +00:00
|
|
|
# mark xfail because Vertex API randomly doesn't respect
# the n/candidate_count parameter
@pytest.mark.xfail
@pytest.mark.scheduled
def test_candidates() -> None:
    """Requesting n=2 should yield two candidate generations for one prompt."""
    chat = ChatVertexAI(model_name="chat-bison@001", temperature=0.3, n=2)
    result = chat.generate(messages=[[HumanMessage(content="Hello")]])
    assert isinstance(result, LLMResult)
    # One prompt was sent ...
    assert len(result.generations) == 1
    # ... and both requested candidates should come back for it.
    assert len(result.generations[0]) == 2
|
|
|
|
|
|
|
|
|
2023-09-23 22:51:59 +00:00
|
|
|
@pytest.mark.scheduled
async def test_vertexai_agenerate() -> None:
    """Async generation returns an AIMessage and matches the sync path."""
    chat = ChatVertexAI(temperature=0)
    prompt = HumanMessage(content="Hello")
    async_result = await chat.agenerate([[prompt]])
    assert isinstance(async_result, LLMResult)
    assert isinstance(async_result.generations[0][0].message, AIMessage)  # type: ignore

    # With temperature=0 the sync call should produce the same generation.
    sync_result = chat.generate([[prompt]])
    assert async_result.generations[0][0] == sync_result.generations[0][0]
|
|
|
|
|
|
|
|
|
2023-12-07 17:46:11 +00:00
|
|
|
@pytest.mark.scheduled
async def test_vertexai_stream() -> None:
    """Every chunk emitted by the synchronous stream is an AIMessageChunk."""
    chat = ChatVertexAI(temperature=0)
    prompt = HumanMessage(content="Hello")

    for part in chat.stream([prompt]):
        assert isinstance(part, AIMessageChunk)
|
|
|
|
|
|
|
|
|
2023-09-23 22:51:59 +00:00
|
|
|
@pytest.mark.scheduled
def test_vertexai_single_call_with_context() -> None:
    """A leading SystemMessage is accepted as chat context before the question."""
    chat = ChatVertexAI()
    system = SystemMessage(
        content=(
            "My name is Ned. You are my personal assistant. My favorite movies "
            "are Lord of the Rings and Hobbit."
        )
    )
    human = HumanMessage(
        content=(
            "Hello, could you recommend a good movie for me to watch this evening, please?"
        )
    )
    reply = chat([system, human])
    assert isinstance(reply, AIMessage)
    assert isinstance(reply.content, str)
|
|
|
|
|
|
|
|
|
2023-09-23 22:51:59 +00:00
|
|
|
@pytest.mark.scheduled
def test_vertexai_single_call_with_examples() -> None:
    """Few-shot examples can be forwarded via the ``examples`` kwarg."""
    chat = ChatVertexAI()
    system = SystemMessage(content="My name is Ned. You are my personal assistant.")
    human = HumanMessage(content="2+2")
    # One worked example: a human input paired with the expected AI output.
    example_input = HumanMessage(content="4+4")
    example_output = AIMessage(content="8")
    reply = chat([system, human], examples=[example_input, example_output])
    assert isinstance(reply, AIMessage)
    assert isinstance(reply.content, str)
|
|
|
|
|
|
|
|
|
2023-09-23 22:51:59 +00:00
|
|
|
@pytest.mark.scheduled
@pytest.mark.parametrize("model_name", [None, "codechat-bison", "chat-bison"])
def test_vertexai_single_call_with_history(model_name: Optional[str]) -> None:
    """A human/AI/human history is accepted and answered with an AIMessage.

    ``model_name`` may be ``None`` (default model), so the annotation is
    ``Optional[str]``.
    """
    if model_name:
        model = ChatVertexAI(model_name=model_name)
    else:
        model = ChatVertexAI()
    text_question1, text_answer1 = "How much is 2+2?", "4"
    text_question2 = "How much is 3+3?"
    message1 = HumanMessage(content=text_question1)
    message2 = AIMessage(content=text_answer1)
    message3 = HumanMessage(content=text_question2)
    response = model([message1, message2, message3])
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)
|
|
|
|
|
|
|
|
|
2023-05-24 22:51:12 +00:00
|
|
|
def test_parse_chat_history_correct() -> None:
    """_parse_chat_history lifts the system message into context and maps roles."""
    from vertexai.language_models import ChatMessage

    text_context = (
        "My name is Ned. You are my personal assistant. My "
        "favorite movies are Lord of the Rings and Hobbit."
    )
    text_question = (
        "Hello, could you recommend a good movie for me to watch this evening, please?"
    )
    text_answer = (
        "Sure, You might enjoy The Lord of the Rings: The Fellowship of the Ring "
        "(2001): This is the first movie in the Lord of the Rings trilogy."
    )
    context = SystemMessage(content=text_context)
    question = HumanMessage(content=text_question)
    answer = AIMessage(content=text_answer)

    history = _parse_chat_history([context, question, answer, question, answer])

    # The leading SystemMessage becomes the chat context ...
    assert history.context == context.content
    # ... and the remaining four messages alternate user/bot turns.
    assert len(history.history) == 4
    one_turn = [
        ChatMessage(content=text_question, author="user"),
        ChatMessage(content=text_answer, author="bot"),
    ]
    assert history.history == one_turn + one_turn
|
2023-05-24 22:51:12 +00:00
|
|
|
|
|
|
|
|
2023-09-23 22:51:59 +00:00
|
|
|
def test_vertexai_single_call_fails_no_message() -> None:
    """Calling the chat model with an empty message list raises ValueError."""
    chat = ChatVertexAI()
    with pytest.raises(ValueError) as exc_info:
        _ = chat([])
    expected = "You should provide at least one message to start the chat!"
    assert str(exc_info.value) == expected
|
2023-06-04 23:59:53 +00:00
|
|
|
|
|
|
|
|
2023-09-22 15:18:09 +00:00
|
|
|
@pytest.mark.parametrize("stop", [None, "stop1"])
def test_vertexai_args_passed(stop: Optional[str]) -> None:
    """Model params and stop sequences are forwarded verbatim to the Vertex SDK."""
    response_text = "Goodbye"
    user_prompt = "Hello"
    prompt_params = {
        "max_output_tokens": 1,
        "temperature": 10000.0,
        "top_k": 10,
        "top_p": 0.5,
    }

    # Mock the library to ensure the args are passed correctly
    with patch(
        "vertexai.language_models._language_models.ChatModel.start_chat"
    ) as start_chat:
        mock_response = MagicMock()
        mock_response.candidates = [Mock(text=response_text)]
        mock_send_message = MagicMock(return_value=mock_response)
        mock_chat = MagicMock(send_message=mock_send_message)
        start_chat.return_value = mock_chat

        model = ChatVertexAI(**prompt_params)
        message = HumanMessage(content=user_prompt)
        if stop:
            response = model([message], stop=[stop])
        else:
            response = model([message])

        assert response.content == response_text
        # The prompt goes to send_message with a single candidate requested.
        mock_send_message.assert_called_once_with(user_prompt, candidate_count=1)
        # start_chat receives the model params plus the (optional) stop list.
        start_chat.assert_called_once_with(
            context=None,
            message_history=[],
            **prompt_params,
            stop_sequences=[stop] if stop else None,
        )
|
2023-07-14 06:03:04 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_parse_examples_correct() -> None:
    """_parse_examples pairs (human, ai) messages into InputOutputTextPairs."""
    from vertexai.language_models import InputOutputTextPair

    text_question = (
        "Hello, could you recommend a good movie for me to watch this evening, please?"
    )
    text_answer = (
        "Sure, You might enjoy The Lord of the Rings: The Fellowship of the Ring "
        "(2001): This is the first movie in the Lord of the Rings trilogy."
    )
    question = HumanMessage(content=text_question)
    answer = AIMessage(content=text_answer)

    examples = _parse_examples([question, answer, question, answer])

    # Four messages collapse into two input/output pairs.
    expected_pair = InputOutputTextPair(
        input_text=text_question, output_text=text_answer
    )
    assert len(examples) == 2
    assert examples == [expected_pair, expected_pair]
|
|
|
|
|
|
|
|
|
2023-08-07 14:34:35 +00:00
|
|
|
def test_parse_examples_fails_wrong_sequence() -> None:
    """_parse_examples raises ValueError on an odd number of messages.

    Fixes the typo in the original test name (``failes`` -> ``fails``) and
    drops a leftover debug ``print``.
    """
    with pytest.raises(ValueError) as exc_info:
        _ = _parse_examples([AIMessage(content="a")])
    assert (
        str(exc_info.value)
        == "Expect examples to have an even amount of messages, got 1."
    )
|