From 0df76bee3770c17fb4d6997de99e2dd5fc96e338 Mon Sep 17 00:00:00 2001
From: ale-delfino <105441283+ale-delfino@users.noreply.github.com>
Date: Fri, 29 Mar 2024 00:55:23 +0000
Subject: [PATCH] core[patch]:: XML parser to cover the case when the xml only
contains the root level tag (#17456)
Description: Fix xml parser to handle strings that only contain the root
tag
Issue: N/A
Dependencies: None
Twitter handle: N/A
A valid xml text can contain only the root level tag. Example:
Some text here
The example above is a valid xml string. If parsed with the current
implementation the result is {"body": []}. This fix checks if the root
level text contains any non-whitespace character and if that's the case
it returns {root.tag: root.text}. The result is that the above text is
correctly parsed as {"body": "Some text here"}
@ale-delfino
Thank you for contributing to LangChain!
Checklist:
- [x] PR title: Please title your PR "package: description", where
"package" is whichever of langchain, community, core, experimental, etc.
is being modified. Use "docs: ..." for purely docs changes, "templates:
..." for template changes, "infra: ..." for CI changes.
- Example: "community: add foobar LLM"
- [x] PR message: **Delete this entire template message** and replace it
with the following bulleted list
- **Description:** a description of the change
- **Issue:** the issue # it fixes, if applicable
- **Dependencies:** any dependencies required for this change
- **Twitter handle:** if your PR gets announced, and you'd like a
mention, we'll gladly shout you out!
- [x] Pass lint and test: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified to check that you're
passing lint and testing. See contribution guidelines for more
information on how to write/run tests, lint, etc:
https://python.langchain.com/docs/contributing/
- [x] Add tests and docs: If you're adding a new integration, please
include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.
Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.
If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @efriis, @eyurtsev, @hwchase17.
---------
Co-authored-by: Eugene Yurtsev
---
.../core/langchain_core/output_parsers/xml.py | 10 ++++--
.../output_parsers/test_xml_parser.py | 32 +++++++++++++++----
2 files changed, 33 insertions(+), 9 deletions(-)
diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py
index 704c67b8e7..890a4d7c71 100644
--- a/libs/core/langchain_core/output_parsers/xml.py
+++ b/libs/core/langchain_core/output_parsers/xml.py
@@ -155,7 +155,7 @@ class XMLOutputParser(BaseTransformOutputParser):
def get_format_instructions(self) -> str:
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
- def parse(self, text: str) -> Dict[str, List[Any]]:
+ def parse(self, text: str) -> Dict[str, Union[str, List[Any]]]:
# Try to find XML string within triple backticks
# Imports are temporarily placed here to avoid issue with caching on CI
# likely if you're reading this you can move them to the top of the file
@@ -207,9 +207,13 @@ class XMLOutputParser(BaseTransformOutputParser):
yield output
streaming_parser.close()
- def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]:
+ def _root_to_dict(self, root: ET.Element) -> Dict[str, Union[str, List[Any]]]:
"""Converts xml tree to python dictionary."""
- result: Dict[str, List[Any]] = {root.tag: []}
+ if root.text and bool(re.search(r"\S", root.text)):
+ # If root text contains any non-whitespace character it
+ # returns {root.tag: root.text}
+ return {root.tag: root.text}
+ result: Dict = {root.tag: []}
for child in root:
if len(child) == 0:
result[root.tag].append({child.tag: child.text})
diff --git a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py
index c30d09ea1b..c71ed7f992 100644
--- a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py
+++ b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py
@@ -45,8 +45,8 @@ DEF_RESULT_EXPECTED = {
async def _test_parser(parser: XMLOutputParser, content: str) -> None:
"""Test parser."""
- xml_content = parser.parse(content)
- assert DEF_RESULT_EXPECTED == xml_content
+ assert parser.parse(content) == DEF_RESULT_EXPECTED
+ assert await parser.aparse(content) == DEF_RESULT_EXPECTED
assert list(parser.transform(iter(content))) == [
{"foo": [{"bar": [{"baz": None}]}]},
@@ -54,10 +54,6 @@ async def _test_parser(parser: XMLOutputParser, content: str) -> None:
{"foo": [{"baz": "tag"}]},
]
- async def _as_iter(iterable: Iterable[str]) -> AsyncIterator[str]:
- for item in iterable:
- yield item
-
chunks = [chunk async for chunk in parser.atransform(_as_iter(content))]
assert list(chunks) == [
@@ -67,6 +63,30 @@ async def _test_parser(parser: XMLOutputParser, content: str) -> None:
]
+ROOT_LEVEL_ONLY = """
+Text of the body.
+"""
+
+ROOT_LEVEL_ONLY_EXPECTED = {"body": "Text of the body."}
+
+
+async def _as_iter(iterable: Iterable[str]) -> AsyncIterator[str]:
+ for item in iterable:
+ yield item
+
+
+async def test_root_only_xml_output_parser() -> None:
+ """Test XMLOutputParser when xml only contains the root level tag"""
+ xml_parser = XMLOutputParser(parser="xml")
+ assert xml_parser.parse(ROOT_LEVEL_ONLY) == {"body": "Text of the body."}
+ assert await xml_parser.aparse(ROOT_LEVEL_ONLY) == {"body": "Text of the body."}
+ assert list(xml_parser.transform(iter(ROOT_LEVEL_ONLY))) == [
+ {"body": "Text of the body."}
+ ]
+ chunks = [chunk async for chunk in xml_parser.atransform(_as_iter(ROOT_LEVEL_ONLY))]
+ assert chunks == [{"body": "Text of the body."}]
+
+
@pytest.mark.parametrize(
"content",
[