standardize json parsing (#5168)

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
Harrison Chase 2023-05-24 10:03:53 -07:00 committed by GitHub
parent 2b2176a3c1
commit 94cf391ef1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 125 additions and 45 deletions

View File

@ -1,8 +1,8 @@
import json
from typing import Union
from langchain.agents.agent import AgentOutputParser
from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS
from langchain.output_parsers.json import parse_json_markdown
from langchain.schema import AgentAction, AgentFinish, OutputParserException
FINAL_ANSWER_ACTION = "Final Answer:"
@ -18,8 +18,7 @@ class ChatOutputParser(AgentOutputParser):
{"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text
)
try:
action = text.split("```")[1]
response = json.loads(action.strip())
response = parse_json_markdown(text)
return AgentAction(response["action"], response["action_input"], text)
except Exception:

View File

@ -1,10 +1,10 @@
from __future__ import annotations
import json
from typing import Union
from langchain.agents import AgentOutputParser
from langchain.agents.conversational_chat.prompt import FORMAT_INSTRUCTIONS
from langchain.output_parsers.json import parse_json_markdown
from langchain.schema import AgentAction, AgentFinish, OutputParserException
@ -14,19 +14,7 @@ class ConvoOutputParser(AgentOutputParser):
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
try:
cleaned_output = text.strip()
if "```json" in cleaned_output:
_, cleaned_output = cleaned_output.split("```json")
if "```" in cleaned_output:
cleaned_output, _ = cleaned_output.split("```")
if cleaned_output.startswith("```json"):
cleaned_output = cleaned_output[len("```json") :]
if cleaned_output.startswith("```"):
cleaned_output = cleaned_output[len("```") :]
if cleaned_output.endswith("```"):
cleaned_output = cleaned_output[: -len("```")]
cleaned_output = cleaned_output.strip()
response = json.loads(cleaned_output)
response = parse_json_markdown(text)
action, action_input = response["action"], response["action_input"]
if action == "Final Answer":
return AgentFinish({"output": action_input}, text)

View File

@ -22,7 +22,7 @@ from langchain.chains.query_constructor.prompt import (
SCHEMA_WITH_LIMIT,
)
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.output_parsers.structured import parse_json_markdown
from langchain.output_parsers.json import parse_and_check_json_markdown
from langchain.schema import BaseOutputParser, OutputParserException
@ -33,7 +33,7 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
def parse(self, text: str) -> StructuredQuery:
try:
expected_keys = ["query", "filter"]
parsed = parse_json_markdown(text, expected_keys)
parsed = parse_and_check_json_markdown(text, expected_keys)
if len(parsed["query"]) == 0:
parsed["query"] = " "
if parsed["filter"] == "NO_FILTER" or not parsed["filter"]:

View File

@ -9,7 +9,7 @@ from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains import LLMChain
from langchain.chains.router.base import RouterChain
from langchain.output_parsers.structured import parse_json_markdown
from langchain.output_parsers.json import parse_and_check_json_markdown
from langchain.prompts import BasePromptTemplate
from langchain.schema import BaseOutputParser, OutputParserException
@ -77,7 +77,7 @@ class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
def parse(self, text: str) -> Dict[str, Any]:
try:
expected_keys = ["destination", "next_inputs"]
parsed = parse_json_markdown(text, expected_keys)
parsed = parse_and_check_json_markdown(text, expected_keys)
if not isinstance(parsed["destination"], str):
raise ValueError("Expected 'destination' to be a string.")
if not isinstance(parsed["next_inputs"], self.next_inputs_type):

View File

@ -0,0 +1,33 @@
from __future__ import annotations
import json
from typing import List
from langchain.schema import OutputParserException
def parse_json_markdown(json_string: str) -> dict:
    """Parse JSON out of text that may wrap it in a Markdown code fence.

    Interpretations are tried from strictest to most lenient; the first one
    that decodes wins:

    1. The whole string as-is. Bare JSON is returned untouched, so backticks
       inside string values are preserved.
    2. The contents of the first complete fenced block (three backticks,
       optionally tagged ``json``). This tolerates prose before and after the
       fence, e.g. ``Thought: ... Action: <fenced JSON>``.
    3. Everything between the first opening fence and the last closing fence,
       which covers fenced JSON whose string values themselves contain
       triple backticks.
    4. The legacy behaviour: delete every fence marker and parse what is
       left. This keeps inputs with an unclosed or stray fence working.

    Args:
        json_string: Text containing a JSON document, fenced or not.

    Returns:
        The decoded JSON value (a dict when the text holds a JSON object).

    Raises:
        json.JSONDecodeError: If none of the interpretations is valid JSON.
            The error raised is the one from the legacy interpretation.
    """
    # Function-scope import so the module's import block does not need editing.
    import re

    candidates = [json_string]
    # Non-greedy first (first complete block), then greedy (widest span).
    for pattern in (r"```(?:json)?(.*?)```", r"```(?:json)?(.*)```"):
        fenced = re.search(pattern, json_string, re.DOTALL)
        if fenced is not None:
            candidates.append(fenced.group(1))

    for candidate in candidates:
        try:
            return json.loads(candidate.strip())
        except json.JSONDecodeError:
            # Deliberate: fall through to the next, more lenient reading.
            continue

    # Legacy fallback: drop every fence marker, wherever it appears.
    without_fences = json_string.replace("```json", "").replace("```", "")
    return json.loads(without_fences.strip())
def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
    """Parse JSON from (optionally fenced) text and verify required keys.

    Args:
        text: Raw text that should contain a JSON object, optionally wrapped
            in a Markdown code fence.
        expected_keys: Keys that must all be present in the parsed object.

    Returns:
        The parsed JSON object.

    Raises:
        OutputParserException: If the text is not valid JSON, if the decoded
            value is not a JSON object, or if any expected key is missing.
    """
    try:
        json_obj = parse_json_markdown(text)
    except json.JSONDecodeError as e:
        # Chain the cause so the decoder's traceback is not lost.
        raise OutputParserException(f"Got invalid JSON object. Error: {e}") from e
    # A scalar such as 42 would make the `in` test below raise a bare
    # TypeError; report it as a parsing failure like every other bad output.
    if not isinstance(json_obj, dict):
        raise OutputParserException(
            f"Got invalid return object. Expected a JSON object, "
            f"but got {json_obj}"
        )
    for key in expected_keys:
        if key not in json_obj:
            raise OutputParserException(
                f"Got invalid return object. Expected key `{key}` "
                f"to be present, but got {json_obj}"
            )
    return json_obj

View File

@ -1,12 +1,12 @@
from __future__ import annotations
import json
from typing import Any, List
from pydantic import BaseModel
from langchain.output_parsers.format_instructions import STRUCTURED_FORMAT_INSTRUCTIONS
from langchain.schema import BaseOutputParser, OutputParserException
from langchain.output_parsers.json import parse_and_check_json_markdown
from langchain.schema import BaseOutputParser
line_template = '\t"{name}": {type} // {description}'
@ -22,27 +22,6 @@ def _get_sub_string(schema: ResponseSchema) -> str:
)
def parse_json_markdown(text: str, expected_keys: List[str]) -> Any:
    """Extract the JSON object from a ```json fenced snippet and check its keys.

    Args:
        text: Text that must contain a ```json opening fence.
        expected_keys: Keys that must all be present in the parsed object.

    Returns:
        The value decoded from the fenced snippet.

    Raises:
        OutputParserException: If no ```json fence is present, the fenced
            content is not valid JSON, or an expected key is missing.
    """
    opening_fence = "```json"
    if opening_fence not in text:
        raise OutputParserException(
            f"Got invalid return object. Expected markdown code snippet with JSON "
            f"object, but got:\n{text}"
        )

    # Keep what follows the first opening fence (up to any later one).
    after_fence = text.split(opening_fence)[1]
    # str.strip treats its argument as a set of characters, so this removes
    # any run of backticks at either end, not just an exact triple.
    payload = after_fence.strip().strip("```").strip()

    try:
        parsed = json.loads(payload)
    except json.JSONDecodeError as e:
        raise OutputParserException(f"Got invalid JSON object. Error: {e}")

    for name in expected_keys:
        if name in parsed:
            continue
        raise OutputParserException(
            f"Got invalid return object. Expected key `{name}` "
            f"to be present, but got {parsed}"
        )
    return parsed
