Implement streaming for xml output parser (#14984)

<!-- Thank you for contributing to LangChain!

Please title your PR "<package>: <description>", where <package> is
whichever of langchain, community, core, experimental, etc. is being
modified.

Replace this entire comment with:
  - **Description:** a description of the change,
  - **Issue:** the issue # it fixes if applicable,
  - **Dependencies:** any dependencies required for this change,
  - **Twitter handle:** we announce bigger features on Twitter. If your PR
    gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` from the root
of the package you've modified to check this locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc: https://python.langchain.com/docs/contributing/

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
pull/14971/head^2
Nuno Campos 6 months ago committed by GitHub
parent 94bc3967a1
commit b471166df7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,13 +1,15 @@
import re
import xml.etree.ElementTree as ET
from typing import Any, Dict, List, Optional
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser
from langchain_core.runnables.utils import AddableDict
from langchain.output_parsers.format_instructions import XML_FORMAT_INSTRUCTIONS
class XMLOutputParser(BaseOutputParser):
class XMLOutputParser(BaseTransformOutputParser):
"""Parse an output using xml format."""
tags: Optional[List[str]] = None
@ -33,6 +35,70 @@ class XMLOutputParser(BaseOutputParser):
else:
raise ValueError(f"Could not parse output: {text}")
def _transform(
    self, input: Iterator[Union[str, BaseMessage]]
) -> Iterator[AddableDict]:
    """Incrementally parse streamed XML, yielding elements as they close.

    Each incoming chunk is fed to an ``ET.XMLPullParser``. An element is
    emitted on its "end" event only when no other element has closed since
    the most recent "start" event, so enclosing elements are not emitted
    again after their children.

    Args:
        input: Iterator of text chunks or messages. For a message, its
            ``content`` is used; messages whose content is not a ``str``
            are skipped.

    Yields:
        The result of ``nested_element`` for the closed element, built
        from the tags of the elements still open around it.
    """
    pull_parser = ET.XMLPullParser(["start", "end"])
    open_tags: List[str] = []
    # True once any element has closed since the latest "start" event.
    closed_since_start = False

    for piece in input:
        if isinstance(piece, BaseMessage):
            content = piece.content
            if not isinstance(content, str):
                # Non-text message content cannot be fed to the parser.
                continue
            text = content
        else:
            text = piece

        pull_parser.feed(text)

        # Drain whatever events this chunk completed.
        for kind, node in pull_parser.read_events():
            if kind == "start":
                open_tags.append(node.tag)
                closed_since_start = False
                continue
            if kind != "end":
                continue
            # Drop the closing element's own tag; what remains is the
            # path of its still-open ancestors.
            open_tags.pop()
            if not closed_since_start:
                yield nested_element(open_tags, node)
            # From here on, enclosing elements must not be emitted.
            closed_since_start = True

    # Input exhausted: finalize the parser.
    pull_parser.close()
async def _atransform(
    self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[AddableDict]:
    """Asynchronously parse streamed XML, yielding elements as they close.

    Chunks are consumed with ``async for`` and fed to an
    ``ET.XMLPullParser``. An element is emitted on its "end" event only
    when no other element has closed since the most recent "start" event,
    so enclosing elements are not emitted again after their children.

    Args:
        input: Async iterator of text chunks or messages. For a message,
            its ``content`` is used; messages whose content is not a
            ``str`` are skipped.

    Yields:
        The result of ``nested_element`` for the closed element, built
        from the tags of the elements still open around it.
    """
    pull_parser = ET.XMLPullParser(["start", "end"])
    open_tags: List[str] = []
    # True once any element has closed since the latest "start" event.
    closed_since_start = False

    async for piece in input:
        if isinstance(piece, BaseMessage):
            content = piece.content
            if not isinstance(content, str):
                # Non-text message content cannot be fed to the parser.
                continue
            text = content
        else:
            text = piece

        pull_parser.feed(text)

        # Drain whatever events this chunk completed.
        for kind, node in pull_parser.read_events():
            if kind == "start":
                open_tags.append(node.tag)
                closed_since_start = False
                continue
            if kind != "end":
                continue
            # Drop the closing element's own tag; what remains is the
            # path of its still-open ancestors.
            open_tags.pop()
            if not closed_since_start:
                yield nested_element(open_tags, node)
            # From here on, enclosing elements must not be emitted.
            closed_since_start = True

    # Input exhausted: finalize the parser.
    pull_parser.close()
def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]:
"""Converts xml tree to python dictionary."""
result: Dict[str, List[Any]] = {root.tag: []}
@ -46,3 +112,11 @@ class XMLOutputParser(BaseOutputParser):
@property
def _type(self) -> str:
    """Return the identifier string for this parser type."""
    return "xml"
def nested_element(path: List[str], elem: ET.Element) -> Any:
    """Wrap an element's tag/text pair in one layer per tag in ``path``.

    With an empty ``path`` the result is ``AddableDict({elem.tag: elem.text})``.
    Each tag in ``path`` adds an enclosing ``AddableDict`` that maps the
    tag to a single-item list holding the next inner layer, with
    ``path[0]`` outermost.

    Args:
        path: Tags to nest under, outermost first.
        elem: Element whose tag and text form the innermost mapping.

    Returns:
        The nested ``AddableDict`` structure.
    """
    wrapped = AddableDict({elem.tag: elem.text})
    # Build from the innermost tag outward so path[0] ends up on top.
    for ancestor in reversed(path):
        wrapped = AddableDict({ancestor: [wrapped]})
    return wrapped

@ -31,6 +31,11 @@ def test_xml_output_parser(result: str) -> None:
xml_result = xml_parser.parse(result)
assert DEF_RESULT_EXPECTED == xml_result
assert list(xml_parser.transform(iter(result))) == [
{"foo": [{"bar": [{"baz": None}]}]},
{"foo": [{"bar": [{"baz": "slim.shady"}]}]},
{"foo": [{"baz": "tag"}]},
]
@pytest.mark.parametrize("result", ["foo></foo>", "<foo></foo", "foo></foo", "foofoo"])

Loading…
Cancel
Save