mirror of https://github.com/hwchase17/langchain
Move json and xml parsers to core (#15026)
<!-- Thank you for contributing to LangChain! Please title your PR "<package>: <description>", where <package> is whichever of langchain, community, core, experimental, etc. is being modified. Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes if applicable, - **Dependencies:** any dependencies required for this change, - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` from the root of the package you've modified to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://python.langchain.com/docs/contributing/ If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->pull/15036/head
parent
d5533b7081
commit
71076cceaf
@ -0,0 +1,195 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from json import JSONDecodeError
|
||||||
|
from typing import Any, Callable, List, Optional
|
||||||
|
|
||||||
|
import jsonpatch # type: ignore[import]
|
||||||
|
|
||||||
|
from langchain_core.exceptions import OutputParserException
|
||||||
|
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
|
||||||
|
|
||||||
|
|
||||||
|
def _replace_new_line(match: re.Match[str]) -> str:
|
||||||
|
value = match.group(2)
|
||||||
|
value = re.sub(r"\n", r"\\n", value)
|
||||||
|
value = re.sub(r"\r", r"\\r", value)
|
||||||
|
value = re.sub(r"\t", r"\\t", value)
|
||||||
|
value = re.sub(r'(?<!\\)"', r"\"", value)
|
||||||
|
|
||||||
|
return match.group(1) + value + match.group(3)
|
||||||
|
|
||||||
|
|
||||||
|
def _custom_parser(multiline_string: str) -> str:
|
||||||
|
"""
|
||||||
|
The LLM response for `action_input` may be a multiline
|
||||||
|
string containing unescaped newlines, tabs or quotes. This function
|
||||||
|
replaces those characters with their escaped counterparts.
|
||||||
|
(newlines in JSON must be double-escaped: `\\n`)
|
||||||
|
"""
|
||||||
|
if isinstance(multiline_string, (bytes, bytearray)):
|
||||||
|
multiline_string = multiline_string.decode()
|
||||||
|
|
||||||
|
multiline_string = re.sub(
|
||||||
|
r'("action_input"\:\s*")(.*)(")',
|
||||||
|
_replace_new_line,
|
||||||
|
multiline_string,
|
||||||
|
flags=re.DOTALL,
|
||||||
|
)
|
||||||
|
|
||||||
|
return multiline_string
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py
# MIT License
def parse_partial_json(s: str, *, strict: bool = False) -> Any:
    """Parse a JSON string that may be missing closing braces.

    Args:
        s: The JSON string to parse.
        strict: Whether to use strict parsing. Defaults to False.
            (strict=False permits literal control characters inside strings,
            which streamed LLM output frequently contains.)

    Returns:
        The parsed JSON object as a Python dictionary, or None if the
        string cannot be repaired into valid JSON.
    """
    # Attempt to parse the string as-is.
    try:
        return json.loads(s, strict=strict)
    except json.JSONDecodeError:
        pass

    # Initialize variables. Characters are accumulated in a list and joined
    # once at the end to avoid O(n^2) repeated string concatenation.
    chars: List[str] = []
    stack = []  # expected closing characters, innermost last
    is_inside_string = False
    escaped = False

    # Process each character in the string one at a time.
    for char in s:
        if is_inside_string:
            if char == '"' and not escaped:
                is_inside_string = False
            elif char == "\n" and not escaped:
                char = "\\n"  # Replace the newline character with the escape sequence.
            elif char == "\\":
                escaped = not escaped
            else:
                escaped = False
        else:
            if char == '"':
                is_inside_string = True
                escaped = False
            elif char == "{":
                stack.append("}")
            elif char == "[":
                stack.append("]")
            elif char == "}" or char == "]":
                if stack and stack[-1] == char:
                    stack.pop()
                else:
                    # Mismatched closing character; the input is malformed.
                    return None

        # Append the processed character to the new string.
        chars.append(char)

    # If we're still inside a string at the end of processing,
    # we need to close the string.
    if is_inside_string:
        chars.append('"')

    # Close any remaining open structures in the reverse order that they were opened.
    chars.extend(reversed(stack))

    # Attempt to parse the modified string as JSON.
    try:
        return json.loads("".join(chars), strict=strict)
    except json.JSONDecodeError:
        # If we still can't parse the string as JSON, return None to indicate failure.
        return None
|
||||||
|
|
||||||
|
|
||||||
|
def parse_json_markdown(
    json_string: str, *, parser: Callable[[str], Any] = json.loads
) -> dict:
    """
    Parse a JSON string from a Markdown string.

    Args:
        json_string: The Markdown string.

    Returns:
        The parsed JSON object as a Python dictionary.
    """
    # Prefer the content of a triple-backtick fence (optionally tagged
    # ```json); if no fence is present, treat the whole input as JSON.
    fence = re.search(r"```(json)?(.*)```", json_string, re.DOTALL)
    payload = json_string if fence is None else fence.group(2)

    # Trim surrounding whitespace/newlines, then escape raw newlines and
    # other special characters that may appear inside an `action_input`
    # value, before handing the result to the configured parser.
    payload = _custom_parser(payload.strip())
    return parser(payload)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
    """
    Parse a JSON string from a Markdown string and check that it
    contains the expected keys.

    Args:
        text: The Markdown string.
        expected_keys: The expected keys in the JSON string.

    Returns:
        The parsed JSON object as a Python dictionary.

    Raises:
        OutputParserException: if the text is not valid JSON or an
            expected key is missing.
    """
    try:
        json_obj = parse_json_markdown(text)
    except json.JSONDecodeError as e:
        # Chain the original decode error so callers can see the root cause.
        raise OutputParserException(f"Got invalid JSON object. Error: {e}") from e
    for key in expected_keys:
        if key not in json_obj:
            raise OutputParserException(
                f"Got invalid return object. Expected key `{key}` "
                f"to be present, but got {json_obj}"
            )
    return json_obj
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
    """Parse the output of an LLM call to a JSON object.

    When used in streaming mode, it will yield partial JSON objects containing
    all the keys that have been returned so far.

    In streaming, if `diff` is set to `True`, yields JSONPatch operations
    describing the difference between the previous and the current object.
    """

    def _diff(self, prev: Optional[Any], next: Any) -> Any:
        # JSONPatch operations transforming `prev` into `next`
        # (used by the base class when diff=True).
        return jsonpatch.make_patch(prev, next).patch

    def parse(self, text: str) -> Any:
        """Parse `text` (optionally wrapped in a markdown fence) to JSON.

        Raises:
            OutputParserException: if the text cannot be parsed as
                (possibly partial) JSON.
        """
        # Strip once; the original stripped the same text twice redundantly.
        text = text.strip()
        try:
            # parse_partial_json lets still-incomplete streamed JSON parse.
            return parse_json_markdown(text, parser=parse_partial_json)
        except JSONDecodeError as e:
            raise OutputParserException(f"Invalid json output: {text}") from e

    @property
    def _type(self) -> str:
        return "simple_json_output_parser"
