diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py index 5a89c5f763..f74bd4c105 100644 --- a/libs/core/langchain_core/output_parsers/xml.py +++ b/libs/core/langchain_core/output_parsers/xml.py @@ -1,6 +1,6 @@ import re +import xml.etree.ElementTree as ET from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union -from xml.etree import ElementTree as ET from langchain_core.exceptions import OutputParserException from langchain_core.messages import BaseMessage @@ -35,10 +35,6 @@ class XMLOutputParser(BaseTransformOutputParser): return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags) def parse(self, text: str) -> Dict[str, List[Any]]: - # Imports are temporarily placed here to avoid issue with caching on CI - # likely if you're reading this you can move them to the top of the file - from defusedxml import ElementTree as DET # type: ignore[import] - # Try to find XML string within triple backticks match = re.search(r"```(xml)?(.*)```", text, re.DOTALL) if match is not None: @@ -50,18 +46,18 @@ class XMLOutputParser(BaseTransformOutputParser): text = text.strip() try: - root = DET.fromstring(text) + root = ET.fromstring(text) return self._root_to_dict(root) - except (DET.ParseError, DET.EntitiesForbidden) as e: + except ET.ParseError as e: msg = f"Failed to parse XML format from completion {text}. Got: {e}" raise OutputParserException(msg, llm_output=text) from e def _transform( self, input: Iterator[Union[str, BaseMessage]] ) -> Iterator[AddableDict]: - parser = ET.XMLPullParser(["start", "end"]) xml_start_re = re.compile(r"<[a-zA-Z:_]") + parser = ET.XMLPullParser(["start", "end"]) xml_started = False current_path: List[str] = [] current_path_has_children = False @@ -87,7 +83,6 @@ class XMLOutputParser(BaseTransformOutputParser): parser.feed(buffer) buffer = "" # yield all events - for event, elem in parser.read_events(): if event == "start": # update current path @@ -111,11 +106,8 @@ class XMLOutputParser(BaseTransformOutputParser): self, input: AsyncIterator[Union[str, BaseMessage]] ) -> AsyncIterator[AddableDict]: parser = ET.XMLPullParser(["start", "end"]) - xml_start_re = re.compile(r"<[a-zA-Z:_]") - xml_started = False current_path: List[str] = [] current_path_has_children = False - buffer = "" async for chunk in input: if isinstance(chunk, BaseMessage): # extract text @@ -123,19 +115,8 @@ class XMLOutputParser(BaseTransformOutputParser): if not isinstance(chunk_content, str): continue chunk = chunk_content - # add chunk to buffer of unprocessed text - buffer += chunk - # if xml string hasn't started yet, continue to next chunk - if not xml_started: - if match := xml_start_re.search(buffer): - # if xml string has started, remove all text before it - buffer = buffer[match.start() :] - xml_started = True - else: - continue - # feed buffer to parser - parser.feed(buffer) - buffer = "" + # pass chunk to parser + parser.feed(chunk) # yield all events for event, elem in parser.read_events(): if event == "start": @@ -149,10 +130,7 @@ class XMLOutputParser(BaseTransformOutputParser): if not current_path_has_children: yield nested_element(current_path, elem) # prevent yielding of parent element - if current_path: - current_path_has_children = True - else: - xml_started = False + current_path_has_children = True # close parser parser.close() diff --git a/libs/core/poetry.lock b/libs/core/poetry.lock index fc812ce734..9495b9d4a8 100644 --- a/libs/core/poetry.lock +++ b/libs/core/poetry.lock @@ -2966,4 +2966,4 @@ extended-testing = ["jinja2"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "2f61e22c118e13c40a1b7980afe06a37a6349ee239c948b9c49e8b1dc06facc1" +content-hash = "203d96b330412ce9defad6739381e4031fc9e995c2d9e0a61a905fc79fff11dd" diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index daa6b9bdb8..7f476fca7f 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -18,7 +18,6 @@ PyYAML = ">=5.3" requests = "^2" packaging = "^23.2" jinja2 = { version = "^3", optional = true } -defusedxml = "^0.7" [tool.poetry.group.lint] optional = true diff --git a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py index 7ba68f42a4..65b095f308 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py +++ b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py @@ -1,6 +1,4 @@ """Test XMLOutputParser""" -from typing import AsyncIterator - import pytest from langchain_core.exceptions import OutputParserException @@ -42,24 +40,14 @@ More random text """, ], ) -async def test_xml_output_parser(result: str) -> None: +def test_xml_output_parser(result: str) -> None: """Test XMLOutputParser.""" xml_parser = XMLOutputParser() - assert DEF_RESULT_EXPECTED == xml_parser.parse(result) - assert DEF_RESULT_EXPECTED == (await xml_parser.aparse(result)) - assert list(xml_parser.transform(iter(result))) == [ - {"foo": [{"bar": [{"baz": None}]}]}, - {"foo": [{"bar": [{"baz": "slim.shady"}]}]}, - {"foo": [{"baz": "tag"}]}, - ] - - async def _as_iter(string: str) -> AsyncIterator[str]: - for c in string: - yield c - chunks = [chunk async for chunk in xml_parser.atransform(_as_iter(result))] - assert chunks == [ + xml_result = xml_parser.parse(result) + assert DEF_RESULT_EXPECTED == xml_result + assert list(xml_parser.transform(iter(result))) == [ {"foo": [{"bar": [{"baz": None}]}]}, {"foo": [{"bar": [{"baz": "slim.shady"}]}]}, {"foo": [{"baz": "tag"}]}, @@ -75,27 +63,3 @@ def test_xml_output_parser_fail(result: str) -> None: with pytest.raises(OutputParserException) as e: xml_parser.parse(result) assert "Failed to parse" in str(e) - - -MALICIOUS_XML = """ - - - - - - - - - - -]> -&lol9;""" - - -async def tests_billion_laughs_attack() -> None: - parser = XMLOutputParser() - with pytest.raises(OutputParserException): - parser.parse(MALICIOUS_XML) - - with pytest.raises(OutputParserException): - await parser.aparse(MALICIOUS_XML)