|
|
"""Tool for the identification of prompt injection attacks."""
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
|
|
from enum import Enum
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from langchain.pydantic_v1 import Field
|
|
|
|
from langchain.tools.base import BaseTool
|
|
|
|
from langchain.tools.base import BaseTool
|
|
|
|
from transformers import Pipeline, pipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
|
|
|
from transformers import Pipeline
|
|
|
|
|
|
|
|
|
|
|
|
class PromptInjectionModelOutput(str, Enum):
    """Output labels of the prompt injection model.

    Mixing in ``str`` lets members compare equal to the plain label
    strings found in the classifier pipeline's output dicts, so
    ``result["label"] == PromptInjectionModelOutput.INJECTION`` works
    without an explicit conversion.
    """

    LEGIT = "LEGIT"  # benign user input
    INJECTION = "INJECTION"  # detected prompt-injection attempt
|
|
|
|
def _model_default_factory() -> Pipeline:
    """Build the default prompt-injection classification pipeline.

    ``transformers`` is imported lazily here (rather than at module import
    time) so that it remains an optional dependency: it is only required
    when a ``HuggingFaceInjectionIdentifier`` is actually instantiated
    without an explicit model.

    Returns:
        A Hugging Face text-classification ``Pipeline`` backed by the
        ``deepset/deberta-v3-base-injection`` model.

    Raises:
        ImportError: If the ``transformers`` package is not installed.
    """
    try:
        from transformers import pipeline
    except ImportError as e:
        raise ImportError(
            "Cannot import transformers, please install with "
            "`pip install transformers`."
        ) from e
    return pipeline("text-classification", model="deepset/deberta-v3-base-injection")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HuggingFaceInjectionIdentifier(BaseTool):
    """Tool that uses the deberta-v3-base-injection model
    to identify prompt injection attacks.
    """

    name: str = "hugging_face_injection_identifier"
    description: str = (
        "Useful for when you need to ensure that prompt is free of injection attacks. "
        "Input should be any message from the user."
    )
    # Built lazily via the factory so that ``transformers`` stays an
    # optional dependency: constructing the pipeline eagerly in the class
    # body would fail at module import time when the package is missing.
    model: Pipeline = Field(default_factory=_model_default_factory)

    def _classify_user_input(self, query: str) -> bool:
        """Return ``True`` if ``query`` is classified as safe, else ``False``.

        The pipeline returns a list of ``{"label": ..., "score": ...}``
        dicts; the highest-scoring label decides the verdict.
        """
        predictions = self.model(query)
        # max() finds the top-scoring prediction in one O(n) pass; a full
        # descending sort is unnecessary when only the winner matters.
        top = max(predictions, key=lambda item: item["score"])
        # Compare against the enum (a str subclass) rather than a raw
        # "INJECTION" literal, so label names live in exactly one place.
        return top["label"] != PromptInjectionModelOutput.INJECTION

    def _run(self, query: str) -> str:
        """Use the tool.

        Args:
            query: Raw user message to screen for injection attacks.

        Returns:
            The unmodified ``query`` when it is classified as legitimate.

        Raises:
            ValueError: If the model classifies ``query`` as an injection
                attack.
        """
        is_query_safe = self._classify_user_input(query)
        if not is_query_safe:
            raise ValueError("Prompt injection attack detected")
        return query
|
|
|
|