langchain[patch]: fix moderation chain init (#25778)
[This commit](d3ca2cc8c3) has broken the moderation chain, so we hit a crash when migrating LangChain from v0.1 to v0.2.
The issue appears to be that the class attribute the code refers to doesn't hold the value processed in the `validate_environment` method. We had `extras={}` in this attribute, and it was cast to `True` when it should have been `False`. Adding a simple assignment seems to resolve the issue, though I'm not sure it's the right way.
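For context, a minimal sketch of the suspected pydantic v1 behaviour behind this (hypothetical names, not the LangChain source): an underscore-prefixed attribute is not registered as a model field, so the key a `root_validator` writes into `values` is silently dropped, and reading the attribute on an instance falls back to the class-level `Field(...)` object, which is truthy.

```python
# Hypothetical reproduction, assuming pydantic v1 (pydantic<2); not LangChain code.
from typing import Dict

from pydantic import BaseModel, Field, root_validator


class Broken(BaseModel):
    # Leading underscore: pydantic v1 does not treat this as a model field.
    _pre_1_0: bool = Field(default=None)

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        values["_pre_1_0"] = False  # dropped: "_pre_1_0" is not a declared field
        return values


chain = Broken()
# The attribute still resolves to the class-level FieldInfo object, so it is truthy.
print(bool(chain._pre_1_0))  # True, even though the validator "set" it to False
```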
---
---------
Co-authored-by: Michael Rubél <mrubel@oroinc.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
parent 63a1569d5f
commit 9decd0b243
@@ -38,7 +38,7 @@ class OpenAIModerationChain(Chain):
     output_key: str = "output"  #: :meta private:
     openai_api_key: Optional[str] = None
     openai_organization: Optional[str] = None
-    _openai_pre_1_0: bool = Field(default=None)
+    openai_pre_1_0: bool = Field(default=None)
 
     @root_validator(pre=True)
     def validate_environment(cls, values: Dict) -> Dict:
@@ -58,16 +58,17 @@ class OpenAIModerationChain(Chain):
             openai.api_key = openai_api_key
             if openai_organization:
                 openai.organization = openai_organization
-            values["_openai_pre_1_0"] = False
+            values["openai_pre_1_0"] = False
             try:
                 check_package_version("openai", gte_version="1.0")
             except ValueError:
-                values["_openai_pre_1_0"] = True
-            if values["_openai_pre_1_0"]:
+                values["openai_pre_1_0"] = True
+            if values["openai_pre_1_0"]:
                 values["client"] = openai.Moderation
             else:
                 values["client"] = openai.OpenAI()
+                values["async_client"] = openai.AsyncOpenAI()
 
         except ImportError:
             raise ImportError(
                 "Could not import openai python package. "
@@ -92,7 +93,7 @@ class OpenAIModerationChain(Chain):
         return [self.output_key]
 
     def _moderate(self, text: str, results: Any) -> str:
-        if self._openai_pre_1_0:
+        if self.openai_pre_1_0:
             condition = results["flagged"]
         else:
             condition = results.flagged
@@ -110,7 +111,7 @@ class OpenAIModerationChain(Chain):
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         text = inputs[self.input_key]
-        if self._openai_pre_1_0:
+        if self.openai_pre_1_0:
             results = self.client.create(text)
             output = self._moderate(text, results["results"][0])
         else:
@@ -123,7 +124,7 @@ class OpenAIModerationChain(Chain):
         inputs: Dict[str, Any],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
-        if self._openai_pre_1_0:
+        if self.openai_pre_1_0:
             return await super()._acall(inputs, run_manager=run_manager)
         text = inputs[self.input_key]
         results = await self.async_client.moderations.create(input=text)
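A matching sketch (same assumptions, hypothetical names) of the pattern after the rename: without the leading underscore the attribute is a declared pydantic field, so the value assigned in the root validator is what the instance actually carries.

```python
# Hypothetical counterpart of the fix, again assuming pydantic v1; not LangChain code.
from typing import Dict

from pydantic import BaseModel, Field, root_validator


class Fixed(BaseModel):
    # No leading underscore: this is a regular model field.
    pre_1_0: bool = Field(default=None)

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        values["pre_1_0"] = False  # lands on the declared field
        return values


chain = Fixed()
print(chain.pre_1_0)  # False: the validator's assignment is preserved
```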