langchain-google-vertexai: preserving grounding metadata (#16309)

Revival of https://github.com/langchain-ai/langchain/pull/14549 that
closes https://github.com/langchain-ai/langchain/issues/14548.
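For context, here is a minimal sketch (not part of the diff) of what the preserved metadata looks like from a caller's side. It assumes configured Google Cloud credentials; the prompt and model name are illustrative, while the `generation_info` keys shown are the ones returned by the updated `get_generation_info`:

```python
from langchain_core.messages import HumanMessage
from langchain_google_vertexai import ChatVertexAI

# Illustrative only: requires a Google Cloud project with Vertex AI enabled.
llm = ChatVertexAI(model_name="gemini-pro")
result = llm.generate([[HumanMessage(content="Tell me about the Eiffel Tower.")]])

generation = result.generations[0][0]
# Before this change the candidate's citation/grounding data was dropped;
# now it is surfaced alongside the safety information.
print(generation.generation_info)
# e.g. {"is_blocked": False, "safety_ratings": [...], "citation_metadata": ...}
```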
James Braza 5 months ago committed by GitHub
parent adc008407e
commit 0785432e7b

@@ -1,4 +1,6 @@
"""Utilities to init Vertex AI."""
import dataclasses
from importlib import metadata
from typing import Any, Callable, Dict, Optional, Union
@@ -10,7 +12,13 @@ from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import create_base_retry_decorator
from vertexai.preview.generative_models import Image # type: ignore
from vertexai.generative_models._generative_models import ( # type: ignore[import-untyped]
Candidate,
)
from vertexai.language_models import ( # type: ignore[import-untyped]
TextGenerationResponse,
)
from vertexai.preview.generative_models import Image # type: ignore[import-untyped]
def create_retry_decorator(
@@ -88,27 +96,23 @@ def is_gemini_model(model_name: str) -> bool:
return model_name is not None and "gemini" in model_name
def get_generation_info(candidate: Any, is_gemini: bool) -> Optional[Dict[str, Any]]:
try:
if is_gemini:
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body
return {
"is_blocked": any(
[rating.blocked for rating in candidate.safety_ratings]
),
"safety_ratings": [
{
"category": rating.category.name,
"probability_label": rating.probability.name,
}
for rating in candidate.safety_ratings
],
}
else:
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body
return {
"is_blocked": candidate.is_blocked,
"safety_attributes": candidate.safety_attributes,
}
except Exception:
return None
def get_generation_info(
candidate: Union[TextGenerationResponse, Candidate], is_gemini: bool
) -> Dict[str, Any]:
if is_gemini:
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body
return {
"is_blocked": any([rating.blocked for rating in candidate.safety_ratings]),
"safety_ratings": [
{
"category": rating.category.name,
"probability_label": rating.probability.name,
}
for rating in candidate.safety_ratings
],
"citation_metadata": candidate.citation_metadata,
}
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body
candidate_dc = dataclasses.asdict(candidate)
candidate_dc.pop("text")
return {k: v for k, v in candidate_dc.items() if not k.startswith("_")}
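As a quick illustration of the new non-Gemini branch above: it relies on `TextGenerationResponse` being a dataclass and simply keeps every public field except the text itself. The stand-in dataclass below is hypothetical (its field names only mimic a subset of the real response), but the dict manipulation mirrors the code above:

```python
import dataclasses
from dataclasses import dataclass
from typing import Any, Dict


# Hypothetical stand-in for vertexai's TextGenerationResponse; real field
# names may differ, the point is that it is a dataclass.
@dataclass
class FakeTextGenerationResponse:
    text: str
    is_blocked: bool
    safety_attributes: Dict[str, Any]
    grounding_metadata: Any = None
    _prediction_response: Any = None  # "private" field, expected to be dropped


candidate = FakeTextGenerationResponse(
    text="Paris is the capital of France.",
    is_blocked=False,
    safety_attributes={"Violent": 0.1},
    grounding_metadata={"citations": []},
)

# Mirrors the non-Gemini branch of get_generation_info:
info = dataclasses.asdict(candidate)
info.pop("text")
info = {k: v for k, v in info.items() if not k.startswith("_")}
print(info)
# {'is_blocked': False, 'safety_attributes': {'Violent': 0.1},
#  'grounding_metadata': {'citations': []}}
```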

@@ -1,5 +1,5 @@
"""Test ChatGoogleVertexAI chat model."""
from typing import cast
from typing import Optional, cast
import pytest
from langchain_core.messages import (
@@ -16,7 +16,7 @@ model_names_to_test = [None, "codechat-bison", "chat-bison", "gemini-pro"]
@pytest.mark.parametrize("model_name", model_names_to_test)
def test_initialization(model_name: str) -> None:
def test_initialization(model_name: Optional[str]) -> None:
"""Test chat model initialization."""
if model_name:
model = ChatVertexAI(model_name=model_name)
@@ -30,7 +30,7 @@ def test_initialization(model_name: str) -> None:
@pytest.mark.parametrize("model_name", model_names_to_test)
def test_vertexai_single_call(model_name: str) -> None:
def test_vertexai_single_call(model_name: Optional[str]) -> None:
if model_name:
model = ChatVertexAI(model_name=model_name)
else:
@@ -164,7 +164,7 @@ def test_vertexai_single_call_with_examples() -> None:
@pytest.mark.parametrize("model_name", model_names_to_test)
def test_vertexai_single_call_with_history(model_name: str) -> None:
def test_vertexai_single_call_with_history(model_name: Optional[str]) -> None:
if model_name:
model = ChatVertexAI(model_name=model_name)
else:
@@ -203,7 +203,7 @@ def test_chat_vertexai_gemini_system_message_error(model_name: str) -> None:
@pytest.mark.parametrize("model_name", model_names_to_test)
def test_chat_vertexai_system_message(model_name: str) -> None:
def test_chat_vertexai_system_message(model_name: Optional[str]) -> None:
if model_name:
model = ChatVertexAI(
model_name=model_name, convert_system_message_to_human=True
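The annotation changes above simply acknowledge that `model_names_to_test` begins with `None` (meaning "construct the model with its default name"), so the parametrized argument is `Optional[str]` rather than `str`. A self-contained sketch of the pattern, where the assertion is purely illustrative:

```python
from typing import Optional

import pytest

# Copied from the tests: None selects the default model, a string selects a
# specific one, hence Optional[str] on the parametrized argument.
model_names_to_test = [None, "codechat-bison", "chat-bison", "gemini-pro"]


@pytest.mark.parametrize("model_name", model_names_to_test)
def test_model_name_annotation(model_name: Optional[str]) -> None:
    assert model_name is None or isinstance(model_name, str)
```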

@@ -1,5 +1,7 @@
"""Test chat model integration."""
from typing import Any, Dict, Optional
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from unittest.mock import MagicMock, Mock, patch
import pytest
@@ -45,6 +47,13 @@ def test_parse_examples_failes_wrong_sequence() -> None:
)
@dataclass
class StubTextChatResponse:
"""Stub text-chat response from VertexAI for testing."""
text: str
@pytest.mark.parametrize("stop", [None, "stop1"])
def test_vertexai_args_passed(stop: Optional[str]) -> None:
response_text = "Goodbye"
@@ -59,7 +68,7 @@ def test_vertexai_args_passed(stop: Optional[str]) -> None:
# Mock the library to ensure the args are passed correctly
with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg:
mock_response = MagicMock()
mock_response.candidates = [Mock(text=response_text)]
mock_response.candidates = [StubTextChatResponse(text=response_text)]
mock_chat = MagicMock()
mock_send_message = MagicMock(return_value=mock_response)
mock_chat.send_message = mock_send_message
@@ -136,7 +145,7 @@ def test_default_params_palm() -> None:
with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg:
mock_response = MagicMock()
mock_response.candidates = [Mock(text="Goodbye")]
mock_response.candidates = [StubTextChatResponse(text="Goodbye")]
mock_chat = MagicMock()
mock_send_message = MagicMock(return_value=mock_response)
mock_chat.send_message = mock_send_message
@@ -159,13 +168,28 @@
)
@dataclass
class StubGeminiResponse:
"""Stub gemini response from VertexAI for testing."""
text: str
content: Any
citation_metadata: Any
safety_ratings: List[Any] = field(default_factory=list)
def test_default_params_gemini() -> None:
user_prompt = "Hello"
with patch("langchain_google_vertexai.chat_models.GenerativeModel") as gm:
mock_response = MagicMock()
content = Mock(parts=[Mock(function_call=None)])
mock_response.candidates = [Mock(text="Goodbye", content=content)]
mock_response.candidates = [
StubGeminiResponse(
text="Goodbye",
content=Mock(parts=[Mock(function_call=None)]),
citation_metadata=Mock(),
)
]
mock_chat = MagicMock()
mock_send_message = MagicMock(return_value=mock_response)
mock_chat.send_message = mock_send_message
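Why the bare `Mock` candidates were replaced with dataclass stubs: the updated `get_generation_info` calls `dataclasses.asdict` on non-Gemini candidates and iterates `safety_ratings` on Gemini ones, neither of which a plain `MagicMock` supports. A self-contained sketch of the failure mode (the stub is re-declared here so the snippet runs on its own):

```python
import dataclasses
from dataclasses import dataclass
from unittest.mock import Mock


@dataclass
class StubTextChatResponse:
    """Stub text-chat response from VertexAI for testing."""

    text: str


# asdict() rejects anything that is not a dataclass instance, so a bare Mock
# candidate would now break the non-Gemini branch of get_generation_info:
try:
    dataclasses.asdict(Mock(text="Goodbye"))
except TypeError as exc:
    print(exc)  # asdict() should be called on dataclass instances

# The dataclass stub exposes just the fields the code reads and converts fine:
print(dataclasses.asdict(StubTextChatResponse(text="Goodbye")))  # {'text': 'Goodbye'}
```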
