forked from Archives/langchain
standardize json parsing (#5168)
Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
parent
2b2176a3c1
commit
94cf391ef1
@ -1,8 +1,8 @@
|
|||||||
import json
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from langchain.agents.agent import AgentOutputParser
|
from langchain.agents.agent import AgentOutputParser
|
||||||
from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS
|
from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS
|
||||||
|
from langchain.output_parsers.json import parse_json_markdown
|
||||||
from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||||
|
|
||||||
FINAL_ANSWER_ACTION = "Final Answer:"
|
FINAL_ANSWER_ACTION = "Final Answer:"
|
||||||
@ -18,8 +18,7 @@ class ChatOutputParser(AgentOutputParser):
|
|||||||
{"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text
|
{"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
action = text.split("```")[1]
|
response = parse_json_markdown(text)
|
||||||
response = json.loads(action.strip())
|
|
||||||
return AgentAction(response["action"], response["action_input"], text)
|
return AgentAction(response["action"], response["action_input"], text)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from langchain.agents import AgentOutputParser
|
from langchain.agents import AgentOutputParser
|
||||||
from langchain.agents.conversational_chat.prompt import FORMAT_INSTRUCTIONS
|
from langchain.agents.conversational_chat.prompt import FORMAT_INSTRUCTIONS
|
||||||
|
from langchain.output_parsers.json import parse_json_markdown
|
||||||
from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||||
|
|
||||||
|
|
||||||
@ -14,19 +14,7 @@ class ConvoOutputParser(AgentOutputParser):
|
|||||||
|
|
||||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||||
try:
|
try:
|
||||||
cleaned_output = text.strip()
|
response = parse_json_markdown(text)
|
||||||
if "```json" in cleaned_output:
|
|
||||||
_, cleaned_output = cleaned_output.split("```json")
|
|
||||||
if "```" in cleaned_output:
|
|
||||||
cleaned_output, _ = cleaned_output.split("```")
|
|
||||||
if cleaned_output.startswith("```json"):
|
|
||||||
cleaned_output = cleaned_output[len("```json") :]
|
|
||||||
if cleaned_output.startswith("```"):
|
|
||||||
cleaned_output = cleaned_output[len("```") :]
|
|
||||||
if cleaned_output.endswith("```"):
|
|
||||||
cleaned_output = cleaned_output[: -len("```")]
|
|
||||||
cleaned_output = cleaned_output.strip()
|
|
||||||
response = json.loads(cleaned_output)
|
|
||||||
action, action_input = response["action"], response["action_input"]
|
action, action_input = response["action"], response["action_input"]
|
||||||
if action == "Final Answer":
|
if action == "Final Answer":
|
||||||
return AgentFinish({"output": action_input}, text)
|
return AgentFinish({"output": action_input}, text)
|
||||||
|
@ -22,7 +22,7 @@ from langchain.chains.query_constructor.prompt import (
|
|||||||
SCHEMA_WITH_LIMIT,
|
SCHEMA_WITH_LIMIT,
|
||||||
)
|
)
|
||||||
from langchain.chains.query_constructor.schema import AttributeInfo
|
from langchain.chains.query_constructor.schema import AttributeInfo
|
||||||
from langchain.output_parsers.structured import parse_json_markdown
|
from langchain.output_parsers.json import parse_and_check_json_markdown
|
||||||
from langchain.schema import BaseOutputParser, OutputParserException
|
from langchain.schema import BaseOutputParser, OutputParserException
|
||||||
|
|
||||||
|
|
||||||
@ -33,7 +33,7 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
|
|||||||
def parse(self, text: str) -> StructuredQuery:
|
def parse(self, text: str) -> StructuredQuery:
|
||||||
try:
|
try:
|
||||||
expected_keys = ["query", "filter"]
|
expected_keys = ["query", "filter"]
|
||||||
parsed = parse_json_markdown(text, expected_keys)
|
parsed = parse_and_check_json_markdown(text, expected_keys)
|
||||||
if len(parsed["query"]) == 0:
|
if len(parsed["query"]) == 0:
|
||||||
parsed["query"] = " "
|
parsed["query"] = " "
|
||||||
if parsed["filter"] == "NO_FILTER" or not parsed["filter"]:
|
if parsed["filter"] == "NO_FILTER" or not parsed["filter"]:
|
||||||
|
@ -9,7 +9,7 @@ from langchain.base_language import BaseLanguageModel
|
|||||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||||
from langchain.chains import LLMChain
|
from langchain.chains import LLMChain
|
||||||
from langchain.chains.router.base import RouterChain
|
from langchain.chains.router.base import RouterChain
|
||||||
from langchain.output_parsers.structured import parse_json_markdown
|
from langchain.output_parsers.json import parse_and_check_json_markdown
|
||||||
from langchain.prompts import BasePromptTemplate
|
from langchain.prompts import BasePromptTemplate
|
||||||
from langchain.schema import BaseOutputParser, OutputParserException
|
from langchain.schema import BaseOutputParser, OutputParserException
|
||||||
|
|
||||||
@ -77,7 +77,7 @@ class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
|
|||||||
def parse(self, text: str) -> Dict[str, Any]:
|
def parse(self, text: str) -> Dict[str, Any]:
|
||||||
try:
|
try:
|
||||||
expected_keys = ["destination", "next_inputs"]
|
expected_keys = ["destination", "next_inputs"]
|
||||||
parsed = parse_json_markdown(text, expected_keys)
|
parsed = parse_and_check_json_markdown(text, expected_keys)
|
||||||
if not isinstance(parsed["destination"], str):
|
if not isinstance(parsed["destination"], str):
|
||||||
raise ValueError("Expected 'destination' to be a string.")
|
raise ValueError("Expected 'destination' to be a string.")
|
||||||
if not isinstance(parsed["next_inputs"], self.next_inputs_type):
|
if not isinstance(parsed["next_inputs"], self.next_inputs_type):
|
||||||
|
33
langchain/output_parsers/json.py
Normal file
33
langchain/output_parsers/json.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from langchain.schema import OutputParserException
|
||||||
|
|
||||||
|
|
||||||
|
def parse_json_markdown(json_string: str) -> dict:
    """Parse a JSON value out of text that may wrap it in a Markdown code fence.

    The input may be bare JSON, or JSON inside a triple-backtick fence (with or
    without a ``json`` language tag). When a complete fence is present, text
    before and after it (for example a model's free-form preamble) is ignored.

    Args:
        json_string: Raw text containing the JSON payload.

    Returns:
        The decoded JSON value (a dict when the payload is a JSON object).

    Raises:
        json.JSONDecodeError: If no parseable JSON payload can be found.
    """
    import re  # local import so this block does not depend on file-level edits

    # Prefer the contents of the first complete fenced block, so prose that
    # surrounds the fence does not break JSON decoding.
    fenced = re.search(r"```(?:json)?(.*?)```", json_string, re.DOTALL)
    if fenced is not None:
        try:
            return json.loads(fenced.group(1).strip())
        except json.JSONDecodeError:
            # The fenced slice was not valid JSON on its own (e.g. backticks
            # inside a string value cut it short); fall back to the legacy
            # behaviour below.
            pass

    # Legacy behaviour: drop every fence marker, trim whitespace, and parse
    # whatever remains. Also covers bare JSON and unterminated fences.
    stripped = json_string.replace("```json", "").replace("```", "").strip()
    return json.loads(stripped)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
    """Parse JSON from (possibly fenced) text and require certain top-level keys.

    Args:
        text: Raw text handed to ``parse_json_markdown``.
        expected_keys: Keys that must all be present in the parsed object.

    Returns:
        The parsed JSON object.

    Raises:
        OutputParserException: If the text is not valid JSON, if the parsed
            value is not a JSON object, or if any expected key is missing.
    """
    try:
        json_obj = parse_json_markdown(text)
    except json.JSONDecodeError as e:
        # Chain explicitly so the original decode error stays attached.
        raise OutputParserException(f"Got invalid JSON object. Error: {e}") from e

    # Valid JSON is not necessarily an object: a number would make the
    # membership test below raise TypeError, and a string would silently
    # pass it through substring matching. Reject both up front.
    if not isinstance(json_obj, dict):
        raise OutputParserException(
            "Got invalid return object. Expected a JSON object, "
            f"but got {json_obj!r}"
        )

    for key in expected_keys:
        if key not in json_obj:
            raise OutputParserException(
                f"Got invalid return object. Expected key `{key}` "
                f"to be present, but got {json_obj}"
            )
    return json_obj
|
@ -1,12 +1,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from langchain.output_parsers.format_instructions import STRUCTURED_FORMAT_INSTRUCTIONS
|
from langchain.output_parsers.format_instructions import STRUCTURED_FORMAT_INSTRUCTIONS
|
||||||
from langchain.schema import BaseOutputParser, OutputParserException
|
from langchain.output_parsers.json import parse_and_check_json_markdown
|
||||||
|
from langchain.schema import BaseOutputParser
|
||||||
|
|
||||||
line_template = '\t"{name}": {type} // {description}'
|
line_template = '\t"{name}": {type} // {description}'
|
||||||
|
|
||||||
@ -22,27 +22,6 @@ def _get_sub_string(schema: ResponseSchema) -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_json_markdown(text: str, expected_keys: List[str]) -> Any:
|
|
||||||
if "```json" not in text:
|
|
||||||
raise OutputParserException(
|
|
||||||
f"Got invalid return object. Expected markdown code snippet with JSON "
|
|
||||||
f"object, but got:\n{text}"
|
|
||||||
)
|
|
||||||
|
|
||||||
json_string = text.split("```json")[1].strip().strip("```").strip()
|
|
||||||
try:
|
|
||||||
json_obj = json.loads(json_string)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
raise OutputParserException(f"Got invalid JSON object. Error: {e}")
|
|
||||||
for key in expected_keys:
|
|
||||||
if key not in json_obj:
|
|
||||||
raise OutputParserException(
|
|
||||||
f"Got invalid return object. Expected key `{key}` "
|
|
||||||
f"to be present, but got {json_obj}"
|
|
||||||
)
|
|
||||||
return json_obj
|
|
||||||
|
|
||||||
|
|
||||||
class StructuredOutputParser(BaseOutputParser):
|
class StructuredOutputParser(BaseOutputParser):
|
||||||
response_schemas: List[ResponseSchema]
|
response_schemas: List[ResponseSchema]
|
||||||
|
|
||||||
@ -60,7 +39,7 @@ class StructuredOutputParser(BaseOutputParser):
|
|||||||
|
|
||||||
def parse(self, text: str) -> Any:
|
def parse(self, text: str) -> Any:
|
||||||
expected_keys = [rs.name for rs in self.response_schemas]
|
expected_keys = [rs.name for rs in self.response_schemas]
|
||||||
return parse_json_markdown(text, expected_keys)
|
return parse_and_check_json_markdown(text, expected_keys)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _type(self) -> str:
|
def _type(self) -> str:
|
||||||
|
81
tests/unit_tests/output_parsers/test_json.py
Normal file
81
tests/unit_tests/output_parsers/test_json.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.output_parsers.json import parse_json_markdown
|
||||||
|
|
||||||
|
GOOD_JSON = """```json
|
||||||
|
{
|
||||||
|
"foo": "bar"
|
||||||
|
}
|
||||||
|
```"""
|
||||||
|
|
||||||
|
JSON_WITH_NEW_LINES = """
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"foo": "bar"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
JSON_WITH_NEW_LINES_INSIDE = """```json
|
||||||
|
{
|
||||||
|
|
||||||
|
"foo": "bar"
|
||||||
|
|
||||||
|
}
|
||||||
|
```"""
|
||||||
|
|
||||||
|
JSON_WITH_NEW_LINES_EVERYWHERE = """
|
||||||
|
|
||||||
|
```json
|
||||||
|
|
||||||
|
{
|
||||||
|
|
||||||
|
"foo": "bar"
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
TICKS_WITH_NEW_LINES_EVERYWHERE = """
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
{
|
||||||
|
|
||||||
|
"foo": "bar"
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
NO_TICKS = """{
|
||||||
|
"foo": "bar"
|
||||||
|
}"""
|
||||||
|
|
||||||
|
NO_TICKS_WHITE_SPACE = """
|
||||||
|
{
|
||||||
|
"foo": "bar"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
TEST_CASES = [
|
||||||
|
GOOD_JSON,
|
||||||
|
JSON_WITH_NEW_LINES,
|
||||||
|
JSON_WITH_NEW_LINES_INSIDE,
|
||||||
|
JSON_WITH_NEW_LINES_EVERYWHERE,
|
||||||
|
TICKS_WITH_NEW_LINES_EVERYWHERE,
|
||||||
|
NO_TICKS,
|
||||||
|
NO_TICKS_WHITE_SPACE,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("json_string", TEST_CASES)
def test_parse_json(json_string: str) -> None:
    """Each fixture variant must decode to the same single-key object."""
    expected = {"foo": "bar"}
    assert parse_json_markdown(json_string) == expected
|
Loading…
Reference in New Issue
Block a user