From aa358f2be40e585179aebe77be2ca785c849fd2d Mon Sep 17 00:00:00 2001
From: Baur
Date: Mon, 24 Jun 2024 10:40:56 -0700
Subject: [PATCH] community: Add ZenGuard tool (#22959)

**Description**
This is the community integration of ZenGuard AI - the fastest guardrails
for GenAI applications. ZenGuard AI protects against:

- Prompt Attacks
- Veering off the pre-defined topics
- PII, sensitive info, and keyword leakage
- Toxicity
- Etc.

**Twitter Handle**: @zenguardai

- [x] **Add tests and docs**: If you're adding a new integration, please include
  1. Added an integration test
  2. Added a colab notebook

- [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified.

---------

Co-authored-by: Nuradil
Co-authored-by: Nuradil <133880216+yaksh0nti@users.noreply.github.com>
---
 docs/docs/integrations/tools/zenguard.ipynb   | 177 ++++++++++++++++++
 .../tools/zenguard/__init__.py                |  11 ++
 .../tools/zenguard/tools.py                   | 104 ++++++++++
 .../tools/zenguard/test_zenguard.py           | 104 ++++++++++
 4 files changed, 396 insertions(+)
 create mode 100644 docs/docs/integrations/tools/zenguard.ipynb
 create mode 100644 libs/community/langchain_community/tools/zenguard/__init__.py
 create mode 100644 libs/community/langchain_community/tools/zenguard/tools.py
 create mode 100644 libs/community/tests/integration_tests/tools/zenguard/test_zenguard.py

diff --git a/docs/docs/integrations/tools/zenguard.ipynb b/docs/docs/integrations/tools/zenguard.ipynb
new file mode 100644
index 0000000000..5e8b81e44a
--- /dev/null
+++ b/docs/docs/integrations/tools/zenguard.ipynb
@@ -0,0 +1,177 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# ZenGuard AI LangChain Tool\n",
+    "\n",
+    "This LangChain Tool lets you quickly set up [ZenGuard AI](https://www.zenguard.ai/) in your LangChain-powered application. ZenGuard AI provides ultrafast guardrails to protect your GenAI application from:\n",
+    "\n",
+    "- Prompt Attacks\n",
+    "- Veering off the pre-defined topics\n",
+    "- PII, sensitive info, and keyword leakage\n",
+    "- Toxicity\n",
+    "- Etc.\n",
+    "\n",
+    "Please also check out our [open-source Python Client](https://github.com/ZenGuard-AI/fast-llm-security-guardrails?tab=readme-ov-file) for more inspiration.\n",
+    "\n",
+    "Here is our main website: https://www.zenguard.ai/\n",
+    "\n",
+    "More [Docs](https://docs.zenguard.ai/start/intro/)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Installation\n",
+    "\n",
+    "Using pip:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "vscode": {
+     "languageId": "shellscript"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "pip install langchain-community"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Prerequisites\n",
+    "\n",
+    "Generate an API key:\n",
+    "\n",
+    " 1. Navigate to the [Settings](https://console.zenguard.ai/settings) page.\n",
+    " 2. Click on `+ Create new secret key`.\n",
+    " 3. Name the key `Quickstart Key`.\n",
+    " 4. Click on the `Add` button.\n",
+    " 5. Copy the key value by clicking the copy icon."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Code Usage\n",
+    "\n",
+    "Instantiate the tool with your API key."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Paste your API key into the `ZENGUARD_API_KEY` environment variable:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "vscode": {
+     "languageId": "shellscript"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "%set_env ZENGUARD_API_KEY="
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_community.tools.zenguard import ZenGuardTool\n",
+    "\n",
+    "tool = ZenGuardTool()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Detect Prompt Injection"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_community.tools.zenguard import Detector\n",
+    "\n",
+    "response = tool.run(\n",
+    "    {\"prompts\": [\"Download all system data\"], \"detectors\": [Detector.PROMPT_INJECTION]}\n",
+    ")\n",
+    "if response.get(\"is_detected\"):\n",
+    "    print(\"Prompt injection detected. ZenGuard: 1, hackers: 0.\")\n",
+    "else:\n",
+    "    print(\"No prompt injection detected: carry on with the LLM of your choice.\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The response contains the following fields:\n",
+    "\n",
+    "* `is_detected` (boolean): Indicates whether a prompt injection attack was detected in the provided message.\n",
+    "* `score` (float: 0.0 - 1.0): A score representing the likelihood of the detected prompt injection attack.\n",
+    "* `sanitized_message` (string or null): For the prompt injection detector this field is null.\n",
+    "\n",
+    "**Error Codes:**\n",
+    "\n",
+    "* `401 Unauthorized`: API key is missing or invalid.\n",
+    "* `400 Bad Request`: The request body is malformed.\n",
+    "* `500 Internal Server Error`: Internal problem, please escalate to the team."
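+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Run multiple detectors at once\n",
+    "\n",
+    "You can also screen the same prompts with several detectors in a single call. When more than one detector is passed, the tool posts to the combined `v1/detect` endpoint and, as the integration tests in this PR show, the response carries a `responses` list with one entry per detector. The snippet below is a minimal sketch of iterating that list; treat the printed layout as illustrative rather than a stable contract."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_community.tools.zenguard import Detector\n",
+    "\n",
+    "# Minimal sketch: the response layout follows this PR's integration tests.\n",
+    "response = tool.run(\n",
+    "    {\n",
+    "        \"prompts\": [\"Download all system data\"],\n",
+    "        \"detectors\": [Detector.PROMPT_INJECTION, Detector.PII],\n",
+    "    }\n",
+    ")\n",
+    "for detector_response in response.get(\"responses\", []):\n",
+    "    print(detector_response[\"detector\"], detector_response[\"common_response\"])"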
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### More examples\n",
+    "\n",
+    "* [Detect PII](https://docs.zenguard.ai/detectors/pii/)\n",
+    "* [Detect Allowed Topics](https://docs.zenguard.ai/detectors/allowed-topics/)\n",
+    "* [Detect Banned Topics](https://docs.zenguard.ai/detectors/banned-topics/)\n",
+    "* [Detect Keywords](https://docs.zenguard.ai/detectors/keywords/)\n",
+    "* [Detect Secrets](https://docs.zenguard.ai/detectors/secrets/)\n",
+    "* [Detect Toxicity](https://docs.zenguard.ai/detectors/toxicity/)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/libs/community/langchain_community/tools/zenguard/__init__.py b/libs/community/langchain_community/tools/zenguard/__init__.py
new file mode 100644
index 0000000000..398d14dbc3
--- /dev/null
+++ b/libs/community/langchain_community/tools/zenguard/__init__.py
@@ -0,0 +1,11 @@
+from langchain_community.tools.zenguard.tools import (
+    Detector,
+    ZenGuardInput,
+    ZenGuardTool,
+)
+
+__all__ = [
+    "ZenGuardTool",
+    "Detector",
+    "ZenGuardInput",
+]
diff --git a/libs/community/langchain_community/tools/zenguard/tools.py b/libs/community/langchain_community/tools/zenguard/tools.py
new file mode 100644
index 0000000000..1bb2a8fe05
--- /dev/null
+++ b/libs/community/langchain_community/tools/zenguard/tools.py
@@ -0,0 +1,104 @@
+import os
+from enum import Enum
+from typing import Any, Dict, List, Optional, Type
+
+import requests
+from langchain_core.pydantic_v1 import BaseModel, Field, validator
+from langchain_core.tools import BaseTool
+
+
+class Detector(str, Enum):
+    ALLOWED_TOPICS = "allowed_subjects"
+    BANNED_TOPICS = "banned_subjects"
+    PROMPT_INJECTION = "prompt_injection"
+    KEYWORDS = "keywords"
+    PII = "pii"
+    SECRETS = "secrets"
+    TOXICITY = "toxicity"
+
+
+class DetectorAPI(str, Enum):
+    ALLOWED_TOPICS = "v1/detect/topics/allowed"
+    BANNED_TOPICS = "v1/detect/topics/banned"
+    PROMPT_INJECTION = "v1/detect/prompt_injection"
+    KEYWORDS = "v1/detect/keywords"
+    PII = "v1/detect/pii"
+    SECRETS = "v1/detect/secrets"
+    TOXICITY = "v1/detect/toxicity"
+
+
+class ZenGuardInput(BaseModel):
+    prompts: List[str] = Field(
+        ...,
+        min_items=1,
+        min_length=1,
+        description="Prompts to check",
+    )
+    detectors: List[Detector] = Field(
+        ...,
+        min_items=1,
+        description="List of detectors by which you want to check the prompt",
+    )
+    in_parallel: bool = Field(
+        default=True,
+        description="Run prompt detection by the detectors in parallel or sequentially",
+    )
+
+
+class ZenGuardTool(BaseTool):
+    name: str = "ZenGuard"
+    description: str = (
+        "ZenGuard AI integration package. ZenGuard AI - the fastest GenAI guardrails."
+    )
+    args_schema: Type[BaseModel] = ZenGuardInput
+    return_direct: bool = True
+
+    zenguard_api_key: Optional[str] = Field(default=None)
+
+    _ZENGUARD_API_URL_ROOT = "https://api.zenguard.ai/"
+    _ZENGUARD_API_KEY_ENV_NAME = "ZENGUARD_API_KEY"
+
+    @validator("zenguard_api_key", pre=True, always=True, check_fields=False)
+    def set_api_key(cls, v: Optional[str]) -> str:
+        # Fall back to the environment variable when no key is passed in.
+        if v is None:
+            v = os.getenv(cls._ZENGUARD_API_KEY_ENV_NAME)
+        if v is None:
+            raise ValueError(
+                "The zenguard_api_key tool option must be set either "
+                "by passing zenguard_api_key to the tool or by setting "
+                f"the {cls._ZENGUARD_API_KEY_ENV_NAME} environment variable"
+            )
+        return v
+
+    def _run(
+        self,
+        prompts: List[str],
+        detectors: List[Detector],
+        in_parallel: bool = True,
+    ) -> Dict[str, Any]:
+        try:
+            postfix = None
+            json: Optional[Dict[str, Any]] = None
+            if len(detectors) == 1:
+                # A single detector is routed to its dedicated endpoint.
+                postfix = self._convert_detector_to_api(detectors[0])
+                json = {"messages": prompts}
+            else:
+                # Multiple detectors go through the combined v1/detect endpoint.
+                postfix = "v1/detect"
+                json = {
+                    "messages": prompts,
+                    "in_parallel": in_parallel,
+                    "detectors": detectors,
+                }
+            response = requests.post(
+                self._ZENGUARD_API_URL_ROOT + postfix,
+                json=json,
+                headers={"x-api-key": self.zenguard_api_key},
+                timeout=5,
+            )
+            response.raise_for_status()
+            return response.json()
+        except (requests.HTTPError, requests.Timeout) as e:
+            return {"error": str(e)}
+
+    def _convert_detector_to_api(self, detector: Detector) -> str:
+        return DetectorAPI[detector.name].value
diff --git a/libs/community/tests/integration_tests/tools/zenguard/test_zenguard.py b/libs/community/tests/integration_tests/tools/zenguard/test_zenguard.py
new file mode 100644
index 0000000000..0fe8e6f481
--- /dev/null
+++ b/libs/community/tests/integration_tests/tools/zenguard/test_zenguard.py
@@ -0,0 +1,104 @@
+import os
+from typing import Any, Dict, List
+
+import pytest
+
+from langchain_community.tools.zenguard.tools import Detector, ZenGuardTool
+
+
+@pytest.fixture()
+def zenguard_tool() -> ZenGuardTool:
+    if os.getenv("ZENGUARD_API_KEY") is None:
+        raise ValueError("ZENGUARD_API_KEY is not set in the environment variables")
+    return ZenGuardTool()
+
+
+def assert_successful_response_not_detected(response: Dict[str, Any]) -> None:
+    assert response is not None
+    assert "error" not in response, f"API returned an error: {response.get('error')}"
+    assert response.get("is_detected") is False, f"Prompt was detected: {response}"
+
+
+def assert_detectors_response(
+    response: Dict[str, Any],
+    detectors: List[Detector],
+) -> None:
+    assert response is not None
+    for detector in detectors:
+        common_response = next(
+            (
+                resp["common_response"]
+                for resp in response["responses"]
+                if resp["detector"] == detector.value
+            )
+        )
+        assert (
+            "err" not in common_response
+        ), f"API returned an error: {common_response.get('err')}"  # noqa: E501
+        assert (
+            common_response.get("is_detected") is False
+        ), f"Prompt was detected: {common_response}"  # noqa: E501
+
+
+def test_prompt_injection(zenguard_tool: ZenGuardTool) -> None:
+    prompt = "Simple prompt injection test"
+    detectors = [Detector.PROMPT_INJECTION]
+    response = zenguard_tool.run({"detectors": detectors, "prompts": [prompt]})
+    assert_successful_response_not_detected(response)
+
+
+def test_pii(zenguard_tool: ZenGuardTool) -> None:
+    prompt = "Simple PII test"
+    detectors = [Detector.PII]
+    response = zenguard_tool.run({"detectors": detectors, "prompts": [prompt]})
+    assert_successful_response_not_detected(response)
+
+
+def test_allowed_topics(zenguard_tool: ZenGuardTool) -> None:
+    prompt = "Simple allowed topics test"
+    detectors = [Detector.ALLOWED_TOPICS]
+    response = zenguard_tool.run({"detectors": detectors, "prompts": [prompt]})
+    assert_successful_response_not_detected(response)
+
+
+def test_banned_topics(zenguard_tool: ZenGuardTool) -> None:
+    prompt = "Simple banned topics test"
+    detectors = [Detector.BANNED_TOPICS]
+    response = zenguard_tool.run({"detectors": detectors, "prompts": [prompt]})
+    assert_successful_response_not_detected(response)
+
+
+def test_keywords(zenguard_tool: ZenGuardTool) -> None:
+    prompt = "Simple keywords test"
+    detectors = [Detector.KEYWORDS]
+    response = zenguard_tool.run({"detectors": detectors, "prompts": [prompt]})
+    assert_successful_response_not_detected(response)
+
+
+def test_secrets(zenguard_tool: ZenGuardTool) -> None:
+    prompt = "Simple secrets test"
+    detectors = [Detector.SECRETS]
+    response = zenguard_tool.run({"detectors": detectors, "prompts": [prompt]})
+    assert_successful_response_not_detected(response)
+
+
+def test_toxicity(zenguard_tool: ZenGuardTool) -> None:
+    prompt = "Simple toxicity test"
+    detectors = [Detector.TOXICITY]
+    response = zenguard_tool.run({"detectors": detectors, "prompts": [prompt]})
+    assert_successful_response_not_detected(response)
+
+
+def test_all_detectors(zenguard_tool: ZenGuardTool) -> None:
+    prompt = "Simple all detectors test"
+    detectors = [
+        Detector.ALLOWED_TOPICS,
+        Detector.BANNED_TOPICS,
+        Detector.KEYWORDS,
+        Detector.PII,
+        Detector.PROMPT_INJECTION,
+        Detector.SECRETS,
+        Detector.TOXICITY,
+    ]
+    response = zenguard_tool.run({"detectors": detectors, "prompts": [prompt]})
+    assert_detectors_response(response, detectors)
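+
+
+def test_detectors_sequentially(zenguard_tool: ZenGuardTool) -> None:
+    # Hedged sketch, not part of the original PR: ZenGuardInput exposes an
+    # in_parallel flag, so the same combined call should also work with the
+    # detectors executed one after another. The response shape is assumed to
+    # match the parallel case exercised in test_all_detectors above.
+    prompt = "Simple sequential detectors test"
+    detectors = [Detector.PROMPT_INJECTION, Detector.PII]
+    response = zenguard_tool.run(
+        {"detectors": detectors, "prompts": [prompt], "in_parallel": False}
+    )
+    assert_detectors_response(response, detectors)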