forked from Archives/langchain
standardize json parsing (#5168)
Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
parent
2b2176a3c1
commit
94cf391ef1
@ -1,8 +1,8 @@
|
||||
import json
|
||||
from typing import Union
|
||||
|
||||
from langchain.agents.agent import AgentOutputParser
|
||||
from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain.output_parsers.json import parse_json_markdown
|
||||
from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||
|
||||
FINAL_ANSWER_ACTION = "Final Answer:"
|
||||
@ -18,8 +18,7 @@ class ChatOutputParser(AgentOutputParser):
|
||||
{"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text
|
||||
)
|
||||
try:
|
||||
action = text.split("```")[1]
|
||||
response = json.loads(action.strip())
|
||||
response = parse_json_markdown(text)
|
||||
return AgentAction(response["action"], response["action_input"], text)
|
||||
|
||||
except Exception:
|
||||
|
@ -1,10 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Union
|
||||
|
||||
from langchain.agents import AgentOutputParser
|
||||
from langchain.agents.conversational_chat.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain.output_parsers.json import parse_json_markdown
|
||||
from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||
|
||||
|
||||
@ -14,19 +14,7 @@ class ConvoOutputParser(AgentOutputParser):
|
||||
|
||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||
try:
|
||||
cleaned_output = text.strip()
|
||||
if "```json" in cleaned_output:
|
||||
_, cleaned_output = cleaned_output.split("```json")
|
||||
if "```" in cleaned_output:
|
||||
cleaned_output, _ = cleaned_output.split("```")
|
||||
if cleaned_output.startswith("```json"):
|
||||
cleaned_output = cleaned_output[len("```json") :]
|
||||
if cleaned_output.startswith("```"):
|
||||
cleaned_output = cleaned_output[len("```") :]
|
||||
if cleaned_output.endswith("```"):
|
||||
cleaned_output = cleaned_output[: -len("```")]
|
||||
cleaned_output = cleaned_output.strip()
|
||||
response = json.loads(cleaned_output)
|
||||
response = parse_json_markdown(text)
|
||||
action, action_input = response["action"], response["action_input"]
|
||||
if action == "Final Answer":
|
||||
return AgentFinish({"output": action_input}, text)
|
||||
|
@ -22,7 +22,7 @@ from langchain.chains.query_constructor.prompt import (
|
||||
SCHEMA_WITH_LIMIT,
|
||||
)
|
||||
from langchain.chains.query_constructor.schema import AttributeInfo
|
||||
from langchain.output_parsers.structured import parse_json_markdown
|
||||
from langchain.output_parsers.json import parse_and_check_json_markdown
|
||||
from langchain.schema import BaseOutputParser, OutputParserException
|
||||
|
||||
|
||||
@ -33,7 +33,7 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
|
||||
def parse(self, text: str) -> StructuredQuery:
|
||||
try:
|
||||
expected_keys = ["query", "filter"]
|
||||
parsed = parse_json_markdown(text, expected_keys)
|
||||
parsed = parse_and_check_json_markdown(text, expected_keys)
|
||||
if len(parsed["query"]) == 0:
|
||||
parsed["query"] = " "
|
||||
if parsed["filter"] == "NO_FILTER" or not parsed["filter"]:
|
||||
|
@ -9,7 +9,7 @@ from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.chains.router.base import RouterChain
|
||||
from langchain.output_parsers.structured import parse_json_markdown
|
||||
from langchain.output_parsers.json import parse_and_check_json_markdown
|
||||
from langchain.prompts import BasePromptTemplate
|
||||
from langchain.schema import BaseOutputParser, OutputParserException
|
||||
|
||||
@ -77,7 +77,7 @@ class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
|
||||
def parse(self, text: str) -> Dict[str, Any]:
|
||||
try:
|
||||
expected_keys = ["destination", "next_inputs"]
|
||||
parsed = parse_json_markdown(text, expected_keys)
|
||||
parsed = parse_and_check_json_markdown(text, expected_keys)
|
||||
if not isinstance(parsed["destination"], str):
|
||||
raise ValueError("Expected 'destination' to be a string.")
|
||||
if not isinstance(parsed["next_inputs"], self.next_inputs_type):
|
||||
|
33
langchain/output_parsers/json.py
Normal file
33
langchain/output_parsers/json.py
Normal file
@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
from langchain.schema import OutputParserException
|
||||
|
||||
|
||||
def parse_json_markdown(json_string: str) -> dict:
|
||||
# Remove the triple backticks if present
|
||||
json_string = json_string.replace("```json", "").replace("```", "")
|
||||
|
||||
# Strip whitespace and newlines from the start and end
|
||||
json_string = json_string.strip()
|
||||
|
||||
# Parse the JSON string into a Python dictionary
|
||||
parsed = json.loads(json_string)
|
||||
|
||||
return parsed
|
||||
|
||||
|
||||
def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
|
||||
try:
|
||||
json_obj = parse_json_markdown(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise OutputParserException(f"Got invalid JSON object. Error: {e}")
|
||||
for key in expected_keys:
|
||||
if key not in json_obj:
|
||||
raise OutputParserException(
|
||||
f"Got invalid return object. Expected key `{key}` "
|
||||
f"to be present, but got {json_obj}"
|
||||
)
|
||||
return json_obj
|
@ -1,12 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.output_parsers.format_instructions import STRUCTURED_FORMAT_INSTRUCTIONS
|
||||
from langchain.schema import BaseOutputParser, OutputParserException
|
||||
from langchain.output_parsers.json import parse_and_check_json_markdown
|
||||
from langchain.schema import BaseOutputParser
|
||||
|
||||
line_template = '\t"{name}": {type} // {description}'
|
||||
|
||||
@ -22,27 +22,6 @@ def _get_sub_string(schema: ResponseSchema) -> str:
|
||||
)
|
||||
|
||||
|
||||
def parse_json_markdown(text: str, expected_keys: List[str]) -> Any:
|
||||
if "```json" not in text:
|
||||
raise OutputParserException(
|
||||
f"Got invalid return object. Expected markdown code snippet with JSON "
|
||||
f"object, but got:\n{text}"
|
||||
)
|
||||
|
||||
json_string = text.split("```json")[1].strip().strip("```").strip()
|
||||
try:
|
||||
json_obj = json.loads(json_string)
|
||||
except json.JSONDecodeError as e:
|
||||
raise OutputParserException(f"Got invalid JSON object. Error: {e}")
|
||||
for key in expected_keys:
|
||||
if key not in json_obj:
|
||||
raise OutputParserException(
|
||||
f"Got invalid return object. Expected key `{key}` "
|
||||
f"to be present, but got {json_obj}"
|
||||
)
|
||||
return json_obj
|
||||
|
||||
|
||||
class StructuredOutputParser(BaseOutputParser):
|
||||
response_schemas: List[ResponseSchema]
|
||||
|
||||
@ -60,7 +39,7 @@ class StructuredOutputParser(BaseOutputParser):
|
||||
|
||||
def parse(self, text: str) -> Any:
|
||||
expected_keys = [rs.name for rs in self.response_schemas]
|
||||
return parse_json_markdown(text, expected_keys)
|
||||
return parse_and_check_json_markdown(text, expected_keys)
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
|
81
tests/unit_tests/output_parsers/test_json.py
Normal file
81
tests/unit_tests/output_parsers/test_json.py
Normal file
@ -0,0 +1,81 @@
|
||||
import pytest
|
||||
|
||||
from langchain.output_parsers.json import parse_json_markdown
|
||||
|
||||
GOOD_JSON = """```json
|
||||
{
|
||||
"foo": "bar"
|
||||
}
|
||||
```"""
|
||||
|
||||
JSON_WITH_NEW_LINES = """
|
||||
|
||||
```json
|
||||
{
|
||||
"foo": "bar"
|
||||
}
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
JSON_WITH_NEW_LINES_INSIDE = """```json
|
||||
{
|
||||
|
||||
"foo": "bar"
|
||||
|
||||
}
|
||||
```"""
|
||||
|
||||
JSON_WITH_NEW_LINES_EVERYWHERE = """
|
||||
|
||||
```json
|
||||
|
||||
{
|
||||
|
||||
"foo": "bar"
|
||||
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
TICKS_WITH_NEW_LINES_EVERYWHERE = """
|
||||
|
||||
```
|
||||
|
||||
{
|
||||
|
||||
"foo": "bar"
|
||||
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
NO_TICKS = """{
|
||||
"foo": "bar"
|
||||
}"""
|
||||
|
||||
NO_TICKS_WHITE_SPACE = """
|
||||
{
|
||||
"foo": "bar"
|
||||
}
|
||||
"""
|
||||
|
||||
TEST_CASES = [
|
||||
GOOD_JSON,
|
||||
JSON_WITH_NEW_LINES,
|
||||
JSON_WITH_NEW_LINES_INSIDE,
|
||||
JSON_WITH_NEW_LINES_EVERYWHERE,
|
||||
TICKS_WITH_NEW_LINES_EVERYWHERE,
|
||||
NO_TICKS,
|
||||
NO_TICKS_WHITE_SPACE,
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("json_string", TEST_CASES)
|
||||
def test_parse_json(json_string: str) -> None:
|
||||
parsed = parse_json_markdown(json_string)
|
||||
assert parsed == {"foo": "bar"}
|
Loading…
Reference in New Issue
Block a user