core[patch]: Reverting changes with defusedXML (#19604)

DefusedXML is causing parsing errors on previously functional code with
the 0.7.x versions. These do not seem to support newer version of python
well. 0.8.x has only been released as rc, so we're not going to to use
it in the core package
pull/19610/head
Eugene Yurtsev 3 months ago committed by GitHub
parent 9ea2a9b0c1
commit 8bc5cdccee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,6 +1,6 @@
import re
import xml.etree.ElementTree as ET
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
from xml.etree import ElementTree as ET
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import BaseMessage
@ -35,10 +35,6 @@ class XMLOutputParser(BaseTransformOutputParser):
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
def parse(self, text: str) -> Dict[str, List[Any]]:
# Imports are temporarily placed here to avoid issue with caching on CI
# likely if you're reading this you can move them to the top of the file
from defusedxml import ElementTree as DET # type: ignore[import]
# Try to find XML string within triple backticks
match = re.search(r"```(xml)?(.*)```", text, re.DOTALL)
if match is not None:
@ -50,18 +46,18 @@ class XMLOutputParser(BaseTransformOutputParser):
text = text.strip()
try:
root = DET.fromstring(text)
root = ET.fromstring(text)
return self._root_to_dict(root)
except (DET.ParseError, DET.EntitiesForbidden) as e:
except ET.ParseError as e:
msg = f"Failed to parse XML format from completion {text}. Got: {e}"
raise OutputParserException(msg, llm_output=text) from e
def _transform(
self, input: Iterator[Union[str, BaseMessage]]
) -> Iterator[AddableDict]:
parser = ET.XMLPullParser(["start", "end"])
xml_start_re = re.compile(r"<[a-zA-Z:_]")
parser = ET.XMLPullParser(["start", "end"])
xml_started = False
current_path: List[str] = []
current_path_has_children = False
@ -87,7 +83,6 @@ class XMLOutputParser(BaseTransformOutputParser):
parser.feed(buffer)
buffer = ""
# yield all events
for event, elem in parser.read_events():
if event == "start":
# update current path
@ -111,11 +106,8 @@ class XMLOutputParser(BaseTransformOutputParser):
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[AddableDict]:
parser = ET.XMLPullParser(["start", "end"])
xml_start_re = re.compile(r"<[a-zA-Z:_]")
xml_started = False
current_path: List[str] = []
current_path_has_children = False
buffer = ""
async for chunk in input:
if isinstance(chunk, BaseMessage):
# extract text
@ -123,19 +115,8 @@ class XMLOutputParser(BaseTransformOutputParser):
if not isinstance(chunk_content, str):
continue
chunk = chunk_content
# add chunk to buffer of unprocessed text
buffer += chunk
# if xml string hasn't started yet, continue to next chunk
if not xml_started:
if match := xml_start_re.search(buffer):
# if xml string has started, remove all text before it
buffer = buffer[match.start() :]
xml_started = True
else:
continue
# feed buffer to parser
parser.feed(buffer)
buffer = ""
# pass chunk to parser
parser.feed(chunk)
# yield all events
for event, elem in parser.read_events():
if event == "start":
@ -149,10 +130,7 @@ class XMLOutputParser(BaseTransformOutputParser):
if not current_path_has_children:
yield nested_element(current_path, elem)
# prevent yielding of parent element
if current_path:
current_path_has_children = True
else:
xml_started = False
current_path_has_children = True
# close parser
parser.close()

@ -2966,4 +2966,4 @@ extended-testing = ["jinja2"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "2f61e22c118e13c40a1b7980afe06a37a6349ee239c948b9c49e8b1dc06facc1"
content-hash = "203d96b330412ce9defad6739381e4031fc9e995c2d9e0a61a905fc79fff11dd"

@ -18,7 +18,6 @@ PyYAML = ">=5.3"
requests = "^2"
packaging = "^23.2"
jinja2 = { version = "^3", optional = true }
defusedxml = "^0.7"
[tool.poetry.group.lint]
optional = true

@ -1,6 +1,4 @@
"""Test XMLOutputParser"""
from typing import AsyncIterator
import pytest
from langchain_core.exceptions import OutputParserException
@ -42,24 +40,14 @@ More random text
""",
],
)
async def test_xml_output_parser(result: str) -> None:
def test_xml_output_parser(result: str) -> None:
"""Test XMLOutputParser."""
xml_parser = XMLOutputParser()
assert DEF_RESULT_EXPECTED == xml_parser.parse(result)
assert DEF_RESULT_EXPECTED == (await xml_parser.aparse(result))
assert list(xml_parser.transform(iter(result))) == [
{"foo": [{"bar": [{"baz": None}]}]},
{"foo": [{"bar": [{"baz": "slim.shady"}]}]},
{"foo": [{"baz": "tag"}]},
]
async def _as_iter(string: str) -> AsyncIterator[str]:
for c in string:
yield c
chunks = [chunk async for chunk in xml_parser.atransform(_as_iter(result))]
assert chunks == [
xml_result = xml_parser.parse(result)
assert DEF_RESULT_EXPECTED == xml_result
assert list(xml_parser.transform(iter(result))) == [
{"foo": [{"bar": [{"baz": None}]}]},
{"foo": [{"bar": [{"baz": "slim.shady"}]}]},
{"foo": [{"baz": "tag"}]},
@ -75,27 +63,3 @@ def test_xml_output_parser_fail(result: str) -> None:
with pytest.raises(OutputParserException) as e:
xml_parser.parse(result)
assert "Failed to parse" in str(e)
MALICIOUS_XML = """<?xml version="1.0"?>
<!DOCTYPE lolz [<!ENTITY lol "lol"><!ELEMENT lolz (#PCDATA)>
<!ENTITY lol1 "&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;">
<!ENTITY lol2 "&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;">
<!ENTITY lol3 "&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;">
<!ENTITY lol4 "&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;">
<!ENTITY lol5 "&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;">
<!ENTITY lol6 "&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;">
<!ENTITY lol7 "&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;">
<!ENTITY lol8 "&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;">
<!ENTITY lol9 "&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;">
]>
<lolz>&lol9;</lolz>"""
async def tests_billion_laughs_attack() -> None:
parser = XMLOutputParser()
with pytest.raises(OutputParserException):
parser.parse(MALICIOUS_XML)
with pytest.raises(OutputParserException):
await parser.aparse(MALICIOUS_XML)

Loading…
Cancel
Save