diff --git a/langchain/output_parsers/combining.py b/langchain/output_parsers/combining.py new file mode 100644 index 00000000..038919b6 --- /dev/null +++ b/langchain/output_parsers/combining.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from typing import Any, Dict, List + +from pydantic import root_validator + +from langchain.schema import BaseOutputParser + + +class CombiningOutputParser(BaseOutputParser): + """Class to combine multiple output parsers into one.""" + + parsers: List[BaseOutputParser] + + @root_validator() + def validate_parsers(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Validate the parsers.""" + parsers = values["parsers"] + if len(parsers) < 2: + raise ValueError("Must have at least two parsers") + for parser in parsers: + if parser._type == "combining": + raise ValueError("Cannot nest combining parsers") + if parser._type == "list": + raise ValueError("Cannot comine list parsers") + return values + + @property + def _type(self) -> str: + """Return the type key.""" + return "combining" + + def get_format_instructions(self) -> str: + """Instructions on how the LLM output should be formatted.""" + + initial = f"For your first output: {self.parsers[0].get_format_instructions()}" + subsequent = "\n".join( + [ + f"Complete that output fully. Then produce another output, separated by two newline characters: {p.get_format_instructions()}" # noqa: E501 + for p in self.parsers[1:] + ] + ) + return f"{initial}\n{subsequent}" + + def parse(self, text: str) -> Dict[str, Any]: + """Parse the output of an LLM call.""" + texts = text.split("\n\n") + output = dict() + for i, parser in enumerate(self.parsers): + output.update(parser.parse(texts[i].strip())) + return output diff --git a/langchain/output_parsers/fix.py b/langchain/output_parsers/fix.py index b695586e..dfa3d639 100644 --- a/langchain/output_parsers/fix.py +++ b/langchain/output_parsers/fix.py @@ -41,3 +41,7 @@ class OutputFixingParser(BaseOutputParser[T]): def get_format_instructions(self) -> str: return self.parser.get_format_instructions() + + @property + def _type(self) -> str: + return self.parser._type diff --git a/langchain/output_parsers/list.py b/langchain/output_parsers/list.py index 32b26742..1cf2d39f 100644 --- a/langchain/output_parsers/list.py +++ b/langchain/output_parsers/list.py @@ -9,6 +9,10 @@ from langchain.schema import BaseOutputParser class ListOutputParser(BaseOutputParser): """Class to parse the output of an LLM call to a list.""" + @property + def _type(self) -> str: + return "list" + @abstractmethod def parse(self, text: str) -> List[str]: """Parse the output of an LLM call.""" diff --git a/langchain/output_parsers/pydantic.py b/langchain/output_parsers/pydantic.py index 7a4050d0..e1e5e716 100644 --- a/langchain/output_parsers/pydantic.py +++ b/langchain/output_parsers/pydantic.py @@ -43,3 +43,7 @@ class PydanticOutputParser(BaseOutputParser[T]): schema_str = json.dumps(reduced_schema) return PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str) + + @property + def _type(self) -> str: + return "pydantic" diff --git a/langchain/output_parsers/retry.py b/langchain/output_parsers/retry.py index 6ef08cf6..b1982608 100644 --- a/langchain/output_parsers/retry.py +++ b/langchain/output_parsers/retry.py @@ -76,6 +76,10 @@ class RetryOutputParser(BaseOutputParser[T]): def get_format_instructions(self) -> str: return self.parser.get_format_instructions() + @property + def _type(self) -> str: + return self.parser._type + class RetryWithErrorOutputParser(BaseOutputParser[T]): """Wraps a parser and tries to fix parsing errors. diff --git a/langchain/output_parsers/structured.py b/langchain/output_parsers/structured.py index d9c7b83a..af9b80bc 100644 --- a/langchain/output_parsers/structured.py +++ b/langchain/output_parsers/structured.py @@ -56,3 +56,7 @@ class StructuredOutputParser(BaseOutputParser): f"to be present, but got {json_obj}" ) return json_obj + + @property + def _type(self) -> str: + return "structured" diff --git a/tests/unit_tests/output_parsers/test_combining_parser.py b/tests/unit_tests/output_parsers/test_combining_parser.py new file mode 100644 index 00000000..21a3ab6a --- /dev/null +++ b/tests/unit_tests/output_parsers/test_combining_parser.py @@ -0,0 +1,45 @@ +"""Test in memory docstore.""" +from langchain.output_parsers.combining import CombiningOutputParser +from langchain.output_parsers.regex import RegexParser +from langchain.output_parsers.structured import ResponseSchema, StructuredOutputParser + +DEF_EXPECTED_RESULT = { + "answer": "Paris", + "source": "https://en.wikipedia.org/wiki/France", + "confidence": "A", + "explanation": "Paris is the capital of France according to Wikipedia.", +} + +DEF_README = """```json +{ + "answer": "Paris", + "source": "https://en.wikipedia.org/wiki/France" +} +``` + +//Confidence: A, Explanation: Paris is the capital of France according to Wikipedia.""" + + +def test_combining_dict_result() -> None: + """Test combining result.""" + parsers = [ + StructuredOutputParser( + response_schemas=[ + ResponseSchema( + name="answer", description="answer to the user's question" + ), + ResponseSchema( + name="source", + description="source used to answer the user's question", + ), + ] + ), + RegexParser( + regex=r"Confidence: (A|B|C), Explanation: (.*)", + output_keys=["confidence", "explanation"], + default_output_key="noConfidence", + ), + ] + combining_parser = CombiningOutputParser(parsers=parsers) + result_dict = combining_parser.parse(DEF_README) + assert DEF_EXPECTED_RESULT == result_dict