google-vertexai[minor]: added safety_settings property to gemini wrapper (#15344)

**Description:** Gemini models have quite annoying default safety settings, and the current VertexAI class doesn't provide a property to override them.
So, this PR aims to:
- add a `safety_settings` property to VertexAI
- fix incorrect parsing of LLM output when the LLM responds with a 'blocked' response
- fix incorrect parsing of LLM output when the Gemini API blocks the prompt itself as inappropriate
- add `safety_settings`-related tests

I'm not very familiar with the LangChain code base and guidelines, so any comments and/or suggestions are very welcome.
 
**Issue:** likely fixes #14841
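For reviewers, a minimal usage sketch of the new property, based on the tests added in this PR (assumes Vertex AI credentials are configured and `gemini-pro` as the model name):

```python
from langchain_google_vertexai import HarmBlockThreshold, HarmCategory, VertexAI

# Default safety settings, applied to every generation from this instance.
safety_settings = {
    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
}
llm = VertexAI(model_name="gemini-pro", safety_settings=safety_settings)

# safety_settings can also be passed per call to override the instance defaults.
result = llm.generate(["Say foo:"], safety_settings=safety_settings)

# For Gemini models, generation_info now reports whether the candidate was
# blocked, along with per-category safety ratings.
info = result.generations[0][0].generation_info
print(info["is_blocked"], info["safety_ratings"])
```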

---------

Co-authored-by: Erick Friis <erick@langchain.dev>

@ -35,7 +35,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 1,
"metadata": {
"tags": []
},
@ -44,10 +44,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.3.2\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
"^C\n",
"\u001b[31mERROR: Operation cancelled by user\u001b[0m\u001b[31m\n",
"\u001b[0mNote: you may need to restart the kernel to use updated packages.\n"
]
}
],
@ -57,7 +56,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -67,7 +66,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
@ -76,7 +75,7 @@
"AIMessage(content=\" J'aime la programmation.\")"
]
},
"execution_count": 2,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@ -101,7 +100,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [
{
@ -110,7 +109,7 @@
"AIMessage(content=' プログラミングが大好きです')"
]
},
"execution_count": 3,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@ -154,7 +153,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {
"tags": []
},
@ -165,27 +164,51 @@
"text": [
" ```python\n",
"def is_prime(n):\n",
" if n <= 1:\n",
" return False\n",
" for i in range(2, n):\n",
" if n % i == 0:\n",
" return False\n",
" return True\n",
" \"\"\"\n",
" Check if a number is prime.\n",
"\n",
" Args:\n",
" n: The number to check.\n",
"\n",
" Returns:\n",
" True if n is prime, False otherwise.\n",
" \"\"\"\n",
"\n",
" # If n is 1, it is not prime.\n",
" if n == 1:\n",
" return False\n",
"\n",
" # Iterate over all numbers from 2 to the square root of n.\n",
" for i in range(2, int(n ** 0.5) + 1):\n",
" # If n is divisible by any number from 2 to its square root, it is not prime.\n",
" if n % i == 0:\n",
" return False\n",
"\n",
" # If n is divisible by no number from 2 to its square root, it is prime.\n",
" return True\n",
"\n",
"\n",
"def find_prime_numbers(n):\n",
" prime_numbers = []\n",
" for i in range(2, n + 1):\n",
" if is_prime(i):\n",
" prime_numbers.append(i)\n",
" return prime_numbers\n",
" \"\"\"\n",
" Find all prime numbers up to a given number.\n",
"\n",
" Args:\n",
" n: The upper bound for the prime numbers to find.\n",
"\n",
" Returns:\n",
" A list of all prime numbers up to n.\n",
" \"\"\"\n",
"\n",
"print(find_prime_numbers(100))\n",
"```\n",
" # Create a list of all numbers from 2 to n.\n",
" numbers = list(range(2, n + 1))\n",
"\n",
"Output:\n",
" # Iterate over the list of numbers and remove any that are not prime.\n",
" for number in numbers:\n",
" if not is_prime(number):\n",
" numbers.remove(number)\n",
"\n",
"```\n",
"[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97]\n",
" # Return the list of prime numbers.\n",
" return numbers\n",
"```\n"
]
}
@ -199,6 +222,102 @@
"print(message.content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Full generation info\n",
"\n",
"We can use the `generate` method to get back extra metadata like [safety attributes](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_confidence_scoring) and not just chat completions\n",
"\n",
"Note that the `generation_info` will be different depending if you're using a gemini model or not.\n",
"\n",
"### Gemini model\n",
"\n",
"`generation_info` will include:\n",
"\n",
"- `is_blocked`: whether generation was blocked or not\n",
"- `safety_ratings`: safety ratings' categories and probability labels"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'is_blocked': False,\n",
" 'safety_ratings': [{'category': 'HARM_CATEGORY_HARASSMENT',\n",
" 'probability_label': 'NEGLIGIBLE'},\n",
" {'category': 'HARM_CATEGORY_HATE_SPEECH',\n",
" 'probability_label': 'NEGLIGIBLE'},\n",
" {'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',\n",
" 'probability_label': 'NEGLIGIBLE'},\n",
" {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',\n",
" 'probability_label': 'NEGLIGIBLE'}]}\n"
]
}
],
"source": [
"from pprint import pprint\n",
"\n",
"from langchain_core.messages import HumanMessage\n",
"from langchain_google_vertexai import ChatVertexAI, HarmBlockThreshold, HarmCategory\n",
"\n",
"human = \"Translate this sentence from English to French. I love programming.\"\n",
"messages = [HumanMessage(content=human)]\n",
"\n",
"\n",
"chat = ChatVertexAI(\n",
" model_name=\"gemini-pro\",\n",
" safety_settings={\n",
" HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE\n",
" },\n",
")\n",
"\n",
"result = chat.generate([messages])\n",
"pprint(result.generations[0][0].generation_info)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Non-gemini model\n",
"\n",
"`generation_info` will include:\n",
"\n",
"- `is_blocked`: whether generation was blocked or not\n",
"- `safety_attributes`: a dictionary mapping safety attributes to their scores"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'is_blocked': False,\n",
" 'safety_attributes': {'Derogatory': 0.1,\n",
" 'Finance': 0.3,\n",
" 'Insult': 0.1,\n",
" 'Sexual': 0.1}}\n"
]
}
],
"source": [
"chat = ChatVertexAI() # default is `chat-bison`\n",
"\n",
"result = chat.generate([messages])\n",
"pprint(result.generations[0][0].generation_info)"
]
},
{
"cell_type": "markdown",
"metadata": {},
@ -210,7 +329,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -224,7 +343,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
@ -268,7 +387,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [
{

@ -1,5 +1,13 @@
from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory
from langchain_google_vertexai.chat_models import ChatVertexAI
from langchain_google_vertexai.embeddings import VertexAIEmbeddings
from langchain_google_vertexai.llms import VertexAI, VertexAIModelGarden
__all__ = ["ChatVertexAI", "VertexAIEmbeddings", "VertexAI", "VertexAIModelGarden"]
__all__ = [
"ChatVertexAI",
"VertexAIEmbeddings",
"VertexAI",
"VertexAIModelGarden",
"HarmBlockThreshold",
"HarmCategory",
]

@ -0,0 +1,6 @@
from vertexai.preview.generative_models import ( # type: ignore
HarmBlockThreshold,
HarmCategory,
)
__all__ = ["HarmBlockThreshold", "HarmCategory"]

@ -1,6 +1,6 @@
"""Utilities to init Vertex AI."""
from importlib import metadata
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Dict, Optional, Union
import google.api_core
from google.api_core.gapic_v1.client_info import ClientInfo
@ -86,3 +86,29 @@ def is_codey_model(model_name: str) -> bool:
def is_gemini_model(model_name: str) -> bool:
"""Returns True if the model name is a Gemini model."""
return model_name is not None and "gemini" in model_name
def get_generation_info(candidate: Any, is_gemini: bool) -> Optional[Dict[str, Any]]:
try:
if is_gemini:
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body
return {
"is_blocked": any(
[rating.blocked for rating in candidate.safety_ratings]
),
"safety_ratings": [
{
"category": rating.category.name,
"probability_label": rating.probability.name,
}
for rating in candidate.safety_ratings
],
}
else:
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body
return {
"is_blocked": candidate.is_blocked,
"safety_attributes": candidate.safety_attributes,
}
except Exception:
return None

@ -47,6 +47,9 @@ from vertexai.preview.generative_models import ( # type: ignore
)
from langchain_google_vertexai._utils import (
get_generation_info,
is_codey_model,
is_gemini_model,
load_image_from_gcs,
)
from langchain_google_vertexai.functions_utils import (
@ -54,8 +57,6 @@ from langchain_google_vertexai.functions_utils import (
)
from langchain_google_vertexai.llms import (
_VertexAICommon,
is_codey_model,
is_gemini_model,
)
logger = logging.getLogger(__name__)
@ -271,9 +272,16 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment."""
is_gemini = is_gemini_model(values["model_name"])
safety_settings = values["safety_settings"]
if safety_settings and not is_gemini:
raise ValueError("Safety settings are only supported for Gemini models")
cls._init_vertexai(values)
if is_gemini:
values["client"] = GenerativeModel(model_name=values["model_name"])
values["client"] = GenerativeModel(
model_name=values["model_name"], safety_settings=safety_settings
)
else:
if is_codey_model(values["model_name"]):
model_cls = CodeChatModel
@ -306,6 +314,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
ValueError: if the last message in the list is not from human.
"""
should_stream = stream if stream is not None else self.streaming
safety_settings = kwargs.pop("safety_settings", None)
if should_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
@ -325,9 +334,17 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
# set param to `functions` until core tool/function calling implemented
raw_tools = params.pop("functions") if "functions" in params else None
tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
response = chat.send_message(message, generation_config=params, tools=tools)
response = chat.send_message(
message,
generation_config=params,
tools=tools,
safety_settings=safety_settings,
)
generations = [
ChatGeneration(message=_parse_response_candidate(c))
ChatGeneration(
message=_parse_response_candidate(c),
generation_info=get_generation_info(c, self._is_gemini_model),
)
for c in response.candidates
]
else:
@ -339,7 +356,10 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
chat = self._start_chat(history, **params)
response = chat.send_message(question.content, **msg_params)
generations = [
ChatGeneration(message=AIMessage(content=r.text))
ChatGeneration(
message=AIMessage(content=r.text),
generation_info=get_generation_info(r, self._is_gemini_model),
)
for r in response.candidates
]
return ChatResult(generations=generations)
@ -370,6 +390,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
logger.warning("ChatVertexAI does not currently support async streaming.")
params = self._prepare_params(stop=stop, **kwargs)
safety_settings = kwargs.pop("safety_settings", None)
msg_params = {}
if "candidate_count" in params:
msg_params["candidate_count"] = params.pop("candidate_count")
@ -382,22 +403,31 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
raw_tools = params.pop("functions") if "functions" in params else None
tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
response = await chat.send_message_async(
message, generation_config=params, tools=tools
message,
generation_config=params,
tools=tools,
safety_settings=safety_settings,
)
generations = [
ChatGeneration(message=_parse_response_candidate(c))
ChatGeneration(
message=_parse_response_candidate(c),
generation_info=get_generation_info(c, self._is_gemini_model),
)
for c in response.candidates
]
else:
question = _get_question(messages)
history = _parse_chat_history(messages[:-1])
examples = kwargs.get("examples", None)
examples = kwargs.get("examples", None) or self.examples
if examples:
params["examples"] = _parse_examples(examples)
chat = self._start_chat(history, **params)
response = await chat.send_message_async(question.content, **msg_params)
generations = [
ChatGeneration(message=AIMessage(content=r.text))
ChatGeneration(
message=AIMessage(content=r.text),
generation_info=get_generation_info(r, self._is_gemini_model),
)
for r in response.candidates
]
return ChatResult(generations=generations)
@ -441,7 +471,10 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
for response in responses:
if run_manager:
run_manager.on_llm_new_token(response.text)
yield ChatGenerationChunk(message=AIMessageChunk(content=response.text))
yield ChatGenerationChunk(
message=AIMessageChunk(content=response.text),
generation_info=get_generation_info(response, self._is_gemini_model),
)
def _start_chat(
self, history: _ChatHistory, **kwargs: Any

@ -26,7 +26,10 @@ from vertexai.language_models import ( # type: ignore
from vertexai.language_models._language_models import ( # type: ignore
TextGenerationResponse,
)
from vertexai.preview.generative_models import GenerativeModel, Image # type: ignore
from vertexai.preview.generative_models import ( # type: ignore
GenerativeModel,
Image,
)
from vertexai.preview.language_models import ( # type: ignore
CodeGenerationModel as PreviewCodeGenerationModel,
)
@ -34,9 +37,11 @@ from vertexai.preview.language_models import (
TextGenerationModel as PreviewTextGenerationModel,
)
from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory
from langchain_google_vertexai._utils import (
create_retry_decorator,
get_client_info,
get_generation_info,
is_codey_model,
is_gemini_model,
)
@ -66,7 +71,10 @@ def _completion_with_retry(
) -> Any:
if is_gemini:
return llm.client.generate_content(
prompt, stream=stream, generation_config=kwargs
prompt,
stream=stream,
safety_settings=kwargs.pop("safety_settings", None),
generation_config=kwargs,
)
else:
if stream:
@ -94,7 +102,9 @@ async def _acompletion_with_retry(
) -> Any:
if is_gemini:
return await llm.client.generate_content_async(
prompt, generation_config=kwargs
prompt,
generation_config=kwargs,
safety_settings=kwargs.pop("safety_settings", None),
)
return await llm.client.predict_async(prompt, **kwargs)
@ -141,6 +151,21 @@ class _VertexAICommon(_VertexAIBase):
"""How many completions to generate for each prompt."""
streaming: bool = False
"""Whether to stream the results or not."""
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None
"""The default safety settings to use for all generations.
For example:
from langchain_google_vertexai import HarmBlockThreshold, HarmCategory
safety_settings = {
HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
}
""" # noqa: E501
@property
def _llm_type(self) -> str:
@ -237,9 +262,13 @@ class VertexAI(_VertexAICommon, BaseLLM):
"""Validate that the python package exists in environment."""
tuned_model_name = values.get("tuned_model_name")
model_name = values["model_name"]
safety_settings = values["safety_settings"]
is_gemini = is_gemini_model(values["model_name"])
cls._init_vertexai(values)
if safety_settings and (not is_gemini or tuned_model_name):
raise ValueError("Safety settings are only supported for Gemini models")
if is_codey_model(model_name):
model_cls = CodeGenerationModel
preview_model_cls = PreviewCodeGenerationModel
@ -257,8 +286,12 @@ class VertexAI(_VertexAICommon, BaseLLM):
)
else:
if is_gemini:
values["client"] = model_cls(model_name=model_name)
values["client_preview"] = preview_model_cls(model_name=model_name)
values["client"] = model_cls(
model_name=model_name, safety_settings=safety_settings
)
values["client_preview"] = preview_model_cls(
model_name=model_name, safety_settings=safety_settings
)
else:
values["client"] = model_cls.from_pretrained(model_name)
values["client_preview"] = preview_model_cls.from_pretrained(model_name)
@ -285,14 +318,14 @@ class VertexAI(_VertexAICommon, BaseLLM):
self, response: TextGenerationResponse
) -> GenerationChunk:
"""Converts a stream response to a generation chunk."""
try:
generation_info = {
"is_blocked": response.is_blocked,
"safety_attributes": response.safety_attributes,
}
except Exception:
generation_info = None
return GenerationChunk(text=response.text, generation_info=generation_info)
generation_info = get_generation_info(response, self._is_gemini_model)
return GenerationChunk(
text=response.text
if hasattr(response, "text")
else "", # might not exist if blocked
generation_info=generation_info,
)
def _generate(
self,

@ -504,13 +504,13 @@ uritemplate = ">=3.0.1,<5"
[[package]]
name = "google-auth"
version = "2.26.1"
version = "2.26.2"
description = "Google Authentication Library"
optional = false
python-versions = ">=3.7"
files = [
{file = "google-auth-2.26.1.tar.gz", hash = "sha256:54385acca5c0fbdda510cd8585ba6f3fcb06eeecf8a6ecca39d3ee148b092590"},
{file = "google_auth-2.26.1-py2.py3-none-any.whl", hash = "sha256:2c8b55e3e564f298122a02ab7b97458ccfcc5617840beb5d0ac757ada92c9780"},
{file = "google-auth-2.26.2.tar.gz", hash = "sha256:97327dbbf58cccb58fc5a1712bba403ae76668e64814eb30f7316f7e27126b81"},
{file = "google_auth-2.26.2-py2.py3-none-any.whl", hash = "sha256:3f445c8ce9b61ed6459aad86d8ccdba4a9afed841b2d1451a11ef4db08957424"},
]
[package.dependencies]
@ -582,13 +582,13 @@ xai = ["tensorflow (>=2.3.0,<3.0.0dev)"]
[[package]]
name = "google-cloud-bigquery"
version = "3.14.1"
version = "3.16.0"
description = "Google BigQuery API client library"
optional = false
python-versions = ">=3.7"
files = [
{file = "google-cloud-bigquery-3.14.1.tar.gz", hash = "sha256:aa15bd86f79ea76824c7d710f5ae532323c4b3ba01ef4abff42d4ee7a2e9b142"},
{file = "google_cloud_bigquery-3.14.1-py2.py3-none-any.whl", hash = "sha256:a8ded18455da71508db222b7c06197bc12b6dbc6ed5b0b64e7007b76d7016957"},
{file = "google-cloud-bigquery-3.16.0.tar.gz", hash = "sha256:1d6abf4b1d740df17cb43a078789872af8059a0b1dd999f32ea69ebc6f7ba7ef"},
{file = "google_cloud_bigquery-3.16.0-py2.py3-none-any.whl", hash = "sha256:8bac7754f92bf87ee81f38deabb7554d82bb9591fbe06a5c82f33e46e5a482f9"},
]
[package.dependencies]
@ -1110,13 +1110,13 @@ url = "../../core"
[[package]]
name = "langsmith"
version = "0.0.77"
version = "0.0.81"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = false
python-versions = ">=3.8.1,<4.0"
files = [
{file = "langsmith-0.0.77-py3-none-any.whl", hash = "sha256:750c0aa9177240c64e131d831e009ed08dd59038f7cabbd0bbcf62ccb7c8dcac"},
{file = "langsmith-0.0.77.tar.gz", hash = "sha256:c4c8d3a96ad8671a41064f3ccc673e2e22a4153e823b19f915c9c9b8a4f33a2c"},
{file = "langsmith-0.0.81-py3-none-any.whl", hash = "sha256:eb816ad456776ec4c6005ddce8a4c315a1a582ed4d079979888e9f8a1db209b3"},
{file = "langsmith-0.0.81.tar.gz", hash = "sha256:5838e5a4bb1939e9794eb3f802f7c390247a847bd603e31442be5be00068e504"},
]
[package.dependencies]
@ -1410,22 +1410,22 @@ testing = ["google-api-core[grpc] (>=1.31.5)"]
[[package]]
name = "protobuf"
version = "4.25.1"
version = "4.25.2"
description = ""
optional = false
python-versions = ">=3.8"
files = [
{file = "protobuf-4.25.1-cp310-abi3-win32.whl", hash = "sha256:193f50a6ab78a970c9b4f148e7c750cfde64f59815e86f686c22e26b4fe01ce7"},
{file = "protobuf-4.25.1-cp310-abi3-win_amd64.whl", hash = "sha256:3497c1af9f2526962f09329fd61a36566305e6c72da2590ae0d7d1322818843b"},
{file = "protobuf-4.25.1-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:0bf384e75b92c42830c0a679b0cd4d6e2b36ae0cf3dbb1e1dfdda48a244f4bcd"},
{file = "protobuf-4.25.1-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:0f881b589ff449bf0b931a711926e9ddaad3b35089cc039ce1af50b21a4ae8cb"},
{file = "protobuf-4.25.1-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:ca37bf6a6d0046272c152eea90d2e4ef34593aaa32e8873fc14c16440f22d4b7"},
{file = "protobuf-4.25.1-cp38-cp38-win32.whl", hash = "sha256:abc0525ae2689a8000837729eef7883b9391cd6aa7950249dcf5a4ede230d5dd"},
{file = "protobuf-4.25.1-cp38-cp38-win_amd64.whl", hash = "sha256:1484f9e692091450e7edf418c939e15bfc8fc68856e36ce399aed6889dae8bb0"},
{file = "protobuf-4.25.1-cp39-cp39-win32.whl", hash = "sha256:8bdbeaddaac52d15c6dce38c71b03038ef7772b977847eb6d374fc86636fa510"},
{file = "protobuf-4.25.1-cp39-cp39-win_amd64.whl", hash = "sha256:becc576b7e6b553d22cbdf418686ee4daa443d7217999125c045ad56322dda10"},
{file = "protobuf-4.25.1-py3-none-any.whl", hash = "sha256:a19731d5e83ae4737bb2a089605e636077ac001d18781b3cf489b9546c7c80d6"},
{file = "protobuf-4.25.1.tar.gz", hash = "sha256:57d65074b4f5baa4ab5da1605c02be90ac20c8b40fb137d6a8df9f416b0d0ce2"},
{file = "protobuf-4.25.2-cp310-abi3-win32.whl", hash = "sha256:b50c949608682b12efb0b2717f53256f03636af5f60ac0c1d900df6213910fd6"},
{file = "protobuf-4.25.2-cp310-abi3-win_amd64.whl", hash = "sha256:8f62574857ee1de9f770baf04dde4165e30b15ad97ba03ceac65f760ff018ac9"},
{file = "protobuf-4.25.2-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:2db9f8fa64fbdcdc93767d3cf81e0f2aef176284071507e3ede160811502fd3d"},
{file = "protobuf-4.25.2-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:10894a2885b7175d3984f2be8d9850712c57d5e7587a2410720af8be56cdaf62"},
{file = "protobuf-4.25.2-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:fc381d1dd0516343f1440019cedf08a7405f791cd49eef4ae1ea06520bc1c020"},
{file = "protobuf-4.25.2-cp38-cp38-win32.whl", hash = "sha256:33a1aeef4b1927431d1be780e87b641e322b88d654203a9e9d93f218ee359e61"},
{file = "protobuf-4.25.2-cp38-cp38-win_amd64.whl", hash = "sha256:47f3de503fe7c1245f6f03bea7e8d3ec11c6c4a2ea9ef910e3221c8a15516d62"},
{file = "protobuf-4.25.2-cp39-cp39-win32.whl", hash = "sha256:5e5c933b4c30a988b52e0b7c02641760a5ba046edc5e43d3b94a74c9fc57c1b3"},
{file = "protobuf-4.25.2-cp39-cp39-win_amd64.whl", hash = "sha256:d66a769b8d687df9024f2985d5137a337f957a0916cf5464d1513eee96a63ff0"},
{file = "protobuf-4.25.2-py3-none-any.whl", hash = "sha256:a8b7a98d4ce823303145bf3c1a8bdb0f2f4642a414b196f04ad9853ed0c8f830"},
{file = "protobuf-4.25.2.tar.gz", hash = "sha256:fe599e175cb347efc8ee524bcd4b902d11f7262c0e569ececcb89995c15f0a5e"},
]
[[package]]
@ -1775,28 +1775,28 @@ pyasn1 = ">=0.1.3"
[[package]]
name = "ruff"
version = "0.1.11"
version = "0.1.13"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
files = [
{file = "ruff-0.1.11-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:a7f772696b4cdc0a3b2e527fc3c7ccc41cdcb98f5c80fdd4f2b8c50eb1458196"},
{file = "ruff-0.1.11-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:934832f6ed9b34a7d5feea58972635c2039c7a3b434fe5ba2ce015064cb6e955"},
{file = "ruff-0.1.11-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea0d3e950e394c4b332bcdd112aa566010a9f9c95814844a7468325290aabfd9"},
{file = "ruff-0.1.11-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9bd4025b9c5b429a48280785a2b71d479798a69f5c2919e7d274c5f4b32c3607"},
{file = "ruff-0.1.11-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e1ad00662305dcb1e987f5ec214d31f7d6a062cae3e74c1cbccef15afd96611d"},
{file = "ruff-0.1.11-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:4b077ce83f47dd6bea1991af08b140e8b8339f0ba8cb9b7a484c30ebab18a23f"},
{file = "ruff-0.1.11-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4a88efecec23c37b11076fe676e15c6cdb1271a38f2b415e381e87fe4517f18"},
{file = "ruff-0.1.11-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5b25093dad3b055667730a9b491129c42d45e11cdb7043b702e97125bcec48a1"},
{file = "ruff-0.1.11-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:231d8fb11b2cc7c0366a326a66dafc6ad449d7fcdbc268497ee47e1334f66f77"},
{file = "ruff-0.1.11-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:09c415716884950080921dd6237767e52e227e397e2008e2bed410117679975b"},
{file = "ruff-0.1.11-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:0f58948c6d212a6b8d41cd59e349751018797ce1727f961c2fa755ad6208ba45"},
{file = "ruff-0.1.11-py3-none-musllinux_1_2_i686.whl", hash = "sha256:190a566c8f766c37074d99640cd9ca3da11d8deae2deae7c9505e68a4a30f740"},
{file = "ruff-0.1.11-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:6464289bd67b2344d2a5d9158d5eb81025258f169e69a46b741b396ffb0cda95"},
{file = "ruff-0.1.11-py3-none-win32.whl", hash = "sha256:9b8f397902f92bc2e70fb6bebfa2139008dc72ae5177e66c383fa5426cb0bf2c"},
{file = "ruff-0.1.11-py3-none-win_amd64.whl", hash = "sha256:eb85ee287b11f901037a6683b2374bb0ec82928c5cbc984f575d0437979c521a"},
{file = "ruff-0.1.11-py3-none-win_arm64.whl", hash = "sha256:97ce4d752f964ba559c7023a86e5f8e97f026d511e48013987623915431c7ea9"},
{file = "ruff-0.1.11.tar.gz", hash = "sha256:f9d4d88cb6eeb4dfe20f9f0519bd2eaba8119bde87c3d5065c541dbae2b5a2cb"},
{file = "ruff-0.1.13-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:e3fd36e0d48aeac672aa850045e784673449ce619afc12823ea7868fcc41d8ba"},
{file = "ruff-0.1.13-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:9fb6b3b86450d4ec6a6732f9f60c4406061b6851c4b29f944f8c9d91c3611c7a"},
{file = "ruff-0.1.13-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b13ba5d7156daaf3fd08b6b993360a96060500aca7e307d95ecbc5bb47a69296"},
{file = "ruff-0.1.13-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9ebb40442f7b531e136d334ef0851412410061e65d61ca8ce90d894a094feb22"},
{file = "ruff-0.1.13-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:226b517f42d59a543d6383cfe03cccf0091e3e0ed1b856c6824be03d2a75d3b6"},
{file = "ruff-0.1.13-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:5f0312ba1061e9b8c724e9a702d3c8621e3c6e6c2c9bd862550ab2951ac75c16"},
{file = "ruff-0.1.13-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2f59bcf5217c661254bd6bc42d65a6fd1a8b80c48763cb5c2293295babd945dd"},
{file = "ruff-0.1.13-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e6894b00495e00c27b6ba61af1fc666f17de6140345e5ef27dd6e08fb987259d"},
{file = "ruff-0.1.13-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a1600942485c6e66119da294c6294856b5c86fd6df591ce293e4a4cc8e72989"},
{file = "ruff-0.1.13-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:ee3febce7863e231a467f90e681d3d89210b900d49ce88723ce052c8761be8c7"},
{file = "ruff-0.1.13-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:dcaab50e278ff497ee4d1fe69b29ca0a9a47cd954bb17963628fa417933c6eb1"},
{file = "ruff-0.1.13-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f57de973de4edef3ad3044d6a50c02ad9fc2dff0d88587f25f1a48e3f72edf5e"},
{file = "ruff-0.1.13-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:7a36fa90eb12208272a858475ec43ac811ac37e91ef868759770b71bdabe27b6"},
{file = "ruff-0.1.13-py3-none-win32.whl", hash = "sha256:a623349a505ff768dad6bd57087e2461be8db58305ebd5577bd0e98631f9ae69"},
{file = "ruff-0.1.13-py3-none-win_amd64.whl", hash = "sha256:f988746e3c3982bea7f824c8fa318ce7f538c4dfefec99cd09c8770bd33e6539"},
{file = "ruff-0.1.13-py3-none-win_arm64.whl", hash = "sha256:6bbbc3042075871ec17f28864808540a26f0f79a4478c357d3e3d2284e832998"},
{file = "ruff-0.1.13.tar.gz", hash = "sha256:e261f1baed6291f434ffb1d5c6bd8051d1c2a26958072d38dfbec39b3dda7352"},
]
[[package]]
@ -2033,24 +2033,24 @@ files = [
[[package]]
name = "types-protobuf"
version = "4.24.0.4"
version = "4.24.0.20240106"
description = "Typing stubs for protobuf"
optional = false
python-versions = ">=3.7"
python-versions = ">=3.8"
files = [
{file = "types-protobuf-4.24.0.4.tar.gz", hash = "sha256:57ab42cb171dfdba2c74bb5b50c250478538cc3c5ed95b8b368929ad0c9f90a5"},
{file = "types_protobuf-4.24.0.4-py3-none-any.whl", hash = "sha256:131ab7d0cbc9e444bc89c994141327dcce7bcaeded72b1acb72a94827eb9c7af"},
{file = "types-protobuf-4.24.0.20240106.tar.gz", hash = "sha256:024f034f3b5e2bb2bbff55ebc4d591ed0d2280d90faceedcb148b9e714a3f3ee"},
{file = "types_protobuf-4.24.0.20240106-py3-none-any.whl", hash = "sha256:0612ef3156bd80567460a15ac7c109b313f6022f1fee04b4d922ab2789baab79"},
]
[[package]]
name = "types-requests"
version = "2.31.0.20231231"
version = "2.31.0.20240106"
description = "Typing stubs for requests"
optional = false
python-versions = ">=3.7"
python-versions = ">=3.8"
files = [
{file = "types-requests-2.31.0.20231231.tar.gz", hash = "sha256:0f8c0c9764773384122813548d9eea92a5c4e1f33ed54556b508968ec5065cee"},
{file = "types_requests-2.31.0.20231231-py3-none-any.whl", hash = "sha256:2e2230c7bc8dd63fa3153c1c0ae335f8a368447f0582fc332f17d54f88e69027"},
{file = "types-requests-2.31.0.20240106.tar.gz", hash = "sha256:0e1c731c17f33618ec58e022b614a1a2ecc25f7dc86800b36ef341380402c612"},
{file = "types_requests-2.31.0.20240106-py3-none-any.whl", hash = "sha256:da997b3b6a72cc08d09f4dba9802fdbabc89104b35fe24ee588e674037689354"},
]
[package.dependencies]

@ -1,4 +1,6 @@
"""Test ChatGoogleVertexAI chat model."""
from typing import cast
import pytest
from langchain_core.messages import (
AIMessage,
@ -6,7 +8,7 @@ from langchain_core.messages import (
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import LLMResult
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_google_vertexai.chat_models import ChatVertexAI
@ -60,7 +62,13 @@ async def test_vertexai_agenerate(model_name: str) -> None:
assert isinstance(response.generations[0][0].message, AIMessage) # type: ignore
sync_response = model.generate([[message]])
assert response.generations[0][0] == sync_response.generations[0][0]
sync_generation = cast(ChatGeneration, sync_response.generations[0][0])
async_generation = cast(ChatGeneration, response.generations[0][0])
# assert some properties to make debugging easier
assert sync_generation.message.content == async_generation.message.content
assert sync_generation.generation_info == async_generation.generation_info
assert sync_generation == async_generation
@pytest.mark.parametrize("model_name", ["chat-bison@001", "gemini-pro"])

@ -42,6 +42,7 @@ def test_vertex_call(model_name: str) -> None:
assert isinstance(output, str)
@pytest.mark.xfail(reason="VertexAI doesn't always respect number of candidates")
def test_vertex_generate() -> None:
llm = VertexAI(temperature=0.3, n=2, model_name="text-bison@001")
output = llm.generate(["Say foo:"])
@ -50,6 +51,7 @@ def test_vertex_generate() -> None:
assert len(output.generations[0]) == 2
@pytest.mark.xfail(reason="VertexAI doesn't always respect number of candidates")
def test_vertex_generate_code() -> None:
llm = VertexAI(temperature=0.3, n=2, model_name="code-bison@001")
output = llm.generate(["generate a python method that says foo:"])
@ -87,6 +89,7 @@ async def test_vertex_consistency() -> None:
assert output.generations[0][0].text == async_output.generations[0][0].text
@pytest.mark.skip("CI testing not set up")
@pytest.mark.parametrize(
"endpoint_os_variable_name,result_arg",
[("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
@ -115,6 +118,7 @@ def test_model_garden(
assert llm._llm_type == "vertexai_model_garden"
@pytest.mark.skip("CI testing not set up")
@pytest.mark.parametrize(
"endpoint_os_variable_name,result_arg",
[("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
@ -143,6 +147,7 @@ def test_model_garden_generate(
assert len(output.generations) == 2
@pytest.mark.skip("CI testing not set up")
@pytest.mark.asyncio
@pytest.mark.parametrize(
"endpoint_os_variable_name,result_arg",

@ -0,0 +1,97 @@
from langchain_core.outputs import LLMResult
from langchain_google_vertexai import HarmBlockThreshold, HarmCategory, VertexAI
SAFETY_SETTINGS = {
HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
}
# The context and question below are taken from an open-source QA dataset
BLOCKED_PROMPT = """
You are an agent designed to answer questions.
You are given context in triple backticks.
```
The religion\'s failure to report abuse allegations to authorities has also been
criticized. The Watch Tower Society\'s policy is that elders inform authorities when
required by law to do so, but otherwise leave that action up to the victim and his
or her family. The Australian Royal Commission into Institutional Responses to Child
Sexual Abuse found that of 1006 alleged perpetrators of child sexual abuse
identified by the Jehovah\'s Witnesses within their organization since 1950,
"not one was reported by the church to secular authorities." William Bowen, a former
Jehovah\'s Witness elder who established the Silentlambs organization to assist sex
abuse victims within the religion, has claimed Witness leaders discourage followers
from reporting incidents of sexual misconduct to authorities, and other critics claim
the organization is reluctant to alert authorities in order to protect its "crime-free"
reputation. In court cases in the United Kingdom and the United States the Watch Tower
Society has been found to have been negligent in its failure to protect children from
known sex offenders within the congregation and the Society has settled other child
abuse lawsuits out of court, reportedly paying as much as $780,000 to one plaintiff
without admitting wrongdoing.
```
Question: What have courts in both the UK and the US found the Watch Tower Society to
have been for failing to protect children from sexual predators within the
congregation ?
Answer:
"""
def test_gemini_safety_settings_generate() -> None:
llm = VertexAI(model_name="gemini-pro", safety_settings=SAFETY_SETTINGS)
output = llm.generate(["What do you think about child abuse:"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 1
generation_info = output.generations[0][0].generation_info
assert generation_info is not None
assert len(generation_info) > 0
assert not generation_info.get("is_blocked")
blocked_output = llm.generate([BLOCKED_PROMPT])
assert isinstance(blocked_output, LLMResult)
assert len(blocked_output.generations) == 1
assert len(blocked_output.generations[0]) == 0
# test safety_settings passed directly to generate
llm = VertexAI(model_name="gemini-pro")
output = llm.generate(
["What do you think about child abuse:"], safety_settings=SAFETY_SETTINGS
)
assert isinstance(output, LLMResult)
assert len(output.generations) == 1
generation_info = output.generations[0][0].generation_info
assert generation_info is not None
assert len(generation_info) > 0
assert not generation_info.get("is_blocked")
async def test_gemini_safety_settings_agenerate() -> None:
llm = VertexAI(model_name="gemini-pro", safety_settings=SAFETY_SETTINGS)
output = await llm.agenerate(["What do you think about child abuse:"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 1
generation_info = output.generations[0][0].generation_info
assert generation_info is not None
assert len(generation_info) > 0
assert not generation_info.get("is_blocked")
blocked_output = await llm.agenerate([BLOCKED_PROMPT])
assert isinstance(blocked_output, LLMResult)
assert len(blocked_output.generations) == 1
# assert len(blocked_output.generations[0][0].generation_info) > 0
# assert blocked_output.generations[0][0].generation_info.get("is_blocked")
# test safety_settings passed directly to agenerate
llm = VertexAI(model_name="gemini-pro")
output = await llm.agenerate(
["What do you think about child abuse:"], safety_settings=SAFETY_SETTINGS
)
assert isinstance(output, LLMResult)
assert len(output.generations) == 1
generation_info = output.generations[0][0].generation_info
assert generation_info is not None
assert len(generation_info) > 0
assert not generation_info.get("is_blocked")

@ -1,6 +1,13 @@
from langchain_google_vertexai import __all__
EXPECTED_ALL = ["ChatVertexAI", "VertexAIEmbeddings", "VertexAI", "VertexAIModelGarden"]
EXPECTED_ALL = [
"ChatVertexAI",
"VertexAIEmbeddings",
"VertexAI",
"VertexAIModelGarden",
"HarmBlockThreshold",
"HarmCategory",
]
def test_all_imports() -> None:
