forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
100 lines
3.1 KiB
Python
100 lines
3.1 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import re
|
|
from typing import List, Optional
|
|
|
|
from langchain.schema import OutputParserException
|
|
|
|
# Patterns used to locate a payload wrapped in a Markdown code fence.
# In both patterns group 1 is the optional "json" language tag and group 2 is
# the text between the opening and closing triple backticks.
REGEXES = {
    # Non-greedy: the match ends at the first closing fence.
    "json_markdown": r"```(json)?(.*?)```",
    # must use greedy matching to match the outermost code block
    "nested_json_md_code_block": r"```(json)?(.*)```",
}
|
|
|
|
|
|
def parse_json_markdown(json_string: str, regex: Optional[str] = None) -> dict:
    """Parse JSON that may be wrapped in a Markdown code fence.

    Args:
        json_string: Raw text; either bare JSON or JSON inside triple backticks.
        regex: Pattern used to locate the fenced payload. Its second capture
            group must hold the JSON text. Defaults to
            ``REGEXES["json_markdown"]``.

    Returns:
        The value produced by ``json.loads`` for the extracted text.

    Raises:
        json.JSONDecodeError: If the extracted text is not valid JSON.
    """
    if regex is None:
        regex = REGEXES["json_markdown"]

    # DOTALL lets the payload span multiple lines.
    match = re.search(regex, json_string, re.DOTALL)

    # No fence found: treat the whole input as the JSON document.
    # Otherwise use the fenced content (group 1 is the optional "json" tag).
    json_str = json_string if match is None else match.group(2)

    # Surrounding whitespace and newlines are stripped before parsing.
    return json.loads(json_str.strip())
|
|
|
|
|
|
def fix_code_in_json(text: str) -> str:
    """Fixes nested code block in json markdown.

    Finds the first triple-backtick span that contains no backtick, escapes
    newlines, tabs and double quotes inside it, and replaces the fences with
    the ``[BEGIN_CODE]`` / ``[END_CODE]`` markers.

    Args:
        text: Text that may contain a fenced code block.

    Returns:
        ``text`` with the first such block rewritten, or ``text`` unchanged
        when no block is found.
    """
    match = re.search(r"```([^`]*?)```", text)
    if match is None:
        return text

    # Escape the characters handled here so the block can sit inside a JSON
    # string. NOTE: backslashes in the block are left untouched.
    escaped_code_block = (
        match.group(1).replace("\n", "\\n").replace("\t", "\\t").replace('"', '\\"')
    )

    # Splice at the match position instead of going through a textual
    # placeholder, so input that already contains the placeholder token is
    # not altered.
    return (
        text[: match.start()]
        + "[BEGIN_CODE]"
        + escaped_code_block
        + "[END_CODE]"
        + text[match.end() :]
    )
|
|
|
|
|
|
def fix_json_with_embedded_code_block(text: str, max_loop: int = 20) -> dict:
    """Try to fix json with embedded code block.

    Repeatedly rewrites the text with ``fix_code_in_json`` and tries to parse
    it. When parsing fails on a raw newline, that newline is escaped and the
    next attempt is made; any other parse failure is re-raised.

    Args:
        text: JSON string with embedded code block
        max_loop: Maximum number of loops to try fixing the JSON string

    Returns:
        The parsed JSON, with ``[END_CODE]`` markers turned back into triple
        backticks before the final parse.

    Raises:
        ValueError: If the text still does not parse after ``max_loop + 1``
            attempts.
        json.JSONDecodeError: If parsing fails somewhere other than on a raw
            newline character.
    """
    # range(max_loop + 1) keeps the original attempt count (loop = 0..max_loop).
    for _ in range(max_loop + 1):
        text = fix_code_in_json(text)
        try:
            json.loads(text)
            break
        except json.JSONDecodeError as e:
            # e.pos equals len(text) when the error is at end of input
            # (e.g. truncated JSON); guard the index so the decode error is
            # re-raised instead of an IndexError.
            if e.pos < len(text) and text[e.pos] == "\n":
                text = text[: e.pos] + "\\n" + text[e.pos + 1 :]
                # Restore the opening fence so the next pass can match again.
                text = text.replace("[BEGIN_CODE]", "```")
            else:
                raise
    else:
        raise ValueError("Max loop reached")

    final_text = text.replace("[END_CODE]", "```")
    return json.loads(final_text)
|
|
|
|
|
|
def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
    """Parse JSON from Markdown text and verify that required keys exist.

    Args:
        text: Raw text passed to ``parse_json_markdown``.
        expected_keys: Keys that must be present in the parsed object.

    Returns:
        The parsed JSON object.

    Raises:
        OutputParserException: If the text is not valid JSON, if the parsed
            value is not a JSON object, or if an expected key is missing.
    """
    try:
        json_obj = parse_json_markdown(text)
    except json.JSONDecodeError as e:
        # Chain explicitly so the decode error is kept as the cause.
        raise OutputParserException(f"Got invalid JSON object. Error: {e}") from e

    # A scalar would raise TypeError on the membership test below, and a list
    # would be checked by element rather than by key.
    if not isinstance(json_obj, dict):
        raise OutputParserException(
            f"Got invalid return object. Expected a JSON object, but got {json_obj}"
        )

    for key in expected_keys:
        if key not in json_obj:
            raise OutputParserException(
                f"Got invalid return object. Expected key `{key}` "
                f"to be present, but got {json_obj}"
            )
    return json_obj
|