Mirror of https://github.com/hwchase17/langchain (synced 2024-10-31 15:20:26 +00:00, commit d57d08fd01)
This PR implements a custom chain that wraps Amazon Comprehend API calls. The custom chain is intended to be used with LLM chains to provide a moderation capability that lets you detect and redact PII, toxic, and intent content in the LLM prompt or the LLM response. The implementation accepts a configuration object to control which checks are performed on an LLM prompt, and it can be used in a variety of setups with the LangChain Expression Language to apply the configured checks not only in chains but also to other constructs such as a retriever. The included sample notebook goes over the different configuration options and how to use the moderation chain with other chains.

### Usage sample

```python
from langchain_experimental.comprehend_moderation import BaseModerationActions, BaseModerationFilters

moderation_config = {
    "filters": [
        BaseModerationFilters.PII,
        BaseModerationFilters.TOXICITY,
        BaseModerationFilters.INTENT,
    ],
    "pii": {
        "action": BaseModerationActions.ALLOW,
        "threshold": 0.5,
        "labels": ["SSN"],
        "mask_character": "X",
    },
    "toxicity": {
        "action": BaseModerationActions.STOP,
        "threshold": 0.5,
    },
    "intent": {
        "action": BaseModerationActions.STOP,
        "threshold": 0.5,
    },
}

comp_moderation_with_config = AmazonComprehendModerationChain(
    moderation_config=moderation_config,  # specify the configuration
    client=comprehend_client,  # optionally pass the Boto3 Client
    verbose=True,
)

template = """Question: {question}

Answer:"""

prompt = PromptTemplate(template=template, input_variables=["question"])

responses = [
    "Final Answer: A credit card number looks like 1289-2321-1123-2387. A fake SSN number looks like 323-22-9980. John Doe's phone number is (999)253-9876.",
    "Final Answer: This is a really shitty way of constructing a birdhouse. This is fucking insane to think that any birds would actually create their motherfucking nests here.",
]

llm = FakeListLLM(responses=responses)

llm_chain = LLMChain(prompt=prompt, llm=llm)

chain = (
    prompt
    | comp_moderation_with_config
    | {llm_chain.input_keys[0]: lambda x: x['output']}
    | llm_chain
    | {"input": lambda x: x['text']}
    | comp_moderation_with_config
)

response = chain.invoke(
    {"question": "A sample SSN number looks like this 123-456-7890. Can you give me some more samples?"}
)

print(response['output'])
```

### Output

```
> Entering new AmazonComprehendModerationChain chain...
Running AmazonComprehendModerationChain...
Running pii validation...
Found PII content..stopping..
The prompt contains PII entities and cannot be processed
```

---------

Co-authored-by: Piyush Jain <piyushjain@duck.com>
Co-authored-by: Anjan Biswas <anjanavb@amazon.com>
Co-authored-by: Jha <nikjha@amazon.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
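The usage sample above assumes an existing `comprehend_client`. As a minimal sketch (the region name is an illustrative assumption; credentials are expected to come from your AWS environment), such a client can be created with boto3:

```python
import boto3

# Assumes AWS credentials are already configured (environment variables,
# ~/.aws/credentials, or an attached IAM role); the region is an example.
comprehend_client = boto3.client("comprehend", region_name="us-east-1")
```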
210 lines · 7.9 KiB · Python
import asyncio
import importlib
import warnings
from typing import Any, Dict, List, Optional

from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
    ModerationToxicityError,
)


class ComprehendToxicity:
    def __init__(
        self,
        client: Any,
        callback: Optional[Any] = None,
        unique_id: Optional[str] = None,
        chain_id: Optional[str] = None,
    ) -> None:
        self.client = client
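        # Snapshot of this moderation run; it is reported to the toxicity
        # callback (when one is configured) via on_after_toxicity.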
        self.moderation_beacon = {
            "moderation_chain_id": chain_id,
            "moderation_type": "Toxicity",
            "moderation_status": "LABELS_NOT_FOUND",
        }
        self.callback = callback
        self.unique_id = unique_id

    def _toxicity_init_validate(self, max_size: int) -> Any:
        """
        Validate the configured maximum sentence size and initialize the NLTK
        sentence tokenizer.

        Args:
            max_size (int): Maximum sentence size defined in the
                configuration object.

        Raises:
            Exception: If the maximum sentence size exceeds the 5KB limit.

        Note:
            This function ensures that the NLTK punkt tokenizer is downloaded
            if not already present.

        Returns:
            The imported ``nltk`` module.
        """
        if max_size > 1024 * 5:
            raise Exception("The sentence length should not exceed 5KB.")
        try:
            nltk = importlib.import_module("nltk")
            nltk.data.find("tokenizers/punkt")
            return nltk
        except ImportError:
            raise ModuleNotFoundError(
                "Could not import nltk python package. "
                "Please install it with `pip install nltk`."
            )
        except LookupError:
            # The punkt tokenizer data is missing: download it and still
            # return the module so callers always get a usable reference.
            nltk.download("punkt")
            return nltk

    def _split_paragraph(
        self, prompt_value: str, max_size: int = 1024 * 4
    ) -> List[List[str]]:
        """
        Split a paragraph into chunks of sentences, respecting the maximum size limit.

        Args:
            prompt_value (str): The input paragraph to be split into chunks.
            max_size (int, optional): The maximum size limit in bytes for each
                chunk. Defaults to 1024 * 4 (4KB).

        Returns:
            List[List[str]]: A list of chunks, where each chunk is a list of
                sentences.

        Note:
            This function validates the maximum sentence size based on service
            limits using the '_toxicity_init_validate' function. It uses the
            NLTK sentence tokenizer to split the paragraph into sentences.
        """

        # validate max. sentence size based on Service limits
        nltk = self._toxicity_init_validate(max_size)

        sentences = nltk.sent_tokenize(prompt_value)

        chunks = []
        current_chunk = []  # type: ignore
        current_size = 0

        for sentence in sentences:
            sentence_size = len(sentence.encode("utf-8"))

            # If adding a new sentence exceeds max_size or
            # current_chunk has 10 sentences, start a new chunk
            if (current_size + sentence_size > max_size) or (len(current_chunk) >= 10):
                if current_chunk:  # Avoid appending empty chunks
                    chunks.append(current_chunk)
                    current_chunk = []
                    current_size = 0

            current_chunk.append(sentence)
            current_size += sentence_size

        # Add any remaining sentences
        if current_chunk:
            chunks.append(current_chunk)

        return chunks

    def validate(
        self, prompt_value: str, config: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Check the toxicity of a given text prompt using the AWS Comprehend
        service and apply actions based on the configuration.

        Args:
            prompt_value (str): The text content to be checked for toxicity.
            config (Optional[Dict[str, Any]]): Configuration for toxicity
                checks and actions.

        Returns:
            str: The original prompt_value if allowed or no toxicity is found.

        Raises:
            ValueError: If the configured action is neither STOP nor ALLOW.
            ModerationToxicityError: If the prompt contains toxic labels and
                cannot be processed based on the configuration.
        """

        chunks = self._split_paragraph(prompt_value=prompt_value)
        for sentence_list in chunks:
            segments = [{"Text": sentence} for sentence in sentence_list]
            response = self.client.detect_toxic_content(
                TextSegments=segments, LanguageCode="en"
            )
            if self.callback and self.callback.toxicity_callback:
                self.moderation_beacon["moderation_input"] = segments  # type: ignore
                self.moderation_beacon["moderation_output"] = response

            if config:
                from langchain_experimental.comprehend_moderation.base_moderation_enums import (  # noqa: E501
                    BaseModerationActions,
                )

                toxicity_found = False
                action = config.get("action", BaseModerationActions.STOP)
                if action not in [
                    BaseModerationActions.STOP,
                    BaseModerationActions.ALLOW,
                ]:
                    raise ValueError("Action can either be stop or allow")

                threshold = config.get("threshold", 0.5) if config else 0.5
                toxicity_labels = config.get("labels", []) if config else []

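                # With STOP, any returned label at or above the threshold flags
                # the prompt (optionally restricted to the configured label
                # names); with ALLOW, only the explicitly configured labels are
                # checked against the threshold.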
                if action == BaseModerationActions.STOP:
                    for item in response["ResultList"]:
                        for label in item["Labels"]:
                            if (
                                label
                                and (
                                    not toxicity_labels
                                    or label["Name"] in toxicity_labels
                                )
                                and label["Score"] >= threshold
                            ):
                                toxicity_found = True
                                break

                if action == BaseModerationActions.ALLOW:
                    if not toxicity_labels:
                        warnings.warn(
                            "You have allowed toxic content without specifying "
                            "any toxicity labels."
                        )
                    else:
                        for item in response["ResultList"]:
                            for label in item["Labels"]:
                                if (
                                    label["Name"] in toxicity_labels
                                    and label["Score"] >= threshold
                                ):
                                    toxicity_found = True
                                    break

                if self.callback and self.callback.toxicity_callback:
                    if toxicity_found:
                        self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
                    asyncio.create_task(
                        self.callback.on_after_toxicity(
                            self.moderation_beacon, self.unique_id
                        )
                    )
                if toxicity_found:
                    raise ModerationToxicityError
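            # No moderation config was supplied: fall back to the default
            # behaviour and treat any toxic label scoring 0.5 or higher as a
            # violation.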
            else:
                if response["ResultList"]:
                    detected_toxic_labels = list()
                    for item in response["ResultList"]:
                        detected_toxic_labels.extend(item["Labels"])
                    if any(item["Score"] >= 0.5 for item in detected_toxic_labels):
                        if self.callback and self.callback.toxicity_callback:
                            self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
                        asyncio.create_task(
                            self.callback.on_after_toxicity(
                                self.moderation_beacon, self.unique_id
                            )
                        )
                        raise ModerationToxicityError

        return prompt_value
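For reference, a minimal, hypothetical sketch of exercising the `ComprehendToxicity` class defined above on its own (normally it is driven by `AmazonComprehendModerationChain`); the region name and sample text are illustrative, and the config keys mirror the toxicity section of the moderation config shown earlier:

```python
import boto3

from langchain_experimental.comprehend_moderation.base_moderation_enums import (
    BaseModerationActions,
)
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
    ModerationToxicityError,
)

# Illustrative client; credentials come from the AWS environment.
client = boto3.client("comprehend", region_name="us-east-1")
toxicity_check = ComprehendToxicity(client=client)

try:
    # With STOP and no label list, any toxic label scoring >= 0.5 raises
    # ModerationToxicityError.
    checked = toxicity_check.validate(
        "Some user-provided text to screen.",
        config={"action": BaseModerationActions.STOP, "threshold": 0.5},
    )
    print("Prompt allowed:", checked)
except ModerationToxicityError:
    print("Toxic content detected; prompt blocked.")
```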