|
|
|
@ -1,6 +1,6 @@
|
|
|
|
|
import re
|
|
|
|
|
import xml.etree.ElementTree as ET
|
|
|
|
|
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
|
|
|
|
|
from xml.etree import ElementTree as ET
|
|
|
|
|
|
|
|
|
|
from langchain_core.exceptions import OutputParserException
|
|
|
|
|
from langchain_core.messages import BaseMessage
|
|
|
|
@ -35,10 +35,6 @@ class XMLOutputParser(BaseTransformOutputParser):
|
|
|
|
|
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
|
|
|
|
|
|
|
|
|
|
def parse(self, text: str) -> Dict[str, List[Any]]:
|
|
|
|
|
# Imports are temporarily placed here to avoid issue with caching on CI
|
|
|
|
|
# likely if you're reading this you can move them to the top of the file
|
|
|
|
|
from defusedxml import ElementTree as DET # type: ignore[import]
|
|
|
|
|
|
|
|
|
|
# Try to find XML string within triple backticks
|
|
|
|
|
match = re.search(r"```(xml)?(.*)```", text, re.DOTALL)
|
|
|
|
|
if match is not None:
|
|
|
|
@ -50,10 +46,10 @@ class XMLOutputParser(BaseTransformOutputParser):
|
|
|
|
|
|
|
|
|
|
text = text.strip()
|
|
|
|
|
try:
|
|
|
|
|
root = DET.fromstring(text)
|
|
|
|
|
root = ET.fromstring(text)
|
|
|
|
|
return self._root_to_dict(root)
|
|
|
|
|
|
|
|
|
|
except (DET.ParseError, DET.EntitiesForbidden) as e:
|
|
|
|
|
except ET.ParseError as e:
|
|
|
|
|
msg = f"Failed to parse XML format from completion {text}. Got: {e}"
|
|
|
|
|
raise OutputParserException(msg, llm_output=text) from e
|
|
|
|
|
|
|
|
|
|