Mirror of https://github.com/hwchase17/langchain (synced 2024-10-31 15:20:26 +00:00)
d57d08fd01
This PR implements a custom chain that wraps Amazon Comprehend API calls. The custom chain is intended to be used with LLM chains to provide moderation capabilities that let you detect and redact PII, toxic, and intent content in the LLM prompt or the LLM response. The implementation accepts a configuration object to control which checks are performed on an LLM prompt and can be used in a variety of setups using the LangChain Expression Language to detect the configured content not only in chains, but also in other constructs such as a retriever. The included sample notebook goes over the different configuration options and how to use it with other chains.

### Usage sample

```python
from langchain_experimental.comprehend_moderation import BaseModerationActions, BaseModerationFilters

moderation_config = {
    "filters": [
        BaseModerationFilters.PII,
        BaseModerationFilters.TOXICITY,
        BaseModerationFilters.INTENT
    ],
    "pii": {
        "action": BaseModerationActions.ALLOW,
        "threshold": 0.5,
        "labels": ["SSN"],
        "mask_character": "X"
    },
    "toxicity": {
        "action": BaseModerationActions.STOP,
        "threshold": 0.5
    },
    "intent": {
        "action": BaseModerationActions.STOP,
        "threshold": 0.5
    }
}

comp_moderation_with_config = AmazonComprehendModerationChain(
    moderation_config=moderation_config,  # specify the configuration
    client=comprehend_client,             # optionally pass the Boto3 client
    verbose=True
)

template = """Question: {question}

Answer:"""

prompt = PromptTemplate(template=template, input_variables=["question"])

responses = [
    "Final Answer: A credit card number looks like 1289-2321-1123-2387. A fake SSN number looks like 323-22-9980. John Doe's phone number is (999)253-9876.",
    "Final Answer: This is a really shitty way of constructing a birdhouse. This is fucking insane to think that any birds would actually create their motherfucking nests here."
]

llm = FakeListLLM(responses=responses)

llm_chain = LLMChain(prompt=prompt, llm=llm)

chain = (
    prompt
    | comp_moderation_with_config
    | {llm_chain.input_keys[0]: lambda x: x['output']}
    | llm_chain
    | {"input": lambda x: x['text']}
    | comp_moderation_with_config
)

response = chain.invoke({"question": "A sample SSN number looks like this 123-456-7890. Can you give me some more samples?"})

print(response['output'])
```

### Output

```
> Entering new AmazonComprehendModerationChain chain...
Running AmazonComprehendModerationChain...
Running pii validation...
Found PII content..stopping..
The prompt contains PII entities and cannot be processed
```

---------

Co-authored-by: Piyush Jain <piyushjain@duck.com>
Co-authored-by: Anjan Biswas <anjanavb@amazon.com>
Co-authored-by: Jha <nikjha@amazon.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
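The sample above wraps an LLMChain only; the retriever case mentioned in the description is not shown. Below is a minimal sketch of one way that might look, reusing `comp_moderation_with_config` and its `input`/`output` keys from the sample and assuming a generic LangChain `retriever`; none of this is taken from the PR itself.

```python
# Hypothetical sketch: moderate retrieved document text before it is used downstream.
# `retriever` is assumed to be any LangChain retriever instance.
docs = retriever.get_relevant_documents("How do I build a birdhouse?")

moderated = comp_moderation_with_config.invoke(
    {"input": "\n\n".join(doc.page_content for doc in docs)}
)
print(moderated["output"])
```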
65 lines · 2.3 KiB · Python
from typing import Any, Callable, Dict


class BaseModerationCallbackHandler:
    """Async callback handler for the Amazon Comprehend moderation chain."""

    def __init__(self) -> None:
        # Require subclasses to override at least one of the async hooks;
        # otherwise the handler would never be invoked.
        if (
            self._is_method_unchanged(
                BaseModerationCallbackHandler.on_after_pii, self.on_after_pii
            )
            and self._is_method_unchanged(
                BaseModerationCallbackHandler.on_after_toxicity, self.on_after_toxicity
            )
            and self._is_method_unchanged(
                BaseModerationCallbackHandler.on_after_intent, self.on_after_intent
            )
        ):
            raise NotImplementedError(
                "Subclasses must override at least one of on_after_pii(), "
                "on_after_toxicity(), or on_after_intent() functions."
            )

    def _is_method_unchanged(
        self, base_method: Callable, derived_method: Callable
    ) -> bool:
        # A method counts as "unchanged" if the instance still resolves to the
        # base-class implementation (compared by qualified name).
        return base_method.__qualname__ == derived_method.__qualname__

    async def on_after_pii(
        self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
    ) -> None:
        """Run after PII validation is complete."""
        raise NotImplementedError("Subclasses should implement this async method.")

    async def on_after_toxicity(
        self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
    ) -> None:
        """Run after Toxicity validation is complete."""
        raise NotImplementedError("Subclasses should implement this async method.")

    async def on_after_intent(
        self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
    ) -> None:
        """Run after Intent validation is complete."""
        raise NotImplementedError("Subclasses should implement this async method.")

    @property
    def pii_callback(self) -> bool:
        # True if the subclass has overridden on_after_pii.
        return (
            self.on_after_pii.__func__  # type: ignore
            is not BaseModerationCallbackHandler.on_after_pii
        )

    @property
    def toxicity_callback(self) -> bool:
        # True if the subclass has overridden on_after_toxicity.
        return (
            self.on_after_toxicity.__func__  # type: ignore
            is not BaseModerationCallbackHandler.on_after_toxicity
        )

    @property
    def intent_callback(self) -> bool:
        # True if the subclass has overridden on_after_intent.
        return (
            self.on_after_intent.__func__  # type: ignore
            is not BaseModerationCallbackHandler.on_after_intent
        )
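A minimal sketch of how a subclass might use this handler (not part of the file above): override only `on_after_pii`, which satisfies the `__init__` check, then hand the instance to the moderation chain. The import path and the `moderation_callback` keyword are assumptions inferred from the file's location and the PR description, not confirmed by this file.

```python
from typing import Any, Dict

# Assumed import path, based on this file's location under langchain_experimental.
from langchain_experimental.comprehend_moderation.base_moderation_callbacks import (
    BaseModerationCallbackHandler,
)


class MyPIIHandler(BaseModerationCallbackHandler):
    """Handler that only reacts to PII results; overriding one hook is enough."""

    async def on_after_pii(
        self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
    ) -> None:
        # moderation_beacon is assumed to carry the findings for this run.
        print(f"PII check finished for {unique_id}: {moderation_beacon}")


handler = MyPIIHandler()
assert handler.pii_callback and not handler.toxicity_callback

# Assumed wiring; the keyword name `moderation_callback` is not shown in this file.
# comp_moderation_with_config = AmazonComprehendModerationChain(
#     moderation_config=moderation_config,
#     moderation_callback=handler,
#     client=comprehend_client,
#     verbose=True,
# )
```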