|
@ -0,0 +1,135 @@
|
|||||||
|
import re
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
|
||||||
|
|
||||||
|
from langchain_core.messages import BaseMessage
|
||||||
|
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||||
|
from langchain_core.runnables.utils import AddableDict
|
||||||
|
|
||||||
|
XML_FORMAT_INSTRUCTIONS = """The output should be formatted as a XML file.
|
||||||
|
1. Output should conform to the tags below.
|
||||||
|
2. If tags are not given, make them on your own.
|
||||||
|
3. Remember to always open and close all the tags.
|
||||||
|
|
||||||
|
As an example, for the tags ["foo", "bar", "baz"]:
|
||||||
|
1. String "<foo>\n <bar>\n <baz></baz>\n </bar>\n</foo>" is a well-formatted instance of the schema.
|
||||||
|
2. String "<foo>\n <bar>\n </foo>" is a badly-formatted instance.
|
||||||
|
3. String "<foo>\n <tag>\n </tag>\n</foo>" is a badly-formatted instance.
|
||||||
|
|
||||||
|
Here are the output tags:
|
||||||
|
```
|
||||||
|
{tags}
|
||||||
|
```""" # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
|
class XMLOutputParser(BaseTransformOutputParser):
    """Parse an output using xml format."""

    # Tags the model is asked to emit; substituted into the format
    # instructions template. None means "model chooses its own tags".
    tags: Optional[List[str]] = None
    # Matches an XML declaration line (anything containing "encoding" inside
    # angle brackets) followed by the rest of the document, so the
    # declaration can be discarded before parsing.
    encoding_matcher: re.Pattern = re.compile(
        r"<([^>]*encoding[^>]*)>\n(.*)", re.MULTILINE | re.DOTALL
    )

    def get_format_instructions(self) -> str:
        """Return prompt instructions describing the expected XML output."""
        return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)

    def parse(self, text: str) -> Dict[str, List[Any]]:
        """Parse a complete XML string into a nested dict of lists.

        Raises:
            ValueError: if `text` does not look like an XML document.
        """
        # NOTE(review): str.strip removes *characters*, not a prefix — this
        # peels backticks and any leading/trailing 'x'/'m'/'l' characters to
        # drop a ```xml fence; confirm it cannot eat document content whose
        # first/last tag characters are in "xml".
        text = text.strip("`").strip("xml")
        encoding_match = self.encoding_matcher.search(text)
        if encoding_match:
            # Drop the XML declaration line; keep only the document body.
            text = encoding_match.group(2)

        text = text.strip()
        if (text.startswith("<") or text.startswith("\n<")) and (
            text.endswith(">") or text.endswith(">\n")
        ):
            # NOTE(review): ET.fromstring is not hardened against hostile XML
            # (e.g. entity expansion) — acceptable only for trusted output.
            root = ET.fromstring(text)
            return self._root_to_dict(root)
        else:
            raise ValueError(f"Could not parse output: {text}")

    def _transform(
        self, input: Iterator[Union[str, BaseMessage]]
    ) -> Iterator[AddableDict]:
        # Incremental parse of streamed chunks: emit one nested AddableDict
        # per *leaf* element as soon as its closing tag is seen.
        parser = ET.XMLPullParser(["start", "end"])
        current_path: List[str] = []  # tag names from root to current element
        current_path_has_children = False
        for chunk in input:
            if isinstance(chunk, BaseMessage):
                # extract text
                chunk_content = chunk.content
                if not isinstance(chunk_content, str):
                    # Non-string message content cannot be fed to the parser.
                    continue
                chunk = chunk_content
            # pass chunk to parser
            parser.feed(chunk)
            # yield all events
            for event, elem in parser.read_events():
                if event == "start":
                    # update current path
                    current_path.append(elem.tag)
                    current_path_has_children = False
                elif event == "end":
                    # remove last element from current path
                    current_path.pop()
                    # yield element
                    if not current_path_has_children:
                        yield nested_element(current_path, elem)
                    # prevent yielding of parent element
                    current_path_has_children = True
        # close parser
        parser.close()

    async def _atransform(
        self, input: AsyncIterator[Union[str, BaseMessage]]
    ) -> AsyncIterator[AddableDict]:
        # Async twin of _transform — keep both implementations in sync.
        parser = ET.XMLPullParser(["start", "end"])
        current_path: List[str] = []  # tag names from root to current element
        current_path_has_children = False
        async for chunk in input:
            if isinstance(chunk, BaseMessage):
                # extract text
                chunk_content = chunk.content
                if not isinstance(chunk_content, str):
                    # Non-string message content cannot be fed to the parser.
                    continue
                chunk = chunk_content
            # pass chunk to parser
            parser.feed(chunk)
            # yield all events
            for event, elem in parser.read_events():
                if event == "start":
                    # update current path
                    current_path.append(elem.tag)
                    current_path_has_children = False
                elif event == "end":
                    # remove last element from current path
                    current_path.pop()
                    # yield element
                    if not current_path_has_children:
                        yield nested_element(current_path, elem)
                    # prevent yielding of parent element
                    current_path_has_children = True
        # close parser
        parser.close()

    def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]:
        """Converts xml tree to python dictionary."""
        result: Dict[str, List[Any]] = {root.tag: []}
        for child in root:
            if len(child) == 0:
                # Leaf node: record its text content.
                result[root.tag].append({child.tag: child.text})
            else:
                # Internal node: recurse into the subtree.
                result[root.tag].append(self._root_to_dict(child))
        return result

    @property
    def _type(self) -> str:
        # Serialization identifier for this parser type.
        return "xml"
