google-genai[minor]: add safety settings (#16836)

Replace this entire comment with:
- **Description:Expose safety_settings for Gemini integrations on
google-generativeai
  - **Issue:NA,
  - **Dependencies:NA
  - **Twitter handle:@aditya_rane

@lkuligin for review

---------

Co-authored-by: adityarane@google.com <adityarane@google.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
pull/17069/head
Aditya 8 months ago committed by GitHub
parent 584b647b96
commit a23c719c8b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -54,6 +54,8 @@ embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
embeddings.embed_query("hello, world!")
```
""" # noqa: E501
from langchain_google_genai._enums import HarmBlockThreshold, HarmCategory
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
from langchain_google_genai.llms import GoogleGenerativeAI
@ -62,4 +64,6 @@ __all__ = [
"ChatGoogleGenerativeAI",
"GoogleGenerativeAIEmbeddings",
"GoogleGenerativeAI",
"HarmBlockThreshold",
"HarmCategory",
]

@ -0,0 +1,6 @@
from google.generativeai.types.safety_types import ( # type: ignore
HarmBlockThreshold,
HarmCategory,
)
__all__ = ["HarmBlockThreshold", "HarmCategory"]

@ -517,6 +517,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
"temperature": self.temperature,
"top_k": self.top_k,
"n": self.n,
"safety_settings": self.safety_settings,
}
def _prepare_params(
@ -549,7 +550,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
params, chat, message = self._prepare_chat(
messages,
stop=stop,
functions=kwargs.get("functions"),
**kwargs,
)
response: genai.types.GenerateContentResponse = _chat_with_retry(
content=message,
@ -568,7 +569,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
params, chat, message = self._prepare_chat(
messages,
stop=stop,
functions=kwargs.get("functions"),
**kwargs,
)
response: genai.types.GenerateContentResponse = await _achat_with_retry(
content=message,
@ -587,7 +588,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
params, chat, message = self._prepare_chat(
messages,
stop=stop,
functions=kwargs.get("functions"),
**kwargs,
)
response: genai.types.GenerateContentResponse = _chat_with_retry(
content=message,
@ -613,7 +614,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
params, chat, message = self._prepare_chat(
messages,
stop=stop,
functions=kwargs.get("functions"),
**kwargs,
)
async for chunk in await _achat_with_retry(
content=message,
@ -636,9 +637,14 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
client = self.client
functions = kwargs.pop("functions", None)
if functions:
tools = convert_to_genai_function_declarations(functions)
client = genai.GenerativeModel(model_name=self.model, tools=tools)
safety_settings = kwargs.pop("safety_settings", self.safety_settings)
if functions or safety_settings:
tools = (
convert_to_genai_function_declarations(functions) if functions else None
)
client = genai.GenerativeModel(
model_name=self.model, tools=tools, safety_settings=safety_settings
)
params = self._prepare_params(stop, **kwargs)
history = _parse_chat_history(

@ -15,6 +15,11 @@ from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_google_genai._enums import (
HarmBlockThreshold,
HarmCategory,
)
class GoogleModelFamily(str, Enum):
GEMINI = auto()
@ -77,7 +82,10 @@ def _completion_with_retry(
try:
if is_gemini:
return llm.client.generate_content(
contents=prompt, stream=stream, generation_config=generation_config
contents=prompt,
stream=stream,
generation_config=generation_config,
safety_settings=kwargs.pop("safety_settings", None),
)
return llm.client.generate_text(prompt=prompt, **kwargs)
except google.api_core.exceptions.FailedPrecondition as exc:
@ -143,6 +151,22 @@ Supported examples:
description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
)
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None
"""The default safety settings to use for all generations.
For example:
from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory
safety_settings = {
HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
}
""" # noqa: E501
@property
def lc_secrets(self) -> Dict[str, str]:
return {"google_api_key": "GOOGLE_API_KEY"}
@ -184,6 +208,8 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
)
model_name = values["model"]
safety_settings = values["safety_settings"]
if isinstance(google_api_key, SecretStr):
google_api_key = google_api_key.get_secret_value()
@ -193,8 +219,15 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
client_options=values.get("client_options"),
)
if safety_settings and (
not GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI
):
raise ValueError("Safety settings are only supported for Gemini models")
if GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI:
values["client"] = genai.GenerativeModel(model_name=model_name)
values["client"] = genai.GenerativeModel(
model_name=model_name, safety_settings=safety_settings
)
else:
values["client"] = genai
@ -237,6 +270,7 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
is_gemini=True,
run_manager=run_manager,
generation_config=generation_config,
safety_settings=kwargs.pop("safety_settings", None),
)
candidates = [
"".join([p.text for p in c.content.parts]) for c in res.candidates
@ -278,6 +312,7 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
is_gemini=True,
run_manager=run_manager,
generation_config=generation_config,
safety_settings=kwargs.pop("safety_settings", None),
**kwargs,
):
chunk = GenerationChunk(text=stream_resp.text)

@ -1146,13 +1146,13 @@ files = [
[[package]]
name = "tqdm"
version = "4.66.1"
version = "4.66.2"
description = "Fast, Extensible Progress Meter"
optional = false
python-versions = ">=3.7"
files = [
{file = "tqdm-4.66.1-py3-none-any.whl", hash = "sha256:d302b3c5b53d47bce91fea46679d9c3c6508cf6332229aa1e7d8653723793386"},
{file = "tqdm-4.66.1.tar.gz", hash = "sha256:d88e651f9db8d8551a62556d3cff9e3034274ca5d66e93197cf2490e2dcb69c7"},
{file = "tqdm-4.66.2-py3-none-any.whl", hash = "sha256:1ee4f8a893eb9bef51c6e35730cebf234d5d0b6bd112b0271e10ed7c24a02bd9"},
{file = "tqdm-4.66.2.tar.gz", hash = "sha256:6cd52cdf0fef0e0f543299cfc96fec90d7b8a7e88745f411ec33eb44d5ed3531"},
]
[package.dependencies]

@ -1,12 +1,16 @@
"""Test ChatGoogleGenerativeAI chat model."""
from typing import Generator
import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_google_genai.chat_models import (
from langchain_google_genai import (
ChatGoogleGenerativeAI,
ChatGoogleGenerativeAIError,
HarmBlockThreshold,
HarmCategory,
)
from langchain_google_genai.chat_models import ChatGoogleGenerativeAIError
_MODEL = "gemini-pro" # TODO: Use nano when it's available.
_VISION_MODEL = "gemini-pro-vision"
@ -193,3 +197,32 @@ def test_generativeai_get_num_tokens_gemini() -> None:
llm = ChatGoogleGenerativeAI(temperature=0, model="gemini-pro")
output = llm.get_num_tokens("How are you?")
assert output == 4
def test_safety_settings_gemini() -> None:
safety_settings = {
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}
# test with safety filters on bind
llm = ChatGoogleGenerativeAI(temperature=0, model="gemini-pro").bind(
safety_settings=safety_settings
)
output = llm.invoke("how to make a bomb?")
assert isinstance(output, AIMessage)
assert len(output.content) > 0
# test direct to stream
streamed_messages = []
output_stream = llm.stream("how to make a bomb?", safety_settings=safety_settings)
assert isinstance(output_stream, Generator)
for message in output_stream:
streamed_messages.append(message)
assert len(streamed_messages) > 0
# test as init param
llm = ChatGoogleGenerativeAI(
temperature=0, model="gemini-pro", safety_settings=safety_settings
)
out2 = llm.invoke("how to make a bomb")
assert isinstance(out2, AIMessage)
assert len(out2.content) > 0

@ -4,10 +4,12 @@ Note: This test must be run with the GOOGLE_API_KEY environment variable set to
valid API key.
"""
from typing import Generator
import pytest
from langchain_core.outputs import LLMResult
from langchain_google_genai.llms import GoogleGenerativeAI
from langchain_google_genai import GoogleGenerativeAI, HarmBlockThreshold, HarmCategory
model_names = ["models/text-bison-001", "gemini-pro"]
@ -66,3 +68,39 @@ def test_generativeai_get_num_tokens_gemini() -> None:
llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
output = llm.get_num_tokens("How are you?")
assert output == 4
def test_safety_settings_gemini() -> None:
# test with blocked prompt
llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
output = llm.generate(prompts=["how to make a bomb?"])
assert isinstance(output, LLMResult)
assert len(output.generations[0]) == 0
# safety filters
safety_settings = {
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}
# test with safety filters directly to generate
output = llm.generate(["how to make a bomb?"], safety_settings=safety_settings)
assert isinstance(output, LLMResult)
assert len(output.generations[0]) > 0
# test with safety filters directly to stream
streamed_messages = []
output_stream = llm.stream("how to make a bomb?", safety_settings=safety_settings)
assert isinstance(output_stream, Generator)
for message in output_stream:
streamed_messages.append(message)
assert len(streamed_messages) > 0
# test with safety filters on instantiation
llm = GoogleGenerativeAI(
model="gemini-pro",
safety_settings=safety_settings,
temperature=0,
)
output = llm.generate(prompts=["how to make a bomb?"])
assert isinstance(output, LLMResult)
assert len(output.generations[0]) > 0

@ -4,6 +4,8 @@ EXPECTED_ALL = [
"ChatGoogleGenerativeAI",
"GoogleGenerativeAIEmbeddings",
"GoogleGenerativeAI",
"HarmBlockThreshold",
"HarmCategory",
]

Loading…
Cancel
Save