From d3ca2cc8c3488eb973dd5ee3ebd4a8d10d8e2dbf Mon Sep 17 00:00:00 2001
From: Matt Florence
Date: Fri, 10 May 2024 13:04:13 -0600
Subject: [PATCH] langchain: Fix broken `OpenAIModerationChain` and implement
 async (#18537)

Thank you for contributing to LangChain!

## PR title
langchain[patch]: fix `OpenAIModerationChain` and implement async

## PR message
Description: fix `OpenAIModerationChain` and implement async

Issues:
- https://github.com/langchain-ai/langchain/issues/18533
- https://github.com/langchain-ai/langchain/issues/13685

Dependencies: none

Twitter handle: mattflo

## Add tests and docs
Existing documentation is broken:
https://python.langchain.com/docs/guides/safety/moderation

- [x] **Lint and test**: Run `make format`, `make lint` and `make test` from
  the root of the package(s) you've modified. See contribution guidelines for
  more: https://python.langchain.com/docs/contributing/

---------

Co-authored-by: Emilia Katari
Co-authored-by: ccurme
Co-authored-by: Erick Friis
---
 libs/langchain/langchain/chains/moderation.py | 55 +++++++++++++++----
 1 file changed, 45 insertions(+), 10 deletions(-)

diff --git a/libs/langchain/langchain/chains/moderation.py b/libs/langchain/langchain/chains/moderation.py
index 6deb63ed06..8b2f73d5cb 100644
--- a/libs/langchain/langchain/chains/moderation.py
+++ b/libs/langchain/langchain/chains/moderation.py
@@ -1,9 +1,13 @@
 """Pass input through a moderation endpoint."""
+
 from typing import Any, Dict, List, Optional
 
-from langchain_core.callbacks import CallbackManagerForChainRun
-from langchain_core.pydantic_v1 import root_validator
-from langchain_core.utils import get_from_dict_or_env
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForChainRun,
+    CallbackManagerForChainRun,
+)
+from langchain_core.pydantic_v1 import Field, root_validator
+from langchain_core.utils import check_package_version, get_from_dict_or_env
 
 from langchain.chains.base import Chain
 
@@ -25,6 +29,7 @@ class OpenAIModerationChain(Chain):
     """
 
     client: Any  #: :meta private:
+    async_client: Any  #: :meta private:
     model_name: Optional[str] = None
     """Moderation model name to use."""
     error: bool = False
@@ -33,6 +38,7 @@ class OpenAIModerationChain(Chain):
     output_key: str = "output"  #: :meta private:
     openai_api_key: Optional[str] = None
     openai_organization: Optional[str] = None
+    _openai_pre_1_0: bool = Field(default=None)
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
@@ -52,7 +58,16 @@ class OpenAIModerationChain(Chain):
                 openai.api_key = openai_api_key
             if openai_organization:
                 openai.organization = openai_organization
-            values["client"] = openai.Moderation  # type: ignore
+            values["_openai_pre_1_0"] = False
+            try:
+                check_package_version("openai", gte_version="1.0")
+            except ValueError:
+                values["_openai_pre_1_0"] = True
+            if values["_openai_pre_1_0"]:
+                values["client"] = openai.Moderation
+            else:
+                values["client"] = openai.OpenAI()
+                values["async_client"] = openai.AsyncOpenAI()
         except ImportError:
             raise ImportError(
                 "Could not import openai python package. "
@@ -76,8 +91,12 @@ class OpenAIModerationChain(Chain):
         """
         return [self.output_key]
 
-    def _moderate(self, text: str, results: dict) -> str:
-        if results["flagged"]:
+    def _moderate(self, text: str, results: Any) -> str:
+        if self._openai_pre_1_0:
+            condition = results["flagged"]
+        else:
+            condition = results.flagged
+        if condition:
             error_str = "Text was found that violates OpenAI's content policy."
             if self.error:
                 raise ValueError(error_str)
@@ -87,10 +106,26 @@ class OpenAIModerationChain(Chain):
 
     def _call(
         self,
-        inputs: Dict[str, str],
+        inputs: Dict[str, Any],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
+    ) -> Dict[str, Any]:
+        text = inputs[self.input_key]
+        if self._openai_pre_1_0:
+            results = self.client.create(text)
+            output = self._moderate(text, results["results"][0])
+        else:
+            results = self.client.moderations.create(input=text)
+            output = self._moderate(text, results.results[0])
+        return {self.output_key: output}
+
+    async def _acall(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
+        if self._openai_pre_1_0:
+            return await super()._acall(inputs, run_manager=run_manager)
         text = inputs[self.input_key]
-        results = self.client.create(text)
-        output = self._moderate(text, results["results"][0])
+        results = await self.async_client.moderations.create(input=text)
+        output = self._moderate(text, results.results[0])
         return {self.output_key: output}
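
For reference, here is a minimal usage sketch (not part of the patch) showing how the patched chain could be exercised on both of the code paths it now supports: the sync `_call` and the new async `_acall`. It assumes `openai>=1.0` is installed and `OPENAI_API_KEY` is set in the environment; the sample input strings are placeholders. On pre-1.0 clients the async path simply defers to the base `Chain._acall`.

```python
import asyncio

from langchain.chains.moderation import OpenAIModerationChain

# error=False (the default) returns the policy message instead of raising.
moderation = OpenAIModerationChain()

# Sync path: on openai>=1.0 this goes through openai.OpenAI().moderations.create.
result = moderation.invoke({"input": "Some text to screen."})
print(result["output"])


async def main() -> None:
    # Async path: the new _acall awaits openai.AsyncOpenAI().moderations.create.
    result = await moderation.ainvoke({"input": "Some other text to screen."})
    print(result["output"])


asyncio.run(main())
```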