JsonOutputFunctionParser: Fix mutation in place bug (#8758)

Fixes mutation in place in the JsonOutputFunctionParser. This causes
issues when trying to re-use the original AI message.
pull/8902/head
Eugene Yurtsev 11 months ago committed by GitHub
parent ab47557db3
commit f616aee35a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,3 +1,4 @@
import copy
import json
from typing import Any, Dict, List, Type, Union
@ -25,8 +26,8 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
)
message = generation.message
try:
func_call = message.additional_kwargs["function_call"]
except ValueError as exc:
func_call = copy.deepcopy(message.additional_kwargs["function_call"])
except KeyError as exc:
raise OutputParserException(f"Could not parse function call: {exc}")
if self.args_only:
@ -38,11 +39,16 @@ class JsonOutputFunctionsParser(OutputFunctionsParser):
"""Parse an output as the Json object."""
def parse_result(self, result: List[Generation]) -> Any:
func = super().parse_result(result)
function_call_info = super().parse_result(result)
if self.args_only:
return json.loads(func)
func["arguments"] = json.loads(func["arguments"])
return func
try:
return json.loads(function_call_info)
except (json.JSONDecodeError, TypeError) as exc:
raise OutputParserException(
f"Could not parse function call data: {exc}"
)
function_call_info["arguments"] = json.loads(function_call_info["arguments"])
return function_call_info
class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):

@ -0,0 +1,76 @@
import json
import pytest
from langchain.output_parsers.openai_functions import (
JsonOutputFunctionsParser,
)
from langchain.schema import BaseMessage, ChatGeneration, OutputParserException
from langchain.schema.messages import AIMessage, HumanMessage
@pytest.fixture
def ai_message() -> AIMessage:
"""Return a simple AIMessage."""
content = "This is a test message"
args = json.dumps(
{
"arg1": "value1",
}
)
function_call = {"name": "function_name", "arguments": args}
additional_kwargs = {"function_call": function_call}
return AIMessage(content=content, additional_kwargs=additional_kwargs)
def test_json_output_function_parser(ai_message: AIMessage) -> None:
"""Test that the JsonOutputFunctionsParser with full output."""
chat_generation = ChatGeneration(message=ai_message)
# Full output
parser = JsonOutputFunctionsParser(args_only=False)
result = parser.parse_result([chat_generation])
assert result == {"arguments": {"arg1": "value1"}, "name": "function_name"}
# Args only
parser = JsonOutputFunctionsParser(args_only=True)
result = parser.parse_result([chat_generation])
assert result == {"arg1": "value1"}
# Verify that the original message is not modified
assert ai_message.additional_kwargs == {
"function_call": {"name": "function_name", "arguments": '{"arg1": "value1"}'}
}
@pytest.mark.parametrize(
"bad_message",
[
# Human message has no function call
HumanMessage(content="This is a test message"),
# AIMessage has no function call information.
AIMessage(content="This is a test message", additional_kwargs={}),
# Bad function call information (arguments should be a string)
AIMessage(
content="This is a test message",
additional_kwargs={
"function_call": {"name": "function_name", "arguments": {}}
},
),
# Bad function call information (arguments should be proper json)
AIMessage(
content="This is a test message",
additional_kwargs={
"function_call": {"name": "function_name", "arguments": "noqweqwe"}
},
),
],
)
def test_exceptions_raised_while_parsing(bad_message: BaseMessage) -> None:
"""Test exceptions raised correctly while using JSON parser."""
chat_generation = ChatGeneration(message=bad_message)
with pytest.raises(OutputParserException):
JsonOutputFunctionsParser().parse_result([chat_generation])
Loading…
Cancel
Save