forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
83 lines
2.5 KiB
Python
83 lines
2.5 KiB
Python
"""Pass input through a moderation endpoint."""
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from pydantic import BaseModel, root_validator
|
|
|
|
from langchain.chains.base import Chain
|
|
from langchain.utils import get_from_dict_or_env
|
|
|
|
|
|
class OpenAIModerationChain(Chain, BaseModel):
    """Pass input through a moderation endpoint.

    To use, you should have the ``openai`` python package installed, and the
    environment variable ``OPENAI_API_KEY`` set with your API key.

    Any parameters that are valid to be passed to the openai.create call can be passed
    in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain.chains import OpenAIModerationChain
            moderation = OpenAIModerationChain()
    """

    client: Any  #: :meta private:
    model_name: Optional[str] = None
    """Moderation model name to use."""
    error: bool = False
    """Whether or not to error if bad content was found."""
    input_key: str = "input"  #: :meta private:
    output_key: str = "output"  #: :meta private:
    openai_api_key: Optional[str] = None
    """API key; looked up from ``OPENAI_API_KEY`` by the validator if not given."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment.

        Args:
            values: Field values collected for the model.

        Returns:
            The same ``values`` dict with ``client`` set to ``openai.Moderation``.

        Raises:
            ValueError: If the ``openai`` package cannot be imported.
        """
        openai_api_key = get_from_dict_or_env(
            values, "openai_api_key", "OPENAI_API_KEY"
        )
        # Keep the try body down to the one statement that can raise ImportError.
        try:
            import openai
        except ImportError as err:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            ) from err

        # NOTE: this assigns the key on the openai module itself, so it is
        # shared by everything else in the process that uses that module.
        openai.api_key = openai_api_key
        values["client"] = openai.Moderation
        return values

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        :meta private:
        """
        return [self.output_key]

    def _moderate(self, text: str, results: dict) -> str:
        """Turn one moderation result into the chain's output string.

        Args:
            text: The text that was submitted for moderation.
            results: A single moderation result; only its ``"flagged"`` entry
                is read.

        Returns:
            ``text`` unchanged when the result is not flagged, otherwise a
            fixed policy-violation message (when ``self.error`` is false).

        Raises:
            ValueError: If the result is flagged and ``self.error`` is true.
        """
        if not results["flagged"]:
            return text

        error_str = "Text was found that violates OpenAI's content policy."
        if self.error:
            raise ValueError(error_str)
        return error_str

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Send the input text to the moderation client and wrap the outcome.

        Args:
            inputs: Mapping that holds the text under ``self.input_key``.

        Returns:
            A dict mapping ``self.output_key`` to the moderated output.
        """
        text = inputs[self.input_key]
        if self.model_name is None:
            # No model configured: leave the choice to the client's default.
            results = self.client.create(text)
        else:
            # Honour the configured model; previously model_name was ignored.
            results = self.client.create(text, model=self.model_name)
        # Only the first entry of the response's "results" list is inspected.
        output = self._moderate(text, results["results"][0])
        return {self.output_key: output}
|