langchain-google-vertexai: perserving grounding metadata (#16309)

Revival of https://github.com/langchain-ai/langchain/pull/14549 that
closes https://github.com/langchain-ai/langchain/issues/14548.
pull/16557/head
James Braza 8 months ago committed by GitHub
parent adc008407e
commit 0785432e7b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,4 +1,6 @@
"""Utilities to init Vertex AI.""" """Utilities to init Vertex AI."""
import dataclasses
from importlib import metadata from importlib import metadata
from typing import Any, Callable, Dict, Optional, Union from typing import Any, Callable, Dict, Optional, Union
@ -10,7 +12,13 @@ from langchain_core.callbacks import (
CallbackManagerForLLMRun, CallbackManagerForLLMRun,
) )
from langchain_core.language_models.llms import create_base_retry_decorator from langchain_core.language_models.llms import create_base_retry_decorator
from vertexai.preview.generative_models import Image # type: ignore from vertexai.generative_models._generative_models import ( # type: ignore[import-untyped]
Candidate,
)
from vertexai.language_models import ( # type: ignore[import-untyped]
TextGenerationResponse,
)
from vertexai.preview.generative_models import Image # type: ignore[import-untyped]
def create_retry_decorator( def create_retry_decorator(
@ -88,27 +96,23 @@ def is_gemini_model(model_name: str) -> bool:
return model_name is not None and "gemini" in model_name return model_name is not None and "gemini" in model_name
def get_generation_info(candidate: Any, is_gemini: bool) -> Optional[Dict[str, Any]]: def get_generation_info(
try: candidate: Union[TextGenerationResponse, Candidate], is_gemini: bool
if is_gemini: ) -> Dict[str, Any]:
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body if is_gemini:
return { # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body
"is_blocked": any( return {
[rating.blocked for rating in candidate.safety_ratings] "is_blocked": any([rating.blocked for rating in candidate.safety_ratings]),
), "safety_ratings": [
"safety_ratings": [ {
{ "category": rating.category.name,
"category": rating.category.name, "probability_label": rating.probability.name,
"probability_label": rating.probability.name, }
} for rating in candidate.safety_ratings
for rating in candidate.safety_ratings ],
], "citation_metadata": candidate.citation_metadata,
} }
else: # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body candidate_dc = dataclasses.asdict(candidate)
return { candidate_dc.pop("text")
"is_blocked": candidate.is_blocked, return {k: v for k, v in candidate_dc.items() if not k.startswith("_")}
"safety_attributes": candidate.safety_attributes,
}
except Exception:
return None

@ -1,5 +1,5 @@
"""Test ChatGoogleVertexAI chat model.""" """Test ChatGoogleVertexAI chat model."""
from typing import cast from typing import Optional, cast
import pytest import pytest
from langchain_core.messages import ( from langchain_core.messages import (
@ -16,7 +16,7 @@ model_names_to_test = [None, "codechat-bison", "chat-bison", "gemini-pro"]
@pytest.mark.parametrize("model_name", model_names_to_test) @pytest.mark.parametrize("model_name", model_names_to_test)
def test_initialization(model_name: str) -> None: def test_initialization(model_name: Optional[str]) -> None:
"""Test chat model initialization.""" """Test chat model initialization."""
if model_name: if model_name:
model = ChatVertexAI(model_name=model_name) model = ChatVertexAI(model_name=model_name)
@ -30,7 +30,7 @@ def test_initialization(model_name: str) -> None:
@pytest.mark.parametrize("model_name", model_names_to_test) @pytest.mark.parametrize("model_name", model_names_to_test)
def test_vertexai_single_call(model_name: str) -> None: def test_vertexai_single_call(model_name: Optional[str]) -> None:
if model_name: if model_name:
model = ChatVertexAI(model_name=model_name) model = ChatVertexAI(model_name=model_name)
else: else:
@ -164,7 +164,7 @@ def test_vertexai_single_call_with_examples() -> None:
@pytest.mark.parametrize("model_name", model_names_to_test) @pytest.mark.parametrize("model_name", model_names_to_test)
def test_vertexai_single_call_with_history(model_name: str) -> None: def test_vertexai_single_call_with_history(model_name: Optional[str]) -> None:
if model_name: if model_name:
model = ChatVertexAI(model_name=model_name) model = ChatVertexAI(model_name=model_name)
else: else:
@ -203,7 +203,7 @@ def test_chat_vertexai_gemini_system_message_error(model_name: str) -> None:
@pytest.mark.parametrize("model_name", model_names_to_test) @pytest.mark.parametrize("model_name", model_names_to_test)
def test_chat_vertexai_system_message(model_name: str) -> None: def test_chat_vertexai_system_message(model_name: Optional[str]) -> None:
if model_name: if model_name:
model = ChatVertexAI( model = ChatVertexAI(
model_name=model_name, convert_system_message_to_human=True model_name=model_name, convert_system_message_to_human=True

@ -1,5 +1,7 @@
"""Test chat model integration.""" """Test chat model integration."""
from typing import Any, Dict, Optional
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from unittest.mock import MagicMock, Mock, patch from unittest.mock import MagicMock, Mock, patch
import pytest import pytest
@ -45,6 +47,13 @@ def test_parse_examples_failes_wrong_sequence() -> None:
) )
@dataclass
class StubTextChatResponse:
"""Stub text-chat response from VertexAI for testing."""
text: str
@pytest.mark.parametrize("stop", [None, "stop1"]) @pytest.mark.parametrize("stop", [None, "stop1"])
def test_vertexai_args_passed(stop: Optional[str]) -> None: def test_vertexai_args_passed(stop: Optional[str]) -> None:
response_text = "Goodbye" response_text = "Goodbye"
@ -59,7 +68,7 @@ def test_vertexai_args_passed(stop: Optional[str]) -> None:
# Mock the library to ensure the args are passed correctly # Mock the library to ensure the args are passed correctly
with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg: with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg:
mock_response = MagicMock() mock_response = MagicMock()
mock_response.candidates = [Mock(text=response_text)] mock_response.candidates = [StubTextChatResponse(text=response_text)]
mock_chat = MagicMock() mock_chat = MagicMock()
mock_send_message = MagicMock(return_value=mock_response) mock_send_message = MagicMock(return_value=mock_response)
mock_chat.send_message = mock_send_message mock_chat.send_message = mock_send_message
@ -136,7 +145,7 @@ def test_default_params_palm() -> None:
with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg: with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg:
mock_response = MagicMock() mock_response = MagicMock()
mock_response.candidates = [Mock(text="Goodbye")] mock_response.candidates = [StubTextChatResponse(text="Goodbye")]
mock_chat = MagicMock() mock_chat = MagicMock()
mock_send_message = MagicMock(return_value=mock_response) mock_send_message = MagicMock(return_value=mock_response)
mock_chat.send_message = mock_send_message mock_chat.send_message = mock_send_message
@ -159,13 +168,28 @@ def test_default_params_palm() -> None:
) )
@dataclass
class StubGeminiResponse:
"""Stub gemini response from VertexAI for testing."""
text: str
content: Any
citation_metadata: Any
safety_ratings: List[Any] = field(default_factory=list)
def test_default_params_gemini() -> None: def test_default_params_gemini() -> None:
user_prompt = "Hello" user_prompt = "Hello"
with patch("langchain_google_vertexai.chat_models.GenerativeModel") as gm: with patch("langchain_google_vertexai.chat_models.GenerativeModel") as gm:
mock_response = MagicMock() mock_response = MagicMock()
content = Mock(parts=[Mock(function_call=None)]) mock_response.candidates = [
mock_response.candidates = [Mock(text="Goodbye", content=content)] StubGeminiResponse(
text="Goodbye",
content=Mock(parts=[Mock(function_call=None)]),
citation_metadata=Mock(),
)
]
mock_chat = MagicMock() mock_chat = MagicMock()
mock_send_message = MagicMock(return_value=mock_response) mock_send_message = MagicMock(return_value=mock_response)
mock_chat.send_message = mock_send_message mock_chat.send_message = mock_send_message

Loading…
Cancel
Save