Harrison/string inplace (#10153)

Co-authored-by: Wrick Talukdar <wrick.talukdar@gmail.com>
Co-authored-by: Anjan Biswas <anjanavb@amazon.com>
Co-authored-by: Jha <nikjha@amazon.com>
Co-authored-by: Lucky-Lance <77819606+Lucky-Lance@users.noreply.github.com>
Co-authored-by: 陆徐东 <luxudong@MacBook-Pro.local>
Harrison Chase committed 12 months ago (via GitHub)
parent f5af756397
commit 4abe85be57

@@ -111,7 +111,9 @@ class TaskExecutor:
             dep_task = self.id_task_map[dep_id]
             for k, v in task.args.items():
                 if f"<resource-{dep_id}>" in v:
-                    task.args[k].replace(f"<resource-{dep_id}>", dep_task.result)
+                    task.args[k] = task.args[k].replace(
+                        f"<resource-{dep_id}>", dep_task.result
+                    )

     def run(self) -> str:
         for task in self.tasks:
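For context on the fix above: Python strings are immutable, so str.replace() returns a new string instead of modifying the receiver, and the old call silently discarded that result. A minimal standalone sketch of the bug and the fix, with hypothetical values:

    # str.replace() does not work "in place": its return value must be stored.
    args = {"text": "Use <resource-1> here"}
    dep_result = "42"

    # Old pattern: the replaced string is thrown away, args is unchanged.
    args["text"].replace("<resource-1>", dep_result)
    assert args["text"] == "Use <resource-1> here"

    # New pattern: assign the returned string back into the container.
    args["text"] = args["text"].replace("<resource-1>", dep_result)
    assert args["text"] == "Use 42 here"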

@@ -5,9 +5,11 @@ from langchain_experimental.comprehend_moderation.base_moderation import BaseModeration
 from langchain_experimental.comprehend_moderation.base_moderation_callbacks import (
     BaseModerationCallbackHandler,
 )
-from langchain_experimental.comprehend_moderation.base_moderation_enums import (
-    BaseModerationActions,
-    BaseModerationFilters,
+from langchain_experimental.comprehend_moderation.base_moderation_config import (
+    BaseModerationConfig,
+    ModerationIntentConfig,
+    ModerationPiiConfig,
+    ModerationToxicityConfig,
 )
 from langchain_experimental.comprehend_moderation.intent import ComprehendIntent
 from langchain_experimental.comprehend_moderation.pii import ComprehendPII

@@ -15,11 +17,13 @@ from langchain_experimental.comprehend_moderation.toxicity import ComprehendToxicity
 __all__ = [
     "BaseModeration",
-    "BaseModerationActions",
-    "BaseModerationFilters",
     "ComprehendPII",
     "ComprehendIntent",
     "ComprehendToxicity",
+    "BaseModerationConfig",
+    "ModerationPiiConfig",
+    "ModerationToxicityConfig",
+    "ModerationIntentConfig",
     "BaseModerationCallbackHandler",
     "AmazonComprehendModerationChain",
 ]
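Taken together with the export list above, the package's public configuration surface is now the pydantic config models; the removed enum names would no longer import. A minimal sketch of the new-style imports, assuming the names in __all__ are re-exported from the package __init__:

    # New-style imports: moderation behaviour is configured with pydantic models,
    # not the removed BaseModerationActions / BaseModerationFilters enums.
    from langchain_experimental.comprehend_moderation import (
        AmazonComprehendModerationChain,
        BaseModerationConfig,
        ModerationIntentConfig,
        ModerationPiiConfig,
        ModerationToxicityConfig,
    )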

@@ -3,12 +3,13 @@ from typing import Any, Dict, List, Optional
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
-from langchain_experimental.comprehend_moderation.base_moderation import (
-    BaseModeration,
-)
+from langchain_experimental.comprehend_moderation.base_moderation import BaseModeration
 from langchain_experimental.comprehend_moderation.base_moderation_callbacks import (
     BaseModerationCallbackHandler,
 )
+from langchain_experimental.comprehend_moderation.base_moderation_config import (
+    BaseModerationConfig,
+)
 from langchain_experimental.pydantic_v1 import root_validator

@@ -21,10 +22,13 @@ class AmazonComprehendModerationChain(Chain):
     input_key: str = "input"  #: :meta private:
     """Key used to fetch/store the input in data containers. Defaults to `input`"""

-    moderation_config: Optional[Dict[str, Any]] = None
-    """Configuration settings for moderation"""
+    moderation_config: BaseModerationConfig = BaseModerationConfig()
+    """
+    Configuration settings for moderation,
+    defaults to BaseModerationConfig with default values
+    """

-    client: Optional[Any]
+    client: Optional[Any] = None
     """boto3 client object for connection to Amazon Comprehend"""

     region_name: Optional[str] = None
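Since moderation_config now defaults to BaseModerationConfig(), the chain can be built without any configuration at all and still run the PII, toxicity, and intent checks. A hedged construction sketch; the boto3 client settings are illustrative and not part of this diff:

    import boto3

    from langchain_experimental.comprehend_moderation import (
        AmazonComprehendModerationChain,
    )

    # Illustrative Comprehend client; region and credentials are assumptions.
    comprehend_client = boto3.client("comprehend", region_name="us-east-1")

    # moderation_config is omitted on purpose: the field now defaults to
    # BaseModerationConfig(), which enables all three filters with default settings.
    moderation_chain = AmazonComprehendModerationChain(client=comprehend_client)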

@@ -1,5 +1,5 @@
 import uuid
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Optional

 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.prompts.base import StringPromptValue

@@ -15,7 +15,7 @@ class BaseModeration:
     def __init__(
         self,
         client: Any,
-        config: Optional[Dict[str, Any]] = None,
+        config: Optional[Any] = None,
         moderation_callback: Optional[Any] = None,
         unique_id: Optional[str] = None,
         run_manager: Optional[CallbackManagerForChainRun] = None,
@@ -105,6 +105,11 @@ class BaseModeration:
         self.run_manager.on_text(message)

     def moderate(self, prompt: Any) -> str:
+        from langchain_experimental.comprehend_moderation.base_moderation_config import (  # noqa: E501
+            ModerationIntentConfig,
+            ModerationPiiConfig,
+            ModerationToxicityConfig,
+        )
         from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (  # noqa: E501
             ModerationIntentionError,
             ModerationPiiError,

@@ -115,33 +120,30 @@ class BaseModeration:
         # convert prompt to text
         input_text = self._convert_prompt_to_text(prompt=prompt)
         output_text = str()
-        # perform moderation
-        if self.config is None:
-            # In absence of config Action will default to STOP only
-            self._log_message_for_verbose("Running pii validation...\n")
-            pii_validate = self._moderation_class(moderation_class=ComprehendPII)
-            output_text = pii_validate(prompt_value=input_text)
-            self._log_message_for_verbose("Running toxicity validation...\n")
-            toxicity_validate = self._moderation_class(
-                moderation_class=ComprehendToxicity
-            )
-            output_text = toxicity_validate(prompt_value=output_text)
-            self._log_message_for_verbose("Running intent validation...\n")
-            intent_validate = self._moderation_class(
-                moderation_class=ComprehendIntent
-            )
-            output_text = intent_validate(prompt_value=output_text)
-        else:
-            filter_functions = {
-                "pii": ComprehendPII,
-                "toxicity": ComprehendToxicity,
-                "intent": ComprehendIntent,
-            }
-            filters = self.config["filters"]
-            for _filter in filters:
-                filter_name = f"{_filter}"
+        # perform moderation
+        filter_functions = {
+            "pii": ComprehendPII,
+            "toxicity": ComprehendToxicity,
+            "intent": ComprehendIntent,
+        }
+        filters = self.config.filters  # type: ignore
+        for _filter in filters:
+            filter_name = (
+                "pii"
+                if isinstance(_filter, ModerationPiiConfig)
+                else (
+                    "toxicity"
+                    if isinstance(_filter, ModerationToxicityConfig)
+                    else (
+                        "intent"
+                        if isinstance(_filter, ModerationIntentConfig)
+                        else None
+                    )
+                )
+            )
             if filter_name in filter_functions:
                 self._log_message_for_verbose(
                     f"Running {filter_name} Validation...\n"
                 )

@@ -152,10 +154,9 @@ class BaseModeration:
                 input_text = input_text if not output_text else output_text
                 output_text = validation_fn(
                     prompt_value=input_text,
-                    config=self.config[filter_name]
-                    if filter_name in self.config
-                    else None,
+                    config=_filter.dict(),
                 )
         # convert text to prompt and return
         return self._convert_text_to_prompt(prompt=prompt, text=output_text)
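The nested conditional added above only maps the type of each config object to a filter name. An equivalent formulation, shown purely to make that dispatch explicit (it is not the code in this PR), is a type-keyed dict:

    from typing import Optional

    from langchain_experimental.comprehend_moderation.base_moderation_config import (
        ModerationIntentConfig,
        ModerationPiiConfig,
        ModerationToxicityConfig,
    )

    # Type-to-name dispatch equivalent to the nested conditional in moderate().
    FILTER_NAME_BY_TYPE = {
        ModerationPiiConfig: "pii",
        ModerationToxicityConfig: "toxicity",
        ModerationIntentConfig: "intent",
    }


    def filter_name_for(filter_config: object) -> Optional[str]:
        # Returns None for unknown config types, mirroring the `else None` branch.
        return FILTER_NAME_BY_TYPE.get(type(filter_config))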

@@ -28,19 +28,19 @@ class BaseModerationCallbackHandler:
         self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
     ) -> None:
         """Run after PII validation is complete."""
-        raise NotImplementedError("Subclasses should implement this async method.")
+        pass

     async def on_after_toxicity(
         self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
     ) -> None:
         """Run after Toxicity validation is complete."""
-        raise NotImplementedError("Subclasses should implement this async method.")
+        pass

     async def on_after_intent(
         self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
     ) -> None:
         """Run after Toxicity validation is complete."""
-        raise NotImplementedError("Subclasses should implement this async method.")
+        pass

     @property
     def pii_callback(self) -> bool:
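Because these hooks are now no-op pass bodies instead of raising NotImplementedError, a callback subclass only needs to override the events it cares about. A minimal sketch, assuming the PII hook is named on_after_pii to match the on_after_toxicity / on_after_intent pattern shown above; the print body is purely illustrative:

    from typing import Any, Dict

    from langchain_experimental.comprehend_moderation import BaseModerationCallbackHandler


    class PiiOnlyCallback(BaseModerationCallbackHandler):
        """React to PII results only; the toxicity and intent hooks stay as no-ops."""

        async def on_after_pii(
            self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
        ) -> None:
            print(f"[{unique_id}] PII moderation beacon:", moderation_beacon)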

@@ -0,0 +1,51 @@
+from typing import List, Union
+
+from pydantic import BaseModel
+
+
+class ModerationPiiConfig(BaseModel):
+    threshold: float = 0.5
+    """Threshold for PII confidence score, defaults to 0.5 i.e. 50%"""
+
+    labels: List[str] = []
+    """
+    List of PII Universal Labels.
+    Defaults to `list[]`
+    """
+
+    redact: bool = False
+    """Whether to perform redaction of detected PII entities"""
+
+    mask_character: str = "*"
+    """Redaction mask character in case redact=True, defaults to asterisk (*)"""
+
+
+class ModerationToxicityConfig(BaseModel):
+    threshold: float = 0.5
+    """Threshold for Toxic label confidence score, defaults to 0.5 i.e. 50%"""
+
+    labels: List[str] = []
+    """List of toxic labels, defaults to `list[]`"""
+
+
+class ModerationIntentConfig(BaseModel):
+    threshold: float = 0.5
+    """
+    Threshold for Intent classification
+    confidence score, defaults to 0.5 i.e. 50%
+    """
+
+
+class BaseModerationConfig(BaseModel):
+    filters: List[
+        Union[ModerationPiiConfig, ModerationToxicityConfig, ModerationIntentConfig]
+    ] = [
+        ModerationPiiConfig(),
+        ModerationToxicityConfig(),
+        ModerationIntentConfig(),
+    ]
+    """
+    Filters applied to the moderation chain, defaults to
+    `[ModerationPiiConfig(), ModerationToxicityConfig(),
+    ModerationIntentConfig()]`
+    """

@@ -1,12 +0,0 @@
-from enum import Enum
-
-
-class BaseModerationActions(Enum):
-    STOP = 1
-    ALLOW = 2
-
-
-class BaseModerationFilters(str, Enum):
-    PII = "pii"
-    TOXICITY = "toxicity"
-    INTENT = "intent"

@@ -1,6 +1,5 @@
 import asyncio
-import warnings
-from typing import Any, Dict, Optional
+from typing import Any, Optional

 from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
     ModerationIntentionError,

@@ -30,16 +29,13 @@ class ComprehendIntent:
         intent_endpoint = "document-classifier-endpoint/prompt-intent"
         return f"arn:aws:{service}:{region_name}:aws:{intent_endpoint}"

-    def validate(
-        self, prompt_value: str, config: Optional[Dict[str, Any]] = None
-    ) -> str:
+    def validate(self, prompt_value: str, config: Any = None) -> str:
         """
         Check and validate the intent of the given prompt text.

         Args:
-            comprehend_client: Comprehend client for intent classification
-            prompt_value (str): The input text to be checked for unintended intent
-            config (Dict[str, Any]): Configuration settings for intent checks
+            prompt_value (str): The input text to be checked for unintended intent.
+            config (Dict[str, Any]): Configuration settings for intent checks.

         Raises:
             ValueError: If unintended intent is found in the prompt text based

@@ -53,26 +49,16 @@ class ComprehendIntent:
         Comprehend's classify_document API and raises an error if unintended
         intent is detected with a score above the specified threshold.
-
-        Example:
-            comprehend_client = boto3.client('comprehend')
-            prompt_text = "Please tell me your credit card information."
-            config = {"threshold": 0.7}
-            checked_prompt = check_intent(comprehend_client, prompt_text, config)
         """
-        from langchain_experimental.comprehend_moderation.base_moderation_enums import (
-            BaseModerationActions,
-        )
-
-        threshold = config.get("threshold", 0.5) if config else 0.5
-        action = (
-            config.get("action", BaseModerationActions.STOP)
-            if config
-            else BaseModerationActions.STOP
-        )
+        threshold = config.get("threshold")
         intent_found = False
-        if action == BaseModerationActions.ALLOW:
-            warnings.warn(
-                "You have allowed content with Harmful content."
-                "Defaulting to STOP action..."
-            )
-            action = BaseModerationActions.STOP

         endpoint_arn = self._get_arn()
         response = self.client.classify_document(
             Text=prompt_value, EndpointArn=endpoint_arn

@@ -23,33 +23,19 @@ class ComprehendPII:
         self.callback = callback
         self.unique_id = unique_id

-    def validate(
-        self, prompt_value: str, config: Optional[Dict[str, Any]] = None
-    ) -> str:
-        from langchain_experimental.comprehend_moderation.base_moderation_enums import (
-            BaseModerationActions,
-        )
-
-        if config:
-            action = config.get("action", BaseModerationActions.STOP)
-            if action not in [BaseModerationActions.STOP, BaseModerationActions.ALLOW]:
-                raise ValueError("Action can either be stop or allow")
-
-            return (
-                self._contains_pii(prompt_value=prompt_value, config=config)
-                if action == BaseModerationActions.STOP
-                else self._detect_pii(prompt_value=prompt_value, config=config)
-            )
-        else:
-            return self._contains_pii(prompt_value=prompt_value)
-
-    def _contains_pii(
-        self, prompt_value: str, config: Optional[Dict[str, Any]] = None
-    ) -> str:
+    def validate(self, prompt_value: str, config: Any = None) -> str:
+        redact = config.get("redact")
+        return (
+            self._detect_pii(prompt_value=prompt_value, config=config)
+            if redact
+            else self._contains_pii(prompt_value=prompt_value, config=config)
+        )
+
+    def _contains_pii(self, prompt_value: str, config: Any = None) -> str:
         """
         Checks for Personally Identifiable Information (PII) labels above a
-        specified threshold.
+        specified threshold. Uses Amazon Comprehend Contains PII Entities API. See -
+        https://docs.aws.amazon.com/comprehend/latest/APIReference/API_ContainsPiiEntities.html

         Args:
             prompt_value (str): The input text to be checked for PII labels.
             config (Dict[str, Any]): Configuration for PII check and actions.

@@ -68,8 +54,8 @@ class ComprehendPII:
             self.moderation_beacon["moderation_input"] = prompt_value
             self.moderation_beacon["moderation_output"] = pii_identified

-        threshold = config.get("threshold", 0.5) if config else 0.5
-        pii_labels = config.get("labels", []) if config else []
+        threshold = config.get("threshold")
+        pii_labels = config.get("labels")
         pii_found = False
         for entity in pii_identified["Labels"]:
             if (entity["Score"] >= threshold and entity["Name"] in pii_labels) or (

@@ -93,7 +79,8 @@ class ComprehendPII:
         Detects and handles Personally Identifiable Information (PII) entities in the
         given prompt text using Amazon Comprehend's detect_pii_entities API. The
         function provides options to redact or stop processing based on the identified
-        PII entities and a provided configuration.
+        PII entities and a provided configuration. Uses Amazon Comprehend Detect PII
+        Entities API.

         Args:
             prompt_value (str): The input text to be checked for PII entities.

@@ -143,9 +130,9 @@ class ComprehendPII:
             if pii_found:
                 raise ModerationPiiError
         else:
-            threshold = config.get("threshold", 0.5)  # type: ignore
-            pii_labels = config.get("labels", [])  # type: ignore
-            mask_marker = config.get("mask_character", "*")  # type: ignore
+            threshold = config.get("threshold")  # type: ignore
+            pii_labels = config.get("labels")  # type: ignore
+            mask_marker = config.get("mask_character")  # type: ignore
             pii_found = False

             for entity in pii_identified["Entities"]:

@@ -157,10 +144,14 @@ class ComprehendPII:
                     pii_found = True

                     char_offset_begin = entity["BeginOffset"]
                     char_offset_end = entity["EndOffset"]
+                    mask_length = char_offset_end - char_offset_begin + 1
+                    masked_part = mask_marker * mask_length
                     prompt_value = (
                         prompt_value[:char_offset_begin]
-                        + mask_marker * (char_offset_end - char_offset_begin)
-                        + prompt_value[char_offset_end:]
+                        + masked_part
+                        + prompt_value[char_offset_end + 1 :]
                     )

             if self.callback and self.callback.pii_callback:
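The masking change above computes the mask length as char_offset_end - char_offset_begin + 1 and resumes the original text at char_offset_end + 1, i.e. it treats the end offset as inclusive; whether Comprehend reports offsets that way is not established by this diff, the sketch below just follows the updated code. A worked example with made-up offsets:

    # Worked example of the updated masking arithmetic, with hypothetical offsets.
    prompt_value = "My SSN is 123-45-6789."
    mask_marker = "*"
    char_offset_begin = 10  # index of the first digit
    char_offset_end = 20    # index of the last digit, treated as inclusive here

    mask_length = char_offset_end - char_offset_begin + 1
    masked_part = mask_marker * mask_length
    masked = (
        prompt_value[:char_offset_begin]
        + masked_part
        + prompt_value[char_offset_end + 1 :]
    )
    assert masked == "My SSN is ***********."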

@@ -1,7 +1,6 @@
 import asyncio
 import importlib
-import warnings
-from typing import Any, Dict, List, Optional
+from typing import Any, List, Optional

 from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
     ModerationToxicityError,

@@ -30,14 +29,15 @@ class ComprehendToxicity:
         Validate and initialize toxicity processing configuration.

         Args:
-            max_size (int): Maximum sentence size defined in the configuration object.
+            max_size (int): Maximum sentence size defined in the
+            configuration object.

         Raises:
             Exception: If the maximum sentence size exceeds the 5KB limit.

         Note:
-            This function ensures that the NLTK punkt tokenizer is downloaded if not
-            already present.
+            This function ensures that the NLTK punkt tokenizer is downloaded
+            if not already present.

         Returns:
             None

@@ -63,34 +63,36 @@ class ComprehendToxicity:
         Split a paragraph into chunks of sentences, respecting the maximum size limit.

         Args:
-            paragraph (str): The input paragraph to be split into chunks
-            max_size (int, optional): The maximum size limit in bytes for each chunk
-            Defaults to 1024.
+            paragraph (str): The input paragraph to be split into chunks.
+            max_size (int, optional): The maximum size limit in bytes for
+            each chunk. Defaults to 1024.

         Returns:
-            List[List[str]]: A list of chunks, where each chunk is a list of sentences
+            List[List[str]]: A list of chunks, where each chunk is a list
+            of sentences.

         Note:
-            This function validates the maximum sentence size based on service limits
-            using the 'toxicity_init_validate' function. It uses the NLTK sentence
-            tokenizer to split the paragraph into sentences.
-
-        Example:
-            paragraph = "This is a sample paragraph. It
-            contains multiple sentences. ..."
-            chunks = split_paragraph(paragraph, max_size=2048)
+            This function validates the maximum sentence size based on service
+            limits using the 'toxicity_init_validate' function. It uses the NLTK
+            sentence tokenizer to split the paragraph into sentences.
         """
         # validate max. sentence size based on Service limits
         nltk = self._toxicity_init_validate(max_size)
         sentences = nltk.sent_tokenize(prompt_value)
-        chunks = []
-        current_chunk = []  # type: ignore
+        chunks = list()  # type: ignore
+        current_chunk = list()  # type: ignore
         current_size = 0
         for sentence in sentences:
             sentence_size = len(sentence.encode("utf-8"))
-            # If adding a new sentence exceeds max_size or
-            # current_chunk has 10 sentences, start a new chunk
+            # If adding a new sentence exceeds max_size
+            # or current_chunk has 10 sentences, start a new chunk
             if (current_size + sentence_size > max_size) or (len(current_chunk) >= 10):
                 if current_chunk:  # Avoid appending empty chunks
                     chunks.append(current_chunk)

@@ -103,16 +105,12 @@ class ComprehendToxicity:
             # Add any remaining sentences
             if current_chunk:
                 chunks.append(current_chunk)
         return chunks

-    def validate(
-        self, prompt_value: str, config: Optional[Dict[str, Any]] = None
-    ) -> str:
+    def validate(self, prompt_value: str, config: Any = None) -> str:
         """
-        Check the toxicity of a given text prompt using AWS Comprehend service
-        and apply actions based on configuration.
+        Check the toxicity of a given text prompt using AWS
+        Comprehend service and apply actions based on configuration.

         Args:
             prompt_value (str): The text content to be checked for toxicity.
             config (Dict[str, Any]): Configuration for toxicity checks and actions.

@@ -134,43 +132,16 @@ class ComprehendToxicity:
         if self.callback and self.callback.toxicity_callback:
             self.moderation_beacon["moderation_input"] = segments  # type: ignore
             self.moderation_beacon["moderation_output"] = response

-        if config:
-            from langchain_experimental.comprehend_moderation.base_moderation_enums import (  # noqa: E501
-                BaseModerationActions,
-            )
-
-            toxicity_found = False
-            action = config.get("action", BaseModerationActions.STOP)
-            if action not in [
-                BaseModerationActions.STOP,
-                BaseModerationActions.ALLOW,
-            ]:
-                raise ValueError("Action can either be stop or allow")
-
-            threshold = config.get("threshold", 0.5) if config else 0.5
-            toxicity_labels = config.get("labels", []) if config else []
-
-            if action == BaseModerationActions.STOP:
-                for item in response["ResultList"]:
-                    for label in item["Labels"]:
-                        if (
-                            label
-                            and (
-                                not toxicity_labels
-                                or label["Name"] in toxicity_labels
-                            )
-                            and label["Score"] >= threshold
-                        ):
-                            toxicity_found = True
-                            break
-
-            if action == BaseModerationActions.ALLOW:
-                if not toxicity_labels:
-                    warnings.warn(
-                        "You have allowed toxic content without specifying "
-                        "any toxicity labels."
-                    )
-                else:
-                    for item in response["ResultList"]:
-                        for label in item["Labels"]:
+        toxicity_found = False
+        threshold = config.get("threshold")
+        toxicity_labels = config.get("labels")
+
+        if not toxicity_labels:
+            for item in response["ResultList"]:
+                for label in item["Labels"]:
+                    if label["Score"] >= threshold:
+                        toxicity_found = True
+                        break
+        else:
+            for item in response["ResultList"]:
+                for label in item["Labels"]:

@@ -191,19 +162,4 @@ class ComprehendToxicity:
                 )
         if toxicity_found:
             raise ModerationToxicityError
-        else:
-            if response["ResultList"]:
-                detected_toxic_labels = list()
-                for item in response["ResultList"]:
-                    detected_toxic_labels.extend(item["Labels"])
-                if any(item["Score"] >= 0.5 for item in detected_toxic_labels):
-                    if self.callback and self.callback.toxicity_callback:
-                        self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
-                        asyncio.create_task(
-                            self.callback.on_after_toxicity(
-                                self.moderation_beacon, self.unique_id
-                            )
-                        )
-                    raise ModerationToxicityError
         return prompt_value
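The toxicity check now reads threshold and labels straight from the filter config: with no labels configured, any label scoring at or above the threshold counts; with labels configured, only those labels count. A compact restatement of that decision (not the PR's code) on a hypothetical Comprehend-style response:

    # Hypothetical response in the ResultList/Labels shape used above.
    response = {
        "ResultList": [
            {
                "Labels": [
                    {"Name": "PROFANITY", "Score": 0.82},  # illustrative label names
                    {"Name": "INSULT", "Score": 0.34},
                ]
            }
        ]
    }
    threshold = 0.6
    toxicity_labels = ["PROFANITY"]  # an empty list would mean "any label counts"

    toxicity_found = False
    for item in response["ResultList"]:
        for label in item["Labels"]:
            name_matches = (not toxicity_labels) or (label["Name"] in toxicity_labels)
            if name_matches and label["Score"] >= threshold:
                toxicity_found = True
                break

    assert toxicity_found is True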
