mirror of https://github.com/hwchase17/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
174 lines
6.6 KiB
Python
174 lines
6.6 KiB
Python
1 year ago
|
import asyncio
|
||
|
from typing import Any, Dict, Optional
|
||
|
|
||
|
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
|
||
|
ModerationPiiError,
|
||
|
)
|
||
|
|
||
|
|
||
|
class ComprehendPII:
    """PII moderation built on an Amazon Comprehend style client.

    The client passed in must expose ``contains_pii_entities`` and
    ``detect_pii_entities``; both are called with ``Text`` and
    ``LanguageCode="en"`` keyword arguments.
    """

    def __init__(
        self,
        client: Any,
        callback: Optional[Any] = None,
        unique_id: Optional[str] = None,
        chain_id: Optional[str] = None,
    ) -> None:
        """
        Store the client and prepare the beacon handed to the callback.

        Args:
            client (Any): Object providing ``contains_pii_entities`` and
                ``detect_pii_entities``.
            callback (Optional[Any]): Optional handler. It is used only when
                its ``pii_callback`` attribute is truthy, in which case its
                ``on_after_pii`` coroutine is scheduled after each check.
            unique_id (Optional[str]): Passed through to ``on_after_pii``.
            chain_id (Optional[str]): Stored in the beacon as
                ``moderation_chain_id``.
        """
        self.client = client
        self.moderation_beacon: Dict[str, Any] = {
            "moderation_chain_id": chain_id,
            "moderation_type": "PII",
            "moderation_status": "LABELS_NOT_FOUND",
        }
        self.callback = callback
        self.unique_id = unique_id

    def _callback_enabled(self) -> bool:
        """Return True when a callback with a truthy ``pii_callback`` is set."""
        return bool(self.callback and self.callback.pii_callback)

    def _record_exchange(self, prompt_value: str, pii_identified: Any) -> None:
        """Copy the request text and the client response into the beacon."""
        if self._callback_enabled():
            self.moderation_beacon["moderation_input"] = prompt_value
            self.moderation_beacon["moderation_output"] = pii_identified

    def _emit_beacon(self, pii_found: bool) -> None:
        """Update the beacon status and schedule ``on_after_pii``."""
        if not self._callback_enabled():
            return
        if pii_found:
            self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
        # NOTE: asyncio.create_task needs a running event loop and raises
        # RuntimeError when there is none.
        asyncio.create_task(
            self.callback.on_after_pii(  # type: ignore[union-attr]
                self.moderation_beacon, self.unique_id
            )
        )

    def validate(
        self, prompt_value: str, config: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Run the PII check selected by ``config["action"]``.

        Args:
            prompt_value (str): The text to check.
            config (Optional[Dict[str, Any]]): Optional settings. ``action``
                selects the behaviour (default: stop); the remaining keys
                are read by the check that is selected.

        Returns:
            str: The prompt, unchanged for the stop action or with matching
                entities masked for the allow action.

        Raises:
            ValueError: If ``action`` is neither stop nor allow.
            ModerationPiiError: If the selected check rejects the prompt.
        """
        if not config:
            return self._contains_pii(prompt_value=prompt_value)

        # Imported lazily, as in the original; only needed when config is set.
        from langchain_experimental.comprehend_moderation.base_moderation_enums import (
            BaseModerationActions,
        )

        action = config.get("action", BaseModerationActions.STOP)
        if action not in (BaseModerationActions.STOP, BaseModerationActions.ALLOW):
            raise ValueError("Action can either be stop or allow")

        if action == BaseModerationActions.STOP:
            return self._contains_pii(prompt_value=prompt_value, config=config)
        return self._detect_pii(prompt_value=prompt_value, config=config)

    def _contains_pii(
        self, prompt_value: str, config: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Checks for Personally Identifiable Information (PII) labels above a
        specified threshold.

        Args:
            prompt_value (str): The input text to be checked for PII labels.
            config (Dict[str, Any]): Configuration for PII check and actions.
                ``threshold`` (default 0.5) is the minimum score and
                ``labels`` (default empty, meaning every label) restricts
                which label names count.

        Returns:
            str: the original prompt

        Raises:
            ModerationPiiError: If any label meets the threshold and passes
                the ``labels`` filter.

        Note:
            - The provided client should be initialized with valid AWS credentials.
        """
        pii_identified = self.client.contains_pii_entities(
            Text=prompt_value, LanguageCode="en"
        )
        self._record_exchange(prompt_value, pii_identified)

        threshold = config.get("threshold", 0.5) if config else 0.5
        pii_labels = config.get("labels", []) if config else []

        # An empty label filter means every label name is eligible.
        pii_found = any(
            label["Score"] >= threshold
            and (not pii_labels or label["Name"] in pii_labels)
            for label in pii_identified["Labels"]
        )

        # The callback is scheduled before raising so it also sees rejections.
        self._emit_beacon(pii_found)
        if pii_found:
            raise ModerationPiiError
        return prompt_value

    def _detect_pii(self, prompt_value: str, config: Optional[Dict[str, Any]]) -> str:
        """
        Detects and handles Personally Identifiable Information (PII) entities in the
        given prompt text using the client's detect_pii_entities call. Matching
        entities are masked when a configuration is supplied; without one the
        prompt is rejected.

        Args:
            prompt_value (str): The input text to be checked for PII entities.
            config (Dict[str, Any]): A configuration specifying how to handle
                PII entities: ``threshold`` (default 0.5), ``labels``
                (default empty, meaning every entity type) and
                ``mask_character`` (default "*").

        Returns:
            str: The prompt with every matching entity replaced by
                ``mask_character`` repeated once per character of the entity,
                or the original prompt when nothing matches.

        Raises:
            ModerationPiiError: If no configuration is supplied and an entity
                scores 0.5 or higher.

        Note:
            - If PII is not found in the prompt, the original prompt is returned.
            - The client should be initialized with valid AWS credentials.
        """
        pii_identified = self.client.detect_pii_entities(
            Text=prompt_value, LanguageCode="en"
        )
        self._record_exchange(prompt_value, pii_identified)

        entities = pii_identified["Entities"]
        if not entities:
            self._emit_beacon(False)
            return prompt_value

        if not config:
            # No configuration: any confident entity rejects the prompt.
            pii_found = any(entity["Score"] >= 0.5 for entity in entities)
            self._emit_beacon(pii_found)
            if pii_found:
                raise ModerationPiiError
            return prompt_value

        threshold = config.get("threshold", 0.5)
        pii_labels = config.get("labels", [])
        mask_marker = config.get("mask_character", "*")

        matches = [
            entity
            for entity in entities
            if entity["Score"] >= threshold
            and (not pii_labels or entity["Type"] in pii_labels)
        ]

        # Redact from the end of the text towards the start. The offsets refer
        # to the original prompt, so a mask that changes the string length
        # (a multi-character mask_character) must not shift spans still to be
        # processed.
        for entity in sorted(
            matches, key=lambda item: item["BeginOffset"], reverse=True
        ):
            begin = entity["BeginOffset"]
            end = entity["EndOffset"]
            prompt_value = (
                prompt_value[:begin] + mask_marker * (end - begin) + prompt_value[end:]
            )

        self._emit_beacon(bool(matches))
        return prompt_value
|