langchain/libs/experimental/langchain_experimental/comprehend_moderation/intent.py
nikhilkjha d57d08fd01
Initial commit for comprehend moderator (#9665)
This PR implements a custom chain that wraps Amazon Comprehend API
calls. The custom chain is intended to be used with LLM chains to provide
moderation capability that lets you detect and redact PII, Toxic, and
Intent content in the LLM prompt or the LLM response. The
implementation accepts a configuration object to control which checks
are performed on an LLM prompt, and it can be used in a variety of setups
using the LangChain expression language to detect the configured
content not only in chains but also in other constructs, such as a
retriever.
The included sample notebook goes over the different configuration
options and shows how to use the moderation chain with other chains.

###  Usage sample
```python
from langchain_experimental.comprehend_moderation import BaseModerationActions, BaseModerationFilters

moderation_config = { 
        "filters":[ 
                BaseModerationFilters.PII, 
                BaseModerationFilters.TOXICITY,
                BaseModerationFilters.INTENT
        ],
        "pii":{ 
                "action": BaseModerationActions.ALLOW, 
                "threshold":0.5, 
                "labels":["SSN"],
                "mask_character": "X"
        },
        "toxicity":{ 
                "action": BaseModerationActions.STOP, 
                "threshold":0.5
        },
        "intent":{ 
                "action": BaseModerationActions.STOP, 
                "threshold":0.5
        }
}

comp_moderation_with_config = AmazonComprehendModerationChain(
    moderation_config=moderation_config, #specify the configuration
    client=comprehend_client,            #optionally pass the Boto3 Client
    verbose=True
)

template = """Question: {question}

Answer:"""

prompt = PromptTemplate(template=template, input_variables=["question"])

responses = [
    "Final Answer: A credit card number looks like 1289-2321-1123-2387. A fake SSN number looks like 323-22-9980. John Doe's phone number is (999)253-9876.", 
    "Final Answer: This is a really shitty way of constructing a birdhouse. This is fucking insane to think that any birds would actually create their motherfucking nests here."
]
llm = FakeListLLM(responses=responses)

llm_chain = LLMChain(prompt=prompt, llm=llm)

chain = ( 
    prompt 
    | comp_moderation_with_config 
    | {llm_chain.input_keys[0]: lambda x: x['output'] }  
    | llm_chain 
    | { "input": lambda x: x['text'] } 
    | comp_moderation_with_config 
)

response = chain.invoke({"question": "A sample SSN number looks like this 123-456-7890. Can you give me some more samples?"})

print(response['output'])


```
### Output
```
> Entering new AmazonComprehendModerationChain chain...
Running AmazonComprehendModerationChain...
Running pii validation...
Found PII content..stopping..
The prompt contains PII entities and cannot be processed
```

---------

Co-authored-by: Piyush Jain <piyushjain@duck.com>
Co-authored-by: Anjan Biswas <anjanavb@amazon.com>
Co-authored-by: Jha <nikjha@amazon.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
2023-08-25 15:11:27 -07:00

102 lines
3.4 KiB
Python

import asyncio
import warnings
from typing import Any, Dict, Optional
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
ModerationIntentionError,
)
class ComprehendIntent:
    """Intent moderation check that classifies a prompt with Amazon Comprehend.

    The instance wraps a client object exposing ``classify_document`` and
    ``meta.region_name``, and optionally reports each check to a callback
    object through its ``on_after_intent`` coroutine.
    """

    def __init__(
        self,
        client: Any,
        callback: Optional[Any] = None,
        unique_id: Optional[str] = None,
        chain_id: Optional[str] = None,
    ) -> None:
        """
        Store the client and callback and initialise the moderation beacon.

        Args:
            client: Client used to call ``classify_document``; its
                ``meta.region_name`` is used to build the endpoint ARN.
            callback: Optional object with an ``intent_callback`` flag and an
                ``on_after_intent`` coroutine method.
            unique_id: Optional identifier forwarded to the callback.
            chain_id: Optional identifier recorded in the beacon as
                ``moderation_chain_id``.
        """
        self.client = client
        self.moderation_beacon = {
            "moderation_chain_id": chain_id,
            "moderation_type": "Intent",
            "moderation_status": "LABELS_NOT_FOUND",
        }
        self.callback = callback
        self.unique_id = unique_id
        # Strong references to in-flight callback tasks; the event loop only
        # keeps weak references, so an unreferenced task may be collected
        # before it finishes.
        self._callback_tasks: set = set()

    def _get_arn(self) -> str:
        """Return the ARN of the prompt-intent classifier endpoint for the
        client's region."""
        region_name = self.client.meta.region_name
        service = "comprehend"
        intent_endpoint = "document-classifier-endpoint/prompt-intent"
        return f"arn:aws:{service}:{region_name}:aws:{intent_endpoint}"

    @staticmethod
    def _has_undesired_intent(response: Dict[str, Any], threshold: float) -> bool:
        """Return True if any entry of ``response["Classes"]`` is named
        ``UNDESIRED_PROMPT`` with a score at or above ``threshold``."""
        # ``response["Classes"]`` is indexed strictly on purpose: a response
        # without the key raises instead of silently passing the prompt.
        return any(
            class_result["Score"] >= threshold
            and class_result["Name"] == "UNDESIRED_PROMPT"
            for class_result in response["Classes"]
        )

    def _emit_callback(self, intent_found: bool) -> None:
        """Record the outcome in the beacon and run ``on_after_intent``.

        The status is written on every call so that a reused instance never
        reports a stale ``LABELS_FOUND`` from an earlier prompt.
        """
        self.moderation_beacon["moderation_status"] = (
            "LABELS_FOUND" if intent_found else "LABELS_NOT_FOUND"
        )
        # Hand the callback a snapshot: a scheduled task may run after a later
        # validate() call has already rewritten the shared beacon.
        coroutine = self.callback.on_after_intent(
            dict(self.moderation_beacon), self.unique_id
        )
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # Called from plain synchronous code: there is no loop to schedule
            # on, so run the callback to completion here instead of failing.
            asyncio.run(coroutine)
            return
        task = loop.create_task(coroutine)
        self._callback_tasks.add(task)
        task.add_done_callback(self._callback_tasks.discard)

    def validate(
        self, prompt_value: str, config: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Check and validate the intent of the given prompt text.

        Args:
            prompt_value (str): The input text to be checked for unintended
                intent.
            config (Dict[str, Any]): Optional settings for the intent check.
                ``threshold`` (default 0.5) is the minimum score at which the
                ``UNDESIRED_PROMPT`` class counts as found. ``action`` is
                accepted, but the check always stops on undesired intent; an
                ``ALLOW`` action only triggers a warning.

        Raises:
            ModerationIntentionError: If the ``UNDESIRED_PROMPT`` class is
                returned with a score at or above the threshold.

        Returns:
            str: The input prompt_value, unchanged.

        Note:
            The prompt is classified with the client's ``classify_document``
            call against the endpoint returned by ``_get_arn``. When a
            callback with a truthy ``intent_callback`` is configured, its
            ``on_after_intent`` coroutine is invoked before any error is
            raised.
        """
        from langchain_experimental.comprehend_moderation.base_moderation_enums import (  # noqa: E501
            BaseModerationActions,
        )

        settings = config or {}
        threshold = settings.get("threshold", 0.5)
        action = settings.get("action", BaseModerationActions.STOP)

        if action == BaseModerationActions.ALLOW:
            # ALLOW is not honoured for intent checks; the STOP behaviour
            # below applies regardless of the configured action.
            warnings.warn(
                "You have allowed content with Harmful content. "
                "Defaulting to STOP action..."
            )

        response = self.client.classify_document(
            Text=prompt_value, EndpointArn=self._get_arn()
        )
        intent_found = self._has_undesired_intent(response, threshold)

        if self.callback and self.callback.intent_callback:
            self.moderation_beacon["moderation_input"] = prompt_value
            self.moderation_beacon["moderation_output"] = response
            self._emit_callback(intent_found)

        if intent_found:
            raise ModerationIntentionError
        return prompt_value