move output parsing (#1605)

This commit is contained in:
Harrison Chase 2023-03-11 16:41:03 -08:00 committed by GitHub
parent cb04ba0136
commit c9b5a30b37
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 130 additions and 76 deletions

View File

@ -635,7 +635,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain.prompts.base import RegexParser\n", "from langchain.output_parsers import RegexParser\n",
"\n", "\n",
"output_parser = RegexParser(\n", "output_parser = RegexParser(\n",
" regex=r\"(.*?)\\nScore: (.*)\",\n", " regex=r\"(.*?)\\nScore: (.*)\",\n",

View File

@ -635,7 +635,7 @@
} }
], ],
"source": [ "source": [
"from langchain.prompts.base import RegexParser\n", "from langchain.output_parsers import RegexParser\n",
"\n", "\n",
"output_parser = RegexParser(\n", "output_parser = RegexParser(\n",
" regex=r\"(.*?)\\nScore: (.*)\",\n", " regex=r\"(.*?)\\nScore: (.*)\",\n",

View File

@ -9,7 +9,7 @@ from pydantic import BaseModel, Extra, root_validator
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.prompts.base import RegexParser from langchain.output_parsers.regex import RegexParser
class MapRerankDocumentsChain(BaseCombineDocumentsChain, BaseModel): class MapRerankDocumentsChain(BaseCombineDocumentsChain, BaseModel):

View File

@ -1,6 +1,6 @@
# flake8: noqa # flake8: noqa
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from langchain.prompts.base import RegexParser from langchain.output_parsers.regex import RegexParser
output_parser = RegexParser( output_parser = RegexParser(
regex=r"(.*?)\nScore: (.*)", regex=r"(.*?)\nScore: (.*)",

View File

@ -1,5 +1,5 @@
# flake8: noqa # flake8: noqa
from langchain.prompts.base import CommaSeparatedListOutputParser from langchain.output_parsers.list import CommaSeparatedListOutputParser
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database. _DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.

View File

@ -1,6 +1,6 @@
# flake8: noqa # flake8: noqa
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from langchain.prompts.base import RegexParser from langchain.output_parsers.regex import RegexParser
template = """You are a teacher coming up with questions to ask on a quiz. template = """You are a teacher coming up with questions to ask on a quiz.
Given the following document, please generate a question and answer based on that document. Given the following document, please generate a question and answer based on that document.

View File

@ -0,0 +1,13 @@
from langchain.output_parsers.base import BaseOutputParser
from langchain.output_parsers.list import (
CommaSeparatedListOutputParser,
ListOutputParser,
)
from langchain.output_parsers.regex import RegexParser
__all__ = [
"RegexParser",
"ListOutputParser",
"CommaSeparatedListOutputParser",
"BaseOutputParser",
]

View File

@ -0,0 +1,25 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict
from pydantic import BaseModel
class BaseOutputParser(BaseModel, ABC):
"""Class to parse the output of an LLM call."""
@abstractmethod
def parse(self, text: str) -> Any:
"""Parse the output of an LLM call."""
@property
def _type(self) -> str:
"""Return the type key."""
raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of output parser."""
output_parser_dict = super().dict()
output_parser_dict["_type"] = self._type
return output_parser_dict

View File

@ -0,0 +1,22 @@
from __future__ import annotations
from abc import abstractmethod
from typing import List
from langchain.output_parsers.base import BaseOutputParser
class ListOutputParser(BaseOutputParser):
"""Class to parse the output of an LLM call to a list."""
@abstractmethod
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
class CommaSeparatedListOutputParser(ListOutputParser):
"""Parse out comma separated lists."""
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
return text.strip().split(", ")

View File

@ -0,0 +1,15 @@
from langchain.output_parsers.regex import RegexParser
def load_output_parser(config: dict) -> dict:
"""Load output parser."""
if "output_parsers" in config:
if config["output_parsers"] is not None:
_config = config["output_parsers"]
output_parser_type = _config["_type"]
if output_parser_type == "regex_parser":
output_parser = RegexParser(**_config)
else:
raise ValueError(f"Unsupported output parser {output_parser_type}")
config["output_parsers"] = output_parser
return config

View File

@ -0,0 +1,35 @@
from __future__ import annotations
import re
from typing import Dict, List, Optional
from pydantic import BaseModel
from langchain.output_parsers.base import BaseOutputParser
class RegexParser(BaseOutputParser, BaseModel):
"""Class to parse the output into a dictionary."""
regex: str
output_keys: List[str]
default_output_key: Optional[str] = None
@property
def _type(self) -> str:
"""Return the type key."""
return "regex_parser"
def parse(self, text: str) -> Dict[str, str]:
"""Parse the output of an LLM call."""
match = re.search(self.regex, text)
if match:
return {key: match.group(i + 1) for i, key in enumerate(self.output_keys)}
else:
if self.default_output_key is None:
raise ValueError(f"Could not parse output: {text}")
else:
return {
key: text if key == self.default_output_key else ""
for key in self.output_keys
}

View File

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
import json import json
import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Union from typing import Any, Callable, Dict, List, Mapping, Optional, Union
@ -11,6 +10,12 @@ import yaml
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import BaseModel, Extra, Field, root_validator
from langchain.formatting import formatter from langchain.formatting import formatter
from langchain.output_parsers.base import BaseOutputParser
from langchain.output_parsers.list import ( # noqa: F401
CommaSeparatedListOutputParser,
ListOutputParser,
)
from langchain.output_parsers.regex import RegexParser # noqa: F401
from langchain.schema import BaseMessage, HumanMessage, PromptValue from langchain.schema import BaseMessage, HumanMessage, PromptValue
@ -54,68 +59,6 @@ def check_valid_template(
) )
class BaseOutputParser(BaseModel, ABC):
"""Class to parse the output of an LLM call."""
@abstractmethod
def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]:
"""Parse the output of an LLM call."""
@property
def _type(self) -> str:
"""Return the type key."""
raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of output parser."""
output_parser_dict = super().dict()
output_parser_dict["_type"] = self._type
return output_parser_dict
class ListOutputParser(BaseOutputParser):
"""Class to parse the output of an LLM call to a list."""
@abstractmethod
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
class CommaSeparatedListOutputParser(ListOutputParser):
"""Parse out comma separated lists."""
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
return text.strip().split(", ")
class RegexParser(BaseOutputParser, BaseModel):
"""Class to parse the output into a dictionary."""
regex: str
output_keys: List[str]
default_output_key: Optional[str] = None
@property
def _type(self) -> str:
"""Return the type key."""
return "regex_parser"
def parse(self, text: str) -> Dict[str, str]:
"""Parse the output of an LLM call."""
match = re.search(self.regex, text)
if match:
return {key: match.group(i + 1) for i, key in enumerate(self.output_keys)}
else:
if self.default_output_key is None:
raise ValueError(f"Could not parse output: {text}")
else:
return {
key: text if key == self.default_output_key else ""
for key in self.output_keys
}
class StringPromptValue(PromptValue): class StringPromptValue(PromptValue):
text: str text: str

View File

@ -7,7 +7,8 @@ from typing import Union
import yaml import yaml
from langchain.prompts.base import BasePromptTemplate, RegexParser from langchain.output_parsers.regex import RegexParser
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
from langchain.utilities.loading import try_load_from_hub from langchain.utilities.loading import try_load_from_hub
@ -73,15 +74,15 @@ def _load_examples(config: dict) -> dict:
def _load_output_parser(config: dict) -> dict: def _load_output_parser(config: dict) -> dict:
"""Load output parser.""" """Load output parser."""
if "output_parser" in config: if "output_parsers" in config:
if config["output_parser"] is not None: if config["output_parsers"] is not None:
_config = config["output_parser"] _config = config["output_parsers"]
output_parser_type = _config["_type"] output_parser_type = _config["_type"]
if output_parser_type == "regex_parser": if output_parser_type == "regex_parser":
output_parser = RegexParser(**_config) output_parser = RegexParser(**_config)
else: else:
raise ValueError(f"Unsupported output parser {output_parser_type}") raise ValueError(f"Unsupported output parser {output_parser_type}")
config["output_parser"] = output_parser config["output_parsers"] = output_parser
return config return config

View File

@ -7,7 +7,7 @@ import pytest
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.loading import load_chain from langchain.chains.loading import load_chain
from langchain.prompts.base import BaseOutputParser from langchain.output_parsers.base import BaseOutputParser
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
from tests.unit_tests.llms.fake_llm import FakeLLM from tests.unit_tests.llms.fake_llm import FakeLLM