From c9b5a30b3767e0fc73e1903fbd2f39cc7feee748 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sat, 11 Mar 2023 16:41:03 -0800 Subject: [PATCH] move output parsing (#1605) --- .../chain_examples/qa_with_sources.ipynb | 4 +- .../chain_examples/question_answering.ipynb | 2 +- .../chains/combine_documents/map_rerank.py | 2 +- .../question_answering/map_rerank_prompt.py | 2 +- langchain/chains/sql_database/prompt.py | 2 +- langchain/evaluation/qa/generate_prompt.py | 2 +- langchain/output_parsers/__init__.py | 13 ++++ langchain/output_parsers/base.py | 25 +++++++ langchain/output_parsers/list.py | 22 ++++++ langchain/output_parsers/loading.py | 15 ++++ langchain/output_parsers/regex.py | 35 ++++++++++ langchain/prompts/base.py | 69 ++----------------- langchain/prompts/loading.py | 11 +-- tests/unit_tests/chains/test_llm.py | 2 +- 14 files changed, 130 insertions(+), 76 deletions(-) create mode 100644 langchain/output_parsers/__init__.py create mode 100644 langchain/output_parsers/base.py create mode 100644 langchain/output_parsers/list.py create mode 100644 langchain/output_parsers/loading.py create mode 100644 langchain/output_parsers/regex.py diff --git a/docs/modules/indexes/chain_examples/qa_with_sources.ipynb b/docs/modules/indexes/chain_examples/qa_with_sources.ipynb index 29c0d7c8bf..70e570f707 100644 --- a/docs/modules/indexes/chain_examples/qa_with_sources.ipynb +++ b/docs/modules/indexes/chain_examples/qa_with_sources.ipynb @@ -635,7 +635,7 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain.prompts.base import RegexParser\n", + "from langchain.output_parsers import RegexParser\n", "\n", "output_parser = RegexParser(\n", " regex=r\"(.*?)\\nScore: (.*)\",\n", @@ -732,4 +732,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/docs/modules/indexes/chain_examples/question_answering.ipynb b/docs/modules/indexes/chain_examples/question_answering.ipynb index 82b0f651da..4820b5d676 100644 --- a/docs/modules/indexes/chain_examples/question_answering.ipynb +++ b/docs/modules/indexes/chain_examples/question_answering.ipynb @@ -635,7 +635,7 @@ } ], "source": [ - "from langchain.prompts.base import RegexParser\n", + "from langchain.output_parsers import RegexParser\n", "\n", "output_parser = RegexParser(\n", " regex=r\"(.*?)\\nScore: (.*)\",\n", diff --git a/langchain/chains/combine_documents/map_rerank.py b/langchain/chains/combine_documents/map_rerank.py index 71855650db..2eb67e4c52 100644 --- a/langchain/chains/combine_documents/map_rerank.py +++ b/langchain/chains/combine_documents/map_rerank.py @@ -9,7 +9,7 @@ from pydantic import BaseModel, Extra, root_validator from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.llm import LLMChain from langchain.docstore.document import Document -from langchain.prompts.base import RegexParser +from langchain.output_parsers.regex import RegexParser class MapRerankDocumentsChain(BaseCombineDocumentsChain, BaseModel): diff --git a/langchain/chains/question_answering/map_rerank_prompt.py b/langchain/chains/question_answering/map_rerank_prompt.py index ab68048b0d..0fd945c4bd 100644 --- a/langchain/chains/question_answering/map_rerank_prompt.py +++ b/langchain/chains/question_answering/map_rerank_prompt.py @@ -1,6 +1,6 @@ # flake8: noqa from langchain.prompts import PromptTemplate -from langchain.prompts.base import RegexParser +from langchain.output_parsers.regex import RegexParser output_parser = RegexParser( regex=r"(.*?)\nScore: (.*)", diff --git a/langchain/chains/sql_database/prompt.py b/langchain/chains/sql_database/prompt.py index 8b0fd1529e..730c5a2374 100644 --- a/langchain/chains/sql_database/prompt.py +++ b/langchain/chains/sql_database/prompt.py @@ -1,5 +1,5 @@ # flake8: noqa -from langchain.prompts.base import CommaSeparatedListOutputParser +from langchain.output_parsers.list import CommaSeparatedListOutputParser from langchain.prompts.prompt import PromptTemplate _DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database. diff --git a/langchain/evaluation/qa/generate_prompt.py b/langchain/evaluation/qa/generate_prompt.py index 7b9fedfd56..2fe278cfea 100644 --- a/langchain/evaluation/qa/generate_prompt.py +++ b/langchain/evaluation/qa/generate_prompt.py @@ -1,6 +1,6 @@ # flake8: noqa from langchain.prompts import PromptTemplate -from langchain.prompts.base import RegexParser +from langchain.output_parsers.regex import RegexParser template = """You are a teacher coming up with questions to ask on a quiz. Given the following document, please generate a question and answer based on that document. diff --git a/langchain/output_parsers/__init__.py b/langchain/output_parsers/__init__.py new file mode 100644 index 0000000000..8509b6f238 --- /dev/null +++ b/langchain/output_parsers/__init__.py @@ -0,0 +1,13 @@ +from langchain.output_parsers.base import BaseOutputParser +from langchain.output_parsers.list import ( + CommaSeparatedListOutputParser, + ListOutputParser, +) +from langchain.output_parsers.regex import RegexParser + +__all__ = [ + "RegexParser", + "ListOutputParser", + "CommaSeparatedListOutputParser", + "BaseOutputParser", +] diff --git a/langchain/output_parsers/base.py b/langchain/output_parsers/base.py new file mode 100644 index 0000000000..3ea30f9c59 --- /dev/null +++ b/langchain/output_parsers/base.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict + +from pydantic import BaseModel + + +class BaseOutputParser(BaseModel, ABC): + """Class to parse the output of an LLM call.""" + + @abstractmethod + def parse(self, text: str) -> Any: + """Parse the output of an LLM call.""" + + @property + def _type(self) -> str: + """Return the type key.""" + raise NotImplementedError + + def dict(self, **kwargs: Any) -> Dict: + """Return dictionary representation of output parser.""" + output_parser_dict = super().dict() + output_parser_dict["_type"] = self._type + return output_parser_dict diff --git a/langchain/output_parsers/list.py b/langchain/output_parsers/list.py new file mode 100644 index 0000000000..028685f6be --- /dev/null +++ b/langchain/output_parsers/list.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing import List + +from langchain.output_parsers.base import BaseOutputParser + + +class ListOutputParser(BaseOutputParser): + """Class to parse the output of an LLM call to a list.""" + + @abstractmethod + def parse(self, text: str) -> List[str]: + """Parse the output of an LLM call.""" + + +class CommaSeparatedListOutputParser(ListOutputParser): + """Parse out comma separated lists.""" + + def parse(self, text: str) -> List[str]: + """Parse the output of an LLM call.""" + return text.strip().split(", ") diff --git a/langchain/output_parsers/loading.py b/langchain/output_parsers/loading.py new file mode 100644 index 0000000000..7acd5aa95b --- /dev/null +++ b/langchain/output_parsers/loading.py @@ -0,0 +1,15 @@ +from langchain.output_parsers.regex import RegexParser + + +def load_output_parser(config: dict) -> dict: + """Load output parser.""" + if "output_parsers" in config: + if config["output_parsers"] is not None: + _config = config["output_parsers"] + output_parser_type = _config["_type"] + if output_parser_type == "regex_parser": + output_parser = RegexParser(**_config) + else: + raise ValueError(f"Unsupported output parser {output_parser_type}") + config["output_parsers"] = output_parser + return config diff --git a/langchain/output_parsers/regex.py b/langchain/output_parsers/regex.py new file mode 100644 index 0000000000..b58137a7b3 --- /dev/null +++ b/langchain/output_parsers/regex.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import re +from typing import Dict, List, Optional + +from pydantic import BaseModel + +from langchain.output_parsers.base import BaseOutputParser + + +class RegexParser(BaseOutputParser, BaseModel): + """Class to parse the output into a dictionary.""" + + regex: str + output_keys: List[str] + default_output_key: Optional[str] = None + + @property + def _type(self) -> str: + """Return the type key.""" + return "regex_parser" + + def parse(self, text: str) -> Dict[str, str]: + """Parse the output of an LLM call.""" + match = re.search(self.regex, text) + if match: + return {key: match.group(i + 1) for i, key in enumerate(self.output_keys)} + else: + if self.default_output_key is None: + raise ValueError(f"Could not parse output: {text}") + else: + return { + key: text if key == self.default_output_key else "" + for key in self.output_keys + } diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py index 614543ba59..b85f31613a 100644 --- a/langchain/prompts/base.py +++ b/langchain/prompts/base.py @@ -2,7 +2,6 @@ from __future__ import annotations import json -import re from abc import ABC, abstractmethod from pathlib import Path from typing import Any, Callable, Dict, List, Mapping, Optional, Union @@ -11,6 +10,12 @@ import yaml from pydantic import BaseModel, Extra, Field, root_validator from langchain.formatting import formatter +from langchain.output_parsers.base import BaseOutputParser +from langchain.output_parsers.list import ( # noqa: F401 + CommaSeparatedListOutputParser, + ListOutputParser, +) +from langchain.output_parsers.regex import RegexParser # noqa: F401 from langchain.schema import BaseMessage, HumanMessage, PromptValue @@ -54,68 +59,6 @@ def check_valid_template( ) -class BaseOutputParser(BaseModel, ABC): - """Class to parse the output of an LLM call.""" - - @abstractmethod - def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]: - """Parse the output of an LLM call.""" - - @property - def _type(self) -> str: - """Return the type key.""" - raise NotImplementedError - - def dict(self, **kwargs: Any) -> Dict: - """Return dictionary representation of output parser.""" - output_parser_dict = super().dict() - output_parser_dict["_type"] = self._type - return output_parser_dict - - -class ListOutputParser(BaseOutputParser): - """Class to parse the output of an LLM call to a list.""" - - @abstractmethod - def parse(self, text: str) -> List[str]: - """Parse the output of an LLM call.""" - - -class CommaSeparatedListOutputParser(ListOutputParser): - """Parse out comma separated lists.""" - - def parse(self, text: str) -> List[str]: - """Parse the output of an LLM call.""" - return text.strip().split(", ") - - -class RegexParser(BaseOutputParser, BaseModel): - """Class to parse the output into a dictionary.""" - - regex: str - output_keys: List[str] - default_output_key: Optional[str] = None - - @property - def _type(self) -> str: - """Return the type key.""" - return "regex_parser" - - def parse(self, text: str) -> Dict[str, str]: - """Parse the output of an LLM call.""" - match = re.search(self.regex, text) - if match: - return {key: match.group(i + 1) for i, key in enumerate(self.output_keys)} - else: - if self.default_output_key is None: - raise ValueError(f"Could not parse output: {text}") - else: - return { - key: text if key == self.default_output_key else "" - for key in self.output_keys - } - - class StringPromptValue(PromptValue): text: str diff --git a/langchain/prompts/loading.py b/langchain/prompts/loading.py index 178c637ea2..c849297924 100644 --- a/langchain/prompts/loading.py +++ b/langchain/prompts/loading.py @@ -7,7 +7,8 @@ from typing import Union import yaml -from langchain.prompts.base import BasePromptTemplate, RegexParser +from langchain.output_parsers.regex import RegexParser +from langchain.prompts.base import BasePromptTemplate from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.prompt import PromptTemplate from langchain.utilities.loading import try_load_from_hub @@ -73,15 +74,15 @@ def _load_examples(config: dict) -> dict: def _load_output_parser(config: dict) -> dict: """Load output parser.""" - if "output_parser" in config: - if config["output_parser"] is not None: - _config = config["output_parser"] + if "output_parsers" in config: + if config["output_parsers"] is not None: + _config = config["output_parsers"] output_parser_type = _config["_type"] if output_parser_type == "regex_parser": output_parser = RegexParser(**_config) else: raise ValueError(f"Unsupported output parser {output_parser_type}") - config["output_parser"] = output_parser + config["output_parsers"] = output_parser return config diff --git a/tests/unit_tests/chains/test_llm.py b/tests/unit_tests/chains/test_llm.py index 1dfe9bb54e..d01e14e5e6 100644 --- a/tests/unit_tests/chains/test_llm.py +++ b/tests/unit_tests/chains/test_llm.py @@ -7,7 +7,7 @@ import pytest from langchain.chains.llm import LLMChain from langchain.chains.loading import load_chain -from langchain.prompts.base import BaseOutputParser +from langchain.output_parsers.base import BaseOutputParser from langchain.prompts.prompt import PromptTemplate from tests.unit_tests.llms.fake_llm import FakeLLM