diff --git a/libs/langchain/langchain/output_parsers/json.py b/libs/langchain/langchain/output_parsers/json.py index aafaedb67d..946aeb6408 100644 --- a/libs/langchain/langchain/output_parsers/json.py +++ b/libs/langchain/langchain/output_parsers/json.py @@ -7,9 +7,10 @@ from typing import Any, Callable, List, Optional import jsonpatch -from langchain.schema import BaseOutputParser, OutputParserException -from langchain.schema.output import ChatGeneration, Generation -from langchain.schema.output_parser import BaseCumulativeTransformOutputParser +from langchain.schema.output_parser import ( + BaseCumulativeTransformOutputParser, + OutputParserException, +) def _replace_new_line(match: re.Match[str]) -> str: @@ -44,10 +45,10 @@ def _custom_parser(multiline_string: str) -> str: # Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py # MIT License -def parse_partial_json(s: str) -> Any: +def parse_partial_json(s: str, *, strict: bool = False) -> Any: # Attempt to parse the string as-is. try: - return json.loads(s) + return json.loads(s, strict=strict) except json.JSONDecodeError: pass @@ -97,7 +98,7 @@ def parse_partial_json(s: str) -> Any: # Attempt to parse the modified string as JSON. try: - return json.loads(new_s) + return json.loads(new_s, strict=strict) except json.JSONDecodeError: # If we still can't parse the string as JSON, return None to indicate failure. return None @@ -162,62 +163,26 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict: return json_obj -class SimpleJsonOutputParser(BaseOutputParser[Any]): - """Parse the output of an LLM call to a JSON object.""" +class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]): + """Parse the output of an LLM call to a JSON object. 
- def parse(self, text: str) -> Any: - text = text.strip() - try: - return parse_partial_json(text) - except JSONDecodeError as e: - raise OutputParserException(f"Invalid json output: {text}") from e - - @property - def _type(self) -> str: - return "simple_json_output_parser" - - -class PartialFunctionsJsonOutputParser(BaseCumulativeTransformOutputParser[Any]): - @property - def _type(self) -> str: - return "partial_functions_json" + When used in streaming mode, it will yield partial JSON objects containing + all the keys that have been returned so far. - def parse_result(self, result: List[Generation]) -> Any: - if len(result) != 1: - raise OutputParserException( - f"Expected exactly one result, but got {len(result)}" - ) - generation = result[0] - if not isinstance(generation, ChatGeneration): - raise OutputParserException( - "This output parser can only be used with a chat generation." - ) - message = generation.message - try: - function_call = message.additional_kwargs["function_call"] - except KeyError: - return None - try: - return parse_partial_json(function_call["arguments"]) - except KeyError: - return None + In streaming, if `diff` is set to `True`, yields JSONPatch operations + describing the difference between the previous and the current object. + """ def _diff(self, prev: Optional[Any], next: Any) -> Any: return jsonpatch.make_patch(prev, next).patch - # This method would be called by the default implementation of `parse_result` - # but we're overriding that method so it's not needed. 
def parse(self, text: str) -> Any: - raise NotImplementedError() - + text = text.strip() + try: + return parse_json_markdown(text, parse_partial_json) + except JSONDecodeError as e: + raise OutputParserException(f"Invalid json output: {text}") from e -class PartialJsonOutputParser(BaseCumulativeTransformOutputParser[Any]): @property def _type(self) -> str: - return "partial_functions_json" - - def _diff(self, prev: Optional[Any], next: Any) -> Any: - return jsonpatch.make_patch(prev, next).patch - - def parse(self, text: str) -> Any: - return parse_json_markdown(text, parse_partial_json) + return "simple_json_output_parser" diff --git a/libs/langchain/langchain/output_parsers/openai_functions.py b/libs/langchain/langchain/output_parsers/openai_functions.py index cabafd599d..f0016b3e33 100644 --- a/libs/langchain/langchain/output_parsers/openai_functions.py +++ b/libs/langchain/langchain/output_parsers/openai_functions.py @@ -1,14 +1,20 @@ import copy import json -from typing import Any, Dict, List, Type, Union +from typing import Any, Dict, List, Optional, Type, Union +import jsonpatch + +from langchain.output_parsers.json import parse_partial_json from langchain.pydantic_v1 import BaseModel, root_validator from langchain.schema import ( ChatGeneration, Generation, OutputParserException, ) -from langchain.schema.output_parser import BaseGenerationOutputParser +from langchain.schema.output_parser import ( + BaseCumulativeTransformOutputParser, + BaseGenerationOutputParser, +) class OutputFunctionsParser(BaseGenerationOutputParser[Any]): @@ -34,7 +40,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]): return func_call -class JsonOutputFunctionsParser(OutputFunctionsParser): +class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]): """Parse an output as the Json object.""" strict: bool = False @@ -45,25 +51,42 @@ class JsonOutputFunctionsParser(OutputFunctionsParser): Useful when the parsed output may include unicode characters
or new lines. """ + args_only: bool = True + """Whether to only return the arguments to the function call.""" + + def _diff(self, prev: Optional[Any], next: Any) -> Any: + return jsonpatch.make_patch(prev, next).patch + def parse_result(self, result: List[Generation]) -> Any: - function_call_info = super().parse_result(result) - if self.args_only: - try: - return json.loads(function_call_info, strict=self.strict) - except (json.JSONDecodeError, TypeError) as exc: - raise OutputParserException( - f"Could not parse function call data: {exc}" - ) - else: - try: - function_call_info["arguments"] = json.loads( - function_call_info["arguments"], strict=self.strict - ) - except (json.JSONDecodeError, TypeError) as exc: - raise OutputParserException( - f"Could not parse function call data: {exc}" - ) - return function_call_info + if len(result) != 1: + raise OutputParserException( + f"Expected exactly one result, but got {len(result)}" + ) + generation = result[0] + if not isinstance(generation, ChatGeneration): + raise OutputParserException( + "This output parser can only be used with a chat generation." + ) + message = generation.message + try: + function_call = message.additional_kwargs["function_call"] + except KeyError: + return None + try: + # Pass `strict` through so control characters are handled as configured. + arguments = parse_partial_json( + function_call["arguments"], strict=self.strict + ) + except KeyError: + return None + if self.args_only: + return arguments + return {**function_call, "arguments": arguments} + + # This method would be called by the default implementation of `parse_result` + # but we're overriding that method so it's not needed.
+ def parse(self, text: str) -> Any: + raise NotImplementedError() class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser): diff --git a/libs/langchain/langchain/schema/output_parser.py b/libs/langchain/langchain/schema/output_parser.py index 6d2e388893..5858730256 100644 --- a/libs/langchain/langchain/schema/output_parser.py +++ b/libs/langchain/langchain/schema/output_parser.py @@ -338,6 +338,9 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): """Base class for an output parser that can handle streaming input.""" diff: bool = False + """In streaming mode, whether to yield diffs between the previous and current + parsed output, or just the current parsed output. + """ def _diff(self, prev: Optional[T], next: T) -> T: """Convert parsed outputs into a diff format. The semantics of this are diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_json.py b/libs/langchain/tests/unit_tests/output_parsers/test_json.py index b9daee1a51..90b2d5a7da 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_json.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_json.py @@ -4,12 +4,12 @@ from typing import Any, AsyncIterator, Iterator, Tuple import pytest from langchain.output_parsers.json import ( - PartialFunctionsJsonOutputParser, - PartialJsonOutputParser, + SimpleJsonOutputParser, parse_json_markdown, parse_partial_json, ) +from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser from langchain.schema.messages import AIMessageChunk GOOD_JSON = """```json { @@ -455,7 +455,7 @@ def test_partial_text_json_output_parser() -> None: for token in STREAMED_TOKENS: yield token - chain = input_iter | PartialJsonOutputParser() + chain = input_iter | SimpleJsonOutputParser() assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON @@ -467,7 +467,7 @@ def test_partial_functions_json_output_parser() -> None: content="", additional_kwargs={"function_call": {"arguments": token}} ) - chain = input_iter |
PartialFunctionsJsonOutputParser() + chain = input_iter | JsonOutputFunctionsParser() assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON @@ -477,7 +477,7 @@ def test_partial_text_json_output_parser_diff() -> None: for token in STREAMED_TOKENS: yield token - chain = input_iter | PartialJsonOutputParser(diff=True) + chain = input_iter | SimpleJsonOutputParser(diff=True) assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF @@ -489,7 +489,7 @@ def test_partial_functions_json_output_parser_diff() -> None: content="", additional_kwargs={"function_call": {"arguments": token}} ) - chain = input_iter | PartialFunctionsJsonOutputParser(diff=True) + chain = input_iter | JsonOutputFunctionsParser(diff=True) assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF @@ -500,7 +500,7 @@ async def test_partial_text_json_output_parser_async() -> None: for token in STREAMED_TOKENS: yield token - chain = input_iter | PartialJsonOutputParser() + chain = input_iter | SimpleJsonOutputParser() assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON @@ -513,7 +513,7 @@ async def test_partial_functions_json_output_parser_async() -> None: content="", additional_kwargs={"function_call": {"arguments": token}} ) - chain = input_iter | PartialFunctionsJsonOutputParser() + chain = input_iter | JsonOutputFunctionsParser() assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON @@ -524,7 +524,7 @@ async def test_partial_text_json_output_parser_diff_async() -> None: for token in STREAMED_TOKENS: yield token - chain = input_iter | PartialJsonOutputParser(diff=True) + chain = input_iter | SimpleJsonOutputParser(diff=True) assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF @@ -537,6 +537,6 @@ async def test_partial_functions_json_output_parser_diff_async() -> None: content="", additional_kwargs={"function_call": {"arguments": token}} ) - chain = input_iter | PartialFunctionsJsonOutputParser(diff=True) + chain = input_iter 
| JsonOutputFunctionsParser(diff=True) assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF