langchain/libs/experimental/langchain_experimental/comprehend_moderation/prompt_safety.py
Nikhil Jha dff24285ea
Comprehend Moderation 0.2 (#11730)
This PR replaces the previous `Intent` check with the new `Prompt
Safety` check. The logic and steps for enabling chain moderation via the
Amazon Comprehend service, which lets you detect and redact PII,
toxicity, and prompt-safety issues in the LLM prompt or answer, remain
unchanged.
This implementation updates the code and configuration types to
reflect `Prompt Safety`.


### Usage sample

```python
import boto3

from langchain.chains import LLMChain
from langchain.llms.fake import FakeListLLM
from langchain.prompts import PromptTemplate
from langchain_experimental.comprehend_moderation import (
    AmazonComprehendModerationChain,
    BaseModerationConfig,
    ModerationPiiConfig,
    ModerationPromptSafetyConfig,
    ModerationToxicityConfig,
)

comprehend_client = boto3.client("comprehend", region_name="us-east-1")  # example region

pii_config = ModerationPiiConfig(
    labels=["SSN"],
    redact=True,
    mask_character="X"
)

toxicity_config = ModerationToxicityConfig(
    threshold=0.5
)

prompt_safety_config = ModerationPromptSafetyConfig(
    threshold=0.5
)

moderation_config = BaseModerationConfig(
    filters=[pii_config, toxicity_config, prompt_safety_config]
)

comp_moderation_with_config = AmazonComprehendModerationChain(
    moderation_config=moderation_config,  # specify the configuration
    client=comprehend_client,             # optionally pass the Boto3 client
    verbose=True
)

template = """Question: {question}

Answer:"""

prompt = PromptTemplate(template=template, input_variables=["question"])

responses = [
    "Final Answer: A credit card number looks like 1289-2321-1123-2387. A fake SSN number looks like 323-22-9980. John Doe's phone number is (999)253-9876.", 
    "Final Answer: This is a really shitty way of constructing a birdhouse. This is fucking insane to think that any birds would actually create their motherfucking nests here."
]
llm = FakeListLLM(responses=responses)

llm_chain = LLMChain(prompt=prompt, llm=llm)

chain = (
    prompt
    | comp_moderation_with_config
    | {llm_chain.input_keys[0]: lambda x: x['output']}
    | llm_chain
    | {"input": lambda x: x['text']}
    | comp_moderation_with_config
)

try:
    response = chain.invoke({"question": "A sample SSN number looks like this 123-456-7890. Can you give me some more samples?"})
except Exception as e:
    print(str(e))
else:
    print(response['output'])

```
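Each moderation check raises its own exception type, so instead of the catch-all `except Exception` above, a caller can branch on the specific failure. A minimal sketch using the error classes from `base_moderation_exceptions` (the PII and toxicity names follow the same pattern as the `ModerationPromptSafetyError` imported in the module at the bottom of this page):

```python
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
    ModerationPiiError,
    ModerationPromptSafetyError,
    ModerationToxicityError,
)

try:
    response = chain.invoke(
        {"question": "A sample SSN number looks like this 123-456-7890."}
    )
except ModerationPiiError:
    print("Stopped: PII was found in the text.")
except ModerationToxicityError:
    print("Stopped: toxic content was found.")
except ModerationPromptSafetyError:
    print("Stopped: the prompt was classified as unsafe.")
else:
    print(response['output'])
```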

### Output

```
> Entering new AmazonComprehendModerationChain chain...
Running AmazonComprehendModerationChain...
Running pii Validation...
Running toxicity Validation...
Running prompt safety Validation...

> Finished chain.


> Entering new AmazonComprehendModerationChain chain...
Running AmazonComprehendModerationChain...
Running pii Validation...
Running toxicity Validation...
Running prompt safety Validation...

> Finished chain.
Final Answer: A credit card number looks like 1289-2321-1123-2387. A fake SSN number looks like XXXXXXXXXXXX John Doe's phone number is (999)253-9876.
```
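The checks can also report results back through an async callback handler. A hedged sketch, assuming the `BaseModerationCallbackHandler` base class exported by the package and the `moderation_callback` / `unique_id` chain parameters from this release (treat the parameter names as assumptions, not a confirmed API):

```python
from langchain_experimental.comprehend_moderation import BaseModerationCallbackHandler

class ModerationLogger(BaseModerationCallbackHandler):
    async def on_after_prompt_safety(self, moderation_beacon, unique_id, **kwargs):
        # moderation_beacon carries the check type, the raw Comprehend response,
        # and a LABELS_FOUND / LABELS_NOT_FOUND status (see prompt_safety.py below).
        print(unique_id, moderation_beacon["moderation_status"])

moderated_chain = AmazonComprehendModerationChain(
    moderation_config=moderation_config,
    client=comprehend_client,
    moderation_callback=ModerationLogger(),  # assumed parameter name
    unique_id="user@example.com",
    verbose=True
)
```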

---------

Co-authored-by: Jha <nikjha@amazon.com>
Co-authored-by: Anjan Biswas <anjanavb@amazon.com>
Co-authored-by: Anjan Biswas <84933469+anjanvb@users.noreply.github.com>

import asyncio
from typing import Any, Optional

from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
    ModerationPromptSafetyError,
)


class ComprehendPromptSafety:
    def __init__(
        self,
        client: Any,
        callback: Optional[Any] = None,
        unique_id: Optional[str] = None,
        chain_id: Optional[str] = None,
    ) -> None:
        self.client = client
        self.moderation_beacon = {
            "moderation_chain_id": chain_id,
            "moderation_type": "PromptSafety",
            "moderation_status": "LABELS_NOT_FOUND",
        }
        self.callback = callback
        self.unique_id = unique_id

    def _get_arn(self) -> str:
        # The prompt-safety classifier is an AWS-managed endpoint, hence the
        # literal "aws" account id in the ARN.
        region_name = self.client.meta.region_name
        service = "comprehend"
        prompt_safety_endpoint = "document-classifier-endpoint/prompt-safety"
        return f"arn:aws:{service}:{region_name}:aws:{prompt_safety_endpoint}"

    def validate(self, prompt_value: str, config: Any = None) -> str:
        """
        Check and validate the safety of the given prompt text.

        Args:
            prompt_value (str): The input text to be checked for unsafe text.
            config (Dict[str, Any]): Configuration settings for prompt safety checks.

        Raises:
            ModerationPromptSafetyError: If an unsafe prompt is detected in the
            prompt text based on the specified threshold.

        Returns:
            str: The input prompt_value.

        Note:
            This function checks the safety of the provided prompt text using
            Comprehend's classify_document API and raises an error if unsafe
            text is detected with a score above the specified threshold.

        Example:
            comprehend_client = boto3.client('comprehend')
            prompt_text = "Please tell me your credit card information."
            config = {"threshold": 0.7}
            safety = ComprehendPromptSafety(client=comprehend_client)
            checked_prompt = safety.validate(prompt_text, config)
        """
        threshold = config.get("threshold")
        unsafe_prompt = False

        endpoint_arn = self._get_arn()
        response = self.client.classify_document(
            Text=prompt_value, EndpointArn=endpoint_arn
        )

        if self.callback and self.callback.prompt_safety_callback:
            self.moderation_beacon["moderation_input"] = prompt_value
            self.moderation_beacon["moderation_output"] = response

        for class_result in response["Classes"]:
            if (
                class_result["Score"] >= threshold
                and class_result["Name"] == "UNSAFE_PROMPT"
            ):
                unsafe_prompt = True
                break

        if self.callback and self.callback.prompt_safety_callback:
            if unsafe_prompt:
                self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
            asyncio.create_task(
                self.callback.on_after_prompt_safety(
                    self.moderation_beacon, self.unique_id
                )
            )

        if unsafe_prompt:
            raise ModerationPromptSafetyError
        return prompt_value
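

# A hypothetical standalone usage sketch (not part of the original module);
# assumes AWS credentials and access to the Comprehend prompt-safety endpoint:
if __name__ == "__main__":
    import boto3

    client = boto3.client("comprehend", region_name="us-east-1")  # example region
    checker = ComprehendPromptSafety(client=client)
    # Returns the prompt unchanged when safe; raises ModerationPromptSafetyError
    # when the UNSAFE_PROMPT score meets the threshold.
    print(checker.validate("What is the capital of France?", config={"threshold": 0.5}))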