class StructuredOutputParser(BaseOutputParser):
response_schemas: List[ResponseSchema]
@ -60,7 +39,7 @@ class StructuredOutputParser(BaseOutputParser):
def parse(self, text: str) -> Any:
expected_keys = [rs.name for rs in self.response_schemas]
return parse_json_markdown(text, expected_keys)
return parse_and_check_json_markdown(text, expected_keys)
@property
def _type(self) -> str:

View File

@ -0,0 +1,81 @@
import pytest
from langchain.output_parsers.json import parse_json_markdown
# Fenced block tagged `json`, with no whitespace outside the fence.
GOOD_JSON = """```json
{
"foo": "bar"
}
```"""
# Same fenced block, preceded and followed by a newline.
JSON_WITH_NEW_LINES = """
```json
{
"foo": "bar"
}
```
"""
# NOTE(review): as displayed here this literal is identical to GOOD_JSON; the
# name suggests extra blank lines inside the fence that may have been lost
# when the file was rendered - confirm against the repository copy.
JSON_WITH_NEW_LINES_INSIDE = """```json
{
"foo": "bar"
}
```"""
# NOTE(review): as displayed here this literal is identical to
# JSON_WITH_NEW_LINES; presumably it also had blank lines inside the fence -
# confirm against the repository copy.
JSON_WITH_NEW_LINES_EVERYWHERE = """
```json
{
"foo": "bar"
}
```
"""
# Bare fence (no `json` tag), preceded and followed by a newline.
TICKS_WITH_NEW_LINES_EVERYWHERE = """
```
{
"foo": "bar"
}
```
"""
# Raw JSON with no fence at all.
NO_TICKS = """{
"foo": "bar"
}"""
# Raw JSON with no fence, preceded and followed by a newline.
NO_TICKS_WHITE_SPACE = """
{
"foo": "bar"
}
"""
# Every input variant above; test_parse_json expects each one to parse to
# {"foo": "bar"}.
TEST_CASES = [
GOOD_JSON,
JSON_WITH_NEW_LINES,
JSON_WITH_NEW_LINES_INSIDE,
JSON_WITH_NEW_LINES_EVERYWHERE,
TICKS_WITH_NEW_LINES_EVERYWHERE,
NO_TICKS,
NO_TICKS_WHITE_SPACE,
]
@pytest.mark.parametrize("json_string", TEST_CASES)
def test_parse_json(json_string: str) -> None:
    """Each fenced or unfenced input variant must decode to the same dict."""
    expected = {"foo": "bar"}
    assert parse_json_markdown(json_string) == expected