diff --git a/langchain/output_parsers/__init__.py b/langchain/output_parsers/__init__.py index cac915e0..a5d64840 100644 --- a/langchain/output_parsers/__init__.py +++ b/langchain/output_parsers/__init__.py @@ -4,10 +4,12 @@ from langchain.output_parsers.list import ( ListOutputParser, ) from langchain.output_parsers.regex import RegexParser +from langchain.output_parsers.regex_dict import RegexDictParser from langchain.output_parsers.structured import ResponseSchema, StructuredOutputParser __all__ = [ "RegexParser", + "RegexDictParser", "ListOutputParser", "CommaSeparatedListOutputParser", "BaseOutputParser", diff --git a/langchain/output_parsers/regex_dict.py b/langchain/output_parsers/regex_dict.py new file mode 100644 index 00000000..f17e8203 --- /dev/null +++ b/langchain/output_parsers/regex_dict.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import re +from typing import Dict, Optional + +from pydantic import BaseModel + +from langchain.output_parsers.base import BaseOutputParser + + +class RegexDictParser(BaseOutputParser, BaseModel): + """Class to parse the output into a dictionary.""" + + regex_pattern: str = r"{}:\s?([^.'\n']*)\.?" # : :meta private: + output_key_to_format: Dict[str, str] + no_update_value: Optional[str] = None + + @property + def _type(self) -> str: + """Return the type key.""" + return "regex_dict_parser" + + def parse(self, text: str) -> Dict[str, str]: + """Parse the output of an LLM call.""" + result = {} + for output_key, expected_format in self.output_key_to_format.items(): + specific_regex = self.regex_pattern.format(re.escape(expected_format)) + matches = re.findall(specific_regex, text) + if not matches: + raise ValueError( + f"No match found for output key: {output_key} with expected format \ + {expected_format} on text {text}" + ) + elif len(matches) > 1: + raise ValueError( + f"Multiple matches found for output key: {output_key} with \ + expected format {expected_format} on text {text}" + ) + elif ( + self.no_update_value is not None and matches[0] == self.no_update_value + ): + continue + else: + result[output_key] = matches[0] + return result diff --git a/tests/unit_tests/output_parsers/__init__.py b/tests/unit_tests/output_parsers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/output_parsers/test_regex_dict.py b/tests/unit_tests/output_parsers/test_regex_dict.py new file mode 100644 index 00000000..09df585a --- /dev/null +++ b/tests/unit_tests/output_parsers/test_regex_dict.py @@ -0,0 +1,37 @@ +"""Test in memory docstore.""" +from langchain.output_parsers.regex_dict import RegexDictParser + +DEF_EXPECTED_RESULT = {"action": "Search", "action_input": "How to use this class?"} + +DEF_OUTPUT_KEY_TO_FORMAT = {"action": "Action", "action_input": "Action Input"} + +DEF_README = """We have just received a new result from the LLM, and our next step is +to filter and read its format using regular expressions to identify specific fields, +such as: + +- Action: Search +- Action Input: How to use this class? +- Additional Fields: "N/A" + +To assist us in this task, we use the regex_dict class. This class allows us to send a +dictionary containing an output key and the expected format, which in turn enables us to +retrieve the result of the matching formats and extract specific information from it. + +To exclude irrelevant information from our return dictionary, we can instruct the LLM to +use a specific command that notifies us when it doesn't know the answer. We call this +variable the "no_update_value", and for our current case, we set it to "N/A". Therefore, +we expect the result to only contain the following fields: +{ + {key = action, value = search} + {key = action_input, value = "How to use this class?"}. +}""" + + +def test_regex_dict_result() -> None: + """Test regex dict result.""" + regex_dict_parser = RegexDictParser( + output_key_to_format=DEF_OUTPUT_KEY_TO_FORMAT, no_update_value="N/A" + ) + result_dict = regex_dict_parser.parse(DEF_README) + print("parse_result:", result_dict) + assert DEF_EXPECTED_RESULT == result_dict