HF Injection Identifier Refactor

pull/10464/head
Bagatur 1 year ago
parent 2c656e457c
commit 0f81b3dd2f

@ -1,21 +1,28 @@
"""Tool for the identification of prompt injection attacks.""" """Tool for the identification of prompt injection attacks."""
from __future__ import annotations
from enum import Enum from typing import TYPE_CHECKING
from langchain.pydantic_v1 import Field
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
from transformers import Pipeline, pipeline
if TYPE_CHECKING:
from transformers import Pipeline
class PromptInjectionModelOutput(str, Enum):
"""Output of the prompt injection model."""
LEGIT = "LEGIT" def _model_default_factory() -> Pipeline:
INJECTION = "INJECTION" try:
from transformers import pipeline
except ImportError as e:
raise ImportError(
"Cannot import transformers, please install with "
"`pip install transformers`."
) from e
return pipeline("text-classification", model="deepset/deberta-v3-base-injection")
class HuggingFaceInjectionIdentifier(BaseTool): class HuggingFaceInjectionIdentifier(BaseTool):
"""Tool that uses deberta-v3-base-injection model """Tool that uses deberta-v3-base-injection to detect prompt injection attacks."""
to identify prompt injection attacks."""
name: str = "hugging_face_injection_identifier" name: str = "hugging_face_injection_identifier"
description: str = ( description: str = (
@ -23,21 +30,12 @@ class HuggingFaceInjectionIdentifier(BaseTool):
"Useful for when you need to ensure that prompt is free of injection attacks. " "Useful for when you need to ensure that prompt is free of injection attacks. "
"Input should be any message from the user." "Input should be any message from the user."
) )
model: Pipeline = Field(default_factory=_model_default_factory)
model: Pipeline = pipeline(
"text-classification", model="deepset/deberta-v3-base-injection"
)
def _classify_user_input(self, query: str) -> bool:
result = self.model(query)
result = sorted(result, key=lambda x: x["score"], reverse=True)
if result[0]["label"] == PromptInjectionModelOutput.INJECTION:
return False
return True
def _run(self, query: str) -> str: def _run(self, query: str) -> str:
"""Use the tool.""" """Use the tool."""
is_query_safe = self._classify_user_input(query) result = self.model(query)
if not is_query_safe: result = sorted(result, key=lambda x: x["score"], reverse=True)
if result[0]["label"] == "INJECTION":
raise ValueError("Prompt injection attack detected") raise ValueError("Prompt injection attack detected")
return query return query

Loading…
Cancel
Save