diff --git a/libs/langchain/langchain/chains/moderation.py b/libs/langchain/langchain/chains/moderation.py index 6deb63ed06..8b2f73d5cb 100644 --- a/libs/langchain/langchain/chains/moderation.py +++ b/libs/langchain/langchain/chains/moderation.py @@ -1,9 +1,13 @@ """Pass input through a moderation endpoint.""" + from typing import Any, Dict, List, Optional -from langchain_core.callbacks import CallbackManagerForChainRun -from langchain_core.pydantic_v1 import root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.callbacks import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) +from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.utils import check_package_version, get_from_dict_or_env from langchain.chains.base import Chain @@ -25,6 +29,7 @@ class OpenAIModerationChain(Chain): """ client: Any #: :meta private: + async_client: Any #: :meta private: model_name: Optional[str] = None """Moderation model name to use.""" error: bool = False @@ -33,6 +38,7 @@ class OpenAIModerationChain(Chain): output_key: str = "output" #: :meta private: openai_api_key: Optional[str] = None openai_organization: Optional[str] = None + _openai_pre_1_0: bool = Field(default=None) @root_validator() def validate_environment(cls, values: Dict) -> Dict: @@ -52,7 +58,16 @@ class OpenAIModerationChain(Chain): openai.api_key = openai_api_key if openai_organization: openai.organization = openai_organization - values["client"] = openai.Moderation # type: ignore + values["_openai_pre_1_0"] = False + try: + check_package_version("openai", gte_version="1.0") + except ValueError: + values["_openai_pre_1_0"] = True + if values["_openai_pre_1_0"]: + values["client"] = openai.Moderation + else: + values["client"] = openai.OpenAI() + values["async_client"] = openai.AsyncOpenAI() except ImportError: raise ImportError( "Could not import openai python package. " @@ -76,8 +91,12 @@ class OpenAIModerationChain(Chain): """ return [self.output_key] - def _moderate(self, text: str, results: dict) -> str: - if results["flagged"]: + def _moderate(self, text: str, results: Any) -> str: + if self._openai_pre_1_0: + condition = results["flagged"] + else: + condition = results.flagged + if condition: error_str = "Text was found that violates OpenAI's content policy." if self.error: raise ValueError(error_str) @@ -87,10 +106,26 @@ class OpenAIModerationChain(Chain): def _call( self, - inputs: Dict[str, str], + inputs: Dict[str, Any], run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, str]: + ) -> Dict[str, Any]: + text = inputs[self.input_key] + if self._openai_pre_1_0: + results = self.client.create(text) + output = self._moderate(text, results["results"][0]) + else: + results = self.client.moderations.create(input=text) + output = self._moderate(text, results.results[0]) + return {self.output_key: output} + + async def _acall( + self, + inputs: Dict[str, Any], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + if self._openai_pre_1_0: + return await super()._acall(inputs, run_manager=run_manager) text = inputs[self.input_key] - results = self.client.create(text) - output = self._moderate(text, results["results"][0]) + results = await self.async_client.moderations.create(input=text) + output = self._moderate(text, results.results[0]) return {self.output_key: output}