mirror of https://github.com/hwchase17/langchain
Move json and xml parsers to core (#15026)
<!-- Thank you for contributing to LangChain! Please title your PR "<package>: <description>", where <package> is whichever of langchain, community, core, experimental, etc. is being modified. Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes if applicable, - **Dependencies:** any dependencies required for this change, - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` from the root of the package you've modified to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://python.langchain.com/docs/contributing/ If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->pull/15036/head
parent
d5533b7081
commit
71076cceaf
@ -0,0 +1,195 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Callable, List, Optional
|
||||
|
||||
import jsonpatch # type: ignore[import]
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
|
||||
|
||||
|
||||
def _replace_new_line(match: re.Match[str]) -> str:
    """Escape raw control characters inside a matched `action_input` value.

    The match is expected to have group 1 = the `"action_input": "` prefix,
    group 2 = the raw string body, and group 3 = the closing quote. Returns
    the same span with newlines, carriage returns, tabs, and unescaped double
    quotes replaced by their JSON escape sequences.
    """
    body = match.group(2)
    # Each substitution operates on the result of the previous one.
    for raw, escaped in ((r"\n", r"\\n"), (r"\r", r"\\r"), (r"\t", r"\\t")):
        body = re.sub(raw, escaped, body)
    # Escape any double quote that is not already preceded by a backslash.
    body = re.sub(r'(?<!\\)"', r"\"", body)
    return match.group(1) + body + match.group(3)
|
||||
|
||||
|
||||
def _custom_parser(multiline_string: str) -> str:
    """
    The LLM response for `action_input` may be a multiline
    string containing unescaped newlines, tabs or quotes. This function
    replaces those characters with their escaped counterparts.
    (newlines in JSON must be double-escaped: `\\n`)
    """
    # Tolerate byte input even though the annotation says str.
    if isinstance(multiline_string, (bytes, bytearray)):
        multiline_string = multiline_string.decode()

    # Only the value of the "action_input" key is rewritten; everything else
    # in the payload is passed through untouched.
    pattern = r'("action_input"\:\s*")(.*)(")'
    return re.sub(pattern, _replace_new_line, multiline_string, flags=re.DOTALL)
|
||||
|
||||
|
||||
# Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py
# MIT License
def parse_partial_json(s: str, *, strict: bool = False) -> Any:
    """Parse a JSON string that may be missing closing braces.

    Args:
        s: The JSON string to parse.
        strict: Whether to use strict parsing. Defaults to False.

    Returns:
        The parsed JSON object as a Python dictionary.
    """
    # Fast path: the string may already be complete, valid JSON.
    try:
        return json.loads(s, strict=strict)
    except json.JSONDecodeError:
        pass

    # Walk the string one character at a time, tracking whether we are inside
    # a string literal and which container closers are still owed, so the
    # truncated tail can be repaired afterwards.
    repaired: List[str] = []
    closers: List[str] = []  # closing characters owed, in open order
    in_string = False
    escape_pending = False

    for ch in s:
        if in_string:
            if ch == '"' and not escape_pending:
                in_string = False
            elif ch == "\n" and not escape_pending:
                # Escape a raw newline inside a string literal.
                ch = "\\n"
            elif ch == "\\":
                escape_pending = not escape_pending
            else:
                escape_pending = False
        else:
            if ch == '"':
                in_string = True
                escape_pending = False
            elif ch == "{":
                closers.append("}")
            elif ch == "[":
                closers.append("]")
            elif ch in ("}", "]"):
                if closers and closers[-1] == ch:
                    closers.pop()
                else:
                    # Mismatched closing character; the input is malformed.
                    return None

        repaired.append(ch)

    # Terminate an unfinished string literal, if any.
    if in_string:
        repaired.append('"')

    # Close remaining open containers in the reverse order they were opened.
    repaired.extend(reversed(closers))

    # Attempt to parse the repaired string as JSON.
    try:
        return json.loads("".join(repaired), strict=strict)
    except json.JSONDecodeError:
        # Still unparseable: signal failure with None.
        return None
|
||||
|
||||
|
||||
def parse_json_markdown(
    json_string: str, *, parser: Callable[[str], Any] = json.loads
) -> dict:
    """
    Parse a JSON string from a Markdown string.

    Args:
        json_string: The Markdown string.
        parser: Callable used to decode the extracted JSON text.
            Defaults to ``json.loads``.

    Returns:
        The parsed JSON object as a Python dictionary.
    """
    # Prefer the content of a fenced ``` block (optionally tagged "json");
    # if no fence is found, treat the entire input as JSON.
    fenced = re.search(r"```(json)?(.*)```", json_string, re.DOTALL)
    json_str = json_string if fenced is None else fenced.group(2)

    # Strip whitespace and newlines from the start and end, then escape raw
    # newlines/tabs/quotes that may appear inside the `action_input` value.
    json_str = _custom_parser(json_str.strip())

    # Delegate the actual decoding to the injected parser.
    return parser(json_str)
|
||||
|
||||
|
||||
def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
    """
    Parse a JSON string from a Markdown string and check that it
    contains the expected keys.

    Args:
        text: The Markdown string.
        expected_keys: The expected keys in the JSON string.

    Returns:
        The parsed JSON object as a Python dictionary.

    Raises:
        OutputParserException: If the text is not valid JSON or one of the
            expected keys is missing from the parsed object.
    """
    try:
        json_obj = parse_json_markdown(text)
    except json.JSONDecodeError as e:
        # Chain the original decode error so callers can inspect the root
        # cause (consistent with SimpleJsonOutputParser.parse).
        raise OutputParserException(f"Got invalid JSON object. Error: {e}") from e
    for key in expected_keys:
        if key not in json_obj:
            raise OutputParserException(
                f"Got invalid return object. Expected key `{key}` "
                f"to be present, but got {json_obj}"
            )
    return json_obj
|
||||
|
||||
|
||||
class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
    """Parse the output of an LLM call to a JSON object.

    When used in streaming mode, it will yield partial JSON objects containing
    all the keys that have been returned so far.

    In streaming, if `diff` is set to `True`, yields JSONPatch operations
    describing the difference between the previous and the current object.
    """

    def _diff(self, prev: Optional[Any], next: Any) -> Any:
        # Express the change between consecutive parses as a JSONPatch op list.
        patch = jsonpatch.make_patch(prev, next)
        return patch.patch

    def parse(self, text: str) -> Any:
        """Parse ``text`` as (possibly partial) JSON embedded in Markdown."""
        text = text.strip()
        try:
            # parse_partial_json tolerates truncated output mid-stream.
            return parse_json_markdown(text, parser=parse_partial_json)
        except JSONDecodeError as e:
            raise OutputParserException(f"Invalid json output: {text}") from e

    @property
    def _type(self) -> str:
        return "simple_json_output_parser"
|
@ -0,0 +1,135 @@
|
||||
import re
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||
from langchain_core.runnables.utils import AddableDict
|
||||
|
||||
XML_FORMAT_INSTRUCTIONS = """The output should be formatted as a XML file.
|
||||
1. Output should conform to the tags below.
|
||||
2. If tags are not given, make them on your own.
|
||||
3. Remember to always open and close all the tags.
|
||||
|
||||
As an example, for the tags ["foo", "bar", "baz"]:
|
||||
1. String "<foo>\n <bar>\n <baz></baz>\n </bar>\n</foo>" is a well-formatted instance of the schema.
|
||||
2. String "<foo>\n <bar>\n </foo>" is a badly-formatted instance.
|
||||
3. String "<foo>\n <tag>\n </tag>\n</foo>" is a badly-formatted instance.
|
||||
|
||||
Here are the output tags:
|
||||
```
|
||||
{tags}
|
||||
```""" # noqa: E501
|
||||
|
||||
|
||||
class XMLOutputParser(BaseTransformOutputParser):
    """Parse an output using xml format."""

    # Optional list of expected tag names; used only when rendering the
    # format instructions, not enforced during parsing.
    tags: Optional[List[str]] = None
    # Matches an XML declaration (any tag mentioning "encoding") followed by
    # the document body, so the declaration can be dropped before parsing.
    encoding_matcher: re.Pattern = re.compile(
        r"<([^>]*encoding[^>]*)>\n(.*)", re.MULTILINE | re.DOTALL
    )

    def get_format_instructions(self) -> str:
        """Return prompt instructions describing the expected XML output."""
        return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)

    def parse(self, text: str) -> Dict[str, List[Any]]:
        """Parse a complete XML document into a nested dict of lists.

        Raises:
            ValueError: If the cleaned text does not look like an XML
                document (must start with "<" and end with ">").
        """
        # NOTE(review): str.strip removes *characters*, not prefixes — this
        # strips backticks, then any leading/trailing "x"/"m"/"l" characters.
        # Intended to peel a ```xml fence; could over-strip — confirm.
        text = text.strip("`").strip("xml")
        encoding_match = self.encoding_matcher.search(text)
        if encoding_match:
            # Drop the <?xml ... encoding=...?> declaration, keep the body.
            text = encoding_match.group(2)

        text = text.strip()
        if (text.startswith("<") or text.startswith("\n<")) and (
            text.endswith(">") or text.endswith(">\n")
        ):
            root = ET.fromstring(text)
            return self._root_to_dict(root)
        else:
            raise ValueError(f"Could not parse output: {text}")

    def _transform(
        self, input: Iterator[Union[str, BaseMessage]]
    ) -> Iterator[AddableDict]:
        # Incrementally parse streamed chunks with a pull parser, yielding one
        # nested dict per completed element that has no children of its own
        # (parents are suppressed to avoid duplicate output).
        parser = ET.XMLPullParser(["start", "end"])
        current_path: List[str] = []
        current_path_has_children = False
        for chunk in input:
            if isinstance(chunk, BaseMessage):
                # extract text
                chunk_content = chunk.content
                # Skip non-string message content (e.g. multimodal parts).
                if not isinstance(chunk_content, str):
                    continue
                chunk = chunk_content
            # pass chunk to parser
            parser.feed(chunk)
            # yield all events
            for event, elem in parser.read_events():
                if event == "start":
                    # update current path
                    current_path.append(elem.tag)
                    current_path_has_children = False
                elif event == "end":
                    # remove last element from current path
                    current_path.pop()
                    # yield element
                    if not current_path_has_children:
                        yield nested_element(current_path, elem)
                    # prevent yielding of parent element
                    current_path_has_children = True
        # close parser
        parser.close()

    async def _atransform(
        self, input: AsyncIterator[Union[str, BaseMessage]]
    ) -> AsyncIterator[AddableDict]:
        # Async mirror of _transform; same pull-parser state machine but
        # consuming an async iterator of chunks.
        parser = ET.XMLPullParser(["start", "end"])
        current_path: List[str] = []
        current_path_has_children = False
        async for chunk in input:
            if isinstance(chunk, BaseMessage):
                # extract text
                chunk_content = chunk.content
                # Skip non-string message content (e.g. multimodal parts).
                if not isinstance(chunk_content, str):
                    continue
                chunk = chunk_content
            # pass chunk to parser
            parser.feed(chunk)
            # yield all events
            for event, elem in parser.read_events():
                if event == "start":
                    # update current path
                    current_path.append(elem.tag)
                    current_path_has_children = False
                elif event == "end":
                    # remove last element from current path
                    current_path.pop()
                    # yield element
                    if not current_path_has_children:
                        yield nested_element(current_path, elem)
                    # prevent yielding of parent element
                    current_path_has_children = True
        # close parser
        parser.close()

    def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]:
        """Converts xml tree to python dictionary."""
        result: Dict[str, List[Any]] = {root.tag: []}
        for child in root:
            if len(child) == 0:
                # Leaf element: record its text content directly.
                result[root.tag].append({child.tag: child.text})
            else:
                # Non-leaf: recurse; note the child's own text is not kept.
                result[root.tag].append(self._root_to_dict(child))
        return result

    @property
    def _type(self) -> str:
        return "xml"
|
||||
|
||||
|
||||
def nested_element(path: List[str], elem: ET.Element) -> Any:
    """Wrap ``elem`` in nested AddableDicts keyed by the tags in ``path``."""
    if not path:
        return AddableDict({elem.tag: elem.text})
    head, rest = path[0], path[1:]
    return AddableDict({head: [nested_element(rest, elem)]})
|
@ -0,0 +1,488 @@
|
||||
import json
|
||||
from typing import Any, AsyncIterator, Iterator, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.output_parsers.json import (
|
||||
SimpleJsonOutputParser,
|
||||
parse_json_markdown,
|
||||
parse_partial_json,
|
||||
)
|
||||
|
||||
GOOD_JSON = """```json
|
||||
{
|
||||
"foo": "bar"
|
||||
}
|
||||
```"""
|
||||
|
||||
JSON_WITH_NEW_LINES = """
|
||||
|
||||
```json
|
||||
{
|
||||
"foo": "bar"
|
||||
}
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
JSON_WITH_NEW_LINES_INSIDE = """```json
|
||||
{
|
||||
|
||||
"foo": "bar"
|
||||
|
||||
}
|
||||
```"""
|
||||
|
||||
JSON_WITH_NEW_LINES_EVERYWHERE = """
|
||||
|
||||
```json
|
||||
|
||||
{
|
||||
|
||||
"foo": "bar"
|
||||
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
TICKS_WITH_NEW_LINES_EVERYWHERE = """
|
||||
|
||||
```
|
||||
|
||||
{
|
||||
|
||||
"foo": "bar"
|
||||
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
JSON_WITH_MARKDOWN_CODE_BLOCK = """```json
|
||||
{
|
||||
"foo": "```bar```"
|
||||
}
|
||||
```"""
|
||||
|
||||
JSON_WITH_MARKDOWN_CODE_BLOCK_AND_NEWLINES = """```json
|
||||
{
|
||||
"action": "Final Answer",
|
||||
"action_input": "```bar\n<div id="1" class=\"value\">\n\ttext\n</div>```"
|
||||
}
|
||||
```"""
|
||||
|
||||
JSON_WITH_UNESCAPED_QUOTES_IN_NESTED_JSON = """```json
|
||||
{
|
||||
"action": "Final Answer",
|
||||
"action_input": "{"foo": "bar", "bar": "foo"}"
|
||||
}
|
||||
```"""
|
||||
|
||||
JSON_WITH_ESCAPED_QUOTES_IN_NESTED_JSON = """```json
|
||||
{
|
||||
"action": "Final Answer",
|
||||
"action_input": "{\"foo\": \"bar\", \"bar\": \"foo\"}"
|
||||
}
|
||||
```"""
|
||||
|
||||
JSON_WITH_PYTHON_DICT = """```json
|
||||
{
|
||||
"action": "Final Answer",
|
||||
"action_input": {"foo": "bar", "bar": "foo"}
|
||||
}
|
||||
```"""
|
||||
|
||||
JSON_WITH_ESCAPED_DOUBLE_QUOTES_IN_NESTED_JSON = """```json
|
||||
{
|
||||
"action": "Final Answer",
|
||||
"action_input": "{\\"foo\\": \\"bar\\", \\"bar\\": \\"foo\\"}"
|
||||
}
|
||||
```"""
|
||||
|
||||
NO_TICKS = """{
|
||||
"foo": "bar"
|
||||
}"""
|
||||
|
||||
NO_TICKS_WHITE_SPACE = """
|
||||
{
|
||||
"foo": "bar"
|
||||
}
|
||||
"""
|
||||
|
||||
TEXT_BEFORE = """Thought: I need to use the search tool
|
||||
|
||||
Action:
|
||||
```
|
||||
{
|
||||
"foo": "bar"
|
||||
}
|
||||
```"""
|
||||
|
||||
TEXT_AFTER = """```
|
||||
{
|
||||
"foo": "bar"
|
||||
}
|
||||
```
|
||||
This should do the trick"""
|
||||
|
||||
TEXT_BEFORE_AND_AFTER = """Action: Testing
|
||||
|
||||
```
|
||||
{
|
||||
"foo": "bar"
|
||||
}
|
||||
```
|
||||
This should do the trick"""
|
||||
|
||||
TEST_CASES = [
|
||||
GOOD_JSON,
|
||||
JSON_WITH_NEW_LINES,
|
||||
JSON_WITH_NEW_LINES_INSIDE,
|
||||
JSON_WITH_NEW_LINES_EVERYWHERE,
|
||||
TICKS_WITH_NEW_LINES_EVERYWHERE,
|
||||
NO_TICKS,
|
||||
NO_TICKS_WHITE_SPACE,
|
||||
TEXT_BEFORE,
|
||||
TEXT_AFTER,
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("json_string", TEST_CASES)
def test_parse_json(json_string: str) -> None:
    """Every fenced/unfenced fixture variant decodes to the same object."""
    parsed = parse_json_markdown(json_string)
    assert parsed == {"foo": "bar"}
|
||||
|
||||
|
||||
def test_parse_json_with_code_blocks() -> None:
    """Backtick fences inside string values must survive extraction."""
    parsed = parse_json_markdown(JSON_WITH_MARKDOWN_CODE_BLOCK)
    assert parsed == {"foo": "```bar```"}

    # Raw newlines/tabs/quotes inside action_input are escaped before parse.
    parsed = parse_json_markdown(JSON_WITH_MARKDOWN_CODE_BLOCK_AND_NEWLINES)

    assert parsed == {
        "action": "Final Answer",
        "action_input": '```bar\n<div id="1" class="value">\n\ttext\n</div>```',
    }
|
||||
|
||||
|
||||
# Nested-JSON variants of `action_input` — unescaped, escaped, and
# double-escaped quotes — that should all parse to the same object.
TEST_CASES_ESCAPED_QUOTES = [
    JSON_WITH_UNESCAPED_QUOTES_IN_NESTED_JSON,
    JSON_WITH_ESCAPED_QUOTES_IN_NESTED_JSON,
    JSON_WITH_ESCAPED_DOUBLE_QUOTES_IN_NESTED_JSON,
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("json_string", TEST_CASES_ESCAPED_QUOTES)
def test_parse_nested_json_with_escaped_quotes(json_string: str) -> None:
    """All quote-escaping variants yield the same nested-JSON string value."""
    parsed = parse_json_markdown(json_string)
    assert parsed == {
        "action": "Final Answer",
        "action_input": '{"foo": "bar", "bar": "foo"}',
    }
|
||||
|
||||
|
||||
def test_parse_json_with_python_dict() -> None:
    """A dict-valued (non-string) action_input is parsed as a nested object."""
    parsed = parse_json_markdown(JSON_WITH_PYTHON_DICT)
    assert parsed == {
        "action": "Final Answer",
        "action_input": {"foo": "bar", "bar": "foo"},
    }
|
||||
|
||||
|
||||
# (input, expected-completion) pairs: the left side may be truncated JSON;
# the right side is the valid JSON that parse_partial_json should recover.
TEST_CASES_PARTIAL = [
    ('{"foo": "bar", "bar": "foo"}', '{"foo": "bar", "bar": "foo"}'),
    ('{"foo": "bar", "bar": "foo', '{"foo": "bar", "bar": "foo"}'),
    ('{"foo": "bar", "bar": "foo}', '{"foo": "bar", "bar": "foo}"}'),
    ('{"foo": "bar", "bar": "foo[', '{"foo": "bar", "bar": "foo["}'),
    ('{"foo": "bar", "bar": "foo\\"', '{"foo": "bar", "bar": "foo\\""}'),
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("json_strings", TEST_CASES_PARTIAL)
def test_parse_partial_json(json_strings: Tuple[str, str]) -> None:
    """Truncated input parses to the same object as its repaired form."""
    case, expected = json_strings
    parsed = parse_partial_json(case)
    assert parsed == json.loads(expected)
|
||||
|
||||
|
||||
STREAMED_TOKENS = """
|
||||
{
|
||||
|
||||
"
|
||||
setup
|
||||
":
|
||||
"
|
||||
Why
|
||||
did
|
||||
the
|
||||
bears
|
||||
start
|
||||
a
|
||||
band
|
||||
called
|
||||
Bears
|
||||
Bears
|
||||
Bears
|
||||
?
|
||||
"
|
||||
,
|
||||
"
|
||||
punchline
|
||||
":
|
||||
"
|
||||
Because
|
||||
they
|
||||
wanted
|
||||
to
|
||||
play
|
||||
bear
|
||||
-y
|
||||
good
|
||||
music
|
||||
!
|
||||
"
|
||||
,
|
||||
"
|
||||
audience
|
||||
":
|
||||
[
|
||||
"
|
||||
Haha
|
||||
"
|
||||
,
|
||||
"
|
||||
So
|
||||
funny
|
||||
"
|
||||
]
|
||||
|
||||
}
|
||||
""".splitlines()
|
||||
|
||||
EXPECTED_STREAMED_JSON = [
|
||||
{},
|
||||
{"setup": ""},
|
||||
{"setup": "Why"},
|
||||
{"setup": "Why did"},
|
||||
{"setup": "Why did the"},
|
||||
{"setup": "Why did the bears"},
|
||||
{"setup": "Why did the bears start"},
|
||||
{"setup": "Why did the bears start a"},
|
||||
{"setup": "Why did the bears start a band"},
|
||||
{"setup": "Why did the bears start a band called"},
|
||||
{"setup": "Why did the bears start a band called Bears"},
|
||||
{"setup": "Why did the bears start a band called Bears Bears"},
|
||||
{"setup": "Why did the bears start a band called Bears Bears Bears"},
|
||||
{"setup": "Why did the bears start a band called Bears Bears Bears ?"},
|
||||
{
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"punchline": "",
|
||||
},
|
||||
{
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"punchline": "Because",
|
||||
},
|
||||
{
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"punchline": "Because they",
|
||||
},
|
||||
{
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"punchline": "Because they wanted",
|
||||
},
|
||||
{
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"punchline": "Because they wanted to",
|
||||
},
|
||||
{
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"punchline": "Because they wanted to play",
|
||||
},
|
||||
{
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"punchline": "Because they wanted to play bear",
|
||||
},
|
||||
{
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"punchline": "Because they wanted to play bear -y",
|
||||
},
|
||||
{
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"punchline": "Because they wanted to play bear -y good",
|
||||
},
|
||||
{
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"punchline": "Because they wanted to play bear -y good music",
|
||||
},
|
||||
{
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"punchline": "Because they wanted to play bear -y good music !",
|
||||
},
|
||||
{
|
||||
"punchline": "Because they wanted to play bear -y good music !",
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"audience": [],
|
||||
},
|
||||
{
|
||||
"punchline": "Because they wanted to play bear -y good music !",
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"audience": [""],
|
||||
},
|
||||
{
|
||||
"punchline": "Because they wanted to play bear -y good music !",
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"audience": ["Haha"],
|
||||
},
|
||||
{
|
||||
"punchline": "Because they wanted to play bear -y good music !",
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"audience": ["Haha", ""],
|
||||
},
|
||||
{
|
||||
"punchline": "Because they wanted to play bear -y good music !",
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"audience": ["Haha", "So"],
|
||||
},
|
||||
{
|
||||
"punchline": "Because they wanted to play bear -y good music !",
|
||||
"setup": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
"audience": ["Haha", "So funny"],
|
||||
},
|
||||
]
|
||||
|
||||
EXPECTED_STREAMED_JSON_DIFF = [
|
||||
[{"op": "replace", "path": "", "value": {}}],
|
||||
[{"op": "add", "path": "/setup", "value": ""}],
|
||||
[{"op": "replace", "path": "/setup", "value": "Why"}],
|
||||
[{"op": "replace", "path": "/setup", "value": "Why did"}],
|
||||
[{"op": "replace", "path": "/setup", "value": "Why did the"}],
|
||||
[{"op": "replace", "path": "/setup", "value": "Why did the bears"}],
|
||||
[{"op": "replace", "path": "/setup", "value": "Why did the bears start"}],
|
||||
[{"op": "replace", "path": "/setup", "value": "Why did the bears start a"}],
|
||||
[{"op": "replace", "path": "/setup", "value": "Why did the bears start a band"}],
|
||||
[
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "/setup",
|
||||
"value": "Why did the bears start a band called",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "/setup",
|
||||
"value": "Why did the bears start a band called Bears",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "/setup",
|
||||
"value": "Why did the bears start a band called Bears Bears",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "/setup",
|
||||
"value": "Why did the bears start a band called Bears Bears Bears",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "/setup",
|
||||
"value": "Why did the bears start a band called Bears Bears Bears ?",
|
||||
}
|
||||
],
|
||||
[{"op": "add", "path": "/punchline", "value": ""}],
|
||||
[{"op": "replace", "path": "/punchline", "value": "Because"}],
|
||||
[{"op": "replace", "path": "/punchline", "value": "Because they"}],
|
||||
[{"op": "replace", "path": "/punchline", "value": "Because they wanted"}],
|
||||
[{"op": "replace", "path": "/punchline", "value": "Because they wanted to"}],
|
||||
[{"op": "replace", "path": "/punchline", "value": "Because they wanted to play"}],
|
||||
[
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "/punchline",
|
||||
"value": "Because they wanted to play bear",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "/punchline",
|
||||
"value": "Because they wanted to play bear -y",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "/punchline",
|
||||
"value": "Because they wanted to play bear -y good",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "/punchline",
|
||||
"value": "Because they wanted to play bear -y good music",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "/punchline",
|
||||
"value": "Because they wanted to play bear -y good music !",
|
||||
}
|
||||
],
|
||||
[{"op": "add", "path": "/audience", "value": []}],
|
||||
[{"op": "add", "path": "/audience/0", "value": ""}],
|
||||
[{"op": "replace", "path": "/audience/0", "value": "Haha"}],
|
||||
[{"op": "add", "path": "/audience/1", "value": ""}],
|
||||
[{"op": "replace", "path": "/audience/1", "value": "So"}],
|
||||
[{"op": "replace", "path": "/audience/1", "value": "So funny"}],
|
||||
]
|
||||
|
||||
|
||||
def test_partial_text_json_output_parser() -> None:
    """Streaming tokens through the parser yields cumulative partial objects."""

    def input_iter(_: Any) -> Iterator[str]:
        for token in STREAMED_TOKENS:
            yield token

    chain = input_iter | SimpleJsonOutputParser()

    assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
|
||||
|
||||
|
||||
def test_partial_text_json_output_parser_diff() -> None:
    """With diff=True, streaming yields JSONPatch ops instead of snapshots."""

    def input_iter(_: Any) -> Iterator[str]:
        for token in STREAMED_TOKENS:
            yield token

    chain = input_iter | SimpleJsonOutputParser(diff=True)

    assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF
|
||||
|
||||
|
||||
async def test_partial_text_json_output_parser_async() -> None:
    """Async streaming mirrors the sync partial-object behavior."""

    async def input_iter(_: Any) -> AsyncIterator[str]:
        for token in STREAMED_TOKENS:
            yield token

    chain = input_iter | SimpleJsonOutputParser()

    assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
|
||||
|
||||
|
||||
async def test_partial_text_json_output_parser_diff_async() -> None:
    """Async streaming with diff=True mirrors the sync JSONPatch behavior."""

    async def input_iter(_: Any) -> AsyncIterator[str]:
        for token in STREAMED_TOKENS:
            yield token

    chain = input_iter | SimpleJsonOutputParser(diff=True)

    assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF
|
@ -1,7 +1,7 @@
|
||||
"""Test XMLOutputParser"""
|
||||
import pytest
|
||||
|
||||
from langchain.output_parsers.xml import XMLOutputParser
|
||||
from langchain_core.output_parsers.xml import XMLOutputParser
|
||||
|
||||
DEF_RESULT_ENCODING = """<?xml version="1.0" encoding="UTF-8"?>
|
||||
<foo>
|
@ -1,194 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Callable, List, Optional
|
||||
|
||||
import jsonpatch
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers import BaseCumulativeTransformOutputParser
|
||||
|
||||
|
||||
def _replace_new_line(match: re.Match[str]) -> str:
|
||||
value = match.group(2)
|
||||
value = re.sub(r"\n", r"\\n", value)
|
||||
value = re.sub(r"\r", r"\\r", value)
|
||||
value = re.sub(r"\t", r"\\t", value)
|
||||
value = re.sub(r'(?<!\\)"', r"\"", value)
|
||||
|
||||
return match.group(1) + value + match.group(3)
|
||||
|
||||
|
||||
def _custom_parser(multiline_string: str) -> str:
|
||||
"""
|
||||
The LLM response for `action_input` may be a multiline
|
||||
string containing unescaped newlines, tabs or quotes. This function
|
||||
replaces those characters with their escaped counterparts.
|
||||
(newlines in JSON must be double-escaped: `\\n`)
|
||||
"""
|
||||
if isinstance(multiline_string, (bytes, bytearray)):
|
||||
multiline_string = multiline_string.decode()
|
||||
|
||||
multiline_string = re.sub(
|
||||
r'("action_input"\:\s*")(.*)(")',
|
||||
_replace_new_line,
|
||||
multiline_string,
|
||||
flags=re.DOTALL,
|
||||
)
|
||||
|
||||
return multiline_string
|
||||
|
||||
|
||||
# Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py
|
||||
# MIT License
|
||||
def parse_partial_json(s: str, *, strict: bool = False) -> Any:
|
||||
"""Parse a JSON string that may be missing closing braces.
|
||||
|
||||
Args:
|
||||
s: The JSON string to parse.
|
||||
strict: Whether to use strict parsing. Defaults to False.
|
||||
|
||||
Returns:
|
||||
The parsed JSON object as a Python dictionary.
|
||||
"""
|
||||
# Attempt to parse the string as-is.
|
||||
try:
|
||||
return json.loads(s, strict=strict)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Initialize variables.
|
||||
new_s = ""
|
||||
stack = []
|
||||
is_inside_string = False
|
||||
escaped = False
|
||||
|
||||
# Process each character in the string one at a time.
|
||||
for char in s:
|
||||
if is_inside_string:
|
||||
if char == '"' and not escaped:
|
||||
is_inside_string = False
|
||||
elif char == "\n" and not escaped:
|
||||
char = "\\n" # Replace the newline character with the escape sequence.
|
||||
elif char == "\\":
|
||||
escaped = not escaped
|
||||
else:
|
||||
escaped = False
|
||||
else:
|
||||
if char == '"':
|
||||
is_inside_string = True
|
||||
escaped = False
|
||||
elif char == "{":
|
||||
stack.append("}")
|
||||
elif char == "[":
|
||||
stack.append("]")
|
||||
elif char == "}" or char == "]":
|
||||
if stack and stack[-1] == char:
|
||||
stack.pop()
|
||||
else:
|
||||
# Mismatched closing character; the input is malformed.
|
||||
return None
|
||||
|
||||
# Append the processed character to the new string.
|
||||
new_s += char
|
||||
|
||||
# If we're still inside a string at the end of processing,
|
||||
# we need to close the string.
|
||||
if is_inside_string:
|
||||
new_s += '"'
|
||||
|
||||
# Close any remaining open structures in the reverse order that they were opened.
|
||||
for closing_char in reversed(stack):
|
||||
new_s += closing_char
|
||||
|
||||
# Attempt to parse the modified string as JSON.
|
||||
try:
|
||||
return json.loads(new_s, strict=strict)
|
||||
except json.JSONDecodeError:
|
||||
# If we still can't parse the string as JSON, return None to indicate failure.
|
||||
return None
|
||||
|
||||
|
||||
def parse_json_markdown(
|
||||
json_string: str, *, parser: Callable[[str], Any] = json.loads
|
||||
) -> dict:
|
||||
"""
|
||||
Parse a JSON string from a Markdown string.
|
||||
|
||||
Args:
|
||||
json_string: The Markdown string.
|
||||
|
||||
Returns:
|
||||
The parsed JSON object as a Python dictionary.
|
||||
"""
|
||||
# Try to find JSON string within triple backticks
|
||||
match = re.search(r"```(json)?(.*)```", json_string, re.DOTALL)
|
||||
|
||||
# If no match found, assume the entire string is a JSON string
|
||||
if match is None:
|
||||
json_str = json_string
|
||||
else:
|
||||
# If match found, use the content within the backticks
|
||||
json_str = match.group(2)
|
||||
|
||||
# Strip whitespace and newlines from the start and end
|
||||
json_str = json_str.strip()
|
||||
|
||||
# handle newlines and other special characters inside the returned value
|
||||
json_str = _custom_parser(json_str)
|
||||
|
||||
# Parse the JSON string into a Python dictionary
|
||||
parsed = parser(json_str)
|
||||
|
||||
return parsed
|
||||
|
||||
|
||||
def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
|
||||
"""
|
||||
Parse a JSON string from a Markdown string and check that it
|
||||
contains the expected keys.
|
||||
|
||||
Args:
|
||||
text: The Markdown string.
|
||||
expected_keys: The expected keys in the JSON string.
|
||||
|
||||
Returns:
|
||||
The parsed JSON object as a Python dictionary.
|
||||
"""
|
||||
try:
|
||||
json_obj = parse_json_markdown(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise OutputParserException(f"Got invalid JSON object. Error: {e}")
|
||||
for key in expected_keys:
|
||||
if key not in json_obj:
|
||||
raise OutputParserException(
|
||||
f"Got invalid return object. Expected key `{key}` "
|
||||
f"to be present, but got {json_obj}"
|
||||
)
|
||||
return json_obj
|
||||
|
||||
|
||||
class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
    """Parse the output of an LLM call to a JSON object.

    When used in streaming mode, it will yield partial JSON objects containing
    all the keys that have been returned so far.

    In streaming, if `diff` is set to `True`, yields JSONPatch operations
    describing the difference between the previous and the current object.
    """

    def _diff(self, prev: Optional[Any], next: Any) -> Any:
        # JSONPatch operations (add/replace/remove) describing prev -> next.
        return jsonpatch.make_patch(prev, next).patch

    def parse(self, text: str) -> Any:
        """Parse ``text`` (optionally Markdown-fenced) into a JSON object.

        Uses ``parse_partial_json`` so truncated output still yields the
        best-effort partial object.

        Raises:
            OutputParserException: If ``text`` cannot be parsed as JSON.
        """
        text = text.strip()
        try:
            # ``text`` is already stripped above; the original called
            # ``text.strip()`` a second time here for no effect.
            return parse_json_markdown(text, parser=parse_partial_json)
        except JSONDecodeError as e:
            raise OutputParserException(f"Invalid json output: {text}") from e

    @property
    def _type(self) -> str:
        return "simple_json_output_parser"
|
||||
from langchain_core.output_parsers.json import (
|
||||
SimpleJsonOutputParser,
|
||||
parse_and_check_json_markdown,
|
||||
parse_json_markdown,
|
||||
parse_partial_json,
|
||||
)
|
||||
|
||||
# Names re-exported for backwards compatibility; the implementations now
# live in ``langchain_core.output_parsers.json``.
__all__ = [
    "SimpleJsonOutputParser",
    "parse_partial_json",
    "parse_json_markdown",
    "parse_and_check_json_markdown",
]
|
||||
|
@ -1,122 +1,3 @@
|
||||
import re
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
|
||||
from langchain_core.output_parsers.xml import XMLOutputParser
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||
from langchain_core.runnables.utils import AddableDict
|
||||
|
||||
from langchain.output_parsers.format_instructions import XML_FORMAT_INSTRUCTIONS
|
||||
|
||||
|
||||
class XMLOutputParser(BaseTransformOutputParser):
    """Parse an output using xml format."""

    # Optional whitelist of tags, used only to render format instructions.
    tags: Optional[List[str]] = None
    # Matches an XML prolog-like tag containing "encoding" (e.g.
    # <?xml version="1.0" encoding="UTF-8"?>) followed by the document body.
    encoding_matcher: re.Pattern = re.compile(
        r"<([^>]*encoding[^>]*)>\n(.*)", re.MULTILINE | re.DOTALL
    )

    def get_format_instructions(self) -> str:
        """Return instructions telling the LLM to answer in XML with ``tags``."""
        return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)

    def parse(self, text: str) -> Dict[str, List[Any]]:
        """Parse a complete XML string into a nested dict of lists.

        Raises:
            ValueError: If ``text`` does not look like an XML document.
        """
        # NOTE(review): str.strip takes a *character set*, so strip("xml")
        # removes any leading/trailing run of 'x', 'm', 'l' characters, not
        # just a literal "xml" language tag — confirm this is intended.
        text = text.strip("`").strip("xml")
        encoding_match = self.encoding_matcher.search(text)
        if encoding_match:
            # Drop the encoding declaration; keep only the document body.
            text = encoding_match.group(2)

        text = text.strip()
        # Cheap sanity check that the payload is an XML document before
        # handing it to the (strict) ElementTree parser.
        if (text.startswith("<") or text.startswith("\n<")) and (
            text.endswith(">") or text.endswith(">\n")
        ):
            root = ET.fromstring(text)
            return self._root_to_dict(root)
        else:
            raise ValueError(f"Could not parse output: {text}")

    def _transform(
        self, input: Iterator[Union[str, BaseMessage]]
    ) -> Iterator[AddableDict]:
        """Incrementally parse streamed chunks, yielding one dict per leaf element."""
        # Pull-parser lets us feed arbitrary chunk boundaries and drain
        # complete start/end events as they become available.
        parser = ET.XMLPullParser(["start", "end"])
        current_path: List[str] = []
        current_path_has_children = False
        for chunk in input:
            if isinstance(chunk, BaseMessage):
                # extract text
                chunk_content = chunk.content
                if not isinstance(chunk_content, str):
                    # skip non-text message content (e.g. structured parts)
                    continue
                chunk = chunk_content
            # pass chunk to parser
            parser.feed(chunk)
            # yield all events
            for event, elem in parser.read_events():
                if event == "start":
                    # update current path
                    current_path.append(elem.tag)
                    current_path_has_children = False
                elif event == "end":
                    # remove last element from current path
                    current_path.pop()
                    # yield element only if it is a leaf; parents are
                    # reconstructed by the consumer adding the dicts together
                    if not current_path_has_children:
                        yield nested_element(current_path, elem)
                    # prevent yielding of parent element
                    current_path_has_children = True
        # close parser
        parser.close()

    async def _atransform(
        self, input: AsyncIterator[Union[str, BaseMessage]]
    ) -> AsyncIterator[AddableDict]:
        """Async twin of :meth:`_transform`; logic is kept identical."""
        parser = ET.XMLPullParser(["start", "end"])
        current_path: List[str] = []
        current_path_has_children = False
        async for chunk in input:
            if isinstance(chunk, BaseMessage):
                # extract text
                chunk_content = chunk.content
                if not isinstance(chunk_content, str):
                    continue
                chunk = chunk_content
            # pass chunk to parser
            parser.feed(chunk)
            # yield all events
            for event, elem in parser.read_events():
                if event == "start":
                    # update current path
                    current_path.append(elem.tag)
                    current_path_has_children = False
                elif event == "end":
                    # remove last element from current path
                    current_path.pop()
                    # yield element
                    if not current_path_has_children:
                        yield nested_element(current_path, elem)
                    # prevent yielding of parent element
                    current_path_has_children = True
        # close parser
        parser.close()

    def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]:
        """Converts xml tree to python dictionary."""
        # Each element maps its tag to a list of children; leaves map
        # tag -> text. NOTE(review): element attributes are discarded.
        result: Dict[str, List[Any]] = {root.tag: []}
        for child in root:
            if len(child) == 0:
                result[root.tag].append({child.tag: child.text})
            else:
                result[root.tag].append(self._root_to_dict(child))
        return result

    @property
    def _type(self) -> str:
        return "xml"
|
||||
|
||||
|
||||
def nested_element(path: List[str], elem: ET.Element) -> Any:
    """Wrap ``elem`` in nested ``AddableDict``s keyed by the tags in ``path``."""
    # Base case: no remaining path — emit the leaf as {tag: text}.
    if not path:
        return AddableDict({elem.tag: elem.text})
    # Recursive case: peel off the head tag and nest the rest beneath it.
    head, *tail = path
    return AddableDict({head: [nested_element(tail, elem)]})
|
||||
__all__ = ["XMLOutputParser"]
|
||||
|
Loading…
Reference in New Issue