@ -1,6 +1,6 @@
import re
import xml . etree . ElementTree as ET
from typing import Any , AsyncIterator , Dict , Iterator , List , Optional , Union
from xml . etree import ElementTree as ET
from langchain_core . exceptions import OutputParserException
from langchain_core . messages import BaseMessage
@ -35,10 +35,6 @@ class XMLOutputParser(BaseTransformOutputParser):
return XML_FORMAT_INSTRUCTIONS . format ( tags = self . tags )
def parse ( self , text : str ) - > Dict [ str , List [ Any ] ] :
# Imports are temporarily placed here to avoid issue with caching on CI
# likely if you're reading this you can move them to the top of the file
from defusedxml import ElementTree as DET # type: ignore[import]
# Try to find XML string within triple backticks
match = re . search ( r " ```(xml)?(.*)``` " , text , re . DOTALL )
if match is not None :
@ -50,18 +46,18 @@ class XMLOutputParser(BaseTransformOutputParser):
text = text . strip ( )
try :
root = D ET. fromstring ( text )
root = ET. fromstring ( text )
return self . _root_to_dict ( root )
except ( D ET. ParseError , DET . EntitiesForbidden ) as e :
except ET. ParseError as e :
msg = f " Failed to parse XML format from completion { text } . Got: { e } "
raise OutputParserException ( msg , llm_output = text ) from e
def _transform (
self , input : Iterator [ Union [ str , BaseMessage ] ]
) - > Iterator [ AddableDict ] :
parser = ET . XMLPullParser ( [ " start " , " end " ] )
xml_start_re = re . compile ( r " <[a-zA-Z:_] " )
parser = ET . XMLPullParser ( [ " start " , " end " ] )
xml_started = False
current_path : List [ str ] = [ ]
current_path_has_children = False
@ -87,7 +83,6 @@ class XMLOutputParser(BaseTransformOutputParser):
parser . feed ( buffer )
buffer = " "
# yield all events
for event , elem in parser . read_events ( ) :
if event == " start " :
# update current path
@ -111,11 +106,8 @@ class XMLOutputParser(BaseTransformOutputParser):
self , input : AsyncIterator [ Union [ str , BaseMessage ] ]
) - > AsyncIterator [ AddableDict ] :
parser = ET . XMLPullParser ( [ " start " , " end " ] )
xml_start_re = re . compile ( r " <[a-zA-Z:_] " )
xml_started = False
current_path : List [ str ] = [ ]
current_path_has_children = False
buffer = " "
async for chunk in input :
if isinstance ( chunk , BaseMessage ) :
# extract text
@ -123,19 +115,8 @@ class XMLOutputParser(BaseTransformOutputParser):
if not isinstance ( chunk_content , str ) :
continue
chunk = chunk_content
# add chunk to buffer of unprocessed text
buffer + = chunk
# if xml string hasn't started yet, continue to next chunk
if not xml_started :
if match := xml_start_re . search ( buffer ) :
# if xml string has started, remove all text before it
buffer = buffer [ match . start ( ) : ]
xml_started = True
else :
continue
# feed buffer to parser
parser . feed ( buffer )
buffer = " "
# pass chunk to parser
parser . feed ( chunk )
# yield all events
for event , elem in parser . read_events ( ) :
if event == " start " :
@ -149,10 +130,7 @@ class XMLOutputParser(BaseTransformOutputParser):
if not current_path_has_children :
yield nested_element ( current_path , elem )
# prevent yielding of parent element
if current_path :
current_path_has_children = True
else :
xml_started = False
# close parser
parser . close ( )