langchain/libs/experimental/langchain_experimental/comprehend_moderation/amazon_comprehend_moderation.py
Harrison Chase 4abe85be57
Harrison/string inplace (#10153)
Co-authored-by: Wrick Talukdar <wrick.talukdar@gmail.com>
Co-authored-by: Anjan Biswas <anjanavb@amazon.com>
Co-authored-by: Jha <nikjha@amazon.com>
Co-authored-by: Lucky-Lance <77819606+Lucky-Lance@users.noreply.github.com>
Co-authored-by: 陆徐东 <luxudong@MacBook-Pro.local>
2023-09-03 14:25:29 -07:00

189 lines
6.6 KiB
Python

from typing import Any, Dict, List, Optional
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain_experimental.comprehend_moderation.base_moderation import BaseModeration
from langchain_experimental.comprehend_moderation.base_moderation_callbacks import (
BaseModerationCallbackHandler,
)
from langchain_experimental.comprehend_moderation.base_moderation_config import (
BaseModerationConfig,
)
from langchain_experimental.pydantic_v1 import root_validator
class AmazonComprehendModerationChain(Chain):
    """Chain that applies Amazon Comprehend moderation to a single text input.

    The value found under ``input_key`` is passed to ``BaseModeration.moderate``
    (configured with ``moderation_config``, ``moderation_callback`` and
    ``unique_id``), and whatever it returns is emitted under ``output_key``.
    If no ``client`` is supplied, a boto3 ``comprehend`` client is created during
    validation from ``credentials_profile_name`` and ``region_name``.
    """

    output_key: str = "output"  #: :meta private:
    """Key used to fetch/store the output in data containers. Defaults to `output`"""

    input_key: str = "input"  #: :meta private:
    """Key used to fetch/store the input in data containers. Defaults to `input`"""

    moderation_config: BaseModerationConfig = BaseModerationConfig()
    """
    Configuration settings for moderation,
    defaults to BaseModerationConfig with default values
    """

    client: Optional[Any] = None
    """boto3 client object for connection to Amazon Comprehend"""

    region_name: Optional[str] = None
    """The aws region e.g., `us-west-2`. Falls back to AWS_DEFAULT_REGION env variable
    or region specified in ~/.aws/config in case it is not provided here.
    """

    credentials_profile_name: Optional[str] = None
    """The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
    has either access keys or role information specified.
    If not specified, the default credential profile or, if on an EC2 instance,
    credentials from IMDS will be used.
    See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
    """

    moderation_callback: Optional[BaseModerationCallbackHandler] = None
    """Callback handler for moderation, this is different
    from regular callbacks which can be used in addition to this."""

    unique_id: Optional[str] = None
    """A unique id that can be used to identify or group a user or session"""

    @root_validator(pre=True)
    def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Create a boto3 Amazon Comprehend client unless one was supplied.

        Runs as a pydantic ``pre`` root validator, i.e. on the raw constructor
        keyword arguments before field validation.

        Args:
            values: Raw constructor keyword arguments.

        Returns:
            ``values`` unchanged if it already holds a non-``None`` ``client``;
            otherwise ``values`` with ``client`` set to a new ``comprehend``
            client built from the optional ``credentials_profile_name`` and
            ``region_name`` entries.

        Raises:
            ModuleNotFoundError: If the ``boto3`` package cannot be imported.
            ValueError: If creating the boto3 session or client raises any
                exception; the original exception is chained as the cause.

        Example:
            .. code-block:: python

                chain = AmazonComprehendModerationChain(
                    credentials_profile_name="my-profile",
                    region_name="us-west-2",
                )
                comprehend_client = chain.client
        """
        # A caller-provided client always wins; nothing to build.
        if values.get("client") is not None:
            return values

        try:
            import boto3
        except ImportError as e:
            raise ModuleNotFoundError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            ) from e

        try:
            profile_name = values.get("credentials_profile_name")
            if profile_name:
                session = boto3.Session(profile_name=profile_name)
            else:
                # No (or empty) profile: let boto3 resolve default credentials.
                session = boto3.Session()

            # Only pass region_name when set, so boto3 can apply its own
            # region resolution otherwise.
            client_params: Dict[str, Any] = {}
            if values.get("region_name"):
                client_params["region_name"] = values["region_name"]

            values["client"] = session.client("comprehend", **client_params)
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                "profile name are valid."
            ) from e
        return values

    @property
    def output_keys(self) -> List[str]:
        """Return the keys this chain writes to its output dict.

        Returns:
            List[str]: A single-element list, ``[self.output_key]``; ``_call``
            stores the moderation result under this key.
        """
        return [self.output_key]

    @property
    def input_keys(self) -> List[str]:
        """Return the keys this chain reads from its input dict.

        Returns:
            List[str]: A single-element list, ``[self.input_key]``; ``_call``
            reads the text to moderate from this key.
        """
        return [self.input_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Run moderation on the input text and return the result.

        Builds a ``BaseModeration`` from this chain's client, config, moderation
        callback, unique id and ``run_manager``, then calls its ``moderate``
        method on ``inputs[self.input_key]``.

        Args:
            inputs: Chain inputs; must contain ``self.input_key``.
            run_manager: Optional run manager; when given, a progress message
                is sent via ``on_text`` and it is forwarded to
                ``BaseModeration``.

        Returns:
            Dict[str, str]: ``{self.output_key: <value returned by moderate>}``.

        Raises:
            KeyError: If ``self.input_key`` is missing from ``inputs``.
            Any exception raised by ``BaseModeration.moderate`` propagates
            unchanged (NOTE(review): presumably moderation-violation errors;
            confirm in base_moderation).
        """
        if run_manager:
            run_manager.on_text("Running AmazonComprehendModerationChain...\n")

        moderation = BaseModeration(
            client=self.client,
            config=self.moderation_config,
            moderation_callback=self.moderation_callback,
            unique_id=self.unique_id,
            run_manager=run_manager,
        )
        response = moderation.moderate(prompt=inputs[self.input_key])
        return {self.output_key: response}