From 0df76bee3770c17fb4d6997de99e2dd5fc96e338 Mon Sep 17 00:00:00 2001 From: ale-delfino <105441283+ale-delfino@users.noreply.github.com> Date: Fri, 29 Mar 2024 00:55:23 +0000 Subject: [PATCH] core[patch]:: XML parser to cover the case when the xml only contains the root level tag (#17456) Description: Fix xml parser to handle strings that only contain the root tag Issue: N/A Dependencies: None Twitter handle: N/A A valid xml text can contain only the root level tag. Example: Some text here The example above is a valid xml string. If parsed with the current implementation the result is {"body": []}. This fix checks if the root level text contains any non-whitespace character and if that's the case it returns {root.tag: root.text}. The result is that the above text is correctly parsed as {"body": "Some text here"} @ale-delfino Thank you for contributing to LangChain! Checklist: - [x] PR title: Please title your PR "package: description", where "package" is whichever of langchain, community, core, experimental, etc. is being modified. Use "docs: ..." for purely docs changes, "templates: ..." for template changes, "infra: ..." for CI changes. - Example: "community: add foobar LLM" - [x] PR message: **Delete this entire template message** and replace it with the following bulleted list - **Description:** a description of the change - **Issue:** the issue # it fixes, if applicable - **Dependencies:** any dependencies required for this change - **Twitter handle:** if your PR gets announced, and you'd like a mention, we'll gladly shout you out! - [x] Pass lint and test: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified to check that you're passing lint and testing. See contribution guidelines for more information on how to write/run tests, lint, etc: https://python.langchain.com/docs/contributing/ - [x] Add tests and docs: If you're adding a new integration, please include 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @efriis, @eyurtsev, @hwchase17. --------- Co-authored-by: Eugene Yurtsev --- .../core/langchain_core/output_parsers/xml.py | 10 ++++-- .../output_parsers/test_xml_parser.py | 32 +++++++++++++++---- 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py index 704c67b8e7..890a4d7c71 100644 --- a/libs/core/langchain_core/output_parsers/xml.py +++ b/libs/core/langchain_core/output_parsers/xml.py @@ -155,7 +155,7 @@ class XMLOutputParser(BaseTransformOutputParser): def get_format_instructions(self) -> str: return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags) - def parse(self, text: str) -> Dict[str, List[Any]]: + def parse(self, text: str) -> Dict[str, Union[str, List[Any]]]: # Try to find XML string within triple backticks # Imports are temporarily placed here to avoid issue with caching on CI # likely if you're reading this you can move them to the top of the file @@ -207,9 +207,13 @@ class XMLOutputParser(BaseTransformOutputParser): yield output streaming_parser.close() - def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]: + def _root_to_dict(self, root: ET.Element) -> Dict[str, Union[str, List[Any]]]: """Converts xml tree to python dictionary.""" - result: Dict[str, List[Any]] = {root.tag: []} + if root.text and bool(re.search(r"\S", root.text)): + # If root text contains any non-whitespace character it + # returns {root.tag: root.text} + return {root.tag: root.text} + result: Dict = {root.tag: []} for child in root: if len(child) == 0: result[root.tag].append({child.tag: child.text}) diff --git a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py index c30d09ea1b..c71ed7f992 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py +++ b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py @@ -45,8 +45,8 @@ DEF_RESULT_EXPECTED = { async def _test_parser(parser: XMLOutputParser, content: str) -> None: """Test parser.""" - xml_content = parser.parse(content) - assert DEF_RESULT_EXPECTED == xml_content + assert parser.parse(content) == DEF_RESULT_EXPECTED + assert await parser.aparse(content) == DEF_RESULT_EXPECTED assert list(parser.transform(iter(content))) == [ {"foo": [{"bar": [{"baz": None}]}]}, @@ -54,10 +54,6 @@ async def _test_parser(parser: XMLOutputParser, content: str) -> None: {"foo": [{"baz": "tag"}]}, ] - async def _as_iter(iterable: Iterable[str]) -> AsyncIterator[str]: - for item in iterable: - yield item - chunks = [chunk async for chunk in parser.atransform(_as_iter(content))] assert list(chunks) == [ @@ -67,6 +63,30 @@ async def _test_parser(parser: XMLOutputParser, content: str) -> None: ] +ROOT_LEVEL_ONLY = """ +Text of the body. +""" + +ROOT_LEVEL_ONLY_EXPECTED = {"body": "Text of the body."} + + +async def _as_iter(iterable: Iterable[str]) -> AsyncIterator[str]: + for item in iterable: + yield item + + +async def test_root_only_xml_output_parser() -> None: + """Test XMLOutputParser when xml only contains the root level tag""" + xml_parser = XMLOutputParser(parser="xml") + assert xml_parser.parse(ROOT_LEVEL_ONLY) == {"body": "Text of the body."} + assert await xml_parser.aparse(ROOT_LEVEL_ONLY) == {"body": "Text of the body."} + assert list(xml_parser.transform(iter(ROOT_LEVEL_ONLY))) == [ + {"body": "Text of the body."} + ] + chunks = [chunk async for chunk in xml_parser.atransform(_as_iter(ROOT_LEVEL_ONLY))] + assert chunks == [{"body": "Text of the body."}] + + @pytest.mark.parametrize( "content", [