Comprehend Moderation 0.2 (#11730)

This PR replaces the previous `Intent` check with the new `Prompt
Safety` check. The logic and steps to enable chain moderation via the
Amazon Comprehend service, which lets you detect and redact PII, toxic
content, and unsafe prompts in the LLM prompt or answer, remain
unchanged.
This implementation updates the code and configuration types with
respect to `Prompt Safety`.
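
For anyone migrating from 0.1, the configuration-level rename is mechanical. A minimal sketch (the old name is taken from the code this PR removes):

```python
from langchain_experimental.comprehend_moderation import ModerationPromptSafetyConfig

# 0.1 (removed in this PR): ModerationIntentConfig(threshold=0.5)
# 0.2: the same check is now configured as Prompt Safety
prompt_safety_config = ModerationPromptSafetyConfig(threshold=0.5)
```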


### Usage sample

```python
import boto3

from langchain.chains import LLMChain
from langchain.llms.fake import FakeListLLM
from langchain.prompts import PromptTemplate
from langchain_experimental.comprehend_moderation import (
    AmazonComprehendModerationChain,
    BaseModerationConfig,
    ModerationPiiConfig,
    ModerationPromptSafetyConfig,
    ModerationToxicityConfig,
)

# Amazon Comprehend client used by the moderation chain
comprehend_client = boto3.client("comprehend", region_name="us-east-1")

pii_config = ModerationPiiConfig(
    labels=["SSN"],
    redact=True,
    mask_character="X"
)

toxicity_config = ModerationToxicityConfig(
    threshold=0.5
)

prompt_safety_config = ModerationPromptSafetyConfig(
    threshold=0.5
)

moderation_config = BaseModerationConfig(
    filters=[pii_config, toxicity_config, prompt_safety_config]
)

comp_moderation_with_config = AmazonComprehendModerationChain(
    moderation_config=moderation_config,  # specify the configuration
    client=comprehend_client,  # optionally pass the Boto3 client
    verbose=True
)

template = """Question: {question}

Answer:"""

prompt = PromptTemplate(template=template, input_variables=["question"])

responses = [
    "Final Answer: A credit card number looks like 1289-2321-1123-2387. A fake SSN number looks like 323-22-9980. John Doe's phone number is (999)253-9876.", 
    "Final Answer: This is a really shitty way of constructing a birdhouse. This is fucking insane to think that any birds would actually create their motherfucking nests here."
]
llm = FakeListLLM(responses=responses)

llm_chain = LLMChain(prompt=prompt, llm=llm)

chain = ( 
    prompt 
    | comp_moderation_with_config 
    | {llm_chain.input_keys[0]: lambda x: x['output'] }  
    | llm_chain 
    | { "input": lambda x: x['text'] } 
    | comp_moderation_with_config 
)

try:
    response = chain.invoke({"question": "A sample SSN number looks like this 123-456-7890. Can you give me some more samples?"})
except Exception as e:
    print(str(e))
else:
    print(response['output'])

```
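
The `try`/`except` above catches a generic `Exception`. If you want to branch on the specific failure, a minimal sketch using the typed errors this PR exports from `base_moderation_exceptions` (the handler bodies are illustrative):

```python
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
    ModerationPiiError,
    ModerationPromptSafetyError,
    ModerationToxicityError,
)

try:
    response = chain.invoke({"question": "A sample SSN number looks like this 123-456-7890. Can you give me some more samples?"})
except ModerationPiiError as e:
    # raised when redact=False and a configured PII label is detected
    print(f"PII detected: {e}")
except ModerationToxicityError as e:
    print(f"Toxic content detected: {e}")
except ModerationPromptSafetyError as e:
    print(f"Unsafe prompt detected: {e}")
else:
    print(response["output"])
```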

### Output

```
> Entering new AmazonComprehendModerationChain chain...
Running AmazonComprehendModerationChain...
Running pii Validation...
Running toxicity Validation...
Running prompt safety Validation...

> Finished chain.


> Entering new AmazonComprehendModerationChain chain...
Running AmazonComprehendModerationChain...
Running pii Validation...
Running toxicity Validation...
Running prompt safety Validation...

> Finished chain.
Final Answer: A credit card number looks like 1289-2321-1123-2387. A fake SSN number looks like XXXXXXXXXXXX John Doe's phone number is (999)253-9876.
```
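
The callback hook is renamed as well: `on_after_intent()` becomes `on_after_prompt_safety()`. A minimal sketch of a custom moderation callback under the new API, following the subclass pattern in the updated notebook (the print body is illustrative; all hooks must be `async`):

```python
from langchain_experimental.comprehend_moderation.base_moderation_callbacks import (
    BaseModerationCallbackHandler,
)


class MyModCallback(BaseModerationCallbackHandler):
    # Subclasses must override at least one of on_after_pii(),
    # on_after_toxicity(), or on_after_prompt_safety().
    async def on_after_pii(self, output_beacon, unique_id):
        pass

    async def on_after_prompt_safety(self, output_beacon, unique_id):
        # output_beacon carries moderation_type ('PromptSafety'), the
        # moderation status, and the full Comprehend response.
        print(output_beacon["moderation_status"])


my_callback = MyModCallback()

comp_moderation_with_config = AmazonComprehendModerationChain(
    moderation_config=moderation_config,
    client=comprehend_client,
    moderation_callback=my_callback,  # invoked asynchronously, non-blocking
    verbose=True,
)
```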

---------

Co-authored-by: Jha <nikjha@amazon.com>
Co-authored-by: Anjan Biswas <anjanavb@amazon.com>
Co-authored-by: Anjan Biswas <84933469+anjanvb@users.noreply.github.com>

@ -16,7 +16,10 @@
"cell_type": "code",
"execution_count": null,
"id": "2c4236d8-4054-473d-84a4-87a4db278a62",
"metadata": {},
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"%pip install boto3 nltk"
@ -24,7 +27,33 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "9c792c3d-c601-409c-8e41-1c05a2fa0e84",
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"%pip install -U langchain_experimental"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "496df413-a840-40a1-9ac0-3af7c1303476",
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"%pip install -U langchain pydantic"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f8518ad-c762-413c-b8c9-f1c211fc311d",
"metadata": {
"tags": []
@ -32,6 +61,7 @@
"outputs": [],
"source": [
"import boto3\n",
"import os\n",
"\n",
"comprehend_client = boto3.client('comprehend', region_name='us-east-1')"
]
@ -73,7 +103,6 @@
"outputs": [],
"source": [
"from langchain.prompts import PromptTemplate\n",
"from langchain.chains import LLMChain\n",
"from langchain.llms.fake import FakeListLLM\n",
"from langchain_experimental.comprehend_moderation.base_moderation_exceptions import ModerationPiiError\n",
"\n",
@ -85,25 +114,22 @@
"\n",
"responses = [\n",
" \"Final Answer: A credit card number looks like 1289-2321-1123-2387. A fake SSN number looks like 323-22-9980. John Doe's phone number is (999)253-9876.\", \n",
" \"Final Answer: This is a really shitty way of constructing a birdhouse. This is fucking insane to think that any birds would actually create their motherfucking nests here.\"\n",
" # replace with your own expletive\n",
" \"Final Answer: This is a really <expletive> way of constructing a birdhouse. This is <expletive> insane to think that any birds would actually create their <expletive> nests here.\"\n",
"]\n",
"llm = FakeListLLM(responses=responses)\n",
"\n",
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
"\n",
"chain = (\n",
" prompt \n",
" | comprehend_moderation \n",
" | {llm_chain.input_keys[0]: lambda x: x['output'] } \n",
" | llm_chain \n",
" | { \"input\": lambda x: x['text'] } \n",
" | {\"input\": (lambda x: x['output'] ) | llm}\n",
" | comprehend_moderation \n",
")\n",
"\n",
"try:\n",
" response = chain.invoke({\"question\": \"A sample SSN number looks like this 123-456-7890. Can you give me some more samples?\"})\n",
" response = chain.invoke({\"question\": \"A sample SSN number looks like this 123-22-3345. Can you give me some more samples?\"})\n",
"except ModerationPiiError as e:\n",
" print(e.message)\n",
" print(str(e))\n",
"else:\n",
" print(response['output'])\n"
]
@ -117,6 +143,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "bfd550e7-5012-41fa-9546-8b78ddf1c673",
"metadata": {},
@ -125,7 +152,7 @@
"\n",
"- PII (Personally Identifiable Information) checks \n",
"- Toxicity content detection\n",
"- Intention detection\n",
"- Prompt Safety detection\n",
"\n",
"Here is an example of a moderation config."
]
@ -139,46 +166,51 @@
},
"outputs": [],
"source": [
"from langchain_experimental.comprehend_moderation import BaseModerationActions, BaseModerationFilters\n",
"\n",
"moderation_config = { \n",
" \"filters\":[ \n",
" BaseModerationFilters.PII, \n",
" BaseModerationFilters.TOXICITY,\n",
" BaseModerationFilters.INTENT\n",
" ],\n",
" \"pii\":{ \n",
" \"action\": BaseModerationActions.ALLOW, \n",
" \"threshold\":0.5, \n",
" \"labels\":[\"SSN\"],\n",
" \"mask_character\": \"X\"\n",
" },\n",
" \"toxicity\":{ \n",
" \"action\": BaseModerationActions.STOP, \n",
" \"threshold\":0.5\n",
" },\n",
" \"intent\":{ \n",
" \"action\": BaseModerationActions.STOP, \n",
" \"threshold\":0.5\n",
" }\n",
"}"
"from langchain_experimental.comprehend_moderation import (BaseModerationConfig, \n",
" ModerationPromptSafetyConfig, \n",
" ModerationPiiConfig, \n",
" ModerationToxicityConfig\n",
")\n",
"\n",
"pii_config = ModerationPiiConfig(\n",
" labels=[\"SSN\"],\n",
" redact=True,\n",
" mask_character=\"X\"\n",
")\n",
"\n",
"toxicity_config = ModerationToxicityConfig(\n",
" threshold=0.5\n",
")\n",
"\n",
"prompt_safety_config = ModerationPromptSafetyConfig(\n",
" threshold=0.5\n",
")\n",
"\n",
"moderation_config = BaseModerationConfig(\n",
" filters=[pii_config, toxicity_config, prompt_safety_config]\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "3634376b-5938-43df-9ed6-70ca7e99290f",
"metadata": {},
"source": [
"At the core of the configuration you have three filters specified in the `filters` key:\n",
"\n",
"1. `BaseModerationFilters.PII`\n",
"2. `BaseModerationFilters.TOXICITY`\n",
"3. `BaseModerationFilters.INTENT`\n",
"At the core of the the configuration there are three configuration models to be used\n",
"\n",
"And an `action` key that defines two possible actions for each moderation function:\n",
"- `ModerationPiiConfig` used for configuring the behavior of the PII validations. Following are the parameters it can be initialized with\n",
" - `labels` the PII entity labels. Defaults to an empty list which means that the PII validation will consider all PII entities.\n",
" - `threshold` the confidence threshold for the detected entities, defaults to 0.5 or 50%\n",
" - `redact` a boolean flag to enforce whether redaction should be performed on the text, defaults to `False`. When `False`, the PII validation will error out when it detects any PII entity, when set to `True` it simply redacts the PII values in the text.\n",
" - `mask_character` the character used for masking, defaults to asterisk (*)\n",
"- `ModerationToxicityConfig` used for configuring the behavior of the toxicity validations. Following are the parameters it can be initialized with\n",
" - `labels` the Toxic entity labels. Defaults to an empty list which means that the toxicity validation will consider all toxic entities. all\n",
" - `threshold` the confidence threshold for the detected entities, defaults to 0.5 or 50% \n",
"- `ModerationPromptSafetyConfig` used for configuring the behavior of the prompt safety validation\n",
" - `threshold` the confidence threshold for the the prompt safety classification, defaults to 0.5 or 50% \n",
"\n",
"1. `BaseModerationActions.ALLOW` - `allows` the prompt to pass through but masks detected PII in case of PII check. The default behavior is to run and redact all PII entities. If there is an entity specified in the `labels` field, then only those entities will go through the PII check and masked.\n",
"2. `BaseModerationActions.STOP` - `stops` the prompt from passing through to the next step in case any PII, Toxicity, or incorrect Intent is detected. The action of `BaseModerationActions.STOP` will raise a Python `Exception` essentially stopping the chain in progress.\n",
"Finally, you use the `BaseModerationConfig` to define the order in which each of these checks are to be performed. The `BaseModerationConfig` takes an optional `filters` parameter which can be a list of one or more than one of the above validation checks, as seen in the previous code block. The `BaseModerationConfig` can also be initialized with any `filters` in which case it will use all the checks with default configuration (more on this explained later).\n",
"\n",
"Using the configuration in the previous cell will perform PII checks and will allow the prompt to pass through however it will mask any SSN numbers present in either the prompt or the LLM output.\n"
]
@ -196,7 +228,20 @@
" moderation_config=moderation_config, #specify the configuration\n",
" client=comprehend_client, #optionally pass the Boto3 Client\n",
" verbose=True\n",
")\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a25e6f93-765b-4f99-8c1c-929157dbd4aa",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.prompts import PromptTemplate\n",
"from langchain.llms.fake import FakeListLLM\n",
"\n",
"template = \"\"\"Question: {question}\n",
"\n",
@ -206,23 +251,21 @@
"\n",
"responses = [\n",
" \"Final Answer: A credit card number looks like 1289-2321-1123-2387. A fake SSN number looks like 323-22-9980. John Doe's phone number is (999)253-9876.\", \n",
" \"Final Answer: This is a really shitty way of constructing a birdhouse. This is fucking insane to think that any birds would actually create their motherfucking nests here.\"\n",
" # replace with your own expletive\n",
" \"Final Answer: This is a really <expletive> way of constructing a birdhouse. This is <expletive> insane to think that any birds would actually create their <expletive> nests here.\"\n",
"]\n",
"llm = FakeListLLM(responses=responses)\n",
"\n",
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
"\n",
"chain = ( \n",
" prompt \n",
" | comp_moderation_with_config \n",
" | {llm_chain.input_keys[0]: lambda x: x['output'] } \n",
" | llm_chain \n",
" | { \"input\": lambda x: x['text'] } \n",
" | {\"input\": (lambda x: x['output'] ) | llm}\n",
" | comp_moderation_with_config \n",
")\n",
"\n",
"\n",
"try:\n",
" response = chain.invoke({\"question\": \"A sample SSN number looks like this 123-456-7890. Can you give me some more samples?\"})\n",
" response = chain.invoke({\"question\": \"A sample SSN number looks like this 123-45-7890. Can you give me some more samples?\"})\n",
"except Exception as e:\n",
" print(str(e))\n",
"else:\n",
@ -230,6 +273,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "ba890681-feeb-43ca-a0d5-9c11d2d9de3e",
"metadata": {
@ -238,25 +282,25 @@
"source": [
"## Unique ID, and Moderation Callbacks\n",
"\n",
"When Amazon Comprehend moderation action is specified as `STOP`, the chain will raise one of the following exceptions-\n",
"When Amazon Comprehend moderation action identifies any of the configugred entity, the chain will raise one of the following exceptions-\n",
" - `ModerationPiiError`, for PII checks\n",
" - `ModerationToxicityError`, for Toxicity checks \n",
" - `ModerationIntentionError` for Intent checks\n",
" - `ModerationPromptSafetyError` for Prompt Safety checks\n",
"\n",
"In addition to the moderation configuration, the `AmazonComprehendModerationChain` can also be initialized with the following parameters\n",
"\n",
"- `unique_id` [Optional] a string parameter. This parameter can be used to pass any string value or ID. For example, in a chat application, you may want to keep track of abusive users, in this case, you can pass the user's username/email ID etc. This defaults to `None`.\n",
"\n",
"- `moderation_callback` [Optional] the `BaseModerationCallbackHandler` will be called asynchronously (non-blocking to the chain). Callback functions are useful when you want to perform additional actions when the moderation functions are executed, for example logging into a database, or writing a log file. You can override three functions by subclassing `BaseModerationCallbackHandler` - `on_after_pii()`, `on_after_toxicity()`, and `on_after_intent()`. Note that all three functions must be `async` functions. These callback functions receive two arguments:\n",
" - `moderation_beacon` is a dictionary that will contain information about the moderation function, the full response from the Amazon Comprehend model, a unique chain id, the moderation status, and the input string which was validated. The dictionary is of the following schema-\n",
"- `moderation_callback` [Optional] the `BaseModerationCallbackHandler` that will be called asynchronously (non-blocking to the chain). Callback functions are useful when you want to perform additional actions when the moderation functions are executed, for example logging into a database, or writing a log file. You can override three functions by subclassing `BaseModerationCallbackHandler` - `on_after_pii()`, `on_after_toxicity()`, and `on_after_prompt_safety()`. Note that all three functions must be `async` functions. These callback functions receive two arguments:\n",
" - `moderation_beacon` a dictionary that will contain information about the moderation function, the full response from Amazon Comprehend model, a unique chain id, the moderation status, and the input string which was validated. The dictionary is of the following schema-\n",
" \n",
" ```\n",
" { \n",
" 'moderation_chain_id': 'xxx-xxx-xxx', # Unique chain ID\n",
" 'moderation_type': 'Toxicity' | 'PII' | 'Intent', \n",
" 'moderation_type': 'Toxicity' | 'PII' | 'PromptSafety', \n",
" 'moderation_status': 'LABELS_FOUND' | 'LABELS_NOT_FOUND',\n",
" 'moderation_input': 'A sample SSN number looks like this 123-456-7890. Can you give me some more samples?',\n",
" 'moderation_output': {...} #Full Amazon Comprehend PII, Toxicity, or Intent Model Output\n",
" 'moderation_output': {...} #Full Amazon Comprehend PII, Toxicity, or Prompt Safety Model Output\n",
" }\n",
" ```\n",
" \n",
@ -313,7 +357,7 @@
" async def on_after_toxicity(self, output_beacon, unique_id):\n",
" pass\n",
" \n",
" async def on_after_intent(self, output_beacon, unique_id):\n",
" async def on_after_prompt_safety(self, output_beacon, unique_id):\n",
" pass\n",
" '''\n",
" \n",
@ -330,22 +374,19 @@
},
"outputs": [],
"source": [
"moderation_config = { \n",
" \"filters\": [ \n",
" BaseModerationFilters.PII, \n",
" BaseModerationFilters.TOXICITY\n",
" ],\n",
" \"pii\":{ \n",
" \"action\": BaseModerationActions.STOP, \n",
" \"threshold\":0.5, \n",
" \"labels\":[\"SSN\"], \n",
" \"mask_character\": \"X\" \n",
" },\n",
" \"toxicity\":{ \n",
" \"action\": BaseModerationActions.STOP, \n",
" \"threshold\":0.5 \n",
" }\n",
"}\n",
"pii_config = ModerationPiiConfig(\n",
" labels=[\"SSN\"],\n",
" redact=True,\n",
" mask_character=\"X\"\n",
")\n",
"\n",
"toxicity_config = ModerationToxicityConfig(\n",
" threshold=0.5\n",
")\n",
"\n",
"moderation_config = BaseModerationConfig(\n",
" filters=[pii_config, toxicity_config]\n",
")\n",
"\n",
"comp_moderation_with_config = AmazonComprehendModerationChain(\n",
" moderation_config=moderation_config, # specify the configuration\n",
@ -366,7 +407,6 @@
"outputs": [],
"source": [
"from langchain.prompts import PromptTemplate\n",
"from langchain.chains import LLMChain\n",
"from langchain.llms.fake import FakeListLLM\n",
"\n",
"template = \"\"\"Question: {question}\n",
@ -377,19 +417,16 @@
"\n",
"responses = [\n",
" \"Final Answer: A credit card number looks like 1289-2321-1123-2387. A fake SSN number looks like 323-22-9980. John Doe's phone number is (999)253-9876.\", \n",
" \"Final Answer: This is a really shitty way of constructing a birdhouse. This is fucking insane to think that any birds would actually create their motherfucking nests here.\"\n",
" # replace with your own expletive\n",
" \"Final Answer: This is a really <expletive> way of constructing a birdhouse. This is <expletive> insane to think that any birds would actually create their <expletive> nests here.\"\n",
"]\n",
"\n",
"llm = FakeListLLM(responses=responses)\n",
"\n",
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
"\n",
"chain = (\n",
" prompt \n",
" | comp_moderation_with_config \n",
" | {llm_chain.input_keys[0]: lambda x: x['output'] } \n",
" | llm_chain \n",
" | { \"input\": lambda x: x['text'] } \n",
" | {\"input\": (lambda x: x['output'] ) | llm}\n",
" | comp_moderation_with_config \n",
") \n",
"\n",
@ -402,6 +439,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "706454b2-2efa-4d41-abc8-ccf2b4e87822",
"metadata": {
@ -410,7 +448,7 @@
"source": [
"## `moderation_config` and moderation execution order\n",
"\n",
"If `AmazonComprehendModerationChain` is not initialized with any `moderation_config` then the default action is `STOP` and the default order of moderation check is as follows.\n",
"If `AmazonComprehendModerationChain` is not initialized with any `moderation_config` then it is initialized with the default values of `BaseModerationConfig`. If no `filters` are used then the sequence of moderation check is as follows.\n",
"\n",
"```\n",
"AmazonComprehendModerationChain\n",
@ -423,39 +461,32 @@
" ├── Callback (if available)\n",
" ├── Label Found ⟶ [Error Stop]\n",
" └── No Label Found\n",
" └──Check Intent with Stop Action\n",
" └──Check Prompt Safety with Stop Action\n",
" ├── Callback (if available)\n",
" ├── Label Found ⟶ [Error Stop]\n",
" └── No Label Found\n",
" └── Return Prompt\n",
"```\n",
"\n",
"If any of the checks raises an exception then the subsequent checks will not be performed. If a `callback` is provided in this case, then it will be called for each of the checks that have been performed. For example, in the case above, if the Chain fails due to the presence of PII then the Toxicity and Intent checks will not be performed.\n",
"If any of the check raises a validation exception then the subsequent checks will not be performed. If a `callback` is provided in this case, then it will be called for each of the checks that have been performed. For example, in the case above, if the Chain fails due to presence of PII then the Toxicity and Prompt Safety checks will not be performed.\n",
"\n",
"You can override the execution order by passing `moderation_config` and simply specifying the desired order in the `filters` key of the configuration. In case you use `moderation_config` then the order of the checks as specified in the `filters` key will be maintained. For example, in the configuration below, first Toxicity check will be performed, then PII, and finally Intent validation will be performed. In this case, `AmazonComprehendModerationChain` will perform the desired checks in the specified order with default values of each model `kwargs`.\n",
"You can override the execution order by passing `moderation_config` and simply specifying the desired order in the `filters` parameter of the `BaseModerationConfig`. In case you specify the filters, then the order of the checks as specified in the `filters` parameter will be maintained. For example, in the configuration below, first Toxicity check will be performed, then PII, and finally Prompt Safety validation will be performed. In this case, `AmazonComprehendModerationChain` will perform the desired checks in the specified order with default values of each model `kwargs`.\n",
"\n",
"```python\n",
"moderation_config = { \n",
" \"filters\":[ BaseModerationFilters.TOXICITY, \n",
" BaseModerationFilters.PII, \n",
" BaseModerationFilters.INTENT]\n",
" }\n",
"pii_check = ModerationPiiConfig()\n",
"toxicity_check = ModerationToxicityConfig()\n",
"prompt_safety_check = ModerationPromptSafetyConfig()\n",
"\n",
"moderation_config = BaseModerationConfig(filters=[toxicity_check, pii_check, prompt_safety_check])\n",
"```\n",
"\n",
"Model `kwargs` are specified by the `pii`, `toxicity`, and `intent` keys within the `moderation_config` dictionary. For example, in the `moderation_config` below, the default order of moderation is overriden and the `pii` & `toxicity` model `kwargs` have been overriden. For `intent` the chain's default `kwargs` will be used.\n",
"You can have also use more than one configuration for a specific moderation check, for example in the sample below, two consecutive PII checks are performed. First the configuration checks for any SSN, if found it would raise an error. If any SSN isn't found then it will next check if any NAME and CREDIT_DEBIT_NUMBER is present in the prompt and will mask it.\n",
"\n",
"```python\n",
" moderation_config = { \n",
" \"filters\":[ BaseModerationFilters.TOXICITY, \n",
" BaseModerationFilters.PII, \n",
" BaseModerationFilters.INTENT],\n",
" \"pii\":{ \"action\": BaseModerationActions.ALLOW, \n",
" \"threshold\":0.5, \n",
" \"labels\":[\"SSN\"], \n",
" \"mask_character\": \"X\" },\n",
" \"toxicity\":{ \"action\": BaseModerationActions.STOP, \n",
" \"threshold\":0.5 }\n",
" }\n",
"pii_check_1 = ModerationPiiConfig(labels=[\"SSN\"])\n",
"pii_check_2 = ModerationPiiConfig(labels=[\"NAME\", \"CREDIT_DEBIT_NUMBER\"], redact=True)\n",
"\n",
"moderation_config = BaseModerationConfig(filters=[pii_check_1, pii_check_2])\n",
"```\n",
"\n",
"1. For a list of PII labels see Amazon Comprehend Universal PII entity types - https://docs.aws.amazon.com/comprehend/latest/dg/how-pii.html#how-pii-types\n",
@ -467,10 +498,11 @@
" - `VIOLENCE_OR_THREAT`: Speech that includes threats which seek to inflict pain, injury or hostility towards a person or group.\n",
" - `INSULT`: Speech that includes demeaning, humiliating, mocking, insulting, or belittling language.\n",
" - `PROFANITY`: Speech that contains words, phrases or acronyms that are impolite, vulgar, or offensive is considered as profane.\n",
"3. For a list of Intent labels refer to documentation [link here]"
"3. For a list of Prompt Safety labels refer to documentation [link here]"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "78905aec-55ae-4fc3-a23b-8a69bd1e33f2",
"metadata": {},
@ -504,7 +536,8 @@
},
"outputs": [],
"source": [
"%env HUGGINGFACEHUB_API_TOKEN=\"<HUGGINGFACEHUB_API_TOKEN>\""
"import os\n",
"os.environ[\"HUGGINGFACEHUB_API_TOKEN\"] = \"<YOUR HF TOKEN HERE>\""
]
},
{
@ -517,7 +550,7 @@
"outputs": [],
"source": [
"# See https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads for some other options\n",
"repo_id = \"google/flan-t5-xxl\" \n"
"repo_id = \"google/flan-t5-xxl\" "
]
},
{
@ -529,20 +562,15 @@
},
"outputs": [],
"source": [
"from langchain.llms import HuggingFaceHub\n",
"from langchain import HuggingFaceHub\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.chains import LLMChain\n",
"\n",
"template = \"\"\"Question: {question}\n",
"\n",
"Answer:\"\"\"\n",
"template = \"\"\"{question}\"\"\"\n",
"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])\n",
"\n",
"llm = HuggingFaceHub(\n",
" repo_id=repo_id, model_kwargs={\"temperature\": 0.5, \"max_length\": 256}\n",
")\n",
"llm_chain = LLMChain(prompt=prompt, llm=llm)"
")"
]
},
{
@ -562,22 +590,41 @@
},
"outputs": [],
"source": [
"moderation_config = { \n",
" \"filters\":[ BaseModerationFilters.PII, BaseModerationFilters.TOXICITY, BaseModerationFilters.INTENT ],\n",
" \"pii\":{\"action\": BaseModerationActions.ALLOW, \"threshold\":0.5, \"labels\":[\"SSN\",\"CREDIT_DEBIT_NUMBER\"], \"mask_character\": \"X\"},\n",
" \"toxicity\":{\"action\": BaseModerationActions.STOP, \"threshold\":0.5},\n",
" \"intent\":{\"action\": BaseModerationActions.ALLOW, \"threshold\":0.5,},\n",
" }\n",
"\n",
"# without any callback\n",
"amazon_comp_moderation = AmazonComprehendModerationChain(moderation_config=moderation_config, \n",
"\n",
"# define filter configs\n",
"pii_config = ModerationPiiConfig(\n",
" labels=[\"SSN\", \"CREDIT_DEBIT_NUMBER\"],\n",
" redact=True,\n",
" mask_character=\"X\"\n",
")\n",
"\n",
"toxicity_config = ModerationToxicityConfig(\n",
" threshold=0.5\n",
")\n",
"\n",
"prompt_safety_config = ModerationPromptSafetyConfig(\n",
" threshold=0.8\n",
")\n",
"\n",
"# define different moderation configs using the filter configs above\n",
"moderation_config_1 = BaseModerationConfig(\n",
" filters=[pii_config, toxicity_config, prompt_safety_config]\n",
")\n",
"\n",
"moderation_config_2 = BaseModerationConfig(\n",
" filters=[pii_config]\n",
")\n",
"\n",
"\n",
"# input prompt moderation chain with callback\n",
"amazon_comp_moderation = AmazonComprehendModerationChain(moderation_config=moderation_config_1, \n",
" client=comprehend_client,\n",
" moderation_callback=my_callback,\n",
" verbose=True)\n",
"\n",
"# with callback\n",
"amazon_comp_moderation_out = AmazonComprehendModerationChain(moderation_config=moderation_config, \n",
"# Output from LLM moderation chain without callback\n",
"amazon_comp_moderation_out = AmazonComprehendModerationChain(moderation_config=moderation_config_2, \n",
" client=comprehend_client,\n",
" moderation_callback=my_callback,\n",
" verbose=True)"
]
},
@ -586,7 +633,7 @@
"id": "b1256bc8-1321-4624-9e8a-a2d4a8df59bf",
"metadata": {},
"source": [
"The `moderation_config` will now prevent any inputs and model outputs containing obscene words or sentences, bad intent, or PII with entities other than SSN with score above threshold or 0.5 or 50%. If it finds Pii entities - SSN - it will redact them before allowing the call to proceed. "
"The `moderation_config` will now prevent any inputs containing obscene words or sentences, bad intent, or PII with entities other than SSN with score above threshold or 0.5 or 50%. If it finds Pii entities - SSN - it will redact them before allowing the call to proceed. It will also mask any SSN or credit card numbers from the model's response."
]
},
{
@ -601,14 +648,15 @@
"chain = (\n",
" prompt \n",
" | amazon_comp_moderation \n",
" | {llm_chain.input_keys[0]: lambda x: x['output'] } \n",
" | llm_chain \n",
" | { \"input\": lambda x: x['text'] } \n",
" | { \"input\" : (lambda x: x['output']) | llm }\n",
" | amazon_comp_moderation_out\n",
")\n",
"\n",
"try:\n",
" response = chain.invoke({\"question\": \"My AnyCompany Financial Services, LLC credit card account 1111-0000-1111-0008 has 24$ due by July 31st. Can you give me some more credit car number samples?\"})\n",
" response = chain.invoke({\"question\": \"\"\"What is John Doe's address, phone number and SSN from the following text?\n",
"\n",
"John Doe, a resident of 1234 Elm Street in Springfield, recently celebrated his birthday on January 1st. Turning 43 this year, John reflected on the years gone by. He often shares memories of his younger days with his close friends through calls on his phone, (555) 123-4567. Meanwhile, during a casual evening, he received an email at johndoe@example.com reminding him of an old acquaintance's reunion. As he navigated through some old documents, he stumbled upon a paper that listed his SSN as 123-45-6789, reminding him to store it in a safer place.\n",
"\"\"\"})\n",
"except Exception as e:\n",
" print(str(e))\n",
"else:\n",
@ -624,7 +672,7 @@
"source": [
"### With Amazon SageMaker Jumpstart\n",
"\n",
"The example below shows how to use the `Amazon Comprehend Moderation chain` with an Amazon SageMaker Jumpstart hosted LLM. You should have an `Amazon SageMaker Jumpstart` hosted LLM endpoint within your AWS Account. "
"The exmaple below shows how to use Amazon Comprehend Moderation chain with an Amazon SageMaker Jumpstart hosted LLM. You should have an Amazon SageMaker Jumpstart hosted LLM endpoint within your AWS Account. Refer to [this notebook](https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart-foundation-models/text-generation-falcon.ipynb) for more on how to deploy an LLM with Amazon SageMaker Jumpstart hosted endpoints."
]
},
{
@ -634,7 +682,8 @@
"metadata": {},
"outputs": [],
"source": [
"endpoint_name = \"<SAGEMAKER_ENDPOINT_NAME>\" # replace with your SageMaker Endpoint name"
"endpoint_name = \"<SAGEMAKER_ENDPOINT_NAME>\" # replace with your SageMaker Endpoint name\n",
"region = \"<REGION>\" # replace with your SageMaker Endpoint region"
]
},
{
@ -644,10 +693,9 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import SagemakerEndpoint\n",
"from langchain import SagemakerEndpoint\n",
"from langchain.llms.sagemaker_endpoint import LLMContentHandler\n",
"from langchain.chains import LLMChain\n",
"from langchain.prompts import load_prompt, PromptTemplate\n",
"from langchain.prompts import PromptTemplate\n",
"import json\n",
"\n",
"class ContentHandler(LLMContentHandler):\n",
@ -664,23 +712,27 @@
"\n",
"content_handler = ContentHandler()\n",
"\n",
"template = \"\"\"From the following 'Document', precisely answer the 'Question'. Do not add any spurious information in your answer.\n",
"\n",
"Document: John Doe, a resident of 1234 Elm Street in Springfield, recently celebrated his birthday on January 1st. Turning 43 this year, John reflected on the years gone by. He often shares memories of his younger days with his close friends through calls on his phone, (555) 123-4567. Meanwhile, during a casual evening, he received an email at johndoe@example.com reminding him of an old acquaintance's reunion. As he navigated through some old documents, he stumbled upon a paper that listed his SSN as 123-45-6789, reminding him to store it in a safer place.\n",
"Question: {question}\n",
"Answer:\n",
"\"\"\"\n",
"\n",
"#prompt template for input text\n",
"llm_prompt = PromptTemplate(input_variables=[\"input_text\"], template=\"{input_text}\")\n",
"llm_prompt = PromptTemplate(template=template, input_variables=[\"question\"])\n",
"\n",
"llm_chain = LLMChain(\n",
" llm=SagemakerEndpoint(\n",
"llm=SagemakerEndpoint(\n",
" endpoint_name=endpoint_name, \n",
" region_name='us-east-1',\n",
" model_kwargs={\"temperature\":0.97,\n",
" region_name=region,\n",
" model_kwargs={\"temperature\":0.95,\n",
" \"max_length\": 200,\n",
" \"num_return_sequences\": 3,\n",
" \"top_k\": 50,\n",
" \"top_p\": 0.95,\n",
" \"do_sample\": True},\n",
" content_handler=content_handler\n",
" ),\n",
" prompt=llm_prompt\n",
")"
" )"
]
},
{
@ -700,15 +752,37 @@
},
"outputs": [],
"source": [
"moderation_config = { \n",
" \"filters\":[ BaseModerationFilters.PII, BaseModerationFilters.TOXICITY ],\n",
" \"pii\":{\"action\": BaseModerationActions.ALLOW, \"threshold\":0.5, \"labels\":[\"SSN\"], \"mask_character\": \"X\"},\n",
" \"toxicity\":{\"action\": BaseModerationActions.STOP, \"threshold\":0.5},\n",
" \"intent\":{\"action\": BaseModerationActions.ALLOW, \"threshold\":0.5,},\n",
" }\n",
"\n",
"amazon_comp_moderation = AmazonComprehendModerationChain(moderation_config=moderation_config, \n",
" client=comprehend_client ,\n",
"# define filter configs\n",
"pii_config = ModerationPiiConfig(\n",
" labels=[\"SSN\"],\n",
" redact=True,\n",
" mask_character=\"X\"\n",
")\n",
"\n",
"toxicity_config = ModerationToxicityConfig(\n",
" threshold=0.5\n",
")\n",
"\n",
"\n",
"# define different moderation configs using the filter configs above\n",
"moderation_config_1 = BaseModerationConfig(\n",
" filters=[pii_config, toxicity_config]\n",
")\n",
"\n",
"moderation_config_2 = BaseModerationConfig(\n",
" filters=[pii_config]\n",
")\n",
"\n",
"\n",
"# input prompt moderation chain with callback\n",
"amazon_comp_moderation = AmazonComprehendModerationChain(moderation_config=moderation_config_1, \n",
" client=comprehend_client,\n",
" moderation_callback=my_callback,\n",
" verbose=True)\n",
"\n",
"# Output from LLM moderation chain without callback\n",
"amazon_comp_moderation_out = AmazonComprehendModerationChain(moderation_config=moderation_config_2, \n",
" client=comprehend_client,\n",
" verbose=True)"
]
},
@ -732,14 +806,12 @@
"chain = (\n",
" prompt \n",
" | amazon_comp_moderation \n",
" | {llm_chain.input_keys[0]: lambda x: x['output'] } \n",
" | llm_chain \n",
" | { \"input\": lambda x: x['text'] } \n",
" | amazon_comp_moderation \n",
" | { \"input\" : (lambda x: x['output']) | llm }\n",
" | amazon_comp_moderation_out\n",
")\n",
"\n",
"try:\n",
" response = chain.invoke({\"question\": \"My AnyCompany Financial Services, LLC credit card account 1111-0000-1111-0008 has 24$ due by July 31st. Can you give me some more samples?\"})\n",
" response = chain.invoke({\"question\": \"What is John Doe's address, phone number and SSN?\"})\n",
"except Exception as e:\n",
" print(str(e))\n",
"else:\n",
@ -1347,7 +1419,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.6"
}
},
"nbformat": 4,

@ -7,23 +7,25 @@ from langchain_experimental.comprehend_moderation.base_moderation_callbacks impo
)
from langchain_experimental.comprehend_moderation.base_moderation_config import (
BaseModerationConfig,
ModerationIntentConfig,
ModerationPiiConfig,
ModerationPromptSafetyConfig,
ModerationToxicityConfig,
)
from langchain_experimental.comprehend_moderation.intent import ComprehendIntent
from langchain_experimental.comprehend_moderation.pii import ComprehendPII
from langchain_experimental.comprehend_moderation.prompt_safety import (
ComprehendPromptSafety,
)
from langchain_experimental.comprehend_moderation.toxicity import ComprehendToxicity
__all__ = [
"BaseModeration",
"ComprehendPII",
"ComprehendIntent",
"ComprehendPromptSafety",
"ComprehendToxicity",
"BaseModerationConfig",
"ModerationPiiConfig",
"ModerationToxicityConfig",
"ModerationIntentConfig",
"ModerationPromptSafetyConfig",
"BaseModerationCallbackHandler",
"AmazonComprehendModerationChain",
]

@ -6,8 +6,10 @@ from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
from langchain.schema import AIMessage, HumanMessage
from langchain_experimental.comprehend_moderation.intent import ComprehendIntent
from langchain_experimental.comprehend_moderation.pii import ComprehendPII
from langchain_experimental.comprehend_moderation.prompt_safety import (
ComprehendPromptSafety,
)
from langchain_experimental.comprehend_moderation.toxicity import ComprehendToxicity
@ -109,13 +111,13 @@ class BaseModeration:
def moderate(self, prompt: Any) -> str:
from langchain_experimental.comprehend_moderation.base_moderation_config import ( # noqa: E501
ModerationIntentConfig,
ModerationPiiConfig,
ModerationPromptSafetyConfig,
ModerationToxicityConfig,
)
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import ( # noqa: E501
ModerationIntentionError,
ModerationPiiError,
ModerationPromptSafetyError,
ModerationToxicityError,
)
@ -128,7 +130,7 @@ class BaseModeration:
filter_functions = {
"pii": ComprehendPII,
"toxicity": ComprehendToxicity,
"intent": ComprehendIntent,
"prompt_safety": ComprehendPromptSafety,
}
filters = self.config.filters # type: ignore
@ -141,8 +143,8 @@ class BaseModeration:
"toxicity"
if isinstance(_filter, ModerationToxicityConfig)
else (
"intent"
if isinstance(_filter, ModerationIntentConfig)
"prompt_safety"
if isinstance(_filter, ModerationPromptSafetyConfig)
else None
)
)
@ -171,7 +173,7 @@ class BaseModeration:
f"Found Toxic content..stopping..\n{str(e)}\n"
)
raise e
except ModerationIntentionError as e:
except ModerationPromptSafetyError as e:
self._log_message_for_verbose(
f"Found Harmful intention..stopping..\n{str(e)}\n"
)

@ -11,12 +11,13 @@ class BaseModerationCallbackHandler:
BaseModerationCallbackHandler.on_after_toxicity, self.on_after_toxicity
)
and self._is_method_unchanged(
BaseModerationCallbackHandler.on_after_intent, self.on_after_intent
BaseModerationCallbackHandler.on_after_prompt_safety,
self.on_after_prompt_safety,
)
):
raise NotImplementedError(
"Subclasses must override at least one of on_after_pii(), "
"on_after_toxicity(), or on_after_intent() functions."
"on_after_toxicity(), or on_after_prompt_safety() functions."
)
def _is_method_unchanged(
@ -36,10 +37,10 @@ class BaseModerationCallbackHandler:
"""Run after Toxicity validation is complete."""
pass
async def on_after_intent(
async def on_after_prompt_safety(
self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
) -> None:
"""Run after Toxicity validation is complete."""
"""Run after Prompt Safety validation is complete."""
pass
@property
@ -57,8 +58,8 @@ class BaseModerationCallbackHandler:
)
@property
def intent_callback(self) -> bool:
def prompt_safety_callback(self) -> bool:
return (
self.on_after_intent.__func__ # type: ignore
is not BaseModerationCallbackHandler.on_after_intent
self.on_after_prompt_safety.__func__ # type: ignore
is not BaseModerationCallbackHandler.on_after_prompt_safety
)

@ -28,24 +28,26 @@ class ModerationToxicityConfig(BaseModel):
"""List of toxic labels, defaults to `list[]`"""
class ModerationIntentConfig(BaseModel):
class ModerationPromptSafetyConfig(BaseModel):
threshold: float = 0.5
"""
Threshold for Intent classification
Threshold for Prompt Safety classification
confidence score, defaults to 0.5 i.e. 50%
"""
class BaseModerationConfig(BaseModel):
filters: List[
Union[ModerationPiiConfig, ModerationToxicityConfig, ModerationIntentConfig]
Union[
ModerationPiiConfig, ModerationToxicityConfig, ModerationPromptSafetyConfig
]
] = [
ModerationPiiConfig(),
ModerationToxicityConfig(),
ModerationIntentConfig(),
ModerationPromptSafetyConfig(),
]
"""
Filters applied to the moderation chain, defaults to
`[ModerationPiiConfig(), ModerationToxicityConfig(),
ModerationIntentConfig()]`
ModerationPromptSafetyConfig()]`
"""

@ -26,7 +26,7 @@ class ModerationToxicityError(Exception):
super().__init__(self.message)
class ModerationIntentionError(Exception):
class ModerationPromptSafetyError(Exception):
"""Exception raised if Intention entities are detected.
Attributes:
@ -35,9 +35,7 @@ class ModerationIntentionError(Exception):
def __init__(
self,
message: str = (
"The prompt indicates an un-desired intent and " "cannot be processed"
),
message: str = ("The prompt is unsafe and cannot be processed"),
):
self.message = message
super().__init__(self.message)

@ -2,11 +2,11 @@ import asyncio
from typing import Any, Optional
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
ModerationIntentionError,
ModerationPromptSafetyError,
)
class ComprehendIntent:
class ComprehendPromptSafety:
def __init__(
self,
client: Any,
@ -17,7 +17,7 @@ class ComprehendIntent:
self.client = client
self.moderation_beacon = {
"moderation_chain_id": chain_id,
"moderation_type": "Intent",
"moderation_type": "PromptSafety",
"moderation_status": "LABELS_NOT_FOUND",
}
self.callback = callback
@ -26,62 +26,62 @@ class ComprehendIntent:
def _get_arn(self) -> str:
region_name = self.client.meta.region_name
service = "comprehend"
intent_endpoint = "document-classifier-endpoint/prompt-intent"
return f"arn:aws:{service}:{region_name}:aws:{intent_endpoint}"
prompt_safety_endpoint = "document-classifier-endpoint/prompt-safety"
return f"arn:aws:{service}:{region_name}:aws:{prompt_safety_endpoint}"
def validate(self, prompt_value: str, config: Any = None) -> str:
"""
Check and validate the intent of the given prompt text.
Check and validate the safety of the given prompt text.
Args:
prompt_value (str): The input text to be checked for unintended intent.
config (Dict[str, Any]): Configuration settings for intent checks.
prompt_value (str): The input text to be checked for unsafe text.
config (Dict[str, Any]): Configuration settings for prompt safety checks.
Raises:
ValueError: If unintended intent is found in the prompt text based
ValueError: If unsafe prompt is found in the prompt text based
on the specified threshold.
Returns:
str: The input prompt_value.
Note:
This function checks the intent of the provided prompt text using
Comprehend's classify_document API and raises an error if unintended
intent is detected with a score above the specified threshold.
This function checks the safety of the provided prompt text using
Comprehend's classify_document API and raises an error if unsafe
text is detected with a score above the specified threshold.
Example:
comprehend_client = boto3.client('comprehend')
prompt_text = "Please tell me your credit card information."
config = {"threshold": 0.7}
checked_prompt = check_intent(comprehend_client, prompt_text, config)
checked_prompt = check_prompt_safety(comprehend_client, prompt_text, config)
"""
threshold = config.get("threshold")
intent_found = False
unsafe_prompt = False
endpoint_arn = self._get_arn()
response = self.client.classify_document(
Text=prompt_value, EndpointArn=endpoint_arn
)
if self.callback and self.callback.intent_callback:
if self.callback and self.callback.prompt_safety_callback:
self.moderation_beacon["moderation_input"] = prompt_value
self.moderation_beacon["moderation_output"] = response
for class_result in response["Classes"]:
if (
class_result["Score"] >= threshold
and class_result["Name"] == "UNDESIRED_PROMPT"
and class_result["Name"] == "UNSAFE_PROMPT"
):
intent_found = True
unsafe_prompt = True
break
if self.callback and self.callback.intent_callback:
if intent_found:
if unsafe_prompt:
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
asyncio.create_task(
self.callback.on_after_intent(self.moderation_beacon, self.unique_id)
)
if intent_found:
raise ModerationIntentionError
if unsafe_prompt:
raise ModerationPromptSafetyError
return prompt_value