|
|
@ -1,5 +1,7 @@
|
|
|
|
"""Test chat model integration."""
|
|
|
|
"""Test chat model integration."""
|
|
|
|
from typing import Any, Dict, Optional
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass, field
|
|
|
|
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from unittest.mock import MagicMock, Mock, patch
|
|
|
|
from unittest.mock import MagicMock, Mock, patch
|
|
|
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
import pytest
|
|
|
@ -45,6 +47,13 @@ def test_parse_examples_failes_wrong_sequence() -> None:
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
|
|
|
class StubTextChatResponse:
|
|
|
|
|
|
|
|
"""Stub text-chat response from VertexAI for testing."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
text: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("stop", [None, "stop1"])
|
|
|
|
@pytest.mark.parametrize("stop", [None, "stop1"])
|
|
|
|
def test_vertexai_args_passed(stop: Optional[str]) -> None:
|
|
|
|
def test_vertexai_args_passed(stop: Optional[str]) -> None:
|
|
|
|
response_text = "Goodbye"
|
|
|
|
response_text = "Goodbye"
|
|
|
@ -59,7 +68,7 @@ def test_vertexai_args_passed(stop: Optional[str]) -> None:
|
|
|
|
# Mock the library to ensure the args are passed correctly
|
|
|
|
# Mock the library to ensure the args are passed correctly
|
|
|
|
with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg:
|
|
|
|
with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg:
|
|
|
|
mock_response = MagicMock()
|
|
|
|
mock_response = MagicMock()
|
|
|
|
mock_response.candidates = [Mock(text=response_text)]
|
|
|
|
mock_response.candidates = [StubTextChatResponse(text=response_text)]
|
|
|
|
mock_chat = MagicMock()
|
|
|
|
mock_chat = MagicMock()
|
|
|
|
mock_send_message = MagicMock(return_value=mock_response)
|
|
|
|
mock_send_message = MagicMock(return_value=mock_response)
|
|
|
|
mock_chat.send_message = mock_send_message
|
|
|
|
mock_chat.send_message = mock_send_message
|
|
|
@ -136,7 +145,7 @@ def test_default_params_palm() -> None:
|
|
|
|
|
|
|
|
|
|
|
|
with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg:
|
|
|
|
with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg:
|
|
|
|
mock_response = MagicMock()
|
|
|
|
mock_response = MagicMock()
|
|
|
|
mock_response.candidates = [Mock(text="Goodbye")]
|
|
|
|
mock_response.candidates = [StubTextChatResponse(text="Goodbye")]
|
|
|
|
mock_chat = MagicMock()
|
|
|
|
mock_chat = MagicMock()
|
|
|
|
mock_send_message = MagicMock(return_value=mock_response)
|
|
|
|
mock_send_message = MagicMock(return_value=mock_response)
|
|
|
|
mock_chat.send_message = mock_send_message
|
|
|
|
mock_chat.send_message = mock_send_message
|
|
|
@ -159,13 +168,28 @@ def test_default_params_palm() -> None:
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
|
|
|
class StubGeminiResponse:
|
|
|
|
|
|
|
|
"""Stub gemini response from VertexAI for testing."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
text: str
|
|
|
|
|
|
|
|
content: Any
|
|
|
|
|
|
|
|
citation_metadata: Any
|
|
|
|
|
|
|
|
safety_ratings: List[Any] = field(default_factory=list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_default_params_gemini() -> None:
|
|
|
|
def test_default_params_gemini() -> None:
|
|
|
|
user_prompt = "Hello"
|
|
|
|
user_prompt = "Hello"
|
|
|
|
|
|
|
|
|
|
|
|
with patch("langchain_google_vertexai.chat_models.GenerativeModel") as gm:
|
|
|
|
with patch("langchain_google_vertexai.chat_models.GenerativeModel") as gm:
|
|
|
|
mock_response = MagicMock()
|
|
|
|
mock_response = MagicMock()
|
|
|
|
content = Mock(parts=[Mock(function_call=None)])
|
|
|
|
mock_response.candidates = [
|
|
|
|
mock_response.candidates = [Mock(text="Goodbye", content=content)]
|
|
|
|
StubGeminiResponse(
|
|
|
|
|
|
|
|
text="Goodbye",
|
|
|
|
|
|
|
|
content=Mock(parts=[Mock(function_call=None)]),
|
|
|
|
|
|
|
|
citation_metadata=Mock(),
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
]
|
|
|
|
mock_chat = MagicMock()
|
|
|
|
mock_chat = MagicMock()
|
|
|
|
mock_send_message = MagicMock(return_value=mock_response)
|
|
|
|
mock_send_message = MagicMock(return_value=mock_response)
|
|
|
|
mock_chat.send_message = mock_send_message
|
|
|
|
mock_chat.send_message = mock_send_message
|
|
|
|