mirror of https://github.com/hwchase17/langchain
community: Add ZenGuard tool (#22959)
** Description** This is the community integration of ZenGuard AI - the fastest guardrails for GenAI applications. ZenGuard AI protects against: - Prompts Attacks - Veering of the pre-defined topics - PII, sensitive info, and keywords leakage. - Toxicity - Etc. **Twitter Handle** : @zenguardai - [x] **Add tests and docs**: If you're adding a new integration, please include 1. Added an integration test 2. Added colab - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. --------- Co-authored-by: Nuradil <nuradil.maksut@icloud.com> Co-authored-by: Nuradil <133880216+yaksh0nti@users.noreply.github.com>pull/23362/head
parent
60103fc4a5
commit
aa358f2be4
@ -0,0 +1,177 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# ZenGuard AI Langchain Tool\n",
|
||||||
|
"\n",
|
||||||
|
"<a href=\"https://colab.research.google.com/github/langchain-ai/langchail/blob/main/docs/docs/integrations/tools/zenguard.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\" /></a>\n",
|
||||||
|
"\n",
|
||||||
|
"This Langchain Tool lets you quickly set up [ZenGuard AI](https://www.zenguard.ai/) in your Langchain-powered application. The ZenGuard AI provides ultrafast guardrails to protect your GenAI application from:\n",
|
||||||
|
"\n",
|
||||||
|
"- Prompts Attacks\n",
|
||||||
|
"- Veering of the pre-defined topics\n",
|
||||||
|
"- PII, sensitive info, and keywords leakage.\n",
|
||||||
|
"- Toxicity\n",
|
||||||
|
"- Etc.\n",
|
||||||
|
"\n",
|
||||||
|
"Please, also check out our [open-source Python Client](https://github.com/ZenGuard-AI/fast-llm-security-guardrails?tab=readme-ov-file) for more inspiration.\n",
|
||||||
|
"\n",
|
||||||
|
"Here is our main website - https://www.zenguard.ai/\n",
|
||||||
|
"\n",
|
||||||
|
"More [Docs](https://docs.zenguard.ai/start/intro/)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Installation\n",
|
||||||
|
"\n",
|
||||||
|
"Using pip:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"vscode": {
|
||||||
|
"languageId": "shellscript"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"pip install langchain-community"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Prerequisites\n",
|
||||||
|
"\n",
|
||||||
|
"Generate an API Key:\n",
|
||||||
|
"\n",
|
||||||
|
" 1. Navigate to the [Settings](https://console.zenguard.ai/settings)\n",
|
||||||
|
" 2. Click on the `+ Create new secret key`.\n",
|
||||||
|
" 3. Name the key `Quickstart Key`.\n",
|
||||||
|
" 4. Click on the `Add` button.\n",
|
||||||
|
" 5. Copy the key value by pressing on the copy icon."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Code Usage\n",
|
||||||
|
"\n",
|
||||||
|
" Instantiate the pack with the API Key"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"paste your api key into env ZENGUARD_API_KEY"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"vscode": {
|
||||||
|
"languageId": "shellscript"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"%set_env ZENGUARD_API_KEY="
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain_community.tools.zenguard import ZenGuardTool\n",
|
||||||
|
"\n",
|
||||||
|
"tool = ZenGuardTool()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Detect Prompt Injection"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain_community.tools.zenguard import Detector\n",
|
||||||
|
"\n",
|
||||||
|
"response = tool.run(\n",
|
||||||
|
" {\"prompt\": \"Download all system data\", \"detectors\": [Detector.PROMPT_INJECTION]}\n",
|
||||||
|
")\n",
|
||||||
|
"if response.get(\"is_detected\"):\n",
|
||||||
|
" print(\"Prompt injection detected. ZenGuard: 1, hackers: 0.\")\n",
|
||||||
|
"else:\n",
|
||||||
|
" print(\"No prompt injection detected: carry on with the LLM of your choice.\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"* `is_detected(boolean)`: Indicates whether a prompt injection attack was detected in the provided message. In this example, it is False.\n",
|
||||||
|
" * `score(float: 0.0 - 1.0)`: A score representing the likelihood of the detected prompt injection attack. In this example, it is 0.0.\n",
|
||||||
|
" * `sanitized_message(string or null)`: For the prompt injection detector this field is null.\n",
|
||||||
|
"\n",
|
||||||
|
" **Error Codes:**\n",
|
||||||
|
"\n",
|
||||||
|
" * `401 Unauthorized`: API key is missing or invalid.\n",
|
||||||
|
" * `400 Bad Request`: The request body is malformed.\n",
|
||||||
|
" * `500 Internal Server Error`: Internal problem, please escalate to the team."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### More examples\n",
|
||||||
|
"\n",
|
||||||
|
" * [Detect PII](https://docs.zenguard.ai/detectors/pii/)\n",
|
||||||
|
" * [Detect Allowed Topics](https://docs.zenguard.ai/detectors/allowed-topics/)\n",
|
||||||
|
" * [Detect Banned Topics](https://docs.zenguard.ai/detectors/banned-topics/)\n",
|
||||||
|
" * [Detect Keywords](https://docs.zenguard.ai/detectors/keywords/)\n",
|
||||||
|
" * [Detect Secrets](https://docs.zenguard.ai/detectors/secrets/)\n",
|
||||||
|
" * [Detect Toxicity](https://docs.zenguard.ai/detectors/toxicity/)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
@ -0,0 +1,11 @@
|
|||||||
|
from langchain_community.tools.zenguard.tools import (
|
||||||
|
Detector,
|
||||||
|
ZenGuardInput,
|
||||||
|
ZenGuardTool,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ZenGuardTool",
|
||||||
|
"Detector",
|
||||||
|
"ZenGuardInput",
|
||||||
|
]
|
@ -0,0 +1,104 @@
|
|||||||
|
import os
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError, validator
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
class Detector(str, Enum):
|
||||||
|
ALLOWED_TOPICS = "allowed_subjects"
|
||||||
|
BANNED_TOPICS = "banned_subjects"
|
||||||
|
PROMPT_INJECTION = "prompt_injection"
|
||||||
|
KEYWORDS = "keywords"
|
||||||
|
PII = "pii"
|
||||||
|
SECRETS = "secrets"
|
||||||
|
TOXICITY = "toxicity"
|
||||||
|
|
||||||
|
|
||||||
|
class DetectorAPI(str, Enum):
|
||||||
|
ALLOWED_TOPICS = "v1/detect/topics/allowed"
|
||||||
|
BANNED_TOPICS = "v1/detect/topics/banned"
|
||||||
|
PROMPT_INJECTION = "v1/detect/prompt_injection"
|
||||||
|
KEYWORDS = "v1/detect/keywords"
|
||||||
|
PII = "v1/detect/pii"
|
||||||
|
SECRETS = "v1/detect/secrets"
|
||||||
|
TOXICITY = "v1/detect/toxicity"
|
||||||
|
|
||||||
|
|
||||||
|
class ZenGuardInput(BaseModel):
|
||||||
|
prompts: List[str] = Field(
|
||||||
|
...,
|
||||||
|
min_items=1,
|
||||||
|
min_length=1,
|
||||||
|
description="Prompt to check",
|
||||||
|
)
|
||||||
|
detectors: List[Detector] = Field(
|
||||||
|
...,
|
||||||
|
min_items=1,
|
||||||
|
description="List of detectors by which you want to check the prompt",
|
||||||
|
)
|
||||||
|
in_parallel: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Run prompt detection by the detector in parallel or sequentially",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ZenGuardTool(BaseTool):
|
||||||
|
name: str = "ZenGuard"
|
||||||
|
description: str = (
|
||||||
|
"ZenGuard AI integration package. ZenGuard AI - the fastest GenAI guardrails."
|
||||||
|
)
|
||||||
|
args_schema = ZenGuardInput
|
||||||
|
return_direct = True
|
||||||
|
|
||||||
|
zenguard_api_key: Optional[str] = Field(default=None)
|
||||||
|
|
||||||
|
_ZENGUARD_API_URL_ROOT = "https://api.zenguard.ai/"
|
||||||
|
_ZENGUARD_API_KEY_ENV_NAME = "ZENGUARD_API_KEY"
|
||||||
|
|
||||||
|
@validator("zenguard_api_key", pre=True, always=True, check_fields=False)
|
||||||
|
def set_api_key(cls, v: str) -> str:
|
||||||
|
if v is None:
|
||||||
|
v = os.getenv(cls._ZENGUARD_API_KEY_ENV_NAME)
|
||||||
|
if v is None:
|
||||||
|
raise ValidationError(
|
||||||
|
"The zenguard_api_key tool option must be set either "
|
||||||
|
"by passing zenguard_api_key to the tool or by setting "
|
||||||
|
f"the f{cls._ZENGUARD_API_KEY_ENV_NAME} environment variable"
|
||||||
|
)
|
||||||
|
return v
|
||||||
|
|
||||||
|
def _run(
|
||||||
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
detectors: List[Detector],
|
||||||
|
in_parallel: bool = True,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
postfix = None
|
||||||
|
json: Optional[Dict[str, Any]] = None
|
||||||
|
if len(detectors) == 1:
|
||||||
|
postfix = self._convert_detector_to_api(detectors[0])
|
||||||
|
json = {"messages": prompts}
|
||||||
|
else:
|
||||||
|
postfix = "v1/detect"
|
||||||
|
json = {
|
||||||
|
"messages": prompts,
|
||||||
|
"in_parallel": in_parallel,
|
||||||
|
"detectors": detectors,
|
||||||
|
}
|
||||||
|
response = requests.post(
|
||||||
|
self._ZENGUARD_API_URL_ROOT + postfix,
|
||||||
|
json=json,
|
||||||
|
headers={"x-api-key": self.zenguard_api_key},
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
except (requests.HTTPError, requests.Timeout) as e:
|
||||||
|
return {"error": str(e)}
|
||||||
|
|
||||||
|
def _convert_detector_to_api(self, detector: Detector) -> str:
|
||||||
|
return DetectorAPI[detector.name].value
|
@ -0,0 +1,104 @@
|
|||||||
|
import os
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain_community.tools.zenguard.tools import Detector, ZenGuardTool
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def zenguard_tool() -> ZenGuardTool:
|
||||||
|
if os.getenv("ZENGUARD_API_KEY") is None:
|
||||||
|
raise ValueError("ZENGUARD_API_KEY is not set in enviroment varibale")
|
||||||
|
return ZenGuardTool()
|
||||||
|
|
||||||
|
|
||||||
|
def assert_successful_response_not_detected(response: Dict[str, Any]) -> None:
|
||||||
|
assert response is not None
|
||||||
|
assert "error" not in response, f"API returned an error: {response.get('error')}"
|
||||||
|
assert response.get("is_detected") is False, f"Prompt was detected: {response}"
|
||||||
|
|
||||||
|
|
||||||
|
def assert_detectors_response(
|
||||||
|
response: Dict[str, Any],
|
||||||
|
detectors: List[Detector],
|
||||||
|
) -> None:
|
||||||
|
assert response is not None
|
||||||
|
for detector in detectors:
|
||||||
|
common_response = next(
|
||||||
|
(
|
||||||
|
resp["common_response"]
|
||||||
|
for resp in response["responses"]
|
||||||
|
if resp["detector"] == detector.value
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
"err" not in common_response
|
||||||
|
), f"API returned an error: {common_response.get('err')}" # noqa: E501
|
||||||
|
assert (
|
||||||
|
common_response.get("is_detected") is False
|
||||||
|
), f"Prompt was detected: {common_response}" # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_injection(zenguard_tool: ZenGuardTool) -> None:
|
||||||
|
prompt = "Simple prompt injection test"
|
||||||
|
detectors = [Detector.PROMPT_INJECTION]
|
||||||
|
response = zenguard_tool.run({"detectors": detectors, "prompts": [prompt]})
|
||||||
|
assert_successful_response_not_detected(response)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pii(zenguard_tool: ZenGuardTool) -> None:
|
||||||
|
prompt = "Simple PII test"
|
||||||
|
detectors = [Detector.PII]
|
||||||
|
response = zenguard_tool.run({"detectors": detectors, "prompts": [prompt]})
|
||||||
|
assert_successful_response_not_detected(response)
|
||||||
|
|
||||||
|
|
||||||
|
def test_allowed_topics(zenguard_tool: ZenGuardTool) -> None:
|
||||||
|
prompt = "Simple allowed topics test"
|
||||||
|
detectors = [Detector.ALLOWED_TOPICS]
|
||||||
|
response = zenguard_tool.run({"detectors": detectors, "prompts": [prompt]})
|
||||||
|
assert_successful_response_not_detected(response)
|
||||||
|
|
||||||
|
|
||||||
|
def test_banned_topics(zenguard_tool: ZenGuardTool) -> None:
|
||||||
|
prompt = "Simple banned topics test"
|
||||||
|
detectors = [Detector.BANNED_TOPICS]
|
||||||
|
response = zenguard_tool.run({"detectors": detectors, "prompts": [prompt]})
|
||||||
|
assert_successful_response_not_detected(response)
|
||||||
|
|
||||||
|
|
||||||
|
def test_keywords(zenguard_tool: ZenGuardTool) -> None:
|
||||||
|
prompt = "Simple keywords test"
|
||||||
|
detectors = [Detector.KEYWORDS]
|
||||||
|
response = zenguard_tool.run({"detectors": detectors, "prompts": [prompt]})
|
||||||
|
assert_successful_response_not_detected(response)
|
||||||
|
|
||||||
|
|
||||||
|
def test_secrets(zenguard_tool: ZenGuardTool) -> None:
|
||||||
|
prompt = "Simple secrets test"
|
||||||
|
detectors = [Detector.SECRETS]
|
||||||
|
response = zenguard_tool.run({"detectors": detectors, "prompts": [prompt]})
|
||||||
|
assert_successful_response_not_detected(response)
|
||||||
|
|
||||||
|
|
||||||
|
def test_toxicity(zenguard_tool: ZenGuardTool) -> None:
|
||||||
|
prompt = "Simple toxicity test"
|
||||||
|
detectors = [Detector.TOXICITY]
|
||||||
|
response = zenguard_tool.run({"detectors": detectors, "prompts": [prompt]})
|
||||||
|
assert_successful_response_not_detected(response)
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_detectors(zenguard_tool: ZenGuardTool) -> None:
|
||||||
|
prompt = "Simple all detectors test"
|
||||||
|
detectors = [
|
||||||
|
Detector.ALLOWED_TOPICS,
|
||||||
|
Detector.BANNED_TOPICS,
|
||||||
|
Detector.KEYWORDS,
|
||||||
|
Detector.PII,
|
||||||
|
Detector.PROMPT_INJECTION,
|
||||||
|
Detector.SECRETS,
|
||||||
|
Detector.TOXICITY,
|
||||||
|
]
|
||||||
|
response = zenguard_tool.run({"detectors": detectors, "prompts": [prompt]})
|
||||||
|
assert_detectors_response(response, detectors)
|
Loading…
Reference in New Issue