import uuid
from typing import Any, Callable, Optional, cast

from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue

from langchain_experimental.comprehend_moderation.pii import ComprehendPII
from langchain_experimental.comprehend_moderation.prompt_safety import (
    ComprehendPromptSafety,
)
from langchain_experimental.comprehend_moderation.toxicity import ComprehendToxicity


class BaseModeration:
    """Base class for moderation.

    Extracts the text to check from a prompt, runs every filter listed in
    ``config.filters`` over that text in order, and rebuilds a prompt of the
    same kind as the input from the resulting text.
    """

    def __init__(
        self,
        client: Any,
        config: Optional[Any] = None,
        moderation_callback: Optional[Any] = None,
        unique_id: Optional[str] = None,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ):
        """Store the collaborators used while moderating.

        Args:
            client: Client object handed to every moderation class.
            config: Object whose ``filters`` attribute is iterated by
                :meth:`moderate`.
            moderation_callback: Passed to every moderation class as
                ``callback``.
            unique_id: Passed through to every moderation class.
            run_manager: When set, receives verbose messages via ``on_text``.
        """
        self.client = client
        self.config = config
        self.moderation_callback = moderation_callback
        self.unique_id = unique_id
        # Index of the chat message whose text was extracted; written by
        # _convert_prompt_to_text and read back by _convert_text_to_prompt.
        self.chat_message_index = 0
        self.run_manager = run_manager
        # Fresh identifier per instance, forwarded to every moderation class.
        self.chain_id = str(uuid.uuid4())

    def _convert_prompt_to_text(self, prompt: Any) -> str:
        """Return the text of *prompt* that should be moderated.

        For a ``ChatPromptValue`` only the last message is considered, and
        its index is remembered in ``self.chat_message_index``. If that
        message is neither a ``HumanMessage`` nor an ``AIMessage`` the
        result is an empty string.

        Raises:
            ValueError: If *prompt* is not a ``StringPromptValue``, ``str``
                or ``ChatPromptValue``.
        """
        if isinstance(prompt, StringPromptValue):
            return prompt.text
        if isinstance(prompt, str):
            return prompt
        if isinstance(prompt, ChatPromptValue):
            # We will just check the last message in the message chain of a
            # ChatPromptTemplate. The typical chronology is
            # SystemMessage > HumanMessage > AIMessage and so on. However,
            # assuming that the chain is invoked with every chat turn, we
            # only check the last message, on the assumption that all
            # previous messages have been checked already. Only HumanMessage
            # and AIMessage will be checked.
            #
            # We could perhaps loop through and use the additional_kwargs
            # property of HumanMessage and AIMessage to mark messages that
            # have been moderated. However, that means this class could
            # generate multiple text chunks and the moderate() logic would
            # need to be updated. It also adds complexity in re-constructing
            # the prompt while keeping the messages in sequence.
            message = prompt.messages[-1]
            self.chat_message_index = len(prompt.messages) - 1
            if isinstance(message, (HumanMessage, AIMessage)):
                return cast(str, message.content)
            return ""
        # Report the type of the offending argument itself.
        raise ValueError(
            f"Invalid input type {type(prompt)}. "
            "Must be a PromptValue, str, or list of BaseMessages."
        )

    def _convert_text_to_prompt(self, prompt: Any, text: str) -> Any:
        """Rebuild a prompt of the same kind as *prompt* carrying *text*.

        For a ``ChatPromptValue`` the message at ``self.chat_message_index``
        is replaced when it is a ``HumanMessage`` or ``AIMessage``; all other
        messages are kept as they are.

        Raises:
            ValueError: If *prompt* is not a ``StringPromptValue``, ``str``
                or ``ChatPromptValue``.
        """
        if isinstance(prompt, StringPromptValue):
            return StringPromptValue(text=text)
        if isinstance(prompt, str):
            return text
        if isinstance(prompt, ChatPromptValue):
            # Copy the messages because we may need to mutate them.
            # We don't want to mutate data we don't own.
            messages = list(prompt.messages)
            index = self.chat_message_index
            message = messages[index]
            for message_cls in (HumanMessage, AIMessage):
                if isinstance(message, message_cls):
                    messages[index] = message_cls(
                        content=text,
                        example=message.example,
                        additional_kwargs=message.additional_kwargs,
                    )
            return ChatPromptValue(messages=messages)
        # Report the type of the offending argument (not the builtin `input`).
        raise ValueError(
            f"Invalid input type {type(prompt)}. "
            "Must be a PromptValue, str, or list of BaseMessages."
        )

    def _moderation_class(self, moderation_class: Any) -> Callable:
        """Instantiate *moderation_class* and return its ``validate`` method."""
        return moderation_class(
            client=self.client,
            callback=self.moderation_callback,
            unique_id=self.unique_id,
            chain_id=self.chain_id,
        ).validate

    def _log_message_for_verbose(self, message: str) -> None:
        """Forward *message* to the run manager, if one was provided."""
        if self.run_manager:
            self.run_manager.on_text(message)

    def moderate(self, prompt: Any) -> Any:
        """Moderate the input prompt.

        Each entry of ``self.config.filters`` that is a PII, toxicity or
        prompt-safety config is applied in order; every filter receives the
        text produced by the previous one. Entries of any other type are
        skipped. If no filter applies, the prompt text is returned unchanged.

        Args:
            prompt: A ``str``, ``StringPromptValue`` or ``ChatPromptValue``.

        Returns:
            A prompt of the same kind as *prompt* holding the moderated text.

        Raises:
            ModerationPiiError: Re-raised after logging when a filter raises it.
            ModerationToxicityError: Re-raised after logging when a filter
                raises it.
            ModerationPromptSafetyError: Re-raised after logging when a filter
                raises it.
            ValueError: If *prompt* is of an unsupported type.
        """
        # Imported here rather than at module level, as in the original code.
        from langchain_experimental.comprehend_moderation.base_moderation_config import (  # noqa: E501
            ModerationPiiConfig,
            ModerationPromptSafetyConfig,
            ModerationToxicityConfig,
        )
        from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (  # noqa: E501
            ModerationPiiError,
            ModerationPromptSafetyError,
            ModerationToxicityError,
        )

        # (config class, name used in the log message, moderation class).
        # Order matters: the first matching config class wins.
        filter_registry = (
            (ModerationPiiConfig, "pii", ComprehendPII),
            (ModerationToxicityConfig, "toxicity", ComprehendToxicity),
            (ModerationPromptSafetyConfig, "prompt_safety", ComprehendPromptSafety),
        )

        try:
            # convert prompt to text
            text = self._convert_prompt_to_text(prompt=prompt)

            # perform moderation, threading the text through every filter
            for _filter in self.config.filters:  # type: ignore
                match = next(
                    (
                        entry
                        for entry in filter_registry
                        if isinstance(_filter, entry[0])
                    ),
                    None,
                )
                if match is None:
                    # Unknown filter config type: nothing to run.
                    continue
                _, filter_name, moderation_cls = match
                self._log_message_for_verbose(f"Running {filter_name} Validation...\n")
                validation_fn = self._moderation_class(moderation_class=moderation_cls)
                text = validation_fn(
                    prompt_value=text,
                    config=_filter.dict(),
                )

            # convert text to prompt and return
            return self._convert_text_to_prompt(prompt=prompt, text=text)
        except ModerationPiiError as e:
            self._log_message_for_verbose(f"Found PII content..stopping..\n{e}\n")
            raise
        except ModerationToxicityError as e:
            self._log_message_for_verbose(f"Found Toxic content..stopping..\n{e}\n")
            raise
        except ModerationPromptSafetyError as e:
            self._log_message_for_verbose(
                f"Found Harmful intention..stopping..\n{e}\n"
            )
            raise