mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
3f6bf852ea
Added missed docstrings. Formatted docsctrings to the consistent format.
186 lines
7.2 KiB
Python
186 lines
7.2 KiB
Python
import uuid
|
|
from typing import Any, Callable, Optional, cast
|
|
|
|
from langchain.callbacks.manager import CallbackManagerForChainRun
|
|
from langchain_core.messages import AIMessage, HumanMessage
|
|
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
|
|
|
|
from langchain_experimental.comprehend_moderation.pii import ComprehendPII
|
|
from langchain_experimental.comprehend_moderation.prompt_safety import (
|
|
ComprehendPromptSafety,
|
|
)
|
|
from langchain_experimental.comprehend_moderation.toxicity import ComprehendToxicity
|
|
|
|
|
|
class BaseModeration:
|
|
"""Base class for moderation."""
|
|
|
|
def __init__(
|
|
self,
|
|
client: Any,
|
|
config: Optional[Any] = None,
|
|
moderation_callback: Optional[Any] = None,
|
|
unique_id: Optional[str] = None,
|
|
run_manager: Optional[CallbackManagerForChainRun] = None,
|
|
):
|
|
self.client = client
|
|
self.config = config
|
|
self.moderation_callback = moderation_callback
|
|
self.unique_id = unique_id
|
|
self.chat_message_index = 0
|
|
self.run_manager = run_manager
|
|
self.chain_id = str(uuid.uuid4())
|
|
|
|
def _convert_prompt_to_text(self, prompt: Any) -> str:
|
|
input_text = str()
|
|
|
|
if isinstance(prompt, StringPromptValue):
|
|
input_text = prompt.text
|
|
elif isinstance(prompt, str):
|
|
input_text = prompt
|
|
elif isinstance(prompt, ChatPromptValue):
|
|
"""
|
|
We will just check the last message in the message Chain of a
|
|
ChatPromptTemplate. The typical chronology is
|
|
SystemMessage > HumanMessage > AIMessage and so on. However assuming
|
|
that with every chat the chain is invoked we will only check the last
|
|
message. This is assuming that all previous messages have been checked
|
|
already. Only HumanMessage and AIMessage will be checked. We can perhaps
|
|
loop through and take advantage of the additional_kwargs property in the
|
|
HumanMessage and AIMessage schema to mark messages that have been moderated.
|
|
However that means that this class could generate multiple text chunks
|
|
and moderate() logics would need to be updated. This also means some
|
|
complexity in re-constructing the prompt while keeping the messages in
|
|
sequence.
|
|
"""
|
|
message = prompt.messages[-1]
|
|
self.chat_message_index = len(prompt.messages) - 1
|
|
if isinstance(message, HumanMessage):
|
|
input_text = cast(str, message.content)
|
|
|
|
if isinstance(message, AIMessage):
|
|
input_text = cast(str, message.content)
|
|
else:
|
|
raise ValueError(
|
|
f"Invalid input type {type(input_text)}. "
|
|
"Must be a PromptValue, str, or list of BaseMessages."
|
|
)
|
|
return input_text
|
|
|
|
def _convert_text_to_prompt(self, prompt: Any, text: str) -> Any:
|
|
if isinstance(prompt, StringPromptValue):
|
|
return StringPromptValue(text=text)
|
|
elif isinstance(prompt, str):
|
|
return text
|
|
elif isinstance(prompt, ChatPromptValue):
|
|
# Copy the messages because we may need to mutate them.
|
|
# We don't want to mutate data we don't own.
|
|
messages = list(prompt.messages)
|
|
|
|
message = messages[self.chat_message_index]
|
|
|
|
if isinstance(message, HumanMessage):
|
|
messages[self.chat_message_index] = HumanMessage(
|
|
content=text,
|
|
example=message.example,
|
|
additional_kwargs=message.additional_kwargs,
|
|
)
|
|
if isinstance(message, AIMessage):
|
|
messages[self.chat_message_index] = AIMessage(
|
|
content=text,
|
|
example=message.example,
|
|
additional_kwargs=message.additional_kwargs,
|
|
)
|
|
return ChatPromptValue(messages=messages)
|
|
else:
|
|
raise ValueError(
|
|
f"Invalid input type {type(input)}. "
|
|
"Must be a PromptValue, str, or list of BaseMessages."
|
|
)
|
|
|
|
def _moderation_class(self, moderation_class: Any) -> Callable:
|
|
return moderation_class(
|
|
client=self.client,
|
|
callback=self.moderation_callback,
|
|
unique_id=self.unique_id,
|
|
chain_id=self.chain_id,
|
|
).validate
|
|
|
|
def _log_message_for_verbose(self, message: str) -> None:
|
|
if self.run_manager:
|
|
self.run_manager.on_text(message)
|
|
|
|
def moderate(self, prompt: Any) -> str:
|
|
"""Moderate the input prompt."""
|
|
|
|
from langchain_experimental.comprehend_moderation.base_moderation_config import ( # noqa: E501
|
|
ModerationPiiConfig,
|
|
ModerationPromptSafetyConfig,
|
|
ModerationToxicityConfig,
|
|
)
|
|
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import ( # noqa: E501
|
|
ModerationPiiError,
|
|
ModerationPromptSafetyError,
|
|
ModerationToxicityError,
|
|
)
|
|
|
|
try:
|
|
# convert prompt to text
|
|
input_text = self._convert_prompt_to_text(prompt=prompt)
|
|
output_text = str()
|
|
|
|
# perform moderation
|
|
filter_functions = {
|
|
"pii": ComprehendPII,
|
|
"toxicity": ComprehendToxicity,
|
|
"prompt_safety": ComprehendPromptSafety,
|
|
}
|
|
|
|
filters = self.config.filters # type: ignore
|
|
|
|
for _filter in filters:
|
|
filter_name = (
|
|
"pii"
|
|
if isinstance(_filter, ModerationPiiConfig)
|
|
else (
|
|
"toxicity"
|
|
if isinstance(_filter, ModerationToxicityConfig)
|
|
else (
|
|
"prompt_safety"
|
|
if isinstance(_filter, ModerationPromptSafetyConfig)
|
|
else None
|
|
)
|
|
)
|
|
)
|
|
if filter_name in filter_functions:
|
|
self._log_message_for_verbose(
|
|
f"Running {filter_name} Validation...\n"
|
|
)
|
|
validation_fn = self._moderation_class(
|
|
moderation_class=filter_functions[filter_name]
|
|
)
|
|
input_text = input_text if not output_text else output_text
|
|
output_text = validation_fn(
|
|
prompt_value=input_text,
|
|
config=_filter.dict(),
|
|
)
|
|
|
|
# convert text to prompt and return
|
|
return self._convert_text_to_prompt(prompt=prompt, text=output_text)
|
|
|
|
except ModerationPiiError as e:
|
|
self._log_message_for_verbose(f"Found PII content..stopping..\n{str(e)}\n")
|
|
raise e
|
|
except ModerationToxicityError as e:
|
|
self._log_message_for_verbose(
|
|
f"Found Toxic content..stopping..\n{str(e)}\n"
|
|
)
|
|
raise e
|
|
except ModerationPromptSafetyError as e:
|
|
self._log_message_for_verbose(
|
|
f"Found Harmful intention..stopping..\n{str(e)}\n"
|
|
)
|
|
raise e
|
|
except Exception as e:
|
|
raise e
|