diff --git a/libs/partners/google-vertexai/langchain_google_vertexai/_utils.py b/libs/partners/google-vertexai/langchain_google_vertexai/_utils.py index 91c4086a7f..4fe052a56b 100644 --- a/libs/partners/google-vertexai/langchain_google_vertexai/_utils.py +++ b/libs/partners/google-vertexai/langchain_google_vertexai/_utils.py @@ -1,4 +1,6 @@ """Utilities to init Vertex AI.""" + +import dataclasses from importlib import metadata from typing import Any, Callable, Dict, Optional, Union @@ -10,7 +12,13 @@ from langchain_core.callbacks import ( CallbackManagerForLLMRun, ) from langchain_core.language_models.llms import create_base_retry_decorator -from vertexai.preview.generative_models import Image # type: ignore +from vertexai.generative_models._generative_models import ( # type: ignore[import-untyped] + Candidate, +) +from vertexai.language_models import ( # type: ignore[import-untyped] + TextGenerationResponse, +) +from vertexai.preview.generative_models import Image # type: ignore[import-untyped] def create_retry_decorator( @@ -88,27 +96,23 @@ def is_gemini_model(model_name: str) -> bool: return model_name is not None and "gemini" in model_name -def get_generation_info(candidate: Any, is_gemini: bool) -> Optional[Dict[str, Any]]: - try: - if is_gemini: - # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body - return { - "is_blocked": any( - [rating.blocked for rating in candidate.safety_ratings] - ), - "safety_ratings": [ - { - "category": rating.category.name, - "probability_label": rating.probability.name, - } - for rating in candidate.safety_ratings - ], - } - else: - # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body - return { - "is_blocked": candidate.is_blocked, - "safety_attributes": candidate.safety_attributes, - } - except Exception: - return None +def get_generation_info( + candidate: Union[TextGenerationResponse, Candidate], is_gemini: bool +) -> Dict[str, Any]: + if is_gemini: + # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body + return { + "is_blocked": any([rating.blocked for rating in candidate.safety_ratings]), + "safety_ratings": [ + { + "category": rating.category.name, + "probability_label": rating.probability.name, + } + for rating in candidate.safety_ratings + ], + "citation_metadata": candidate.citation_metadata, + } + # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body + candidate_dc = dataclasses.asdict(candidate) + candidate_dc.pop("text") + return {k: v for k, v in candidate_dc.items() if not k.startswith("_")} diff --git a/libs/partners/google-vertexai/tests/integration_tests/test_chat_models.py b/libs/partners/google-vertexai/tests/integration_tests/test_chat_models.py index a29094bf92..f2981a0801 100644 --- a/libs/partners/google-vertexai/tests/integration_tests/test_chat_models.py +++ b/libs/partners/google-vertexai/tests/integration_tests/test_chat_models.py @@ -1,5 +1,5 @@ """Test ChatGoogleVertexAI chat model.""" -from typing import cast +from typing import Optional, cast import pytest from langchain_core.messages import ( @@ -16,7 +16,7 @@ model_names_to_test = [None, "codechat-bison", "chat-bison", "gemini-pro"] @pytest.mark.parametrize("model_name", model_names_to_test) -def test_initialization(model_name: str) -> None: +def test_initialization(model_name: Optional[str]) -> None: """Test chat model initialization.""" if model_name: model = ChatVertexAI(model_name=model_name) @@ -30,7 +30,7 @@ def test_initialization(model_name: str) -> None: @pytest.mark.parametrize("model_name", model_names_to_test) -def test_vertexai_single_call(model_name: str) -> None: +def test_vertexai_single_call(model_name: Optional[str]) -> None: if model_name: model = ChatVertexAI(model_name=model_name) else: @@ -164,7 +164,7 @@ def test_vertexai_single_call_with_examples() -> None: @pytest.mark.parametrize("model_name", model_names_to_test) -def test_vertexai_single_call_with_history(model_name: str) -> None: +def test_vertexai_single_call_with_history(model_name: Optional[str]) -> None: if model_name: model = ChatVertexAI(model_name=model_name) else: @@ -203,7 +203,7 @@ def test_chat_vertexai_gemini_system_message_error(model_name: str) -> None: @pytest.mark.parametrize("model_name", model_names_to_test) -def test_chat_vertexai_system_message(model_name: str) -> None: +def test_chat_vertexai_system_message(model_name: Optional[str]) -> None: if model_name: model = ChatVertexAI( model_name=model_name, convert_system_message_to_human=True diff --git a/libs/partners/google-vertexai/tests/unit_tests/test_chat_models.py b/libs/partners/google-vertexai/tests/unit_tests/test_chat_models.py index 3f30fef358..052cf559f0 100644 --- a/libs/partners/google-vertexai/tests/unit_tests/test_chat_models.py +++ b/libs/partners/google-vertexai/tests/unit_tests/test_chat_models.py @@ -1,5 +1,7 @@ """Test chat model integration.""" -from typing import Any, Dict, Optional + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional from unittest.mock import MagicMock, Mock, patch import pytest @@ -45,6 +47,13 @@ def test_parse_examples_failes_wrong_sequence() -> None: ) +@dataclass +class StubTextChatResponse: + """Stub text-chat response from VertexAI for testing.""" + + text: str + + @pytest.mark.parametrize("stop", [None, "stop1"]) def test_vertexai_args_passed(stop: Optional[str]) -> None: response_text = "Goodbye" @@ -59,7 +68,7 @@ def test_vertexai_args_passed(stop: Optional[str]) -> None: # Mock the library to ensure the args are passed correctly with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg: mock_response = MagicMock() - mock_response.candidates = [Mock(text=response_text)] + mock_response.candidates = [StubTextChatResponse(text=response_text)] mock_chat = MagicMock() mock_send_message = MagicMock(return_value=mock_response) mock_chat.send_message = mock_send_message @@ -136,7 +145,7 @@ def test_default_params_palm() -> None: with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg: mock_response = MagicMock() - mock_response.candidates = [Mock(text="Goodbye")] + mock_response.candidates = [StubTextChatResponse(text="Goodbye")] mock_chat = MagicMock() mock_send_message = MagicMock(return_value=mock_response) mock_chat.send_message = mock_send_message @@ -159,13 +168,28 @@ def test_default_params_palm() -> None: ) +@dataclass +class StubGeminiResponse: + """Stub gemini response from VertexAI for testing.""" + + text: str + content: Any + citation_metadata: Any + safety_ratings: List[Any] = field(default_factory=list) + + def test_default_params_gemini() -> None: user_prompt = "Hello" with patch("langchain_google_vertexai.chat_models.GenerativeModel") as gm: mock_response = MagicMock() - content = Mock(parts=[Mock(function_call=None)]) - mock_response.candidates = [Mock(text="Goodbye", content=content)] + mock_response.candidates = [ + StubGeminiResponse( + text="Goodbye", + content=Mock(parts=[Mock(function_call=None)]), + citation_metadata=Mock(), + ) + ] mock_chat = MagicMock() mock_send_message = MagicMock(return_value=mock_response) mock_chat.send_message = mock_send_message