diff --git a/docs/docs_skeleton/docs/guides/safety/amazon_comprehend_chain.ipynb b/docs/docs_skeleton/docs/guides/safety/amazon_comprehend_chain.ipynb new file mode 100644 index 0000000000..614db1b885 --- /dev/null +++ b/docs/docs_skeleton/docs/guides/safety/amazon_comprehend_chain.ipynb @@ -0,0 +1,1422 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "25a3f834-60b7-4c21-bfb4-ad16d30fd3f7", + "metadata": {}, + "source": [ + "# Amazon Comprehend Moderation Chain\n", + "---" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2c4236d8-4054-473d-84a4-87a4db278a62", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install boto3 nltk" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b39ac41a", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install -U langchain" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "3f8518ad-c762-413c-b8c9-f1c211fc311d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import boto3\n", + "\n", + "comprehend_client = boto3.client('comprehend', region_name='us-east-1')" + ] + }, + { + "cell_type": "markdown", + "id": "d1f0ba28", + "metadata": {}, + "source": [ + "Import `AmazonComprehendModerationChain`" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "74550d74-3c01-4ba7-ad32-ca66d955d001", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain_experimental.comprehend_moderation import AmazonComprehendModerationChain" + ] + }, + { + "cell_type": "markdown", + "id": "f00c338b-de9f-40e5-9295-93c9e26058e3", + "metadata": {}, + "source": [ + "Initialize an instance of the Amazon Comprehend Moderation Chain to be used with your LLM chain" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cde58cc6-ff83-493a-9aed-93d755f984a7", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "comprehend_moderation = AmazonComprehendModerationChain(\n", + " client=comprehend_client, #optional\n", + " verbose=True\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "ad646d01-82d2-435a-939b-c450693857ab", + "metadata": {}, + "source": [ + "Using it with your LLM chain. \n", + "\n", + "**Note**: The example below uses the _Fake LLM_ from LangChain, but same concept could be applied to other LLMs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0efa1946-d4a9-467a-920a-a8fb78720fc2", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain import PromptTemplate, LLMChain\n", + "from langchain.llms.fake import FakeListLLM\n", + "from langchain_experimental.comprehend_moderation.base_moderation_exceptions import ModerationPiiError\n", + "\n", + "template = \"\"\"Question: {question}\n", + "\n", + "Answer:\"\"\"\n", + "\n", + "prompt = PromptTemplate(template=template, input_variables=[\"question\"])\n", + "\n", + "responses = [\n", + " \"Final Answer: A credit card number looks like 1289-2321-1123-2387. A fake SSN number looks like 323-22-9980. John Doe's phone number is (999)253-9876.\", \n", + " # replace with your own expletive\n", + " \"Final Answer: This is a really way of constructing a birdhouse. 
This is insane to think that any birds would actually create their nests here.\"\n", + "]\n", + "llm = FakeListLLM(responses=responses)\n", + "\n", + "llm_chain = LLMChain(prompt=prompt, llm=llm)\n", + "\n", + "chain = (\n", + " prompt \n", + " | comprehend_moderation \n", + " | {llm_chain.input_keys[0]: lambda x: x['output'] } \n", + " | llm_chain \n", + " | { \"input\": lambda x: x['text'] } \n", + " | comprehend_moderation \n", + ")\n", + "\n", + "try:\n", + " response = chain.invoke({\"question\": \"A sample SSN number looks like this . Can you give me some more samples?\"})\n", + "except ModerationPiiError as e:\n", + " print(str(e))\n", + "else:\n", + " print(response['output'])\n" + ] + }, + { + "cell_type": "markdown", + "id": "6da25d96-0d96-4c01-94ae-a2ead17f10aa", + "metadata": {}, + "source": [ + "## Using `moderation_config` to customize your moderation\n", + "---" + ] + }, + { + "cell_type": "markdown", + "id": "bfd550e7-5012-41fa-9546-8b78ddf1c673", + "metadata": {}, + "source": [ + "Use Amazon Comprehend Moderation with a configuration to control what moderations you wish to perform and what actions should be taken for each of them. There are three different moderations that happen when no configuration is passed as demonstrated above. These moderations are:\n", + "\n", + "- PII (Personally Identifiable Information) checks \n", + "- Toxicity content detection\n", + "- Intention detection\n", + "\n", + "Here is an example of a moderation config." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "d6e8900a-44ef-4967-bde8-b88af282139d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain_experimental.comprehend_moderation import (BaseModerationConfig, \n", + " ModerationIntentConfig, \n", + " ModerationPiiConfig, \n", + " ModerationToxicityConfig\n", + ")\n", + "\n", + "pii_config = ModerationPiiConfig(\n", + " labels=[\"SSN\"],\n", + " redact=True,\n", + " mask_character=\"X\"\n", + ")\n", + "\n", + "toxicity_config = ModerationToxicityConfig(\n", + " threshold=0.5\n", + ")\n", + "\n", + "intent_config = ModerationIntentConfig(\n", + " threshold=0.5\n", + ")\n", + "\n", + "moderation_config = BaseModerationConfig(\n", + " filters=[pii_config, toxicity_config, intent_config]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "3634376b-5938-43df-9ed6-70ca7e99290f", + "metadata": {}, + "source": [ + "At the core of the the configuration there are three configuration models to be used\n", + "\n", + "- `ModerationPiiConfig` used for configuring the behavior of the PII validations. Following are the parameters it can be initialized with\n", + " - `labels` the PII entity labels. Defaults to an empty list which means that the PII validation will consider all PII entities.\n", + " - `threshold` the confidence threshold for the detected entities, defaults to 0.5 or 50%\n", + " - `redact` a boolean flag to enforce whether redaction should be performed on the text, defaults to `False`. When `False`, the PII validation will error out when it detects any PII entity, when set to `True` it simply redacts the PII values in the text.\n", + " - `mask_character` the character used for masking, defaults to asterisk (*)\n", + "- `ModerationToxicityConfig` used for configuring the behavior of the toxicity validations. Following are the parameters it can be initialized with\n", + " - `labels` the Toxic entity labels. Defaults to an empty list which means that the toxicity validation will consider all toxic entities. 
all\n", + " - `threshold` the confidence threshold for the detected entities, defaults to 0.5 or 50% \n", + "- `ModerationIntentConfig` used for configuring the behavior of the intent validation\n", + " - `threshold` the confidence threshold for the the intent classification, defaults to 0.5 or 50% \n", + "\n", + "Finally, you use the `BaseModerationConfig` to define the order in which each of these checks are to be performed. The `BaseModerationConfig` takes an optional `filters` parameter which can be a list of one or more than one of the above validation checks, as seen in the previous code block. The `BaseModerationConfig` can also be initialized with any `filters` in which case it will use all the checks with default configuration (more on this explained later).\n", + "\n", + "Using the configuration in the previous cell will perform PII checks and will allow the prompt to pass through however it will mask any SSN numbers present in either the prompt or the LLM output.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a4f7e65-f733-4863-ae6d-34c9faffd849", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "comp_moderation_with_config = AmazonComprehendModerationChain(\n", + " moderation_config=moderation_config, #specify the configuration\n", + " client=comprehend_client, #optionally pass the Boto3 Client\n", + " verbose=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a25e6f93-765b-4f99-8c1c-929157dbd4aa", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "template = \"\"\"Question: {question}\n", + "\n", + "Answer:\"\"\"\n", + "\n", + "prompt = PromptTemplate(template=template, input_variables=[\"question\"])\n", + "\n", + "responses = [\n", + " \"Final Answer: A credit card number looks like 1289-2321-1123-2387. A fake SSN number looks like 323-22-9980. John Doe's phone number is (999)253-9876.\", \n", + " # replace with your own expletive\n", + " \"Final Answer: This is a really way of constructing a birdhouse. This is insane to think that any birds would actually create their nests here.\"\n", + "]\n", + "llm = FakeListLLM(responses=responses)\n", + "\n", + "llm_chain = LLMChain(prompt=prompt, llm=llm)\n", + "\n", + "chain = ( \n", + " prompt \n", + " | comp_moderation_with_config \n", + " | {llm_chain.input_keys[0]: lambda x: x['output'] } \n", + " | llm_chain \n", + " | { \"input\": lambda x: x['text'] } \n", + " | comp_moderation_with_config \n", + ")\n", + "\n", + "try:\n", + " response = chain.invoke({\"question\": \"A sample SSN number looks like this 123-456-7890. Can you give me some more samples?\"})\n", + "except Exception as e:\n", + " print(str(e))\n", + "else:\n", + " print(response['output'])" + ] + }, + { + "cell_type": "markdown", + "id": "ba890681-feeb-43ca-a0d5-9c11d2d9de3e", + "metadata": { + "tags": [] + }, + "source": [ + "## Unique ID, and Moderation Callbacks\n", + "---\n", + "\n", + "When Amazon Comprehend moderation action is specified as `STOP`, the chain will raise one of the following exceptions-\n", + " - `ModerationPiiError`, for PII checks\n", + " - `ModerationToxicityError`, for Toxicity checks \n", + " - `ModerationIntentionError` for Intent checks\n", + "\n", + "In addition to the moderation configuration, the `AmazonComprehendModerationChain` can also be initialized with the following parameters\n", + "\n", + "- `unique_id` [Optional] a string parameter. This parameter can be used to pass any string value or ID. 
For example, in a chat application you may want to keep track of abusive users, in this case you can pass the user's username/email id etc. This defaults to `None`.\n", + "\n", + "- `moderation_callback` [Optional] the `BaseModerationCallbackHandler` that will be called asynchronously (non-blocking to the chain). Callback functions are useful when you want to perform additional actions when the moderation functions are executed, for example logging into a database, or writing a log file. You can override three functions by subclassing `BaseModerationCallbackHandler` - `on_after_pii()`, `on_after_toxicity()`, and `on_after_intent()`. Note that all three functions must be `async` functions. These callback functions receive two arguments:\n", + " - `moderation_beacon` a dictionary that will contain information about the moderation function, the full response from Amazon Comprehend model, a unique chain id, the moderation status, and the input string which was validated. The dictionary is of the following schema-\n", + " \n", + " ```\n", + " { \n", + " 'moderation_chain_id': 'xxx-xxx-xxx', # Unique chain ID\n", + " 'moderation_type': 'Toxicity' | 'PII' | 'Intent', \n", + " 'moderation_status': 'LABELS_FOUND' | 'LABELS_NOT_FOUND',\n", + " 'moderation_input': 'A sample SSN number looks like this 123-456-7890. Can you give me some more samples?',\n", + " 'moderation_output': {...} #Full Amazon Comprehend PII, Toxicity, or Intent Model Output\n", + " }\n", + " ```\n", + " \n", + " - `unique_id` if passed to the `AmazonComprehendModerationChain`" + ] + }, + { + "cell_type": "markdown", + "id": "3c178835-0264-4ac6-aef4-091d2993d06c", + "metadata": {}, + "source": [ + "
NOTE: moderation_callback is different from LangChain Chain Callbacks. You can still use LangChain Chain callbacks with AmazonComprehendModerationChain via the callbacks parameter. Example:
\n", + "
\n",
+    "from langchain.callbacks.stdout import StdOutCallbackHandler\n",
+    "comp_moderation_with_config = AmazonComprehendModerationChain(verbose=True, callbacks=[StdOutCallbackHandler()])\n",
+    "
\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0ec38536-8cc9-408e-860b-e4a439283643", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain_experimental.comprehend_moderation import BaseModerationCallbackHandler" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1be744c7-3f99-4165-bf7f-9c5c249bbb53", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Define callback handlers by subclassing BaseModerationCallbackHandler\n", + "\n", + "class MyModCallback(BaseModerationCallbackHandler):\n", + " \n", + " async def on_after_pii(self, output_beacon, unique_id):\n", + " import json\n", + " moderation_type = output_beacon['moderation_type']\n", + " chain_id = output_beacon['moderation_chain_id']\n", + " with open(f'output-{moderation_type}-{chain_id}.json', 'w') as file:\n", + " data = { 'beacon_data': output_beacon, 'unique_id': unique_id }\n", + " json.dump(data, file)\n", + " \n", + " '''\n", + " async def on_after_toxicity(self, output_beacon, unique_id):\n", + " pass\n", + " \n", + " async def on_after_intent(self, output_beacon, unique_id):\n", + " pass\n", + " '''\n", + " \n", + "\n", + "my_callback = MyModCallback()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "362a3fe0-f09f-411e-9df1-d79b3e87510c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "pii_config = ModerationPiiConfig(\n", + " labels=[\"SSN\"],\n", + " redact=True,\n", + " mask_character=\"X\"\n", + ")\n", + "\n", + "toxicity_config = ModerationToxicityConfig(\n", + " threshold=0.5\n", + ")\n", + "\n", + "moderation_config = BaseModerationConfig(\n", + " filters=[pii_config, toxicity_config]\n", + ")\n", + "\n", + "comp_moderation_with_config = AmazonComprehendModerationChain(\n", + " moderation_config=moderation_config, # specify the configuration\n", + " client=comprehend_client, # optionally pass the Boto3 Client\n", + " unique_id='john.doe@email.com', # A unique ID\n", + " moderation_callback=my_callback, # BaseModerationCallbackHandler\n", + " verbose=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2af07937-67ea-4738-8343-c73d4d28c2cc", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain import PromptTemplate, LLMChain\n", + "from langchain.llms.fake import FakeListLLM\n", + "\n", + "template = \"\"\"Question: {question}\n", + "\n", + "Answer:\"\"\"\n", + "\n", + "prompt = PromptTemplate(template=template, input_variables=[\"question\"])\n", + "\n", + "responses = [\n", + " \"Final Answer: A credit card number looks like 1289-2321-1123-2387. A fake SSN number looks like 323-22-9980. John Doe's phone number is (999)253-9876.\", \n", + " # replace with your own expletive\n", + " \"Final Answer: This is a really way of constructing a birdhouse. This is insane to think that any birds would actually create their nests here.\"\n", + "]\n", + "\n", + "llm = FakeListLLM(responses=responses)\n", + "\n", + "llm_chain = LLMChain(prompt=prompt, llm=llm)\n", + "\n", + "chain = (\n", + " prompt \n", + " | comp_moderation_with_config \n", + " | {llm_chain.input_keys[0]: lambda x: x['output'] } \n", + " | llm_chain \n", + " | { \"input\": lambda x: x['text'] } \n", + " | comp_moderation_with_config \n", + ") \n", + "\n", + "try:\n", + " response = chain.invoke({\"question\": \"A sample SSN number looks like this 123-456-7890. 
Can you give me some more samples?\"})\n", + "except Exception as e:\n", + " print(str(e))\n", + "else:\n", + " print(response['output'])" + ] + }, + { + "cell_type": "markdown", + "id": "706454b2-2efa-4d41-abc8-ccf2b4e87822", + "metadata": { + "tags": [] + }, + "source": [ + "## `moderation_config` and moderation execution order\n", + "---\n", + "\n", + "If `AmazonComprehendModerationChain` is not initialized with any `moderation_config` then it is initialized with the default values of `BaseModerationConfig`. If no `filters` are used then the sequence of moderation check is as follows.\n", + "\n", + "```\n", + "AmazonComprehendModerationChain\n", + "│\n", + "└──Check PII with Stop Action\n", + " ├── Callback (if available)\n", + " ├── Label Found ⟶ [Error Stop]\n", + " └── No Label Found \n", + " └──Check Toxicity with Stop Action\n", + " ├── Callback (if available)\n", + " ├── Label Found ⟶ [Error Stop]\n", + " └── No Label Found\n", + " └──Check Intent with Stop Action\n", + " ├── Callback (if available)\n", + " ├── Label Found ⟶ [Error Stop]\n", + " └── No Label Found\n", + " └── Return Prompt\n", + "```\n", + "\n", + "If any of the check raises a validation exception then the subsequent checks will not be performed. If a `callback` is provided in this case, then it will be called for each of the checks that have been performed. For example, in the case above, if the Chain fails due to presence of PII then the Toxicity and Intent checks will not be performed.\n", + "\n", + "You can override the execution order by passing `moderation_config` and simply specifying the desired order in the `filters` parameter of the `BaseModerationConfig`. In case you specify the filters, then the order of the checks as specified in the `filters` parameter will be maintained. For example, in the configuration below, first Toxicity check will be performed, then PII, and finally Intent validation will be performed. In this case, `AmazonComprehendModerationChain` will perform the desired checks in the specified order with default values of each model `kwargs`.\n", + "\n", + "```python\n", + "pii_check = ModerationPiiConfig()\n", + "toxicity_check = ModerationToxicityConfig()\n", + "intent_check = ModerationIntentConfig()\n", + "\n", + "moderation_config = BaseModerationConfig(filters=[toxicity_check, pii_check, intent_check])\n", + "```\n", + "\n", + "You can have also use more than one configuration for a specific moderation check, for example in the sample below, two consecutive PII checks are performed. First the configuration checks for any SSN, if found it would raise an error. If any SSN isn't found then it will next check if any NAME and CREDIT_DEBIT_NUMBER is present in the prompt and will mask it.\n", + "\n", + "```python\n", + "pii_check_1 = ModerationPiiConfig(labels=[\"SSN\"])\n", + "pii_check_2 = ModerationPiiConfig(labels=[\"NAME\", \"CREDIT_DEBIT_NUMBER\"], redact=True)\n", + "\n", + "moderation_config = BaseModerationConfig(filters=[pii_check_1, pii_check_2])\n", + "```\n", + "\n", + "1. For a list of PII labels see Amazon Comprehend Universal PII entity types - https://docs.aws.amazon.com/comprehend/latest/dg/how-pii.html#how-pii-types\n", + "2. 
Following are the list of available Toxicity labels-\n", + " - `HATE_SPEECH`: Speech that criticizes, insults, denounces or dehumanizes a person or a group on the basis of an identity, be it race, ethnicity, gender identity, religion, sexual orientation, ability, national origin, or another identity-group.\n", + " - `GRAPHIC`: Speech that uses visually descriptive, detailed and unpleasantly vivid imagery is considered as graphic. Such language is often made verbose so as to amplify an insult, discomfort or harm to the recipient.\n", + " - `HARASSMENT_OR_ABUSE`: Speech that imposes disruptive power dynamics between the speaker and hearer, regardless of intent, seeks to affect the psychological well-being of the recipient, or objectifies a person should be classified as Harassment.\n", + " - `SEXUAL`: Speech that indicates sexual interest, activity or arousal by using direct or indirect references to body parts or physical traits or sex is considered as toxic with toxicityType \"sexual\". \n", + " - `VIOLENCE_OR_THREAT`: Speech that includes threats which seek to inflict pain, injury or hostility towards a person or group.\n", + " - `INSULT`: Speech that includes demeaning, humiliating, mocking, insulting, or belittling language.\n", + " - `PROFANITY`: Speech that contains words, phrases or acronyms that are impolite, vulgar, or offensive is considered as profane.\n", + "3. For a list of Intent labels refer to documentation [link here]" + ] + }, + { + "cell_type": "markdown", + "id": "78905aec-55ae-4fc3-a23b-8a69bd1e33f2", + "metadata": {}, + "source": [ + "# Examples\n", + "---\n", + "\n", + "## With HuggingFace Hub Models\n", + "\n", + "Get your API Key from Huggingface hub - https://huggingface.co/docs/api-inference/quicktour#get-your-api-token" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "359b9627-769b-46ce-8be2-c8a5cf7728ba", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "%pip install huggingface_hub" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41b7ea98-ad16-4454-8f12-c03c17113a86", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"HUGGINGFACEHUB_API_TOKEN\"] = \"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3b235427-cc06-4c07-874b-1f67c2d1f924", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# See https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads for some other options\n", + "repo_id = \"google/flan-t5-xxl\" " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9d86e256-34fb-4c8e-8092-1a4f863a5c96", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain import HuggingFaceHub\n", + "from langchain import PromptTemplate, LLMChain\n", + "\n", + "template = \"\"\"Question: {question}\"\"\"\n", + "\n", + "prompt = PromptTemplate(template=template, input_variables=[\"question\"])\n", + "llm = HuggingFaceHub(\n", + " repo_id=repo_id, model_kwargs={\"temperature\": 0.5, \"max_length\": 256}\n", + ")\n", + "llm_chain = LLMChain(prompt=prompt, llm=llm)" + ] + }, + { + "cell_type": "markdown", + "id": "ad603796-ad8b-4599-9022-a486f1c1b89a", + "metadata": {}, + "source": [ + "Create a configuration and initialize an Amazon Comprehend Moderation chain" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "decc3409-5be5-433d-b6da-38b9e5c5ee3f", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + 
"pii_config = ModerationPiiConfig(\n", + " labels=[\"SSN\", \"CREDIT_DEBIT_NUMBER\"],\n", + " redact=True,\n", + " mask_character=\"X\"\n", + ")\n", + "\n", + "toxicity_config = ModerationToxicityConfig(\n", + " threshold=0.5\n", + ")\n", + "\n", + "intent_config = ModerationIntentConfig(\n", + " threshold=0.8\n", + ")\n", + "\n", + "moderation_config = BaseModerationConfig(\n", + " filters=[pii_config, toxicity_config, intent_config]\n", + ")\n", + "# with callback\n", + "amazon_comp_moderation = AmazonComprehendModerationChain(moderation_config=moderation_config, \n", + " client=comprehend_client,\n", + " moderation_callback=my_callback,\n", + " verbose=True)\n", + "\n", + "# without callback\n", + "amazon_comp_moderation_out = AmazonComprehendModerationChain(moderation_config=moderation_config, \n", + " client=comprehend_client,\n", + " verbose=True)" + ] + }, + { + "cell_type": "markdown", + "id": "b1256bc8-1321-4624-9e8a-a2d4a8df59bf", + "metadata": {}, + "source": [ + "The `moderation_config` will now prevent any inputs and model outputs containing obscene words or sentences, bad intent, or PII with entities other than SSN with score above threshold or 0.5 or 50%. If it finds Pii entities - SSN - it will redact them before allowing the call to proceed. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0337becc-7c3c-483e-a55c-a225226cb9ee", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "chain = (\n", + " prompt \n", + " | amazon_comp_moderation \n", + " | {llm_chain.input_keys[0]: lambda x: x['output'] } \n", + " | llm_chain \n", + " | { \"input\": lambda x: x['text'] } \n", + " | amazon_comp_moderation_out\n", + ")\n", + "\n", + "try:\n", + " response = chain.invoke({\"question\": \"\"\"What is John Doe's address, phone number and SSN from the following text?\n", + "\n", + "John Doe, a resident of 1234 Elm Street in Springfield, recently celebrated his birthday on January 1st. Turning 43 this year, John reflected on the years gone by. He often shares memories of his younger days with his close friends through calls on his phone, (555) 123-4567. Meanwhile, during a casual evening, he received an email at johndoe@example.com reminding him of an old acquaintance's reunion. As he navigated through some old documents, he stumbled upon a paper that listed his SSN as 123-45-6789, reminding him to store it in a safer place.\n", + "\"\"\"})\n", + "except Exception as e:\n", + " print(str(e))\n", + "else:\n", + " print(response['output'])" + ] + }, + { + "cell_type": "markdown", + "id": "ee52c7b8-6526-4f68-a2b3-b5ad3cf82489", + "metadata": { + "tags": [] + }, + "source": [ + "---\n", + "## With Amazon SageMaker Jumpstart\n", + "\n", + "The exmaple below shows how to use Amazon Comprehend Moderation chain with an Amazon SageMaker Jumpstart hosted LLM. You should have an Amazon SageMaker Jumpstart hosted LLM endpoint within your AWS Account. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cd49d075-bc23-4ab8-a92c-0ddbbc436c30", + "metadata": {}, + "outputs": [], + "source": [ + "endpoint_name = \"\" # replace with your SageMaker Endpoint name" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5978a5e6-667d-4926-842c-d965f88e5640", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain import SagemakerEndpoint\n", + "from langchain.llms.sagemaker_endpoint import LLMContentHandler\n", + "from langchain.chains import LLMChain\n", + "from langchain.prompts import load_prompt, PromptTemplate\n", + "import json\n", + "\n", + "class ContentHandler(LLMContentHandler):\n", + " content_type = \"application/json\"\n", + " accepts = \"application/json\"\n", + "\n", + " def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:\n", + " input_str = json.dumps({\"text_inputs\": prompt, **model_kwargs})\n", + " return input_str.encode('utf-8')\n", + " \n", + " def transform_output(self, output: bytes) -> str:\n", + " response_json = json.loads(output.read().decode(\"utf-8\"))\n", + " return response_json['generated_texts'][0]\n", + "\n", + "content_handler = ContentHandler()\n", + "\n", + "#prompt template for input text\n", + "llm_prompt = PromptTemplate(input_variables=[\"input_text\"], template=\"{input_text}\")\n", + "\n", + "llm_chain = LLMChain(\n", + " llm=SagemakerEndpoint(\n", + " endpoint_name=endpoint_name, \n", + " region_name='us-east-1',\n", + " model_kwargs={\"temperature\":0.97,\n", + " \"max_length\": 200,\n", + " \"num_return_sequences\": 3,\n", + " \"top_k\": 50,\n", + " \"top_p\": 0.95,\n", + " \"do_sample\": True},\n", + " content_handler=content_handler\n", + " ),\n", + " prompt=llm_prompt\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "d577b036-99a4-47fe-9a8e-4a34aa4cd88d", + "metadata": {}, + "source": [ + "Create a configuration and initialize an Amazon Comprehend Moderation chain" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "859da135-94d3-4a9c-970e-a873913592e2", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "pii_config = ModerationPiiConfig(\n", + " labels=[\"SSN\"],\n", + " redact=True,\n", + " mask_character=\"X\"\n", + ")\n", + "\n", + "toxicity_config = ModerationToxicityConfig(\n", + " threshold=0.5\n", + ")\n", + "\n", + "intent_config = ModerationIntentConfig(\n", + " threshold=0.8\n", + ")\n", + "\n", + "moderation_config = BaseModerationConfig(\n", + " filters=[pii_config, toxicity_config, intent_config]\n", + ")\n", + "\n", + "amazon_comp_moderation = AmazonComprehendModerationChain(moderation_config=moderation_config, \n", + " client=comprehend_client,\n", + " verbose=True)" + ] + }, + { + "cell_type": "markdown", + "id": "9abb191f-7a96-4077-8c30-b9ddc225bd6b", + "metadata": {}, + "source": [ + "The `moderation_config` will now prevent any inputs and model outputs containing obscene words or sentences, bad intent, or Pii with entities other than SSN with score above threshold or 0.5 or 50%. If it finds Pii entities - SSN - it will redact them before allowing the call to proceed. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6db5aa2a-9c00-42a0-8e24-c5ba39994f7d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "chain = (\n", + " prompt \n", + " | amazon_comp_moderation \n", + " | {llm_chain.input_keys[0]: lambda x: x['output'] } \n", + " | llm_chain \n", + " | { \"input\": lambda x: x['text'] } \n", + " | amazon_comp_moderation \n", + ")\n", + "\n", + "try:\n", + " response = chain.invoke({\"question\": \"\"\"What is John Doe's address, phone number and SSN from the following text?\n", + "\n", + "John Doe, a resident of 1234 Elm Street in Springfield, recently celebrated his birthday on January 1st. Turning 43 this year, John reflected on the years gone by. He often shares memories of his younger days with his close friends through calls on his phone, (555) 123-4567. Meanwhile, during a casual evening, he received an email at johndoe@example.com reminding him of an old acquaintance's reunion. As he navigated through some old documents, he stumbled upon a paper that listed his SSN as 123-45-6789, reminding him to store it in a safer place.\n", + "\"\"\"})\n", + "except Exception as e:\n", + " print(str(e))\n", + "else:\n", + " print(response['output'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7fdfedf9-1a0a-4a9f-a6b0-d9ed2dbaa5ad", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "availableInstances": [ + { + "_defaultOrder": 0, + "_isFastLaunch": true, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 4, + "name": "ml.t3.medium", + "vcpuNum": 2 + }, + { + "_defaultOrder": 1, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.t3.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 2, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.t3.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 3, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.t3.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 4, + "_isFastLaunch": true, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.m5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 5, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.m5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 6, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.m5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 7, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.m5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 8, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.m5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 9, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.m5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 10, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 
256, + "name": "ml.m5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 11, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.m5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 12, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.m5d.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 13, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.m5d.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 14, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.m5d.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 15, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.m5d.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 16, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.m5d.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 17, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.m5d.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 18, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.m5d.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 19, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.m5d.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 20, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": true, + "memoryGiB": 0, + "name": "ml.geospatial.interactive", + "supportedImageNames": [ + "sagemaker-geospatial-v1-0" + ], + "vcpuNum": 0 + }, + { + "_defaultOrder": 21, + "_isFastLaunch": true, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 4, + "name": "ml.c5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 22, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.c5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 23, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.c5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 24, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.c5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 25, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 72, + "name": "ml.c5.9xlarge", + "vcpuNum": 36 + }, + { + "_defaultOrder": 26, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 96, + "name": "ml.c5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 27, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 144, + "name": "ml.c5.18xlarge", + "vcpuNum": 72 + }, + { + "_defaultOrder": 28, + "_isFastLaunch": 
false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.c5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 29, + "_isFastLaunch": true, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.g4dn.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 30, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.g4dn.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 31, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.g4dn.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 32, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.g4dn.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 33, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.g4dn.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 34, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.g4dn.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 35, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 61, + "name": "ml.p3.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 36, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 244, + "name": "ml.p3.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 37, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 488, + "name": "ml.p3.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 38, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 768, + "name": "ml.p3dn.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 39, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.r5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 40, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.r5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 41, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.r5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 42, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.r5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 43, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.r5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 44, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.r5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 45, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 
512, + "name": "ml.r5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 46, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 768, + "name": "ml.r5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 47, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.g5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 48, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.g5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 49, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.g5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 50, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.g5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 51, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.g5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 52, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.g5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 53, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.g5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 54, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 768, + "name": "ml.g5.48xlarge", + "vcpuNum": 192 + }, + { + "_defaultOrder": 55, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 1152, + "name": "ml.p4d.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 56, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 1152, + "name": "ml.p4de.24xlarge", + "vcpuNum": 96 + } + ], + "instance_type": "ml.t3.medium", + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/experimental/langchain_experimental/autonomous_agents/hugginggpt/task_executor.py b/libs/experimental/langchain_experimental/autonomous_agents/hugginggpt/task_executor.py index bcdea12b90..bf742d192b 100644 --- a/libs/experimental/langchain_experimental/autonomous_agents/hugginggpt/task_executor.py +++ b/libs/experimental/langchain_experimental/autonomous_agents/hugginggpt/task_executor.py @@ -111,7 +111,9 @@ class TaskExecutor: dep_task = self.id_task_map[dep_id] for k, v in task.args.items(): if f"" in v: - task.args[k].replace(f"", dep_task.result) + task.args[k] = task.args[k].replace( + f"", dep_task.result + ) def run(self) -> str: for task in self.tasks: diff --git a/libs/experimental/langchain_experimental/comprehend_moderation/__init__.py 
b/libs/experimental/langchain_experimental/comprehend_moderation/__init__.py index 5e4a268631..d129f9fc7f 100644 --- a/libs/experimental/langchain_experimental/comprehend_moderation/__init__.py +++ b/libs/experimental/langchain_experimental/comprehend_moderation/__init__.py @@ -5,9 +5,11 @@ from langchain_experimental.comprehend_moderation.base_moderation import BaseMod from langchain_experimental.comprehend_moderation.base_moderation_callbacks import ( BaseModerationCallbackHandler, ) -from langchain_experimental.comprehend_moderation.base_moderation_enums import ( - BaseModerationActions, - BaseModerationFilters, +from langchain_experimental.comprehend_moderation.base_moderation_config import ( + BaseModerationConfig, + ModerationIntentConfig, + ModerationPiiConfig, + ModerationToxicityConfig, ) from langchain_experimental.comprehend_moderation.intent import ComprehendIntent from langchain_experimental.comprehend_moderation.pii import ComprehendPII @@ -15,11 +17,13 @@ from langchain_experimental.comprehend_moderation.toxicity import ComprehendToxi __all__ = [ "BaseModeration", - "BaseModerationActions", - "BaseModerationFilters", "ComprehendPII", "ComprehendIntent", "ComprehendToxicity", + "BaseModerationConfig", + "ModerationPiiConfig", + "ModerationToxicityConfig", + "ModerationIntentConfig", "BaseModerationCallbackHandler", "AmazonComprehendModerationChain", ] diff --git a/libs/experimental/langchain_experimental/comprehend_moderation/amazon_comprehend_moderation.py b/libs/experimental/langchain_experimental/comprehend_moderation/amazon_comprehend_moderation.py index d00520e627..100e061556 100644 --- a/libs/experimental/langchain_experimental/comprehend_moderation/amazon_comprehend_moderation.py +++ b/libs/experimental/langchain_experimental/comprehend_moderation/amazon_comprehend_moderation.py @@ -3,12 +3,13 @@ from typing import Any, Dict, List, Optional from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain -from langchain_experimental.comprehend_moderation.base_moderation import ( - BaseModeration, -) +from langchain_experimental.comprehend_moderation.base_moderation import BaseModeration from langchain_experimental.comprehend_moderation.base_moderation_callbacks import ( BaseModerationCallbackHandler, ) +from langchain_experimental.comprehend_moderation.base_moderation_config import ( + BaseModerationConfig, +) from langchain_experimental.pydantic_v1 import root_validator @@ -21,10 +22,13 @@ class AmazonComprehendModerationChain(Chain): input_key: str = "input" #: :meta private: """Key used to fetch/store the input in data containers. 
Defaults to `input`""" - moderation_config: Optional[Dict[str, Any]] = None - """Configuration settings for moderation""" + moderation_config: BaseModerationConfig = BaseModerationConfig() + """ + Configuration settings for moderation, + defaults to BaseModerationConfig with default values + """ - client: Optional[Any] + client: Optional[Any] = None """boto3 client object for connection to Amazon Comprehend""" region_name: Optional[str] = None diff --git a/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation.py b/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation.py index c639112b95..3724ac081d 100644 --- a/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation.py +++ b/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation.py @@ -1,5 +1,5 @@ import uuid -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Optional from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.prompts.base import StringPromptValue @@ -15,7 +15,7 @@ class BaseModeration: def __init__( self, client: Any, - config: Optional[Dict[str, Any]] = None, + config: Optional[Any] = None, moderation_callback: Optional[Any] = None, unique_id: Optional[str] = None, run_manager: Optional[CallbackManagerForChainRun] = None, @@ -105,6 +105,11 @@ class BaseModeration: self.run_manager.on_text(message) def moderate(self, prompt: Any) -> str: + from langchain_experimental.comprehend_moderation.base_moderation_config import ( # noqa: E501 + ModerationIntentConfig, + ModerationPiiConfig, + ModerationToxicityConfig, + ) from langchain_experimental.comprehend_moderation.base_moderation_exceptions import ( # noqa: E501 ModerationIntentionError, ModerationPiiError, @@ -115,47 +120,43 @@ class BaseModeration: # convert prompt to text input_text = self._convert_prompt_to_text(prompt=prompt) output_text = str() + # perform moderation - if self.config is None: - # In absence of config Action will default to STOP only - self._log_message_for_verbose("Running pii validation...\n") - pii_validate = self._moderation_class(moderation_class=ComprehendPII) - output_text = pii_validate(prompt_value=input_text) - - self._log_message_for_verbose("Running toxicity validation...\n") - toxicity_validate = self._moderation_class( - moderation_class=ComprehendToxicity + filter_functions = { + "pii": ComprehendPII, + "toxicity": ComprehendToxicity, + "intent": ComprehendIntent, + } + + filters = self.config.filters # type: ignore + + for _filter in filters: + filter_name = ( + "pii" + if isinstance(_filter, ModerationPiiConfig) + else ( + "toxicity" + if isinstance(_filter, ModerationToxicityConfig) + else ( + "intent" + if isinstance(_filter, ModerationIntentConfig) + else None + ) + ) ) - output_text = toxicity_validate(prompt_value=output_text) + if filter_name in filter_functions: + self._log_message_for_verbose( + f"Running {filter_name} Validation...\n" + ) + validation_fn = self._moderation_class( + moderation_class=filter_functions[filter_name] + ) + input_text = input_text if not output_text else output_text + output_text = validation_fn( + prompt_value=input_text, + config=_filter.dict(), + ) - self._log_message_for_verbose("Running intent validation...\n") - intent_validate = self._moderation_class( - moderation_class=ComprehendIntent - ) - output_text = intent_validate(prompt_value=output_text) - else: - filter_functions = { - "pii": ComprehendPII, - "toxicity": ComprehendToxicity, - 
"intent": ComprehendIntent, - } - filters = self.config["filters"] - for _filter in filters: - filter_name = f"{_filter}" - if filter_name in filter_functions: - self._log_message_for_verbose( - f"Running {filter_name} Validation...\n" - ) - validation_fn = self._moderation_class( - moderation_class=filter_functions[filter_name] - ) - input_text = input_text if not output_text else output_text - output_text = validation_fn( - prompt_value=input_text, - config=self.config[filter_name] - if filter_name in self.config - else None, - ) # convert text to prompt and return return self._convert_text_to_prompt(prompt=prompt, text=output_text) diff --git a/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation_callbacks.py b/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation_callbacks.py index d7fcd76a10..2e3e51db2d 100644 --- a/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation_callbacks.py +++ b/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation_callbacks.py @@ -28,19 +28,19 @@ class BaseModerationCallbackHandler: self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any ) -> None: """Run after PII validation is complete.""" - raise NotImplementedError("Subclasses should implement this async method.") + pass async def on_after_toxicity( self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any ) -> None: """Run after Toxicity validation is complete.""" - raise NotImplementedError("Subclasses should implement this async method.") + pass async def on_after_intent( self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any ) -> None: """Run after Toxicity validation is complete.""" - raise NotImplementedError("Subclasses should implement this async method.") + pass @property def pii_callback(self) -> bool: diff --git a/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation_config.py b/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation_config.py new file mode 100644 index 0000000000..7e91cb783b --- /dev/null +++ b/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation_config.py @@ -0,0 +1,51 @@ +from typing import List, Union + +from pydantic import BaseModel + + +class ModerationPiiConfig(BaseModel): + threshold: float = 0.5 + """Threshold for PII confidence score, defaults to 0.5 i.e. 50%""" + + labels: List[str] = [] + """ + List of PII Universal Labels. + Defaults to `list[]` + """ + + redact: bool = False + """Whether to perform redaction of detected PII entities""" + + mask_character: str = "*" + """Redaction mask character in case redact=True, defaults to asterisk (*)""" + + +class ModerationToxicityConfig(BaseModel): + threshold: float = 0.5 + """Threshold for Toxic label confidence score, defaults to 0.5 i.e. 50%""" + + labels: List[str] = [] + """List of toxic labels, defaults to `list[]`""" + + +class ModerationIntentConfig(BaseModel): + threshold: float = 0.5 + """ + Threshold for Intent classification + confidence score, defaults to 0.5 i.e. 
50% + """ + + +class BaseModerationConfig(BaseModel): + filters: List[ + Union[ModerationPiiConfig, ModerationToxicityConfig, ModerationIntentConfig] + ] = [ + ModerationPiiConfig(), + ModerationToxicityConfig(), + ModerationIntentConfig(), + ] + """ + Filters applied to the moderation chain, defaults to + `[ModerationPiiConfig(), ModerationToxicityConfig(), + ModerationIntentConfig()]` + """ diff --git a/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation_enums.py b/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation_enums.py deleted file mode 100644 index aec629ebcc..0000000000 --- a/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation_enums.py +++ /dev/null @@ -1,12 +0,0 @@ -from enum import Enum - - -class BaseModerationActions(Enum): - STOP = 1 - ALLOW = 2 - - -class BaseModerationFilters(str, Enum): - PII = "pii" - TOXICITY = "toxicity" - INTENT = "intent" diff --git a/libs/experimental/langchain_experimental/comprehend_moderation/intent.py b/libs/experimental/langchain_experimental/comprehend_moderation/intent.py index 761c072868..d71f54d3aa 100644 --- a/libs/experimental/langchain_experimental/comprehend_moderation/intent.py +++ b/libs/experimental/langchain_experimental/comprehend_moderation/intent.py @@ -1,6 +1,5 @@ import asyncio -import warnings -from typing import Any, Dict, Optional +from typing import Any, Optional from langchain_experimental.comprehend_moderation.base_moderation_exceptions import ( ModerationIntentionError, @@ -30,20 +29,17 @@ class ComprehendIntent: intent_endpoint = "document-classifier-endpoint/prompt-intent" return f"arn:aws:{service}:{region_name}:aws:{intent_endpoint}" - def validate( - self, prompt_value: str, config: Optional[Dict[str, Any]] = None - ) -> str: + def validate(self, prompt_value: str, config: Any = None) -> str: """ Check and validate the intent of the given prompt text. Args: - comprehend_client: Comprehend client for intent classification - prompt_value (str): The input text to be checked for unintended intent - config (Dict[str, Any]): Configuration settings for intent checks + prompt_value (str): The input text to be checked for unintended intent. + config (Dict[str, Any]): Configuration settings for intent checks. Raises: ValueError: If unintended intent is found in the prompt text based - on the specified threshold. + on the specified threshold. Returns: str: The input prompt_value. @@ -53,26 +49,16 @@ class ComprehendIntent: Comprehend's classify_document API and raises an error if unintended intent is detected with a score above the specified threshold. + Example: + comprehend_client = boto3.client('comprehend') + prompt_text = "Please tell me your credit card information." + config = {"threshold": 0.7} + checked_prompt = check_intent(comprehend_client, prompt_text, config) """ - from langchain_experimental.comprehend_moderation.base_moderation_enums import ( - BaseModerationActions, - ) - threshold = config.get("threshold", 0.5) if config else 0.5 - action = ( - config.get("action", BaseModerationActions.STOP) - if config - else BaseModerationActions.STOP - ) + threshold = config.get("threshold") intent_found = False - if action == BaseModerationActions.ALLOW: - warnings.warn( - "You have allowed content with Harmful content." - "Defaulting to STOP action..." 
- ) - action = BaseModerationActions.STOP - endpoint_arn = self._get_arn() response = self.client.classify_document( Text=prompt_value, EndpointArn=endpoint_arn diff --git a/libs/experimental/langchain_experimental/comprehend_moderation/pii.py b/libs/experimental/langchain_experimental/comprehend_moderation/pii.py index 2c82b7a400..0d3e07ae64 100644 --- a/libs/experimental/langchain_experimental/comprehend_moderation/pii.py +++ b/libs/experimental/langchain_experimental/comprehend_moderation/pii.py @@ -23,33 +23,19 @@ class ComprehendPII: self.callback = callback self.unique_id = unique_id - def validate( - self, prompt_value: str, config: Optional[Dict[str, Any]] = None - ) -> str: - from langchain_experimental.comprehend_moderation.base_moderation_enums import ( - BaseModerationActions, + def validate(self, prompt_value: str, config: Any = None) -> str: + redact = config.get("redact") + return ( + self._detect_pii(prompt_value=prompt_value, config=config) + if redact + else self._contains_pii(prompt_value=prompt_value, config=config) ) - if config: - action = config.get("action", BaseModerationActions.STOP) - if action not in [BaseModerationActions.STOP, BaseModerationActions.ALLOW]: - raise ValueError("Action can either be stop or allow") - - return ( - self._contains_pii(prompt_value=prompt_value, config=config) - if action == BaseModerationActions.STOP - else self._detect_pii(prompt_value=prompt_value, config=config) - ) - else: - return self._contains_pii(prompt_value=prompt_value) - - def _contains_pii( - self, prompt_value: str, config: Optional[Dict[str, Any]] = None - ) -> str: + def _contains_pii(self, prompt_value: str, config: Any = None) -> str: """ Checks for Personally Identifiable Information (PII) labels above a - specified threshold. - + specified threshold. Uses Amazon Comprehend Contains PII Entities API. See - + https://docs.aws.amazon.com/comprehend/latest/APIReference/API_ContainsPiiEntities.html Args: prompt_value (str): The input text to be checked for PII labels. config (Dict[str, Any]): Configuration for PII check and actions. @@ -68,8 +54,8 @@ class ComprehendPII: self.moderation_beacon["moderation_input"] = prompt_value self.moderation_beacon["moderation_output"] = pii_identified - threshold = config.get("threshold", 0.5) if config else 0.5 - pii_labels = config.get("labels", []) if config else [] + threshold = config.get("threshold") + pii_labels = config.get("labels") pii_found = False for entity in pii_identified["Labels"]: if (entity["Score"] >= threshold and entity["Name"] in pii_labels) or ( @@ -93,7 +79,8 @@ class ComprehendPII: Detects and handles Personally Identifiable Information (PII) entities in the given prompt text using Amazon Comprehend's detect_pii_entities API. The function provides options to redact or stop processing based on the identified - PII entities and a provided configuration. + PII entities and a provided configuration. Uses Amazon Comprehend Detect PII + Entities API. Args: prompt_value (str): The input text to be checked for PII entities. 
@@ -143,9 +130,9 @@ class ComprehendPII: if pii_found: raise ModerationPiiError else: - threshold = config.get("threshold", 0.5) # type: ignore - pii_labels = config.get("labels", []) # type: ignore - mask_marker = config.get("mask_character", "*") # type: ignore + threshold = config.get("threshold") # type: ignore + pii_labels = config.get("labels") # type: ignore + mask_marker = config.get("mask_character") # type: ignore pii_found = False for entity in pii_identified["Entities"]: @@ -157,10 +144,14 @@ class ComprehendPII: pii_found = True char_offset_begin = entity["BeginOffset"] char_offset_end = entity["EndOffset"] + + mask_length = char_offset_end - char_offset_begin + 1 + masked_part = mask_marker * mask_length + prompt_value = ( prompt_value[:char_offset_begin] - + mask_marker * (char_offset_end - char_offset_begin) - + prompt_value[char_offset_end:] + + masked_part + + prompt_value[char_offset_end + 1 :] ) if self.callback and self.callback.pii_callback: diff --git a/libs/experimental/langchain_experimental/comprehend_moderation/toxicity.py b/libs/experimental/langchain_experimental/comprehend_moderation/toxicity.py index b66320ec55..c616e506e7 100644 --- a/libs/experimental/langchain_experimental/comprehend_moderation/toxicity.py +++ b/libs/experimental/langchain_experimental/comprehend_moderation/toxicity.py @@ -1,7 +1,6 @@ import asyncio import importlib -import warnings -from typing import Any, Dict, List, Optional +from typing import Any, List, Optional from langchain_experimental.comprehend_moderation.base_moderation_exceptions import ( ModerationToxicityError, @@ -30,14 +29,15 @@ class ComprehendToxicity: Validate and initialize toxicity processing configuration. Args: - max_size (int): Maximum sentence size defined in the configuration object. + max_size (int): Maximum sentence size defined in the + configuration object. Raises: Exception: If the maximum sentence size exceeds the 5KB limit. Note: - This function ensures that the NLTK punkt tokenizer is downloaded if not - already present. + This function ensures that the NLTK punkt tokenizer is downloaded + if not already present. Returns: None @@ -63,34 +63,36 @@ class ComprehendToxicity: Split a paragraph into chunks of sentences, respecting the maximum size limit. Args: - paragraph (str): The input paragraph to be split into chunks - max_size (int, optional): The maximum size limit in bytes for each chunk - Defaults to 1024. + paragraph (str): The input paragraph to be split into chunks. + max_size (int, optional): The maximum size limit in bytes for + each chunk. Defaults to 1024. Returns: - List[List[str]]: A list of chunks, where each chunk is a list of sentences + List[List[str]]: A list of chunks, where each chunk is a list + of sentences. Note: - This function validates the maximum sentence size based on service limits - using the 'toxicity_init_validate' function. It uses the NLTK sentence - tokenizer to split the paragraph into sentences. - + This function validates the maximum sentence size based on service + limits using the 'toxicity_init_validate' function. It uses the NLTK + sentence tokenizer to split the paragraph into sentences. + + Example: + paragraph = "This is a sample paragraph. It + contains multiple sentences. ..." + chunks = split_paragraph(paragraph, max_size=2048) """ # validate max. 
sentence size based on Service limits nltk = self._toxicity_init_validate(max_size) - sentences = nltk.sent_tokenize(prompt_value) - - chunks = [] - current_chunk = [] # type: ignore + chunks = list() # type: ignore + current_chunk = list() # type: ignore current_size = 0 for sentence in sentences: sentence_size = len(sentence.encode("utf-8")) - - # If adding a new sentence exceeds max_size or - # current_chunk has 10 sentences, start a new chunk + # If adding a new sentence exceeds max_size + # or current_chunk has 10 sentences, start a new chunk if (current_size + sentence_size > max_size) or (len(current_chunk) >= 10): if current_chunk: # Avoid appending empty chunks chunks.append(current_chunk) @@ -103,16 +105,12 @@ class ComprehendToxicity: # Add any remaining sentences if current_chunk: chunks.append(current_chunk) - return chunks - def validate( - self, prompt_value: str, config: Optional[Dict[str, Any]] = None - ) -> str: + def validate(self, prompt_value: str, config: Any = None) -> str: """ - Check the toxicity of a given text prompt using AWS Comprehend service - and apply actions based on configuration. - + Check the toxicity of a given text prompt using AWS + Comprehend service and apply actions based on configuration. Args: prompt_value (str): The text content to be checked for toxicity. config (Dict[str, Any]): Configuration for toxicity checks and actions. @@ -122,7 +120,7 @@ class ComprehendToxicity: Raises: ValueError: If the prompt contains toxic labels and cannot be - processed based on the configuration. + processed based on the configuration. """ chunks = self._split_paragraph(prompt_value=prompt_value) @@ -134,76 +132,34 @@ class ComprehendToxicity: if self.callback and self.callback.toxicity_callback: self.moderation_beacon["moderation_input"] = segments # type: ignore self.moderation_beacon["moderation_output"] = response - - if config: - from langchain_experimental.comprehend_moderation.base_moderation_enums import ( # noqa: E501 - BaseModerationActions, - ) - - toxicity_found = False - action = config.get("action", BaseModerationActions.STOP) - if action not in [ - BaseModerationActions.STOP, - BaseModerationActions.ALLOW, - ]: - raise ValueError("Action can either be stop or allow") - - threshold = config.get("threshold", 0.5) if config else 0.5 - toxicity_labels = config.get("labels", []) if config else [] - - if action == BaseModerationActions.STOP: - for item in response["ResultList"]: - for label in item["Labels"]: - if ( - label - and ( - not toxicity_labels - or label["Name"] in toxicity_labels - ) - and label["Score"] >= threshold - ): - toxicity_found = True - break - - if action == BaseModerationActions.ALLOW: - if not toxicity_labels: - warnings.warn( - "You have allowed toxic content without specifying " - "any toxicity labels." 
- ) - else: - for item in response["ResultList"]: - for label in item["Labels"]: - if ( - label["Name"] in toxicity_labels - and label["Score"] >= threshold - ): - toxicity_found = True - break - - if self.callback and self.callback.toxicity_callback: - if toxicity_found: - self.moderation_beacon["moderation_status"] = "LABELS_FOUND" - asyncio.create_task( - self.callback.on_after_toxicity( - self.moderation_beacon, self.unique_id - ) - ) - if toxicity_found: - raise ModerationToxicityError + toxicity_found = False + threshold = config.get("threshold") + toxicity_labels = config.get("labels") + + if not toxicity_labels: + for item in response["ResultList"]: + for label in item["Labels"]: + if label["Score"] >= threshold: + toxicity_found = True + break else: - if response["ResultList"]: - detected_toxic_labels = list() - for item in response["ResultList"]: - detected_toxic_labels.extend(item["Labels"]) - if any(item["Score"] >= 0.5 for item in detected_toxic_labels): - if self.callback and self.callback.toxicity_callback: - self.moderation_beacon["moderation_status"] = "LABELS_FOUND" - asyncio.create_task( - self.callback.on_after_toxicity( - self.moderation_beacon, self.unique_id - ) - ) - raise ModerationToxicityError + for item in response["ResultList"]: + for label in item["Labels"]: + if ( + label["Name"] in toxicity_labels + and label["Score"] >= threshold + ): + toxicity_found = True + break + if self.callback and self.callback.toxicity_callback: + if toxicity_found: + self.moderation_beacon["moderation_status"] = "LABELS_FOUND" + asyncio.create_task( + self.callback.on_after_toxicity( + self.moderation_beacon, self.unique_id + ) + ) + if toxicity_found: + raise ModerationToxicityError return prompt_value
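To make the simplified toxicity flow concrete, the sketch below walks the two steps the refactored `validate` performs: split the prompt into chunks of at most ten sentences that stay under the byte limit, then scan each result's labels against the configured `labels` and `threshold` and raise if anything crosses it. The `detect_toxic_content` call and its `TextSegments`/`LanguageCode` parameters are assumptions about the Comprehend API surface; the response shape (`ResultList` -> `Labels` -> `Name`/`Score`) matches what the diff consumes.

# Illustrative sketch only, not the library implementation.
# The detect_toxic_content signature is an assumption about the Comprehend API;
# the rest mirrors the chunking and threshold logic shown in the diff.
import boto3
import nltk

comprehend = boto3.client("comprehend", region_name="us-east-1")


def split_paragraph(paragraph: str, max_size: int = 1024) -> list:
    """Group sentences into chunks of at most 10 sentences under max_size bytes."""
    nltk.download("punkt", quiet=True)
    chunks, current_chunk, current_size = [], [], 0
    for sentence in nltk.sent_tokenize(paragraph):
        sentence_size = len(sentence.encode("utf-8"))
        # Start a new chunk when the byte budget or the 10-sentence cap is hit.
        if current_size + sentence_size > max_size or len(current_chunk) >= 10:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk, current_size = [], 0
        current_chunk.append(sentence)
        current_size += sentence_size
    if current_chunk:
        chunks.append(current_chunk)
    return chunks


def check_toxicity(prompt: str, config: dict) -> str:
    threshold, toxicity_labels = config["threshold"], config["labels"]
    for segments in split_paragraph(prompt):
        # Assumed API call; each sentence is sent as its own text segment.
        response = comprehend.detect_toxic_content(
            TextSegments=[{"Text": s} for s in segments], LanguageCode="en"
        )
        for item in response["ResultList"]:
            for label in item["Labels"]:
                # An empty label list means every toxicity label is in scope.
                in_scope = not toxicity_labels or label["Name"] in toxicity_labels
                if in_scope and label["Score"] >= threshold:
                    raise ValueError(f"Toxic label detected: {label['Name']}")
    return prompt


# Usage: flag anything scoring at or above 0.5 on any toxicity label.
# check_toxicity("some text to screen", {"threshold": 0.5, "labels": []})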