From f616aee35aee6af3bd8ad52734055113b754ade3 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 7 Aug 2023 14:32:46 -0400 Subject: [PATCH] JsonOutputFunctionParser: Fix mutation in place bug (#8758) Fixes mutation in place in the JsonOutputFunctionParser. This causes issues when trying to re-use the original AI message. --- .../output_parsers/openai_functions.py | 18 +++-- .../output_parsers/test_openai_functions.py | 76 +++++++++++++++++++ 2 files changed, 88 insertions(+), 6 deletions(-) create mode 100644 libs/langchain/tests/unit_tests/output_parsers/test_openai_functions.py diff --git a/libs/langchain/langchain/output_parsers/openai_functions.py b/libs/langchain/langchain/output_parsers/openai_functions.py index c55801c9bd..646d895962 100644 --- a/libs/langchain/langchain/output_parsers/openai_functions.py +++ b/libs/langchain/langchain/output_parsers/openai_functions.py @@ -1,3 +1,4 @@ +import copy import json from typing import Any, Dict, List, Type, Union @@ -25,8 +26,8 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]): ) message = generation.message try: - func_call = message.additional_kwargs["function_call"] - except ValueError as exc: + func_call = copy.deepcopy(message.additional_kwargs["function_call"]) + except KeyError as exc: raise OutputParserException(f"Could not parse function call: {exc}") if self.args_only: @@ -38,11 +39,16 @@ class JsonOutputFunctionsParser(OutputFunctionsParser): """Parse an output as the Json object.""" def parse_result(self, result: List[Generation]) -> Any: - func = super().parse_result(result) + function_call_info = super().parse_result(result) if self.args_only: - return json.loads(func) - func["arguments"] = json.loads(func["arguments"]) - return func + try: + return json.loads(function_call_info) + except (json.JSONDecodeError, TypeError) as exc: + raise OutputParserException( + f"Could not parse function call data: {exc}" + ) + function_call_info["arguments"] = json.loads(function_call_info["arguments"]) + return function_call_info class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser): diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_openai_functions.py b/libs/langchain/tests/unit_tests/output_parsers/test_openai_functions.py new file mode 100644 index 0000000000..7c6411e61f --- /dev/null +++ b/libs/langchain/tests/unit_tests/output_parsers/test_openai_functions.py @@ -0,0 +1,76 @@ +import json + +import pytest + +from langchain.output_parsers.openai_functions import ( + JsonOutputFunctionsParser, +) +from langchain.schema import BaseMessage, ChatGeneration, OutputParserException +from langchain.schema.messages import AIMessage, HumanMessage + + +@pytest.fixture +def ai_message() -> AIMessage: + """Return a simple AIMessage.""" + content = "This is a test message" + + args = json.dumps( + { + "arg1": "value1", + } + ) + + function_call = {"name": "function_name", "arguments": args} + additional_kwargs = {"function_call": function_call} + return AIMessage(content=content, additional_kwargs=additional_kwargs) + + +def test_json_output_function_parser(ai_message: AIMessage) -> None: + """Test that the JsonOutputFunctionsParser with full output.""" + chat_generation = ChatGeneration(message=ai_message) + + # Full output + parser = JsonOutputFunctionsParser(args_only=False) + result = parser.parse_result([chat_generation]) + assert result == {"arguments": {"arg1": "value1"}, "name": "function_name"} + + # Args only + parser = JsonOutputFunctionsParser(args_only=True) + result = parser.parse_result([chat_generation]) + assert result == {"arg1": "value1"} + + # Verify that the original message is not modified + assert ai_message.additional_kwargs == { + "function_call": {"name": "function_name", "arguments": '{"arg1": "value1"}'} + } + + +@pytest.mark.parametrize( + "bad_message", + [ + # Human message has no function call + HumanMessage(content="This is a test message"), + # AIMessage has no function call information. + AIMessage(content="This is a test message", additional_kwargs={}), + # Bad function call information (arguments should be a string) + AIMessage( + content="This is a test message", + additional_kwargs={ + "function_call": {"name": "function_name", "arguments": {}} + }, + ), + # Bad function call information (arguments should be proper json) + AIMessage( + content="This is a test message", + additional_kwargs={ + "function_call": {"name": "function_name", "arguments": "noqweqwe"} + }, + ), + ], +) +def test_exceptions_raised_while_parsing(bad_message: BaseMessage) -> None: + """Test exceptions raised correctly while using JSON parser.""" + chat_generation = ChatGeneration(message=bad_message) + + with pytest.raises(OutputParserException): + JsonOutputFunctionsParser().parse_result([chat_generation])