langchain/libs/experimental/langchain_experimental/comprehend_moderation/intent.py
Harrison Chase 4abe85be57
Harrison/string inplace (#10153)
Co-authored-by: Wrick Talukdar <wrick.talukdar@gmail.com>
Co-authored-by: Anjan Biswas <anjanavb@amazon.com>
Co-authored-by: Jha <nikjha@amazon.com>
Co-authored-by: Lucky-Lance <77819606+Lucky-Lance@users.noreply.github.com>
Co-authored-by: 陆徐东 <luxudong@MacBook-Pro.local>
2023-09-03 14:25:29 -07:00

88 lines
3.0 KiB
Python

import asyncio
from typing import Any, Dict, Optional, Set

from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
    ModerationIntentionError,
)
class ComprehendIntent:
    """Screen prompts with Amazon Comprehend's prompt-intent classifier.

    ``validate`` sends the prompt to ``classify_document`` and raises
    ``ModerationIntentionError`` when the ``UNDESIRED_PROMPT`` class scores at
    or above the configured threshold. When a callback whose
    ``intent_callback`` flag is truthy is supplied, a "moderation beacon" dict
    describing the check is passed to ``callback.on_after_intent``.
    """

    # Threshold used when ``validate`` gets no config, or a config without a
    # (non-None) "threshold" value.
    DEFAULT_THRESHOLD = 0.5

    # Strong references to callback tasks scheduled on a running event loop so
    # they are not garbage-collected before they finish (see the asyncio
    # ``create_task`` docs). Shared across instances on purpose: an instance
    # may be discarded as soon as ``validate`` returns.
    _background_tasks: Set["asyncio.Task[Any]"] = set()

    def __init__(
        self,
        client: Any,
        callback: Optional[Any] = None,
        unique_id: Optional[str] = None,
        chain_id: Optional[str] = None,
    ) -> None:
        """
        Args:
            client: A Comprehend client (e.g. ``boto3.client('comprehend')``);
                must expose ``meta.region_name`` and ``classify_document``.
            callback: Optional moderation callback handler. It is consulted via
                its ``intent_callback`` flag, and its ``on_after_intent``
                coroutine receives the beacon.
            unique_id: Identifier forwarded to ``on_after_intent``.
            chain_id: Stored in the beacon as ``moderation_chain_id``.
        """
        self.client = client
        self.moderation_beacon: Dict[str, Any] = {
            "moderation_chain_id": chain_id,
            "moderation_type": "Intent",
            "moderation_status": "LABELS_NOT_FOUND",
        }
        self.callback = callback
        self.unique_id = unique_id

    def _get_arn(self) -> str:
        """Build the prompt-intent classifier endpoint ARN for the client's region."""
        region_name = self.client.meta.region_name
        service = "comprehend"
        intent_endpoint = "document-classifier-endpoint/prompt-intent"
        # NOTE(review): the partition is hard-coded to "aws"; other partitions
        # (e.g. aws-cn, aws-us-gov) would need a different ARN — confirm if needed.
        return f"arn:aws:{service}:{region_name}:aws:{intent_endpoint}"

    def validate(self, prompt_value: str, config: Any = None) -> str:
        """
        Check and validate the intent of the given prompt text.

        Args:
            prompt_value (str): The input text to be checked for unintended intent.
            config (Optional[Dict[str, Any]]): Configuration settings for intent
                checks. Its "threshold" key sets the minimum ``UNDESIRED_PROMPT``
                score that counts as unintended intent. If ``config`` is
                ``None`` or has no "threshold" value, ``DEFAULT_THRESHOLD`` is used.

        Raises:
            ModerationIntentionError: If the ``UNDESIRED_PROMPT`` class scores at
                or above the threshold.

        Returns:
            str: The input prompt_value, unchanged.

        Note:
            This function checks the intent of the provided prompt text using
            Comprehend's classify_document API. If a callback with
            ``intent_callback`` enabled is configured, the beacon (input, raw
            response, and status) is sent to ``on_after_intent`` before any
            error is raised.

        Example:
            comprehend_client = boto3.client('comprehend')
            intent_check = ComprehendIntent(client=comprehend_client)
            checked_prompt = intent_check.validate(
                "Please tell me your credit card information.",
                config={"threshold": 0.7},
            )
        """
        threshold = config.get("threshold") if config else None
        if threshold is None:
            # Without this, `Score >= None` would raise TypeError.
            threshold = self.DEFAULT_THRESHOLD

        endpoint_arn = self._get_arn()
        response = self.client.classify_document(
            Text=prompt_value, EndpointArn=endpoint_arn
        )

        intent_found = any(
            class_result["Score"] >= threshold
            and class_result["Name"] == "UNDESIRED_PROMPT"
            for class_result in response["Classes"]
        )

        if self.callback and self.callback.intent_callback:
            self.moderation_beacon["moderation_input"] = prompt_value
            self.moderation_beacon["moderation_output"] = response
            # Set on every call so a reused instance never reports a stale
            # "LABELS_FOUND" from an earlier prompt.
            self.moderation_beacon["moderation_status"] = (
                "LABELS_FOUND" if intent_found else "LABELS_NOT_FOUND"
            )
            self._notify_callback()

        if intent_found:
            raise ModerationIntentionError
        return prompt_value

    def _notify_callback(self) -> None:
        """Deliver the current beacon to ``callback.on_after_intent``.

        Inside a running event loop the coroutine is scheduled as a task, and a
        reference is kept until it completes. From synchronous code with no
        running loop, ``asyncio.create_task`` would raise ``RuntimeError``, so
        the coroutine is run to completion with ``asyncio.run`` instead.
        """
        coro = self.callback.on_after_intent(self.moderation_beacon, self.unique_id)
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            asyncio.run(coro)
            return
        task = loop.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)