mirror of https://github.com/hwchase17/langchain
google-vertexai[minor]: added safety_settings property to gemini wrapper (#15344)
**Description:** Gemini model has quite annoying default safety_settings settings. In addition, current VertexAI class doesn't provide a property to override such settings. So, this PR aims to - add safety_settings property to VertexAI - fix issue with incorrect LLM output parsing when LLM responds with appropriate 'blocked' response - fix issue with incorrect parsing LLM output when Gemini API blocks prompt itself as inappropriate - add safety_settings related tests I'm not enough familiar with langchain code base and guidelines. So, any comments and/or suggestions are very welcome. **Issue:** it will likely fix #14841 --------- Co-authored-by: Erick Friis <erick@langchain.dev>pull/16211/head
parent
ecd4f0a7ec
commit
6b9e3ed9e9
@ -1,5 +1,13 @@
|
|||||||
|
from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory
|
||||||
from langchain_google_vertexai.chat_models import ChatVertexAI
|
from langchain_google_vertexai.chat_models import ChatVertexAI
|
||||||
from langchain_google_vertexai.embeddings import VertexAIEmbeddings
|
from langchain_google_vertexai.embeddings import VertexAIEmbeddings
|
||||||
from langchain_google_vertexai.llms import VertexAI, VertexAIModelGarden
|
from langchain_google_vertexai.llms import VertexAI, VertexAIModelGarden
|
||||||
|
|
||||||
__all__ = ["ChatVertexAI", "VertexAIEmbeddings", "VertexAI", "VertexAIModelGarden"]
|
__all__ = [
|
||||||
|
"ChatVertexAI",
|
||||||
|
"VertexAIEmbeddings",
|
||||||
|
"VertexAI",
|
||||||
|
"VertexAIModelGarden",
|
||||||
|
"HarmBlockThreshold",
|
||||||
|
"HarmCategory",
|
||||||
|
]
|
||||||
|
@ -0,0 +1,6 @@
|
|||||||
|
from vertexai.preview.generative_models import ( # type: ignore
|
||||||
|
HarmBlockThreshold,
|
||||||
|
HarmCategory,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = ["HarmBlockThreshold", "HarmCategory"]
|
@ -0,0 +1,97 @@
|
|||||||
|
from langchain_core.outputs import LLMResult
|
||||||
|
|
||||||
|
from langchain_google_vertexai import HarmBlockThreshold, HarmCategory, VertexAI
|
||||||
|
|
||||||
|
SAFETY_SETTINGS = {
|
||||||
|
HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
|
||||||
|
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
|
||||||
|
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
|
||||||
|
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
|
||||||
|
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# below context and question are taken from one of opensource QA datasets
|
||||||
|
BLOCKED_PROMPT = """
|
||||||
|
You are agent designed to answer questions.
|
||||||
|
You are given context in triple backticks.
|
||||||
|
```
|
||||||
|
The religion\'s failure to report abuse allegations to authorities has also been
|
||||||
|
criticized. The Watch Tower Society\'s policy is that elders inform authorities when
|
||||||
|
required by law to do so, but otherwise leave that action up to the victim and his
|
||||||
|
or her family. The Australian Royal Commission into Institutional Responses to Child
|
||||||
|
Sexual Abuse found that of 1006 alleged perpetrators of child sexual abuse
|
||||||
|
identified by the Jehovah\'s Witnesses within their organization since 1950,
|
||||||
|
"not one was reported by the church to secular authorities." William Bowen, a former
|
||||||
|
Jehovah\'s Witness elder who established the Silentlambs organization to assist sex
|
||||||
|
abuse victims within the religion, has claimed Witness leaders discourage followers
|
||||||
|
from reporting incidents of sexual misconduct to authorities, and other critics claim
|
||||||
|
the organization is reluctant to alert authorities in order to protect its "crime-free"
|
||||||
|
reputation. In court cases in the United Kingdom and the United States the Watch Tower
|
||||||
|
Society has been found to have been negligent in its failure to protect children from
|
||||||
|
known sex offenders within the congregation and the Society has settled other child
|
||||||
|
abuse lawsuits out of court, reportedly paying as much as $780,000 to one plaintiff
|
||||||
|
without admitting wrongdoing.
|
||||||
|
```
|
||||||
|
Question: What have courts in both the UK and the US found the Watch Tower Society to
|
||||||
|
have been for failing to protect children from sexual predators within the
|
||||||
|
congregation ?
|
||||||
|
Answer:
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def test_gemini_safety_settings_generate() -> None:
|
||||||
|
llm = VertexAI(model_name="gemini-pro", safety_settings=SAFETY_SETTINGS)
|
||||||
|
output = llm.generate(["What do you think about child abuse:"])
|
||||||
|
assert isinstance(output, LLMResult)
|
||||||
|
assert len(output.generations) == 1
|
||||||
|
generation_info = output.generations[0][0].generation_info
|
||||||
|
assert generation_info is not None
|
||||||
|
assert len(generation_info) > 0
|
||||||
|
assert not generation_info.get("is_blocked")
|
||||||
|
|
||||||
|
blocked_output = llm.generate([BLOCKED_PROMPT])
|
||||||
|
assert isinstance(blocked_output, LLMResult)
|
||||||
|
assert len(blocked_output.generations) == 1
|
||||||
|
assert len(blocked_output.generations[0]) == 0
|
||||||
|
|
||||||
|
# test safety_settings passed directly to generate
|
||||||
|
llm = VertexAI(model_name="gemini-pro")
|
||||||
|
output = llm.generate(
|
||||||
|
["What do you think about child abuse:"], safety_settings=SAFETY_SETTINGS
|
||||||
|
)
|
||||||
|
assert isinstance(output, LLMResult)
|
||||||
|
assert len(output.generations) == 1
|
||||||
|
generation_info = output.generations[0][0].generation_info
|
||||||
|
assert generation_info is not None
|
||||||
|
assert len(generation_info) > 0
|
||||||
|
assert not generation_info.get("is_blocked")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_gemini_safety_settings_agenerate() -> None:
|
||||||
|
llm = VertexAI(model_name="gemini-pro", safety_settings=SAFETY_SETTINGS)
|
||||||
|
output = await llm.agenerate(["What do you think about child abuse:"])
|
||||||
|
assert isinstance(output, LLMResult)
|
||||||
|
assert len(output.generations) == 1
|
||||||
|
generation_info = output.generations[0][0].generation_info
|
||||||
|
assert generation_info is not None
|
||||||
|
assert len(generation_info) > 0
|
||||||
|
assert not generation_info.get("is_blocked")
|
||||||
|
|
||||||
|
blocked_output = await llm.agenerate([BLOCKED_PROMPT])
|
||||||
|
assert isinstance(blocked_output, LLMResult)
|
||||||
|
assert len(blocked_output.generations) == 1
|
||||||
|
# assert len(blocked_output.generations[0][0].generation_info) > 0
|
||||||
|
# assert blocked_output.generations[0][0].generation_info.get("is_blocked")
|
||||||
|
|
||||||
|
# test safety_settings passed directly to agenerate
|
||||||
|
llm = VertexAI(model_name="gemini-pro")
|
||||||
|
output = await llm.agenerate(
|
||||||
|
["What do you think about child abuse:"], safety_settings=SAFETY_SETTINGS
|
||||||
|
)
|
||||||
|
assert isinstance(output, LLMResult)
|
||||||
|
assert len(output.generations) == 1
|
||||||
|
generation_info = output.generations[0][0].generation_info
|
||||||
|
assert generation_info is not None
|
||||||
|
assert len(generation_info) > 0
|
||||||
|
assert not generation_info.get("is_blocked")
|
Loading…
Reference in New Issue