diff --git a/langchain/agents/chat/output_parser.py b/langchain/agents/chat/output_parser.py index 9f143d07..1cdd839e 100644 --- a/langchain/agents/chat/output_parser.py +++ b/langchain/agents/chat/output_parser.py @@ -1,8 +1,8 @@ -import json from typing import Union from langchain.agents.agent import AgentOutputParser from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS +from langchain.output_parsers.json import parse_json_markdown from langchain.schema import AgentAction, AgentFinish, OutputParserException FINAL_ANSWER_ACTION = "Final Answer:" @@ -18,8 +18,7 @@ class ChatOutputParser(AgentOutputParser): {"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text ) try: - action = text.split("```")[1] - response = json.loads(action.strip()) + response = parse_json_markdown(text) return AgentAction(response["action"], response["action_input"], text) except Exception: diff --git a/langchain/agents/conversational_chat/output_parser.py b/langchain/agents/conversational_chat/output_parser.py index d8a2c593..1aece492 100644 --- a/langchain/agents/conversational_chat/output_parser.py +++ b/langchain/agents/conversational_chat/output_parser.py @@ -1,10 +1,10 @@ from __future__ import annotations -import json from typing import Union from langchain.agents import AgentOutputParser from langchain.agents.conversational_chat.prompt import FORMAT_INSTRUCTIONS +from langchain.output_parsers.json import parse_json_markdown from langchain.schema import AgentAction, AgentFinish, OutputParserException @@ -14,19 +14,7 @@ class ConvoOutputParser(AgentOutputParser): def parse(self, text: str) -> Union[AgentAction, AgentFinish]: try: - cleaned_output = text.strip() - if "```json" in cleaned_output: - _, cleaned_output = cleaned_output.split("```json") - if "```" in cleaned_output: - cleaned_output, _ = cleaned_output.split("```") - if cleaned_output.startswith("```json"): - cleaned_output = cleaned_output[len("```json") :] - if cleaned_output.startswith("```"): - cleaned_output = cleaned_output[len("```") :] - if cleaned_output.endswith("```"): - cleaned_output = cleaned_output[: -len("```")] - cleaned_output = cleaned_output.strip() - response = json.loads(cleaned_output) + response = parse_json_markdown(text) action, action_input = response["action"], response["action_input"] if action == "Final Answer": return AgentFinish({"output": action_input}, text) diff --git a/langchain/chains/query_constructor/base.py b/langchain/chains/query_constructor/base.py index 48adec01..51dc12c6 100644 --- a/langchain/chains/query_constructor/base.py +++ b/langchain/chains/query_constructor/base.py @@ -22,7 +22,7 @@ from langchain.chains.query_constructor.prompt import ( SCHEMA_WITH_LIMIT, ) from langchain.chains.query_constructor.schema import AttributeInfo -from langchain.output_parsers.structured import parse_json_markdown +from langchain.output_parsers.json import parse_and_check_json_markdown from langchain.schema import BaseOutputParser, OutputParserException @@ -33,7 +33,7 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]): def parse(self, text: str) -> StructuredQuery: try: expected_keys = ["query", "filter"] - parsed = parse_json_markdown(text, expected_keys) + parsed = parse_and_check_json_markdown(text, expected_keys) if len(parsed["query"]) == 0: parsed["query"] = " " if parsed["filter"] == "NO_FILTER" or not parsed["filter"]: diff --git a/langchain/chains/router/llm_router.py b/langchain/chains/router/llm_router.py index 9e5be06b..3276324f 100644 --- a/langchain/chains/router/llm_router.py +++ b/langchain/chains/router/llm_router.py @@ -9,7 +9,7 @@ from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains import LLMChain from langchain.chains.router.base import RouterChain -from langchain.output_parsers.structured import parse_json_markdown +from langchain.output_parsers.json import parse_and_check_json_markdown from langchain.prompts import BasePromptTemplate from langchain.schema import BaseOutputParser, OutputParserException @@ -77,7 +77,7 @@ class RouterOutputParser(BaseOutputParser[Dict[str, str]]): def parse(self, text: str) -> Dict[str, Any]: try: expected_keys = ["destination", "next_inputs"] - parsed = parse_json_markdown(text, expected_keys) + parsed = parse_and_check_json_markdown(text, expected_keys) if not isinstance(parsed["destination"], str): raise ValueError("Expected 'destination' to be a string.") if not isinstance(parsed["next_inputs"], self.next_inputs_type): diff --git a/langchain/output_parsers/json.py b/langchain/output_parsers/json.py new file mode 100644 index 00000000..e0c9ac55 --- /dev/null +++ b/langchain/output_parsers/json.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import json +from typing import List + +from langchain.schema import OutputParserException + + +def parse_json_markdown(json_string: str) -> dict: + # Remove the triple backticks if present + json_string = json_string.replace("```json", "").replace("```", "") + + # Strip whitespace and newlines from the start and end + json_string = json_string.strip() + + # Parse the JSON string into a Python dictionary + parsed = json.loads(json_string) + + return parsed + + +def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict: + try: + json_obj = parse_json_markdown(text) + except json.JSONDecodeError as e: + raise OutputParserException(f"Got invalid JSON object. Error: {e}") + for key in expected_keys: + if key not in json_obj: + raise OutputParserException( + f"Got invalid return object. Expected key `{key}` " + f"to be present, but got {json_obj}" + ) + return json_obj diff --git a/langchain/output_parsers/structured.py b/langchain/output_parsers/structured.py index 345950f9..9afaf6cd 100644 --- a/langchain/output_parsers/structured.py +++ b/langchain/output_parsers/structured.py @@ -1,12 +1,12 @@ from __future__ import annotations -import json from typing import Any, List from pydantic import BaseModel from langchain.output_parsers.format_instructions import STRUCTURED_FORMAT_INSTRUCTIONS -from langchain.schema import BaseOutputParser, OutputParserException +from langchain.output_parsers.json import parse_and_check_json_markdown +from langchain.schema import BaseOutputParser line_template = '\t"{name}": {type} // {description}' @@ -22,27 +22,6 @@ def _get_sub_string(schema: ResponseSchema) -> str: ) -def parse_json_markdown(text: str, expected_keys: List[str]) -> Any: - if "```json" not in text: - raise OutputParserException( - f"Got invalid return object. Expected markdown code snippet with JSON " - f"object, but got:\n{text}" - ) - - json_string = text.split("```json")[1].strip().strip("```").strip() - try: - json_obj = json.loads(json_string) - except json.JSONDecodeError as e: - raise OutputParserException(f"Got invalid JSON object. Error: {e}") - for key in expected_keys: - if key not in json_obj: - raise OutputParserException( - f"Got invalid return object. Expected key `{key}` " - f"to be present, but got {json_obj}" - ) - return json_obj - - class StructuredOutputParser(BaseOutputParser): response_schemas: List[ResponseSchema] @@ -60,7 +39,7 @@ class StructuredOutputParser(BaseOutputParser): def parse(self, text: str) -> Any: expected_keys = [rs.name for rs in self.response_schemas] - return parse_json_markdown(text, expected_keys) + return parse_and_check_json_markdown(text, expected_keys) @property def _type(self) -> str: diff --git a/tests/unit_tests/output_parsers/test_json.py b/tests/unit_tests/output_parsers/test_json.py new file mode 100644 index 00000000..4055dd2d --- /dev/null +++ b/tests/unit_tests/output_parsers/test_json.py @@ -0,0 +1,81 @@ +import pytest + +from langchain.output_parsers.json import parse_json_markdown + +GOOD_JSON = """```json +{ + "foo": "bar" +} +```""" + +JSON_WITH_NEW_LINES = """ + +```json +{ + "foo": "bar" +} +``` + +""" + +JSON_WITH_NEW_LINES_INSIDE = """```json +{ + + "foo": "bar" + +} +```""" + +JSON_WITH_NEW_LINES_EVERYWHERE = """ + +```json + +{ + + "foo": "bar" + +} + +``` + +""" + +TICKS_WITH_NEW_LINES_EVERYWHERE = """ + +``` + +{ + + "foo": "bar" + +} + +``` + +""" + +NO_TICKS = """{ + "foo": "bar" +}""" + +NO_TICKS_WHITE_SPACE = """ +{ + "foo": "bar" +} +""" + +TEST_CASES = [ + GOOD_JSON, + JSON_WITH_NEW_LINES, + JSON_WITH_NEW_LINES_INSIDE, + JSON_WITH_NEW_LINES_EVERYWHERE, + TICKS_WITH_NEW_LINES_EVERYWHERE, + NO_TICKS, + NO_TICKS_WHITE_SPACE, +] + + +@pytest.mark.parametrize("json_string", TEST_CASES) +def test_parse_json(json_string: str) -> None: + parsed = parse_json_markdown(json_string) + assert parsed == {"foo": "bar"}