core[patch]: Patch XML vulnerability in XMLOutputParser (CVE-2024-1455) (#19653)

Patch potential XML vulnerability CVE-2024-1455

This patches a potential XML vulnerability in the XMLOutputParser in
langchain-core. The vulnerability in some situations could lead to a
denial of service attack.

At risk are users that:

1) Running older distributions of python that have older version of
libexpat
2) Are using XMLOutputParser with an agent
3) Accept inputs from untrusted sources with this agent (e.g., endpoint
on the web that allows an untrusted user to interact wiith the parser)
pull/19660/head
Eugene Yurtsev 3 months ago committed by GitHub
parent 7042934b5f
commit e8339b1d83
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,7 +1,8 @@
import re
import xml
import xml.etree.ElementTree as ET
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union
from xml.etree.ElementTree import TreeBuilder
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import BaseMessage
@ -24,6 +25,105 @@ Here are the output tags:
```""" # noqa: E501
class _StreamingParser:
"""Streaming parser for XML.
This implementation is pulled into a class to avoid implementation
drift between transform and atransform of the XMLOutputParser.
"""
def __init__(self, parser: Literal["defusedxml", "xml"]) -> None:
"""Initialize the streaming parser.
Args:
parser: Parser to use for XML parsing. Can be either 'defusedxml' or 'xml'.
See documentation in XMLOutputParser for more information.
"""
if parser == "defusedxml":
try:
from defusedxml import ElementTree as DET # type: ignore
except ImportError:
raise ImportError(
"defusedxml is not installed. "
"Please install it to use the defusedxml parser."
"You can install it with `pip install defusedxml` "
)
_parser = DET.DefusedXMLParser(target=TreeBuilder())
else:
_parser = None
self.pull_parser = ET.XMLPullParser(["start", "end"], _parser=_parser)
self.xml_start_re = re.compile(r"<[a-zA-Z:_]")
self.current_path: List[str] = []
self.current_path_has_children = False
self.buffer = ""
self.xml_started = False
def parse(self, chunk: Union[str, BaseMessage]) -> Iterator[AddableDict]:
"""Parse a chunk of text.
Args:
chunk: A chunk of text to parse. This can be a string or a BaseMessage.
Yields:
AddableDict: A dictionary representing the parsed XML element.
"""
if isinstance(chunk, BaseMessage):
# extract text
chunk_content = chunk.content
if not isinstance(chunk_content, str):
# ignore non-string messages (e.g., function calls)
return
chunk = chunk_content
# add chunk to buffer of unprocessed text
self.buffer += chunk
# if xml string hasn't started yet, continue to next chunk
if not self.xml_started:
if match := self.xml_start_re.search(self.buffer):
# if xml string has started, remove all text before it
self.buffer = self.buffer[match.start() :]
self.xml_started = True
else:
return
# feed buffer to parser
self.pull_parser.feed(self.buffer)
self.buffer = ""
# yield all events
try:
for event, elem in self.pull_parser.read_events():
if event == "start":
# update current path
self.current_path.append(elem.tag)
self.current_path_has_children = False
elif event == "end":
# remove last element from current path
#
self.current_path.pop()
# yield element
if not self.current_path_has_children:
yield nested_element(self.current_path, elem)
# prevent yielding of parent element
if self.current_path:
self.current_path_has_children = True
else:
self.xml_started = False
except xml.etree.ElementTree.ParseError:
# This might be junk at the end of the XML input.
# Let's check whether the current path is empty.
if not self.current_path:
# If it is empty, we can ignore this error.
return
else:
raise
def close(self) -> None:
"""Close the parser."""
try:
self.pull_parser.close()
except xml.etree.ElementTree.ParseError:
# Ignore. This will ignore any incomplete XML at the end of the input
pass
class XMLOutputParser(BaseTransformOutputParser):
"""Parse an output using xml format."""
@ -31,12 +131,48 @@ class XMLOutputParser(BaseTransformOutputParser):
encoding_matcher: re.Pattern = re.compile(
r"<([^>]*encoding[^>]*)>\n(.*)", re.MULTILINE | re.DOTALL
)
parser: Literal["defusedxml", "xml"] = "defusedxml"
"""Parser to use for XML parsing. Can be either 'defusedxml' or 'xml'.
* 'defusedxml' is the default parser and is used to prevent XML vulnerabilities
present in some distributions of Python's standard library xml.
`defusedxml` is a wrapper around the standard library parser that
sets up the parser with secure defaults.
* 'xml' is the standard library parser.
Use `xml` only if you are sure that your distribution of the standard library
is not vulnerable to XML vulnerabilities.
Please review the following resources for more information:
* https://docs.python.org/3/library/xml.html#xml-vulnerabilities
* https://github.com/tiran/defusedxml
The standard library relies on libexpat for parsing XML:
https://github.com/libexpat/libexpat
"""
def get_format_instructions(self) -> str:
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
def parse(self, text: str) -> Dict[str, List[Any]]:
# Try to find XML string within triple backticks
# Imports are temporarily placed here to avoid issue with caching on CI
# likely if you're reading this you can move them to the top of the file
if self.parser == "defusedxml":
try:
from defusedxml import ElementTree as DET # type: ignore
except ImportError:
raise ImportError(
"defusedxml is not installed. "
"Please install it to use the defusedxml parser."
"You can install it with `pip install defusedxml`"
"See https://github.com/tiran/defusedxml for more details"
)
_ET = DET # Use the defusedxml parser
else:
_ET = ET # Use the standard library parser
match = re.search(r"```(xml)?(.*)```", text, re.DOTALL)
if match is not None:
# If match found, use the content within the backticks
@ -57,132 +193,19 @@ class XMLOutputParser(BaseTransformOutputParser):
def _transform(
self, input: Iterator[Union[str, BaseMessage]]
) -> Iterator[AddableDict]:
xml_start_re = re.compile(r"<[a-zA-Z:_]")
parser = ET.XMLPullParser(["start", "end"])
xml_started = False
current_path: List[str] = []
current_path_has_children = False
buffer = ""
streaming_parser = _StreamingParser(self.parser)
for chunk in input:
if isinstance(chunk, BaseMessage):
# extract text
chunk_content = chunk.content
if not isinstance(chunk_content, str):
continue
chunk = chunk_content
# add chunk to buffer of unprocessed text
buffer += chunk
# if xml string hasn't started yet, continue to next chunk
if not xml_started:
if match := xml_start_re.search(buffer):
# if xml string has started, remove all text before it
buffer = buffer[match.start() :]
xml_started = True
else:
continue
# feed buffer to parser
parser.feed(buffer)
buffer = ""
# yield all events
try:
for event, elem in parser.read_events():
if event == "start":
# update current path
current_path.append(elem.tag)
current_path_has_children = False
elif event == "end":
# remove last element from current path
#
current_path.pop()
# yield element
if not current_path_has_children:
yield nested_element(current_path, elem)
# prevent yielding of parent element
if current_path:
current_path_has_children = True
else:
xml_started = False
except xml.etree.ElementTree.ParseError:
# This might be junk at the end of the XML input.
# Let's check whether the current path is empty.
if not current_path:
# If it is empty, we can ignore this error.
break
else:
raise
# close parser
try:
parser.close()
except xml.etree.ElementTree.ParseError:
# Ignore. This will ignore any incomplete XML at the end of the input
pass
yield from streaming_parser.parse(chunk)
streaming_parser.close()
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[AddableDict]:
xml_start_re = re.compile(r"<[a-zA-Z:_]")
parser = ET.XMLPullParser(["start", "end"])
xml_started = False
current_path: List[str] = []
current_path_has_children = False
buffer = ""
streaming_parser = _StreamingParser(self.parser)
async for chunk in input:
if isinstance(chunk, BaseMessage):
# extract text
chunk_content = chunk.content
if not isinstance(chunk_content, str):
continue
chunk = chunk_content
# add chunk to buffer of unprocessed text
buffer += chunk
# if xml string hasn't started yet, continue to next chunk
if not xml_started:
if match := xml_start_re.search(buffer):
# if xml string has started, remove all text before it
buffer = buffer[match.start() :]
xml_started = True
else:
continue
# feed buffer to parser
parser.feed(buffer)
buffer = ""
# yield all events
try:
for event, elem in parser.read_events():
if event == "start":
# update current path
current_path.append(elem.tag)
current_path_has_children = False
elif event == "end":
# remove last element from current path
#
current_path.pop()
# yield element
if not current_path_has_children:
yield nested_element(current_path, elem)
# prevent yielding of parent element
if current_path:
current_path_has_children = True
else:
xml_started = False
except xml.etree.ElementTree.ParseError:
# This might be junk at the end of the XML input.
# Let's check whether the current path is empty.
if not current_path:
# If it is empty, we can ignore this error.
break
else:
raise
# close parser
try:
parser.close()
except xml.etree.ElementTree.ParseError:
# Ignore. This will ignore any incomplete XML at the end of the input
pass
for output in streaming_parser.parse(chunk):
yield output
streaming_parser.close()
def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]:
"""Converts xml tree to python dictionary."""

@ -1,4 +1,5 @@
"""Test XMLOutputParser"""
import importlib
from typing import AsyncIterator, Iterable
import pytest
@ -42,24 +43,12 @@ DEF_RESULT_EXPECTED = {
}
@pytest.mark.parametrize(
"result",
[
DATA, # has no xml header
WITH_XML_HEADER,
IN_XML_TAGS_WITH_XML_HEADER,
IN_XML_TAGS_WITH_HEADER_AND_TRAILING_JUNK,
],
)
async def test_xml_output_parser(result: str) -> None:
"""Test XMLOutputParser."""
async def _test_parser(parser: XMLOutputParser, content: str) -> None:
"""Test parser."""
xml_content = parser.parse(content)
assert DEF_RESULT_EXPECTED == xml_content
xml_parser = XMLOutputParser()
xml_result = xml_parser.parse(result)
assert DEF_RESULT_EXPECTED == xml_result
assert list(xml_parser.transform(iter(result))) == [
assert list(parser.transform(iter(content))) == [
{"foo": [{"bar": [{"baz": None}]}]},
{"foo": [{"bar": [{"baz": "slim.shady"}]}]},
{"foo": [{"baz": "tag"}]},
@ -69,7 +58,7 @@ async def test_xml_output_parser(result: str) -> None:
for item in iterable:
yield item
chunks = [chunk async for chunk in xml_parser.atransform(_as_iter(result))]
chunks = [chunk async for chunk in parser.atransform(_as_iter(content))]
assert list(chunks) == [
{"foo": [{"bar": [{"baz": None}]}]},
@ -78,12 +67,72 @@ async def test_xml_output_parser(result: str) -> None:
]
@pytest.mark.parametrize(
"content",
[
DATA, # has no xml header
WITH_XML_HEADER,
IN_XML_TAGS_WITH_XML_HEADER,
IN_XML_TAGS_WITH_HEADER_AND_TRAILING_JUNK,
],
)
async def test_xml_output_parser(content: str) -> None:
"""Test XMLOutputParser."""
xml_parser = XMLOutputParser(parser="xml")
await _test_parser(xml_parser, content)
@pytest.mark.skipif(
importlib.util.find_spec("defusedxml") is None,
reason="defusedxml is not installed",
)
@pytest.mark.parametrize(
"content",
[
DATA, # has no xml header
WITH_XML_HEADER,
IN_XML_TAGS_WITH_XML_HEADER,
IN_XML_TAGS_WITH_HEADER_AND_TRAILING_JUNK,
],
)
async def test_xml_output_parser_defused(content: str) -> None:
"""Test XMLOutputParser."""
xml_parser = XMLOutputParser(parser="defusedxml")
await _test_parser(xml_parser, content)
@pytest.mark.parametrize("result", ["foo></foo>", "<foo></foo", "foo></foo", "foofoo"])
def test_xml_output_parser_fail(result: str) -> None:
"""Test XMLOutputParser where complete output is not in XML format."""
xml_parser = XMLOutputParser()
xml_parser = XMLOutputParser(parser="xml")
with pytest.raises(OutputParserException) as e:
xml_parser.parse(result)
assert "Failed to parse" in str(e)
MALICIOUS_XML = """<?xml version="1.0"?>
<!DOCTYPE lolz [<!ENTITY lol "lol"><!ELEMENT lolz (#PCDATA)>
<!ENTITY lol1 "&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;">
<!ENTITY lol2 "&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;">
<!ENTITY lol3 "&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;">
<!ENTITY lol4 "&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;">
<!ENTITY lol5 "&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;">
<!ENTITY lol6 "&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;">
<!ENTITY lol7 "&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;">
<!ENTITY lol8 "&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;">
<!ENTITY lol9 "&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;">
]>
<lolz>&lol9;</lolz>"""
async def tests_billion_laughs_attack() -> None:
# Testing with standard XML parser since it's safe to use in
# newer versions of Python
parser = XMLOutputParser(parser="xml")
with pytest.raises(OutputParserException):
parser.parse(MALICIOUS_XML)
with pytest.raises(OutputParserException):
await parser.aparse(MALICIOUS_XML)

Loading…
Cancel
Save