|
||||||
|
|
||||||
|
|
||||||
|
def nested_element(path: List[str], elem: ET.Element) -> Any:
    """Get nested element from path."""
    # Recurse down `path`, wrapping the leaf one level deeper per segment.
    if path:
        return AddableDict({path[0]: [nested_element(path[1:], elem)]})
    return AddableDict({elem.tag: elem.text})
|
@ -0,0 +1,488 @@
|
|||||||
|
import json
|
||||||
|
from typing import Any, AsyncIterator, Iterator, Tuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain_core.output_parsers.json import (
|
||||||
|
SimpleJsonOutputParser,
|
||||||
|
parse_json_markdown,
|
||||||
|
parse_partial_json,
|
||||||
|
)
|
||||||
|
|
||||||
|
GOOD_JSON = """```json
|
||||||
|
{
|
||||||
|
"foo": "bar"
|
||||||
|
}
|
||||||
|
```"""
|
||||||
|
|
||||||
|
JSON_WITH_NEW_LINES = """
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"foo": "bar"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
JSON_WITH_NEW_LINES_INSIDE = """```json
|
||||||
|
{
|
||||||
|
|
||||||
|
"foo": "bar"
|
||||||
|
|
||||||
|
}
|
||||||
|
```"""
|
||||||
|
|
||||||
|
JSON_WITH_NEW_LINES_EVERYWHERE = """
|
||||||
|
|
||||||
|
```json
|
||||||
|
|
||||||
|
{
|
||||||
|
|
||||||
|
"foo": "bar"
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
TICKS_WITH_NEW_LINES_EVERYWHERE = """
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
{
|
||||||
|
|
||||||
|
"foo": "bar"
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
JSON_WITH_MARKDOWN_CODE_BLOCK = """```json
|
||||||
|
{
|
||||||
|
"foo": "```bar```"
|
||||||
|
}
|
||||||
|
```"""
|
||||||
|
|
||||||
|
JSON_WITH_MARKDOWN_CODE_BLOCK_AND_NEWLINES = """```json
|
||||||
|
{
|
||||||
|
"action": "Final Answer",
|
||||||
|
"action_input": "```bar\n<div id="1" class=\"value\">\n\ttext\n</div>```"
|
||||||
|
}
|
||||||
|
```"""
|
||||||
|
|
||||||
|
JSON_WITH_UNESCAPED_QUOTES_IN_NESTED_JSON = """```json
|
||||||
|
{
|
||||||
|
"action": "Final Answer",
|
||||||
|
"action_input": "{"foo": "bar", "bar": "foo"}"
|
||||||
|
}
|
||||||
|
```"""
|
||||||
|
|
||||||
|
JSON_WITH_ESCAPED_QUOTES_IN_NESTED_JSON = """```json
|
||||||
|
{
|
||||||
|
"action": "Final Answer",
|
||||||
|
"action_input": "{\"foo\": \"bar\", \"bar\": \"foo\"}"
|
||||||
|
}
|
||||||
|
```"""
|
||||||
|
|
||||||
|
JSON_WITH_PYTHON_DICT = """```json
|
||||||
|
{
|
||||||
|
"action": "Final Answer",
|
||||||
|
"action_input": {"foo": "bar", "bar": "foo"}
|
||||||
|
}
|
||||||
|
```"""
|
||||||
|
|
||||||
|
JSON_WITH_ESCAPED_DOUBLE_QUOTES_IN_NESTED_JSON = """```json
|
||||||
|
{
|
||||||
|
"action": "Final Answer",
|
||||||
|
"action_input": "{\\"foo\\": \\"bar\\", \\"bar\\": \\"foo\\"}"
|
||||||
|
}
|
||||||
|
```"""
|
||||||
|
|
||||||
|
NO_TICKS = """{
|
||||||
|
"foo": "bar"
|
||||||
|
}"""
|
||||||
|
|
||||||
|
NO_TICKS_WHITE_SPACE = """
|
||||||
|
{
|
||||||
|
"foo": "bar"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
TEXT_BEFORE = """Thought: I need to use the search tool
|
||||||
|
|
||||||
|
Action:
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"foo": "bar"
|
||||||
|
}
|
||||||
|
```"""
|
||||||
|
|
||||||
|
TEXT_AFTER = """```
|
||||||
|
{
|
||||||
|
"foo": "bar"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
This should do the trick"""
|
||||||
|
|
||||||
|
TEXT_BEFORE_AND_AFTER = """Action: Testing
|
||||||
|
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"foo": "bar"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
This should do the trick"""
|
||||||
|
|
||||||
|
TEST_CASES = [
|
||||||
|
GOOD_JSON,
|
||||||
|
JSON_WITH_NEW_LINES,
|
||||||
|
JSON_WITH_NEW_LINES_INSIDE,
|
||||||
|
JSON_WITH_NEW_LINES_EVERYWHERE,
|
||||||
|
TICKS_WITH_NEW_LINES_EVERYWHERE,
|
||||||
|
NO_TICKS,
|
||||||
|
NO_TICKS_WHITE_SPACE,
|
||||||
|
TEXT_BEFORE,
|
||||||
|
TEXT_AFTER,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("json_string", TEST_CASES)
def test_parse_json(json_string: str) -> None:
    """Every fenced/unfenced markdown variant should parse to the same dict."""
    parsed = parse_json_markdown(json_string)
    assert parsed == {"foo": "bar"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_json_with_code_blocks() -> None:
    """Backtick fences *inside* a JSON value must survive extraction."""
    parsed = parse_json_markdown(JSON_WITH_MARKDOWN_CODE_BLOCK)
    assert parsed == {"foo": "```bar```"}

    # Embedded fence plus raw newlines/tabs/quotes in the value.
    parsed = parse_json_markdown(JSON_WITH_MARKDOWN_CODE_BLOCK_AND_NEWLINES)

    assert parsed == {
        "action": "Final Answer",
        "action_input": '```bar\n<div id="1" class="value">\n\ttext\n</div>```',
    }
|
||||||
|
|
||||||
|
|
||||||
|
# Inputs whose `action_input` embeds nested JSON with differently-escaped
# quotes; all are expected to normalize to the same parsed payload.
TEST_CASES_ESCAPED_QUOTES = [
    JSON_WITH_UNESCAPED_QUOTES_IN_NESTED_JSON,
    JSON_WITH_ESCAPED_QUOTES_IN_NESTED_JSON,
    JSON_WITH_ESCAPED_DOUBLE_QUOTES_IN_NESTED_JSON,
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("json_string", TEST_CASES_ESCAPED_QUOTES)
def test_parse_nested_json_with_escaped_quotes(json_string: str) -> None:
    """Nested JSON-in-a-string parses identically for all escape styles."""
    parsed = parse_json_markdown(json_string)
    assert parsed == {
        "action": "Final Answer",
        "action_input": '{"foo": "bar", "bar": "foo"}',
    }
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_json_with_python_dict() -> None:
    """An unquoted object value parses as a nested dict, not a string."""
    parsed = parse_json_markdown(JSON_WITH_PYTHON_DICT)
    assert parsed == {
        "action": "Final Answer",
        "action_input": {"foo": "bar", "bar": "foo"},
    }
|
||||||
|
|
||||||
|
|
||||||
|
# Pairs of (truncated input, equivalent complete JSON the repair should
# produce) for parse_partial_json.
TEST_CASES_PARTIAL = [
    ('{"foo": "bar", "bar": "foo"}', '{"foo": "bar", "bar": "foo"}'),
    ('{"foo": "bar", "bar": "foo', '{"foo": "bar", "bar": "foo"}'),
    ('{"foo": "bar", "bar": "foo}', '{"foo": "bar", "bar": "foo}"}'),
    ('{"foo": "bar", "bar": "foo[', '{"foo": "bar", "bar": "foo["}'),
    ('{"foo": "bar", "bar": "foo\\"', '{"foo": "bar", "bar": "foo\\""}'),
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("json_strings", TEST_CASES_PARTIAL)
def test_parse_partial_json(json_strings: Tuple[str, str]) -> None:
    """Repaired partial JSON equals the parse of its completed counterpart."""
    case, expected = json_strings
    parsed = parse_partial_json(case)
    assert parsed == json.loads(expected)
|
||||||
|
|
||||||
|
|
||||||
|
STREAMED_TOKENS = """
|
||||||
|
{
|
||||||
|
|
||||||
|
"
|
||||||
|
setup
|
||||||
|
":
|
||||||
|
"
|
||||||
|
Why
|
||||||
|
did
|
||||||
|
the
|
||||||
|
bears
|
||||||
|
start
|
||||||
|
a
|
||||||
|
band
|
||||||
|
called
|
||||||
|
Bears
|
||||||
|
Bears
|
||||||
|
Bears
|
||||||
|
?
|
||||||
|
"
|
||||||
|
,
|
||||||
|
"
|
||||||
|
punchline
|
||||||
|
":
|
||||||
|
"
|
||||||
|
Because
|
||||||
|
they
|
||||||
|
wanted
|
||||||
|
to
|
||||||
|
play
|
||||||
|
bear
|
||||||
|
-y
|
||||||
|
good
|
||||||
|
music
|
||||||
|
!
|
||||||
|
"
|
||||||
|
,
|
||||||
|
"
|
||||||
|
audience
|
||||||
|
":
|
||||||
|
[
|
||||||
|
"
|
||||||
|
Haha
|
||||||
|
"
|
||||||
|
,
|
||||||
|
"
|
||||||
|
So
|
||||||
|
funny
|
||||||
|
"
|
||||||
|
]
|
||||||
|
|
||||||
|
}
|
||||||
|
""".splitlines()
|
||||||
|
|
||||||
|
EXPECTED_STREAMED_JSON = [
|
||||||
|
{},
|
||||||
|
{"setup": ""},
|
||||||
|
{"setup": "Why"},
|
||||||
|
{"setup": "Why did"},
|
||||||
|
{"setup": "Why did the"},
|
||||||
|
{"setup": "Why did the bears"},
|
||||||
|
{"setup": "Why did the bears start"},
|
||||||
|
{"setup": "Why did the bears start a"},
|
||||||
|
{"setup": "Why did the bears start a band"},
|
||||||
|
{"setup": "Why did the bears start a band called"},
|
||||||
|
{"setup": "Why did the bears start a band called Bears"},
|
||||||
|
{"setup": "Why did the bears start a band called Bears Bears"},
|
||||||
|
{"setup": "Why did the bears start a band called Bears Bears Bears"},
|
||||||
|
{"setup": "Why did the bears start a band called Bears Bears Bears ?"},
|
||||||
|
{
|
||||||
|
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||||
|
"punchline": "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||||
|
"punchline": "Because",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||||
|
"punchline": "Because they",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||||
|
"punchline": "Because they wanted",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||||
|
"punchline": "Because they wanted to",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||||
|
"punchline": "Because they wanted to play",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||||
|
"punchline": "Because they wanted to play bear",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||||
|
"punchline": "Because they wanted to play bear -y",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||||
|
"punchline": "Because they wanted to play bear -y good",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||||
|
"punchline": "Because they wanted to play bear -y good music",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||||
|
"punchline": "Because they wanted to play bear -y good music !",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"punchline": "Because they wanted to play bear -y good music !",
|
||||||
|
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||||
|
"audience": [],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"punchline": "Because they wanted to play bear -y good music !",
|
||||||
|
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||||
|
"audience": [""],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"punchline": "Because they wanted to play bear -y good music !",
|
||||||
|
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||||
|
"audience": ["Haha"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"punchline": "Because they wanted to play bear -y good music !",
|
||||||
|
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||||
|
"audience": ["Haha", ""],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"punchline": "Because they wanted to play bear -y good music !",
|
||||||
|
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||||
|
"audience": ["Haha", "So"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"punchline": "Because they wanted to play bear -y good music !",
|
||||||
|
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||||
|
"audience": ["Haha", "So funny"],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
EXPECTED_STREAMED_JSON_DIFF = [
|
||||||
|
[{"op": "replace", "path": "", "value": {}}],
|
||||||
|
[{"op": "add", "path": "/setup", "value": ""}],
|
||||||
|
[{"op": "replace", "path": "/setup", "value": "Why"}],
|
||||||
|
[{"op": "replace", "path": "/setup", "value": "Why did"}],
|
||||||
|
[{"op": "replace", "path": "/setup", "value": "Why did the"}],
|
||||||
|
[{"op": "replace", "path": "/setup", "value": "Why did the bears"}],
|
||||||
|
[{"op": "replace", "path": "/setup", "value": "Why did the bears start"}],
|
||||||
|
[{"op": "replace", "path": "/setup", "value": "Why did the bears start a"}],
|
||||||
|
[{"op": "replace", "path": "/setup", "value": "Why did the bears start a band"}],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "/setup",
|
||||||
|
"value": "Why did the bears start a band called",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "/setup",
|
||||||
|
"value": "Why did the bears start a band called Bears",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "/setup",
|
||||||
|
"value": "Why did the bears start a band called Bears Bears",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "/setup",
|
||||||
|
"value": "Why did the bears start a band called Bears Bears Bears",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "/setup",
|
||||||
|
"value": "Why did the bears start a band called Bears Bears Bears ?",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
[{"op": "add", "path": "/punchline", "value": ""}],
|
||||||
|
[{"op": "replace", "path": "/punchline", "value": "Because"}],
|
||||||
|
[{"op": "replace", "path": "/punchline", "value": "Because they"}],
|
||||||
|
[{"op": "replace", "path": "/punchline", "value": "Because they wanted"}],
|
||||||
|
[{"op": "replace", "path": "/punchline", "value": "Because they wanted to"}],
|
||||||
|
[{"op": "replace", "path": "/punchline", "value": "Because they wanted to play"}],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "/punchline",
|
||||||
|
"value": "Because they wanted to play bear",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "/punchline",
|
||||||
|
"value": "Because they wanted to play bear -y",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "/punchline",
|
||||||
|
"value": "Because they wanted to play bear -y good",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "/punchline",
|
||||||
|
"value": "Because they wanted to play bear -y good music",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "/punchline",
|
||||||
|
"value": "Because they wanted to play bear -y good music !",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
[{"op": "add", "path": "/audience", "value": []}],
|
||||||
|
[{"op": "add", "path": "/audience/0", "value": ""}],
|
||||||
|
[{"op": "replace", "path": "/audience/0", "value": "Haha"}],
|
||||||
|
[{"op": "add", "path": "/audience/1", "value": ""}],
|
||||||
|
[{"op": "replace", "path": "/audience/1", "value": "So"}],
|
||||||
|
[{"op": "replace", "path": "/audience/1", "value": "So funny"}],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_text_json_output_parser() -> None:
    """Streaming the fixture tokens yields incrementally-growing objects."""

    def input_iter(_: Any) -> Iterator[str]:
        yield from STREAMED_TOKENS

    chain = input_iter | SimpleJsonOutputParser()

    assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_text_json_output_parser_diff() -> None:
    """With diff=True, streaming yields JSONPatch ops between snapshots."""

    def input_iter(_: Any) -> Iterator[str]:
        yield from STREAMED_TOKENS

    chain = input_iter | SimpleJsonOutputParser(diff=True)

    assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF
|
||||||
|
|
||||||
|
|
||||||
|
async def test_partial_text_json_output_parser_async() -> None:
    """Async streaming matches the synchronous incremental snapshots."""

    async def input_iter(_: Any) -> AsyncIterator[str]:
        for piece in STREAMED_TOKENS:
            yield piece

    chain = input_iter | SimpleJsonOutputParser()

    assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
|
||||||
|
|
||||||
|
|
||||||
|
async def test_partial_text_json_output_parser_diff_async() -> None:
    """Async diff streaming matches the synchronous JSONPatch sequence."""

    async def input_iter(_: Any) -> AsyncIterator[str]:
        for piece in STREAMED_TOKENS:
            yield piece

    chain = input_iter | SimpleJsonOutputParser(diff=True)

    assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF
|
@ -1,7 +1,7 @@
|
|||||||
"""Test XMLOutputParser"""
|
"""Test XMLOutputParser"""
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain.output_parsers.xml import XMLOutputParser
|
from langchain_core.output_parsers.xml import XMLOutputParser
|
||||||
|
|
||||||
DEF_RESULT_ENCODING = """<?xml version="1.0" encoding="UTF-8"?>
|
DEF_RESULT_ENCODING = """<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<foo>
|
<foo>
|
@ -1,194 +1,13 @@
|
|||||||
from __future__ import annotations
|
from langchain_core.output_parsers.json import (
|
||||||
|
SimpleJsonOutputParser,
|
||||||
import json
|
parse_and_check_json_markdown,
|
||||||
import re
|
parse_json_markdown,
|
||||||
from json import JSONDecodeError
|
parse_partial_json,
|
||||||
from typing import Any, Callable, List, Optional
|
)
|
||||||
|
|
||||||
import jsonpatch
|
__all__ = [
|
||||||
from langchain_core.exceptions import OutputParserException
|
"SimpleJsonOutputParser",
|
||||||
from langchain_core.output_parsers import BaseCumulativeTransformOutputParser
|
"parse_partial_json",
|
||||||
|
"parse_json_markdown",
|
||||||
|
"parse_and_check_json_markdown",
|
||||||
def _replace_new_line(match: re.Match[str]) -> str:
|
]
|
||||||
value = match.group(2)
|
|
||||||
value = re.sub(r"\n", r"\\n", value)
|
|
||||||
value = re.sub(r"\r", r"\\r", value)
|
|
||||||
value = re.sub(r"\t", r"\\t", value)
|
|
||||||
value = re.sub(r'(?<!\\)"', r"\"", value)
|
|
||||||
|
|
||||||
return match.group(1) + value + match.group(3)
|
|
||||||
|
|
||||||
|
|
||||||
def _custom_parser(multiline_string: str) -> str:
|
|
||||||
"""
|
|
||||||
The LLM response for `action_input` may be a multiline
|
|
||||||
string containing unescaped newlines, tabs or quotes. This function
|
|
||||||
replaces those characters with their escaped counterparts.
|
|
||||||
(newlines in JSON must be double-escaped: `\\n`)
|
|
||||||
"""
|
|
||||||
if isinstance(multiline_string, (bytes, bytearray)):
|
|
||||||
multiline_string = multiline_string.decode()
|
|
||||||
|
|
||||||
multiline_string = re.sub(
|
|
||||||
r'("action_input"\:\s*")(.*)(")',
|
|
||||||
_replace_new_line,
|
|
||||||
multiline_string,
|
|
||||||
flags=re.DOTALL,
|
|
||||||
)
|
|
||||||
|
|
||||||
return multiline_string
|
|
||||||
|
|
||||||
|
|
||||||
# Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py
# MIT License
def parse_partial_json(s: str, *, strict: bool = False) -> Any:
    """Parse a JSON string that may be missing closing braces.

    Args:
        s: The JSON string to parse.
        strict: Whether to use strict parsing. Defaults to False.

    Returns:
        The parsed JSON object as a Python dictionary, or None if the
        input cannot be repaired into valid JSON.
    """
    # Fast path: the string may already be valid JSON.
    try:
        return json.loads(s, strict=strict)
    except json.JSONDecodeError:
        pass

    chars: List[str] = []
    closers: List[str] = []  # expected closing characters, innermost last
    in_string = False
    escaped = False

    # Scan character by character, tracking string/escape state and
    # the stack of containers that still need to be closed.
    for ch in s:
        if in_string:
            if ch == '"' and not escaped:
                in_string = False
            elif ch == "\n" and not escaped:
                # Replace a raw newline with its JSON escape sequence.
                ch = "\\n"
            elif ch == "\\":
                escaped = not escaped
            else:
                escaped = False
        elif ch == '"':
            in_string = True
            escaped = False
        elif ch == "{":
            closers.append("}")
        elif ch == "[":
            closers.append("]")
        elif ch in ("}", "]"):
            if not closers or closers[-1] != ch:
                # Mismatched closing character; the input is malformed.
                return None
            closers.pop()
        chars.append(ch)

    # Terminate an unterminated string, then close any still-open
    # containers in the reverse order they were opened.
    if in_string:
        chars.append('"')
    chars.extend(reversed(closers))

    # Attempt to parse the repaired string as JSON.
    try:
        return json.loads("".join(chars), strict=strict)
    except json.JSONDecodeError:
        # Still not parseable; signal failure with None.
        return None
|
|
||||||
|
|
||||||
def parse_json_markdown(
    json_string: str, *, parser: Callable[[str], Any] = json.loads
) -> dict:
    """Parse a JSON object embedded in a Markdown string.

    Args:
        json_string: The Markdown string, optionally fenced with triple
            backticks (with or without a ``json`` language tag).
        parser: Callable used to turn the extracted text into a Python
            object. Defaults to ``json.loads``.

    Returns:
        The parsed JSON object as a Python dictionary.
    """
    # Prefer the content of a ``` ... ``` fence when one is present;
    # otherwise treat the entire input as the JSON string.
    fenced = re.search(r"```(json)?(.*)```", json_string, re.DOTALL)
    candidate = json_string if fenced is None else fenced.group(2)

    # Trim surrounding whitespace, then escape raw newlines/tabs/quotes
    # that may appear inside an "action_input" value.
    candidate = _custom_parser(candidate.strip())

    return parser(candidate)
|
|
||||||
|
|
||||||
def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
    """Parse a JSON string from a Markdown string and check that it
    contains the expected keys.

    Args:
        text: The Markdown string.
        expected_keys: The keys that must be present in the parsed object.

    Returns:
        The parsed JSON object as a Python dictionary.

    Raises:
        OutputParserException: If the text is not valid JSON or if any
            expected key is missing from the parsed object.
    """
    try:
        json_obj = parse_json_markdown(text)
    except json.JSONDecodeError as e:
        # Chain the original decode error so callers can inspect the cause.
        raise OutputParserException(f"Got invalid JSON object. Error: {e}") from e
    for key in expected_keys:
        if key not in json_obj:
            raise OutputParserException(
                f"Got invalid return object. Expected key `{key}` "
                f"to be present, but got {json_obj}"
            )
    return json_obj
|
|
||||||
|
|
||||||
class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
    """Parse the output of an LLM call to a JSON object.

    When used in streaming mode, it will yield partial JSON objects containing
    all the keys that have been returned so far.

    In streaming, if `diff` is set to `True`, yields JSONPatch operations
    describing the difference between the previous and the current object.
    """

    def _diff(self, prev: Optional[Any], next: Any) -> Any:
        # JSONPatch operations describing how to get from ``prev`` to ``next``.
        return jsonpatch.make_patch(prev, next).patch

    def parse(self, text: str) -> Any:
        """Parse ``text`` (possibly Markdown-fenced, possibly truncated JSON).

        Raises:
            OutputParserException: If the text cannot be parsed as JSON.
        """
        text = text.strip()
        try:
            # ``parse_partial_json`` tolerates truncated (streaming) JSON.
            # Note: ``text`` is already stripped above; the original called
            # ``text.strip()`` a second time here for no effect.
            return parse_json_markdown(text, parser=parse_partial_json)
        except JSONDecodeError as e:
            raise OutputParserException(f"Invalid json output: {text}") from e

    @property
    def _type(self) -> str:
        return "simple_json_output_parser"
|
@ -1,122 +1,3 @@
|
|||||||
import re
|
from langchain_core.output_parsers.xml import XMLOutputParser
|
||||||
import xml.etree.ElementTree as ET
|
|
||||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
|
|
||||||
|
|
||||||
from langchain_core.messages import BaseMessage
|
__all__ = ["XMLOutputParser"]
|
||||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
|
||||||
from langchain_core.runnables.utils import AddableDict
|
|
||||||
|
|
||||||
from langchain.output_parsers.format_instructions import XML_FORMAT_INSTRUCTIONS
|
|
||||||
|
|
||||||
|
|
||||||
class XMLOutputParser(BaseTransformOutputParser):
    """Parse an output using xml format."""

    # Optional tag names; used only to fill in the format instructions.
    tags: Optional[List[str]] = None
    # Matches a leading tag whose text contains "encoding" (e.g. an
    # <?xml ... encoding="..."?> declaration) followed by a newline;
    # group 2 captures everything after it.
    encoding_matcher: re.Pattern = re.compile(
        r"<([^>]*encoding[^>]*)>\n(.*)", re.MULTILINE | re.DOTALL
    )

    def get_format_instructions(self) -> str:
        """Return instructions describing the expected XML output format."""
        return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)

    def parse(self, text: str) -> Dict[str, List[Any]]:
        """Parse the complete ``text`` as an XML document and convert it
        to a nested dictionary via ``_root_to_dict``.

        Raises:
            ValueError: If ``text`` does not look like an XML document.
        """
        # NOTE(review): str.strip removes any of the characters "x", "m", "l"
        # from both ends, not the literal prefix "xml". This is meant to drop
        # a ```xml Markdown fence but could clip content beginning/ending
        # with those letters — confirm intended.
        text = text.strip("`").strip("xml")
        encoding_match = self.encoding_matcher.search(text)
        if encoding_match:
            # Drop the encoding declaration; keep only the document body.
            text = encoding_match.group(2)

        text = text.strip()
        # Only attempt to parse when the text plausibly forms a whole
        # XML document (starts with "<", ends with ">").
        if (text.startswith("<") or text.startswith("\n<")) and (
            text.endswith(">") or text.endswith(">\n")
        ):
            root = ET.fromstring(text)
            return self._root_to_dict(root)
        else:
            raise ValueError(f"Could not parse output: {text}")

    def _transform(
        self, input: Iterator[Union[str, BaseMessage]]
    ) -> Iterator[AddableDict]:
        """Incrementally parse streamed chunks, yielding one nested dict
        per leaf element as soon as its closing tag arrives."""
        parser = ET.XMLPullParser(["start", "end"])
        # Stack of tag names from the root down to the current element.
        current_path: List[str] = []
        # Set to True once a child of the current element has been yielded,
        # so the enclosing parent is not re-yielded on its own "end" event.
        current_path_has_children = False
        for chunk in input:
            if isinstance(chunk, BaseMessage):
                # extract text
                chunk_content = chunk.content
                if not isinstance(chunk_content, str):
                    # Skip non-string message content (e.g. structured parts).
                    continue
                chunk = chunk_content
            # pass chunk to parser
            parser.feed(chunk)
            # yield all events
            for event, elem in parser.read_events():
                if event == "start":
                    # update current path
                    current_path.append(elem.tag)
                    current_path_has_children = False
                elif event == "end":
                    # remove last element from current path
                    current_path.pop()
                    # yield element
                    if not current_path_has_children:
                        yield nested_element(current_path, elem)
                    # prevent yielding of parent element
                    current_path_has_children = True
        # close parser
        parser.close()

    async def _atransform(
        self, input: AsyncIterator[Union[str, BaseMessage]]
    ) -> AsyncIterator[AddableDict]:
        """Async variant of ``_transform``; identical event handling over
        an async iterator of chunks."""
        parser = ET.XMLPullParser(["start", "end"])
        # Stack of tag names from the root down to the current element.
        current_path: List[str] = []
        # Set to True once a child of the current element has been yielded,
        # so the enclosing parent is not re-yielded on its own "end" event.
        current_path_has_children = False
        async for chunk in input:
            if isinstance(chunk, BaseMessage):
                # extract text
                chunk_content = chunk.content
                if not isinstance(chunk_content, str):
                    # Skip non-string message content (e.g. structured parts).
                    continue
                chunk = chunk_content
            # pass chunk to parser
            parser.feed(chunk)
            # yield all events
            for event, elem in parser.read_events():
                if event == "start":
                    # update current path
                    current_path.append(elem.tag)
                    current_path_has_children = False
                elif event == "end":
                    # remove last element from current path
                    current_path.pop()
                    # yield element
                    if not current_path_has_children:
                        yield nested_element(current_path, elem)
                    # prevent yielding of parent element
                    current_path_has_children = True
        # close parser
        parser.close()

    def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]:
        """Converts xml tree to python dictionary."""
        result: Dict[str, List[Any]] = {root.tag: []}
        for child in root:
            if len(child) == 0:
                # Leaf node: map its tag directly to its text content.
                result[root.tag].append({child.tag: child.text})
            else:
                # Non-leaf: recurse into the subtree.
                result[root.tag].append(self._root_to_dict(child))
        return result

    @property
    def _type(self) -> str:
        return "xml"
|
|
||||||
def nested_element(path: List[str], elem: ET.Element) -> Any:
    """Wrap ``elem`` in nested single-key dicts following ``path``.

    An empty path yields ``{tag: text}`` for the element itself; otherwise
    the element is recursively nested inside one-element lists keyed by each
    path segment in order.
    """
    if not path:
        return AddableDict({elem.tag: elem.text})
    return AddableDict({path[0]: [nested_element(path[1:], elem)]})
|
Loading…
Reference in New Issue