Harrison/string inplace (#10153)

Co-authored-by: Wrick Talukdar <wrick.talukdar@gmail.com>
Co-authored-by: Anjan Biswas <anjanavb@amazon.com>
Co-authored-by: Jha <nikjha@amazon.com>
Co-authored-by: Lucky-Lance <77819606+Lucky-Lance@users.noreply.github.com>
Co-authored-by: 陆徐东 <luxudong@MacBook-Pro.local>
Harrison Chase 10 months ago committed by GitHub
parent f5af756397
commit 4abe85be57

@@ -111,7 +111,9 @@ class TaskExecutor:
dep_task = self.id_task_map[dep_id]
for k, v in task.args.items():
if f"<resource-{dep_id}>" in v:
task.args[k].replace(f"<resource-{dep_id}>", dep_task.result)
task.args[k] = task.args[k].replace(
f"<resource-{dep_id}>", dep_task.result
)
def run(self) -> str:
for task in self.tasks:
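The actual fix in this hunk is the in-place string replacement: `str.replace` returns a new string rather than mutating the original, so the old call silently discarded its result. A minimal standalone sketch (the dictionary and values are illustrative, not the repo's test code):

```python
# str.replace never mutates in place; the return value must be assigned back.
args = {"text": "summarize <resource-42> for me"}
dep_result = "quarterly report"

args["text"].replace("<resource-42>", dep_result)        # old code: result discarded
assert args["text"] == "summarize <resource-42> for me"   # unchanged

args["text"] = args["text"].replace("<resource-42>", dep_result)  # new code
assert args["text"] == "summarize quarterly report for me"
```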

@@ -5,9 +5,11 @@ from langchain_experimental.comprehend_moderation.base_moderation import BaseMod
from langchain_experimental.comprehend_moderation.base_moderation_callbacks import (
BaseModerationCallbackHandler,
)
from langchain_experimental.comprehend_moderation.base_moderation_enums import (
BaseModerationActions,
BaseModerationFilters,
from langchain_experimental.comprehend_moderation.base_moderation_config import (
BaseModerationConfig,
ModerationIntentConfig,
ModerationPiiConfig,
ModerationToxicityConfig,
)
from langchain_experimental.comprehend_moderation.intent import ComprehendIntent
from langchain_experimental.comprehend_moderation.pii import ComprehendPII
@@ -15,11 +17,13 @@ from langchain_experimental.comprehend_moderation.toxicity import ComprehendToxi
__all__ = [
"BaseModeration",
"BaseModerationActions",
"BaseModerationFilters",
"ComprehendPII",
"ComprehendIntent",
"ComprehendToxicity",
"BaseModerationConfig",
"ModerationPiiConfig",
"ModerationToxicityConfig",
"ModerationIntentConfig",
"BaseModerationCallbackHandler",
"AmazonComprehendModerationChain",
]
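With the enum-based `BaseModerationActions`/`BaseModerationFilters` exports replaced by Pydantic config classes, the package surface changes roughly as sketched below (class and field names come from this diff; the specific values are illustrative):

```python
from langchain_experimental.comprehend_moderation import (
    BaseModerationConfig,
    ModerationPiiConfig,
    ModerationToxicityConfig,
)

# Compose a moderation config from individual filter configs instead of
# passing a raw dict of filter names and actions.
config = BaseModerationConfig(
    filters=[
        ModerationPiiConfig(threshold=0.6, redact=True),
        ModerationToxicityConfig(threshold=0.7),
    ]
)
```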

@@ -3,12 +3,13 @@ from typing import Any, Dict, List, Optional
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain_experimental.comprehend_moderation.base_moderation import (
BaseModeration,
)
from langchain_experimental.comprehend_moderation.base_moderation import BaseModeration
from langchain_experimental.comprehend_moderation.base_moderation_callbacks import (
BaseModerationCallbackHandler,
)
from langchain_experimental.comprehend_moderation.base_moderation_config import (
BaseModerationConfig,
)
from langchain_experimental.pydantic_v1 import root_validator
@@ -21,10 +22,13 @@ class AmazonComprehendModerationChain(Chain):
input_key: str = "input" #: :meta private:
"""Key used to fetch/store the input in data containers. Defaults to `input`"""
moderation_config: Optional[Dict[str, Any]] = None
"""Configuration settings for moderation"""
moderation_config: BaseModerationConfig = BaseModerationConfig()
"""
Configuration settings for moderation,
defaults to BaseModerationConfig with default values
"""
client: Optional[Any]
client: Optional[Any] = None
"""boto3 client object for connection to Amazon Comprehend"""
region_name: Optional[str] = None
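A hedged usage sketch of the updated chain fields: `moderation_config` now takes the new Pydantic config and defaults to `BaseModerationConfig()`, and `client` defaults to `None`. The boto3 client construction below is an assumption, not part of this diff:

```python
import boto3

from langchain_experimental.comprehend_moderation import (
    AmazonComprehendModerationChain,
    BaseModerationConfig,
)

comprehend_client = boto3.client("comprehend", region_name="us-east-1")

moderation_chain = AmazonComprehendModerationChain(
    client=comprehend_client,                  # Optional[Any] = None after this change
    moderation_config=BaseModerationConfig(),  # PII + toxicity + intent filters with defaults
)
```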

@@ -1,5 +1,5 @@
import uuid
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Optional
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.prompts.base import StringPromptValue
@@ -15,7 +15,7 @@ class BaseModeration:
def __init__(
self,
client: Any,
config: Optional[Dict[str, Any]] = None,
config: Optional[Any] = None,
moderation_callback: Optional[Any] = None,
unique_id: Optional[str] = None,
run_manager: Optional[CallbackManagerForChainRun] = None,
@@ -105,6 +105,11 @@ class BaseModeration:
self.run_manager.on_text(message)
def moderate(self, prompt: Any) -> str:
from langchain_experimental.comprehend_moderation.base_moderation_config import ( # noqa: E501
ModerationIntentConfig,
ModerationPiiConfig,
ModerationToxicityConfig,
)
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import ( # noqa: E501
ModerationIntentionError,
ModerationPiiError,
@@ -115,47 +120,43 @@ class BaseModeration:
# convert prompt to text
input_text = self._convert_prompt_to_text(prompt=prompt)
output_text = str()
# perform moderation
if self.config is None:
# In absence of config Action will default to STOP only
self._log_message_for_verbose("Running pii validation...\n")
pii_validate = self._moderation_class(moderation_class=ComprehendPII)
output_text = pii_validate(prompt_value=input_text)
self._log_message_for_verbose("Running toxicity validation...\n")
toxicity_validate = self._moderation_class(
moderation_class=ComprehendToxicity
filter_functions = {
"pii": ComprehendPII,
"toxicity": ComprehendToxicity,
"intent": ComprehendIntent,
}
filters = self.config.filters # type: ignore
for _filter in filters:
filter_name = (
"pii"
if isinstance(_filter, ModerationPiiConfig)
else (
"toxicity"
if isinstance(_filter, ModerationToxicityConfig)
else (
"intent"
if isinstance(_filter, ModerationIntentConfig)
else None
)
)
)
output_text = toxicity_validate(prompt_value=output_text)
if filter_name in filter_functions:
self._log_message_for_verbose(
f"Running {filter_name} Validation...\n"
)
validation_fn = self._moderation_class(
moderation_class=filter_functions[filter_name]
)
input_text = input_text if not output_text else output_text
output_text = validation_fn(
prompt_value=input_text,
config=_filter.dict(),
)
self._log_message_for_verbose("Running intent validation...\n")
intent_validate = self._moderation_class(
moderation_class=ComprehendIntent
)
output_text = intent_validate(prompt_value=output_text)
else:
filter_functions = {
"pii": ComprehendPII,
"toxicity": ComprehendToxicity,
"intent": ComprehendIntent,
}
filters = self.config["filters"]
for _filter in filters:
filter_name = f"{_filter}"
if filter_name in filter_functions:
self._log_message_for_verbose(
f"Running {filter_name} Validation...\n"
)
validation_fn = self._moderation_class(
moderation_class=filter_functions[filter_name]
)
input_text = input_text if not output_text else output_text
output_text = validation_fn(
prompt_value=input_text,
config=self.config[filter_name]
if filter_name in self.config
else None,
)
# convert text to prompt and return
return self._convert_text_to_prompt(prompt=prompt, text=output_text)
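The rewritten `moderate` no longer branches on a missing config or on string filter names; it resolves each Pydantic filter config to a validator class via `isinstance` checks. A standalone sketch of that dispatch (the mapping mirrors the nested conditional above; the helper function itself is illustrative):

```python
from typing import Optional

from langchain_experimental.comprehend_moderation import (
    ModerationIntentConfig,
    ModerationPiiConfig,
    ModerationToxicityConfig,
)


def filter_name_for(filter_config) -> Optional[str]:
    # Same resolution order as the nested conditional in moderate().
    if isinstance(filter_config, ModerationPiiConfig):
        return "pii"
    if isinstance(filter_config, ModerationToxicityConfig):
        return "toxicity"
    if isinstance(filter_config, ModerationIntentConfig):
        return "intent"
    return None


print(filter_name_for(ModerationToxicityConfig()))  # -> "toxicity"
```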

@@ -28,19 +28,19 @@ class BaseModerationCallbackHandler:
self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
) -> None:
"""Run after PII validation is complete."""
raise NotImplementedError("Subclasses should implement this async method.")
pass
async def on_after_toxicity(
self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
) -> None:
"""Run after Toxicity validation is complete."""
raise NotImplementedError("Subclasses should implement this async method.")
pass
async def on_after_intent(
self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
) -> None:
"""Run after Toxicity validation is complete."""
raise NotImplementedError("Subclasses should implement this async method.")
pass
@property
def pii_callback(self) -> bool:
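Because the async hooks are now no-ops (`pass`) instead of raising `NotImplementedError`, a callback subclass only needs to override the hooks it actually uses. A hedged sketch (the `moderation_status` beacon key is taken from the validators in this PR):

```python
from typing import Any, Dict

from langchain_experimental.comprehend_moderation import BaseModerationCallbackHandler


class LoggingModerationCallback(BaseModerationCallbackHandler):
    async def on_after_pii(
        self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
    ) -> None:
        # The toxicity and intent hooks fall back to the base-class no-ops.
        print(unique_id, moderation_beacon.get("moderation_status"))
```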

@@ -0,0 +1,51 @@
from typing import List, Union
from pydantic import BaseModel
class ModerationPiiConfig(BaseModel):
threshold: float = 0.5
"""Threshold for PII confidence score, defaults to 0.5 i.e. 50%"""
labels: List[str] = []
"""
List of PII Universal Labels.
Defaults to `list[]`
"""
redact: bool = False
"""Whether to perform redaction of detected PII entities"""
mask_character: str = "*"
"""Redaction mask character in case redact=True, defaults to asterisk (*)"""
class ModerationToxicityConfig(BaseModel):
threshold: float = 0.5
"""Threshold for Toxic label confidence score, defaults to 0.5 i.e. 50%"""
labels: List[str] = []
"""List of toxic labels, defaults to `list[]`"""
class ModerationIntentConfig(BaseModel):
threshold: float = 0.5
"""
Threshold for Intent classification
confidence score, defaults to 0.5 i.e. 50%
"""
class BaseModerationConfig(BaseModel):
filters: List[
Union[ModerationPiiConfig, ModerationToxicityConfig, ModerationIntentConfig]
] = [
ModerationPiiConfig(),
ModerationToxicityConfig(),
ModerationIntentConfig(),
]
"""
Filters applied to the moderation chain, defaults to
`[ModerationPiiConfig(), ModerationToxicityConfig(),
ModerationIntentConfig()]`
"""

@@ -1,12 +0,0 @@
from enum import Enum
class BaseModerationActions(Enum):
STOP = 1
ALLOW = 2
class BaseModerationFilters(str, Enum):
PII = "pii"
TOXICITY = "toxicity"
INTENT = "intent"

@@ -1,6 +1,5 @@
import asyncio
import warnings
from typing import Any, Dict, Optional
from typing import Any, Optional
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
ModerationIntentionError,
@@ -30,20 +29,17 @@ class ComprehendIntent:
intent_endpoint = "document-classifier-endpoint/prompt-intent"
return f"arn:aws:{service}:{region_name}:aws:{intent_endpoint}"
def validate(
self, prompt_value: str, config: Optional[Dict[str, Any]] = None
) -> str:
def validate(self, prompt_value: str, config: Any = None) -> str:
"""
Check and validate the intent of the given prompt text.
Args:
comprehend_client: Comprehend client for intent classification
prompt_value (str): The input text to be checked for unintended intent
config (Dict[str, Any]): Configuration settings for intent checks
prompt_value (str): The input text to be checked for unintended intent.
config (Dict[str, Any]): Configuration settings for intent checks.
Raises:
ValueError: If unintended intent is found in the prompt text based
on the specified threshold.
on the specified threshold.
Returns:
str: The input prompt_value.
@@ -53,26 +49,16 @@ class ComprehendIntent:
Comprehend's classify_document API and raises an error if unintended
intent is detected with a score above the specified threshold.
Example:
comprehend_client = boto3.client('comprehend')
prompt_text = "Please tell me your credit card information."
config = {"threshold": 0.7}
checked_prompt = check_intent(comprehend_client, prompt_text, config)
"""
from langchain_experimental.comprehend_moderation.base_moderation_enums import (
BaseModerationActions,
)
threshold = config.get("threshold", 0.5) if config else 0.5
action = (
config.get("action", BaseModerationActions.STOP)
if config
else BaseModerationActions.STOP
)
threshold = config.get("threshold")
intent_found = False
if action == BaseModerationActions.ALLOW:
warnings.warn(
"You have allowed content with Harmful content."
"Defaulting to STOP action..."
)
action = BaseModerationActions.STOP
endpoint_arn = self._get_arn()
response = self.client.classify_document(
Text=prompt_value, EndpointArn=endpoint_arn

@@ -23,33 +23,19 @@ class ComprehendPII:
self.callback = callback
self.unique_id = unique_id
def validate(
self, prompt_value: str, config: Optional[Dict[str, Any]] = None
) -> str:
from langchain_experimental.comprehend_moderation.base_moderation_enums import (
BaseModerationActions,
def validate(self, prompt_value: str, config: Any = None) -> str:
redact = config.get("redact")
return (
self._detect_pii(prompt_value=prompt_value, config=config)
if redact
else self._contains_pii(prompt_value=prompt_value, config=config)
)
if config:
action = config.get("action", BaseModerationActions.STOP)
if action not in [BaseModerationActions.STOP, BaseModerationActions.ALLOW]:
raise ValueError("Action can either be stop or allow")
return (
self._contains_pii(prompt_value=prompt_value, config=config)
if action == BaseModerationActions.STOP
else self._detect_pii(prompt_value=prompt_value, config=config)
)
else:
return self._contains_pii(prompt_value=prompt_value)
def _contains_pii(
self, prompt_value: str, config: Optional[Dict[str, Any]] = None
) -> str:
def _contains_pii(self, prompt_value: str, config: Any = None) -> str:
"""
Checks for Personally Identifiable Information (PII) labels above a
specified threshold.
specified threshold. Uses Amazon Comprehend Contains PII Entities API. See -
https://docs.aws.amazon.com/comprehend/latest/APIReference/API_ContainsPiiEntities.html
Args:
prompt_value (str): The input text to be checked for PII labels.
config (Dict[str, Any]): Configuration for PII check and actions.
@@ -68,8 +54,8 @@ class ComprehendPII:
self.moderation_beacon["moderation_input"] = prompt_value
self.moderation_beacon["moderation_output"] = pii_identified
threshold = config.get("threshold", 0.5) if config else 0.5
pii_labels = config.get("labels", []) if config else []
threshold = config.get("threshold")
pii_labels = config.get("labels")
pii_found = False
for entity in pii_identified["Labels"]:
if (entity["Score"] >= threshold and entity["Name"] in pii_labels) or (
@@ -93,7 +79,8 @@ class ComprehendPII:
Detects and handles Personally Identifiable Information (PII) entities in the
given prompt text using Amazon Comprehend's detect_pii_entities API. The
function provides options to redact or stop processing based on the identified
PII entities and a provided configuration.
PII entities and a provided configuration. Uses Amazon Comprehend Detect PII
Entities API.
Args:
prompt_value (str): The input text to be checked for PII entities.
@@ -143,9 +130,9 @@ class ComprehendPII:
if pii_found:
raise ModerationPiiError
else:
threshold = config.get("threshold", 0.5) # type: ignore
pii_labels = config.get("labels", []) # type: ignore
mask_marker = config.get("mask_character", "*") # type: ignore
threshold = config.get("threshold") # type: ignore
pii_labels = config.get("labels") # type: ignore
mask_marker = config.get("mask_character") # type: ignore
pii_found = False
for entity in pii_identified["Entities"]:
@@ -157,10 +144,14 @@ class ComprehendPII:
pii_found = True
char_offset_begin = entity["BeginOffset"]
char_offset_end = entity["EndOffset"]
mask_length = char_offset_end - char_offset_begin + 1
masked_part = mask_marker * mask_length
prompt_value = (
prompt_value[:char_offset_begin]
+ mask_marker * (char_offset_end - char_offset_begin)
+ prompt_value[char_offset_end:]
+ masked_part
+ prompt_value[char_offset_end + 1 :]
)
if self.callback and self.callback.pii_callback:
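The redaction arithmetic now treats `EndOffset` as an inclusive index: the mask length gains a `+1` and the tail slice starts at `char_offset_end + 1`. A standalone sketch of that arithmetic with assumed offsets:

```python
prompt_value = "My SSN is 123-45-6789."
char_offset_begin, char_offset_end = 10, 20  # assumed offsets covering "123-45-6789"
mask_marker = "*"

mask_length = char_offset_end - char_offset_begin + 1
masked_part = mask_marker * mask_length
prompt_value = (
    prompt_value[:char_offset_begin]
    + masked_part
    + prompt_value[char_offset_end + 1 :]
)
print(prompt_value)  # My SSN is ***********.
```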

@@ -1,7 +1,6 @@
import asyncio
import importlib
import warnings
from typing import Any, Dict, List, Optional
from typing import Any, List, Optional
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
ModerationToxicityError,
@@ -30,14 +29,15 @@ class ComprehendToxicity:
Validate and initialize toxicity processing configuration.
Args:
max_size (int): Maximum sentence size defined in the configuration object.
max_size (int): Maximum sentence size defined in the
configuration object.
Raises:
Exception: If the maximum sentence size exceeds the 5KB limit.
Note:
This function ensures that the NLTK punkt tokenizer is downloaded if not
already present.
This function ensures that the NLTK punkt tokenizer is downloaded
if not already present.
Returns:
None
@@ -63,34 +63,36 @@ class ComprehendToxicity:
Split a paragraph into chunks of sentences, respecting the maximum size limit.
Args:
paragraph (str): The input paragraph to be split into chunks
max_size (int, optional): The maximum size limit in bytes for each chunk
Defaults to 1024.
paragraph (str): The input paragraph to be split into chunks.
max_size (int, optional): The maximum size limit in bytes for
each chunk. Defaults to 1024.
Returns:
List[List[str]]: A list of chunks, where each chunk is a list of sentences
List[List[str]]: A list of chunks, where each chunk is a list
of sentences.
Note:
This function validates the maximum sentence size based on service limits
using the 'toxicity_init_validate' function. It uses the NLTK sentence
tokenizer to split the paragraph into sentences.
This function validates the maximum sentence size based on service
limits using the 'toxicity_init_validate' function. It uses the NLTK
sentence tokenizer to split the paragraph into sentences.
Example:
paragraph = "This is a sample paragraph. It
contains multiple sentences. ..."
chunks = split_paragraph(paragraph, max_size=2048)
"""
# validate max. sentence size based on Service limits
nltk = self._toxicity_init_validate(max_size)
sentences = nltk.sent_tokenize(prompt_value)
chunks = []
current_chunk = [] # type: ignore
chunks = list() # type: ignore
current_chunk = list() # type: ignore
current_size = 0
for sentence in sentences:
sentence_size = len(sentence.encode("utf-8"))
# If adding a new sentence exceeds max_size or
# current_chunk has 10 sentences, start a new chunk
# If adding a new sentence exceeds max_size
# or current_chunk has 10 sentences, start a new chunk
if (current_size + sentence_size > max_size) or (len(current_chunk) >= 10):
if current_chunk: # Avoid appending empty chunks
chunks.append(current_chunk)
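For reference, the chunking rule spelled out across this hunk and the next reads as the following standalone sketch. The reset-and-append steps between the two hunks are not shown in the diff, so this is a hedged reconstruction, and NLTK's punkt data is assumed to be downloaded:

```python
import nltk


def split_paragraph(prompt_value: str, max_size: int = 1024):
    sentences = nltk.sent_tokenize(prompt_value)
    chunks, current_chunk, current_size = [], [], 0
    for sentence in sentences:
        sentence_size = len(sentence.encode("utf-8"))
        # Start a new chunk if adding this sentence would exceed max_size
        # or the current chunk already holds 10 sentences.
        if (current_size + sentence_size > max_size) or (len(current_chunk) >= 10):
            if current_chunk:  # avoid appending empty chunks
                chunks.append(current_chunk)
            current_chunk, current_size = [], 0
        current_chunk.append(sentence)
        current_size += sentence_size
    if current_chunk:  # add any remaining sentences
        chunks.append(current_chunk)
    return chunks
```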
@@ -103,16 +105,12 @@ class ComprehendToxicity:
# Add any remaining sentences
if current_chunk:
chunks.append(current_chunk)
return chunks
def validate(
self, prompt_value: str, config: Optional[Dict[str, Any]] = None
) -> str:
def validate(self, prompt_value: str, config: Any = None) -> str:
"""
Check the toxicity of a given text prompt using AWS Comprehend service
and apply actions based on configuration.
Check the toxicity of a given text prompt using AWS
Comprehend service and apply actions based on configuration.
Args:
prompt_value (str): The text content to be checked for toxicity.
config (Dict[str, Any]): Configuration for toxicity checks and actions.
@@ -122,7 +120,7 @@ class ComprehendToxicity:
Raises:
ValueError: If the prompt contains toxic labels and cannot be
processed based on the configuration.
processed based on the configuration.
"""
chunks = self._split_paragraph(prompt_value=prompt_value)
@@ -134,76 +132,34 @@ class ComprehendToxicity:
if self.callback and self.callback.toxicity_callback:
self.moderation_beacon["moderation_input"] = segments # type: ignore
self.moderation_beacon["moderation_output"] = response
if config:
from langchain_experimental.comprehend_moderation.base_moderation_enums import ( # noqa: E501
BaseModerationActions,
)
toxicity_found = False
action = config.get("action", BaseModerationActions.STOP)
if action not in [
BaseModerationActions.STOP,
BaseModerationActions.ALLOW,
]:
raise ValueError("Action can either be stop or allow")
threshold = config.get("threshold", 0.5) if config else 0.5
toxicity_labels = config.get("labels", []) if config else []
if action == BaseModerationActions.STOP:
for item in response["ResultList"]:
for label in item["Labels"]:
if (
label
and (
not toxicity_labels
or label["Name"] in toxicity_labels
)
and label["Score"] >= threshold
):
toxicity_found = True
break
if action == BaseModerationActions.ALLOW:
if not toxicity_labels:
warnings.warn(
"You have allowed toxic content without specifying "
"any toxicity labels."
)
else:
for item in response["ResultList"]:
for label in item["Labels"]:
if (
label["Name"] in toxicity_labels
and label["Score"] >= threshold
):
toxicity_found = True
break
if self.callback and self.callback.toxicity_callback:
if toxicity_found:
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
asyncio.create_task(
self.callback.on_after_toxicity(
self.moderation_beacon, self.unique_id
)
)
if toxicity_found:
raise ModerationToxicityError
toxicity_found = False
threshold = config.get("threshold")
toxicity_labels = config.get("labels")
if not toxicity_labels:
for item in response["ResultList"]:
for label in item["Labels"]:
if label["Score"] >= threshold:
toxicity_found = True
break
else:
if response["ResultList"]:
detected_toxic_labels = list()
for item in response["ResultList"]:
detected_toxic_labels.extend(item["Labels"])
if any(item["Score"] >= 0.5 for item in detected_toxic_labels):
if self.callback and self.callback.toxicity_callback:
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
asyncio.create_task(
self.callback.on_after_toxicity(
self.moderation_beacon, self.unique_id
)
)
raise ModerationToxicityError
for item in response["ResultList"]:
for label in item["Labels"]:
if (
label["Name"] in toxicity_labels
and label["Score"] >= threshold
):
toxicity_found = True
break
if self.callback and self.callback.toxicity_callback:
if toxicity_found:
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
asyncio.create_task(
self.callback.on_after_toxicity(
self.moderation_beacon, self.unique_id
)
)
if toxicity_found:
raise ModerationToxicityError
return prompt_value
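With the action/ALLOW branches gone, the toxicity check reduces to the two loops above: threshold-only when no labels are configured, threshold plus label match otherwise. A standalone sketch against a fake Comprehend-style response (the label names and scores are invented for illustration):

```python
response = {
    "ResultList": [
        {"Labels": [{"Name": "HATE_SPEECH", "Score": 0.91},
                    {"Name": "INSULT", "Score": 0.20}]}
    ]
}
threshold = 0.5
toxicity_labels = ["HATE_SPEECH"]  # empty list -> threshold-only check

if not toxicity_labels:
    toxicity_found = any(
        label["Score"] >= threshold
        for item in response["ResultList"]
        for label in item["Labels"]
    )
else:
    toxicity_found = any(
        label["Name"] in toxicity_labels and label["Score"] >= threshold
        for item in response["ResultList"]
        for label in item["Labels"]
    )
print(toxicity_found)  # True
```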
