@ -1,7 +1,8 @@
import re
import xml
import xml . etree . ElementTree as ET
from typing import Any , AsyncIterator , Dict , Iterator , List , Optional , Union
from typing import Any , AsyncIterator , Dict , Iterator , List , Literal , Optional , Union
from xml . etree . ElementTree import TreeBuilder
from langchain_core . exceptions import OutputParserException
from langchain_core . messages import BaseMessage
@ -24,6 +25,105 @@ Here are the output tags:
` ` ` """ # noqa: E501
class _StreamingParser:
    """Streaming parser for XML.

    This implementation is pulled into a class to avoid implementation
    drift between transform and atransform of the XMLOutputParser.

    Incoming text is buffered until something that looks like an opening
    XML tag shows up. From then on the text is fed to an ``XMLPullParser``,
    and every element that closes without a child having closed inside it
    is yielded via ``nested_element``.
    """

    def __init__(self, parser: Literal["defusedxml", "xml"]) -> None:
        """Initialize the streaming parser.

        Args:
            parser: Parser to use for XML parsing. Can be either 'defusedxml'
                or 'xml'. See documentation in XMLOutputParser for more
                information.

        Raises:
            ImportError: If ``parser`` is 'defusedxml' and the ``defusedxml``
                package cannot be imported.
        """
        if parser == "defusedxml":
            try:
                from defusedxml import ElementTree as DET  # type: ignore
            except ImportError as e:
                # Chain the original failure so the real import problem
                # stays visible in the traceback.
                raise ImportError(
                    "defusedxml is not installed. "
                    "Please install it to use the defusedxml parser. "
                    "You can install it with `pip install defusedxml`"
                ) from e
            _parser = DET.DefusedXMLParser(target=TreeBuilder())
        else:
            # None lets XMLPullParser build its default parser.
            _parser = None
        self.pull_parser = ET.XMLPullParser(["start", "end"], _parser=_parser)
        # Matches "<" followed by a character that can begin a tag name.
        self.xml_start_re = re.compile(r"<[a-zA-Z:_]")
        # Tags of the elements that are currently open, outermost first.
        self.current_path: List[str] = []
        # True once a child of the innermost open element has been closed.
        self.current_path_has_children = False
        # Text received so far that has not been handed to the pull parser.
        self.buffer = ""
        self.xml_started = False

    def parse(self, chunk: Union[str, BaseMessage]) -> Iterator[AddableDict]:
        """Parse a chunk of text.

        Args:
            chunk: A chunk of text to parse. This can be a string or a
                BaseMessage. Messages whose content is not a string are
                ignored.

        Yields:
            AddableDict: The result of ``nested_element`` for each element
                that closed in this chunk without a child having closed
                inside it.

        Raises:
            xml.etree.ElementTree.ParseError: If the pull parser reports an
                error while at least one element is still open.
        """
        if isinstance(chunk, BaseMessage):
            # extract text
            chunk_content = chunk.content
            if not isinstance(chunk_content, str):
                # ignore non-string messages (e.g., function calls)
                return
            chunk = chunk_content
        # add chunk to buffer of unprocessed text
        self.buffer += chunk
        # if xml string hasn't started yet, continue to next chunk
        if not self.xml_started:
            if match := self.xml_start_re.search(self.buffer):
                # if xml string has started, remove all text before it
                self.buffer = self.buffer[match.start():]
                self.xml_started = True
            else:
                return
        # feed buffer to parser
        self.pull_parser.feed(self.buffer)
        self.buffer = ""
        # yield all events
        try:
            for event, elem in self.pull_parser.read_events():
                if event == "start":
                    # update current path
                    self.current_path.append(elem.tag)
                    self.current_path_has_children = False
                elif event == "end":
                    # remove last element from current path
                    self.current_path.pop()
                    # yield element
                    if not self.current_path_has_children:
                        yield nested_element(self.current_path, elem)
                    # prevent yielding of parent element
                    if self.current_path:
                        self.current_path_has_children = True
                    else:
                        # The outermost element just closed: go back to
                        # scanning for the start of an XML string.
                        self.xml_started = False
        except xml.etree.ElementTree.ParseError:
            # This might be junk at the end of the XML input.
            # Let's check whether the current path is empty.
            if not self.current_path:
                # If it is empty, we can ignore this error.
                return
            else:
                raise

    def close(self) -> None:
        """Close the underlying pull parser, ignoring any ParseError."""
        try:
            self.pull_parser.close()
        except xml.etree.ElementTree.ParseError:
            # Ignore. This will ignore any incomplete XML at the end of the input
            pass
class XMLOutputParser ( BaseTransformOutputParser ) :
""" Parse an output using xml format. """
@ -31,12 +131,48 @@ class XMLOutputParser(BaseTransformOutputParser):
encoding_matcher : re . Pattern = re . compile (
r " <([^>]*encoding[^>]*)> \ n(.*) " , re . MULTILINE | re . DOTALL
)
parser : Literal [ " defusedxml " , " xml " ] = " defusedxml "
""" Parser to use for XML parsing. Can be either ' defusedxml ' or ' xml ' .
* ' defusedxml ' is the default parser and is used to prevent XML vulnerabilities
present in some distributions of Python ' s standard library xml.
` defusedxml ` is a wrapper around the standard library parser that
sets up the parser with secure defaults .
* ' xml ' is the standard library parser .
Use `xml` only if you are sure that your distribution of the standard library
is not affected by known XML vulnerabilities.
Please review the following resources for more information :
* https : / / docs . python . org / 3 / library / xml . html #xml-vulnerabilities
* https : / / github . com / tiran / defusedxml
The standard library relies on libexpat for parsing XML :
https : / / github . com / libexpat / libexpat
"""
def get_format_instructions(self) -> str:
    """Return XML_FORMAT_INSTRUCTIONS formatted with this parser's tags."""
    instructions = XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
    return instructions
def parse ( self , text : str ) - > Dict [ str , List [ Any ] ] :
# Try to find XML string within triple backticks
# Imports are temporarily placed here to avoid issue with caching on CI
# likely if you're reading this you can move them to the top of the file
if self . parser == " defusedxml " :
try :
from defusedxml import ElementTree as DET # type: ignore
except ImportError :
raise ImportError (
" defusedxml is not installed. "
" Please install it to use the defusedxml parser. "
" You can install it with `pip install defusedxml` "
" See https://github.com/tiran/defusedxml for more details "
)
_ET = DET # Use the defusedxml parser
else :
_ET = ET # Use the standard library parser
match = re . search ( r " ```(xml)?(.*)``` " , text , re . DOTALL )
if match is not None :
# If match found, use the content within the backticks
@ -57,132 +193,19 @@ class XMLOutputParser(BaseTransformOutputParser):
def _transform(
    self, input: Iterator[Union[str, BaseMessage]]
) -> Iterator[AddableDict]:
    """Stream parsed XML elements from an iterator of text chunks.

    All parsing state lives in a single ``_StreamingParser`` created for
    this call, so the sync and async transforms share one implementation
    instead of each carrying its own copy of the pull-parser loop.

    Args:
        input: Iterator of chunks, each either a string or a BaseMessage.

    Yields:
        AddableDict: Whatever ``_StreamingParser.parse`` yields for each
            chunk, in order.
    """
    streaming_parser = _StreamingParser(self.parser)
    for chunk in input:
        yield from streaming_parser.parse(chunk)
    # Input is exhausted: close the parser.
    streaming_parser.close()
async def _atransform(
    self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[AddableDict]:
    """Stream parsed XML elements from an async iterator of text chunks.

    All parsing state lives in a single ``_StreamingParser`` created for
    this call, so the sync and async transforms share one implementation
    instead of each carrying its own copy of the pull-parser loop.

    Args:
        input: Async iterator of chunks, each either a string or a
            BaseMessage.

    Yields:
        AddableDict: Whatever ``_StreamingParser.parse`` yields for each
            chunk, in order.
    """
    streaming_parser = _StreamingParser(self.parser)
    async for chunk in input:
        # ``yield from`` is not allowed inside an async generator, so the
        # synchronous parser output is re-yielded item by item.
        for output in streaming_parser.parse(chunk):
            yield output
    # Input is exhausted: close the parser.
    streaming_parser.close()
def _root_to_dict ( self , root : ET . Element ) - > Dict [ str , List [ Any ] ] :
""" Converts xml tree to python dictionary. """