Combine with existing json output parsers

pull/11193/head
Nuno Campos 10 months ago
parent 4b8442896b
commit c9d0f2b984

@@ -7,9 +7,10 @@ from typing import Any, Callable, List, Optional
 import jsonpatch
-from langchain.schema import BaseOutputParser, OutputParserException
-from langchain.schema.output import ChatGeneration, Generation
-from langchain.schema.output_parser import BaseCumulativeTransformOutputParser
+from langchain.schema.output_parser import (
+    BaseCumulativeTransformOutputParser,
+    OutputParserException,
+)
 def _replace_new_line(match: re.Match[str]) -> str:
@@ -44,10 +45,10 @@ def _custom_parser(multiline_string: str) -> str:
 # Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py
 # MIT License
-def parse_partial_json(s: str) -> Any:
+def parse_partial_json(s: str, *, strict: bool = False) -> Any:
     # Attempt to parse the string as-is.
     try:
-        return json.loads(s)
+        return json.loads(s, strict=strict)
     except json.JSONDecodeError:
         pass
@@ -97,7 +98,7 @@ def parse_partial_json(s: str) -> Any:
     # Attempt to parse the modified string as JSON.
     try:
-        return json.loads(new_s)
+        return json.loads(new_s, strict=strict)
     except json.JSONDecodeError:
         # If we still can't parse the string as JSON, return None to indicate failure.
         return None
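For context, a minimal sketch of how the updated parse_partial_json behaves; the function and module come from this diff, while the inputs and expected values below are invented for illustration:

from langchain.output_parsers.json import parse_partial_json

# A truncated chunk, as seen mid-stream: the closing quote and brace are missing.
# parse_partial_json appends the missing terminators before calling json.loads,
# so this should yield roughly {'setup': 'Why did the chicken'}.
print(parse_partial_json('{"setup": "Why did the chicken'))

# strict=False is forwarded to json.loads, so control characters such as a raw
# newline inside a string no longer raise JSONDecodeError.
print(parse_partial_json('{"text": "line one\nline two"}', strict=False))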
@@ -162,62 +163,26 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
     return json_obj
-class SimpleJsonOutputParser(BaseOutputParser[Any]):
-    """Parse the output of an LLM call to a JSON object."""
+class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
+    """Parse the output of an LLM call to a JSON object.
-    def parse(self, text: str) -> Any:
-        text = text.strip()
-        try:
-            return parse_partial_json(text)
-        except JSONDecodeError as e:
-            raise OutputParserException(f"Invalid json output: {text}") from e
-    @property
-    def _type(self) -> str:
-        return "simple_json_output_parser"
-class PartialFunctionsJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
-    @property
-    def _type(self) -> str:
-        return "partial_functions_json"
+    When used in streaming mode, it will yield partial JSON objects containing
+    all the keys that have been returned so far.
-    def parse_result(self, result: List[Generation]) -> Any:
-        if len(result) != 1:
-            raise OutputParserException(
-                f"Expected exactly one result, but got {len(result)}"
-            )
-        generation = result[0]
-        if not isinstance(generation, ChatGeneration):
-            raise OutputParserException(
-                "This output parser can only be used with a chat generation."
-            )
-        message = generation.message
-        try:
-            function_call = message.additional_kwargs["function_call"]
-        except KeyError:
-            return None
-        try:
-            return parse_partial_json(function_call["arguments"])
-        except KeyError:
-            return None
+    In streaming, if `diff` is set to `True`, yields JSONPatch operations
+    describing the difference between the previous and the current object.
+    """
     def _diff(self, prev: Optional[Any], next: Any) -> Any:
         return jsonpatch.make_patch(prev, next).patch
-    # This method would be called by the default implementation of `parse_result`
-    # but we're overriding that method so it's not needed.
     def parse(self, text: str) -> Any:
-        raise NotImplementedError()
+        text = text.strip()
+        try:
+            return parse_json_markdown(text.strip(), parse_partial_json)
+        except JSONDecodeError as e:
+            raise OutputParserException(f"Invalid json output: {text}") from e
-class PartialJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
     @property
     def _type(self) -> str:
-        return "partial_functions_json"
-    def _diff(self, prev: Optional[Any], next: Any) -> Any:
-        return jsonpatch.make_patch(prev, next).patch
-    def parse(self, text: str) -> Any:
-        return parse_json_markdown(text, parse_partial_json)
+        return "simple_json_output_parser"

@@ -1,14 +1,20 @@
 import copy
 import json
-from typing import Any, Dict, List, Type, Union
+from typing import Any, Dict, List, Optional, Type, Union
+import jsonpatch
+from langchain.output_parsers.json import parse_partial_json
 from langchain.pydantic_v1 import BaseModel, root_validator
 from langchain.schema import (
     ChatGeneration,
     Generation,
     OutputParserException,
 )
-from langchain.schema.output_parser import BaseGenerationOutputParser
+from langchain.schema.output_parser import (
+    BaseCumulativeTransformOutputParser,
+    BaseGenerationOutputParser,
+)
 class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
@@ -34,7 +40,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
         return func_call
-class JsonOutputFunctionsParser(OutputFunctionsParser):
+class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
     """Parse an output as the Json object."""
     strict: bool = False
@@ -45,25 +51,42 @@ class JsonOutputFunctionsParser(OutputFunctionsParser):
     Useful when the parsed output may include unicode characters or new lines.
     """
     args_only: bool = True
     """Whether to only return the arguments to the function call."""
+    def _diff(self, prev: Optional[Any], next: Any) -> Any:
+        return jsonpatch.make_patch(prev, next).patch
     def parse_result(self, result: List[Generation]) -> Any:
-        function_call_info = super().parse_result(result)
-        if self.args_only:
-            try:
-                return json.loads(function_call_info, strict=self.strict)
-            except (json.JSONDecodeError, TypeError) as exc:
-                raise OutputParserException(
-                    f"Could not parse function call data: {exc}"
-                )
-        else:
-            try:
-                function_call_info["arguments"] = json.loads(
-                    function_call_info["arguments"], strict=self.strict
-                )
-            except (json.JSONDecodeError, TypeError) as exc:
-                raise OutputParserException(
-                    f"Could not parse function call data: {exc}"
-                )
-        return function_call_info
+        if len(result) != 1:
+            raise OutputParserException(
+                f"Expected exactly one result, but got {len(result)}"
+            )
+        generation = result[0]
+        if not isinstance(generation, ChatGeneration):
+            raise OutputParserException(
+                "This output parser can only be used with a chat generation."
+            )
+        message = generation.message
+        try:
+            function_call = message.additional_kwargs["function_call"]
+        except KeyError:
+            return None
+        try:
+            if self.args_only:
+                return parse_partial_json(function_call["arguments"])
+            else:
+                return {
+                    **function_call,
+                    "arguments": parse_partial_json(function_call["arguments"]),
+                }
+        except KeyError:
+            return None
+    # This method would be called by the default implementation of `parse_result`
+    # but we're overriding that method so it's not needed.
+    def parse(self, text: str) -> Any:
+        raise NotImplementedError()
 class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
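To illustrate the combined behaviour, here is a hedged sketch of streaming OpenAI function-call arguments through the now-cumulative JsonOutputFunctionsParser; the chunk contents and printed states are invented:

from typing import Any, Iterator

from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.schema.messages import AIMessageChunk

def fake_function_call_chunks(_: Any) -> Iterator[AIMessageChunk]:
    # Imitates a chat model streaming the `arguments` field of a function call.
    for token in ['{"city', '": "Par', 'is"}']:
        yield AIMessageChunk(
            content="", additional_kwargs={"function_call": {"arguments": token}}
        )

# With args_only=True (the default) only the parsed arguments are yielded,
# growing as more of the JSON arrives, e.g. {'city': 'Par'} -> {'city': 'Paris'}.
chain = fake_function_call_chunks | JsonOutputFunctionsParser()
for partial_args in chain.stream(None):
    print(partial_args)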

@@ -338,6 +338,9 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
     """Base class for an output parser that can handle streaming input."""
+    diff: bool = False
+    """In streaming mode, whether to yield diffs between the previous and current
+    parsed output, or just the current parsed output.
+    """
     def _diff(self, prev: Optional[T], next: T) -> T:
         """Convert parsed outputs into a diff format. The semantics of this are

@@ -4,12 +4,12 @@ from typing import Any, AsyncIterator, Iterator, Tuple
 import pytest
 from langchain.output_parsers.json import (
-    PartialFunctionsJsonOutputParser,
-    PartialJsonOutputParser,
     SimpleJsonOutputParser,
     parse_json_markdown,
     parse_partial_json,
 )
 from langchain.schema.messages import AIMessageChunk
+from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
 GOOD_JSON = """```json
 {
@@ -455,7 +455,7 @@ def test_partial_text_json_output_parser() -> None:
         for token in STREAMED_TOKENS:
             yield token
-    chain = input_iter | PartialJsonOutputParser()
+    chain = input_iter | SimpleJsonOutputParser()
     assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
@@ -467,7 +467,7 @@ def test_partial_functions_json_output_parser() -> None:
                content="", additional_kwargs={"function_call": {"arguments": token}}
            )
-    chain = input_iter | PartialFunctionsJsonOutputParser()
+    chain = input_iter | JsonOutputFunctionsParser()
     assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
@@ -477,7 +477,7 @@ def test_partial_text_json_output_parser_diff() -> None:
         for token in STREAMED_TOKENS:
             yield token
-    chain = input_iter | PartialJsonOutputParser(diff=True)
+    chain = input_iter | SimpleJsonOutputParser(diff=True)
     assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF
@@ -489,7 +489,7 @@ def test_partial_functions_json_output_parser_diff() -> None:
                content="", additional_kwargs={"function_call": {"arguments": token}}
            )
-    chain = input_iter | PartialFunctionsJsonOutputParser(diff=True)
+    chain = input_iter | JsonOutputFunctionsParser(diff=True)
     assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF
@@ -500,7 +500,7 @@ async def test_partial_text_json_output_parser_async() -> None:
         for token in STREAMED_TOKENS:
             yield token
-    chain = input_iter | PartialJsonOutputParser()
+    chain = input_iter | SimpleJsonOutputParser()
     assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
@@ -513,7 +513,7 @@ async def test_partial_functions_json_output_parser_async() -> None:
                content="", additional_kwargs={"function_call": {"arguments": token}}
            )
-    chain = input_iter | PartialFunctionsJsonOutputParser()
+    chain = input_iter | JsonOutputFunctionsParser()
     assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
@@ -524,7 +524,7 @@ async def test_partial_text_json_output_parser_diff_async() -> None:
         for token in STREAMED_TOKENS:
             yield token
-    chain = input_iter | PartialJsonOutputParser(diff=True)
+    chain = input_iter | SimpleJsonOutputParser(diff=True)
     assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF
@@ -537,6 +537,6 @@ async def test_partial_functions_json_output_parser_diff_async() -> None:
                content="", additional_kwargs={"function_call": {"arguments": token}}
            )
-    chain = input_iter | PartialFunctionsJsonOutputParser(diff=True)
+    chain = input_iter | JsonOutputFunctionsParser(diff=True)
     assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF
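Outside the test suite, the same pattern works asynchronously; a hedged sketch with an invented token stream, mirroring the async tests above:

import asyncio
from typing import Any, AsyncIterator

from langchain.output_parsers.json import SimpleJsonOutputParser

async def token_stream(_: Any) -> AsyncIterator[str]:
    for token in ['{"country', '": "Fra', 'nce"}']:
        yield token

async def main() -> None:
    # diff=True yields JSONPatch operations between successive partial objects.
    chain = token_stream | SimpleJsonOutputParser(diff=True)
    async for ops in chain.astream(None):
        print(ops)

asyncio.run(main())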
