From 9ec01dfc164777f234ba784b51d07fdaf6713384 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Tue, 27 Dec 2022 20:28:08 -0500 Subject: [PATCH] regex output parser (#435) --- langchain/evaluation/qa/generate_prompt.py | 25 +++++----------------- langchain/prompts/base.py | 18 +++++++++++++++- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/langchain/evaluation/qa/generate_prompt.py b/langchain/evaluation/qa/generate_prompt.py index 9ee74e8ceb..6eb3137476 100644 --- a/langchain/evaluation/qa/generate_prompt.py +++ b/langchain/evaluation/qa/generate_prompt.py @@ -1,24 +1,6 @@ # flake8: noqa -import re -from typing import Dict - from langchain.prompts import PromptTemplate -from langchain.prompts.base import BaseOutputParser - - -class QAGenerationOutputParser(BaseOutputParser): - """Parse output in question/answer pair.""" - - def parse(self, text: str) -> Dict[str, str]: - regex = r"QUESTION: (.*?)\nANSWER: (.*)" - match = re.search(regex, text) - if match: - question = match.group(1) - answer = match.group(2) - return {"query": question, "answer": answer} - else: - raise ValueError(f"Could not parse output: {text}") - +from langchain.prompts.base import RegexParser template = """You are a teacher coming up with questions to ask on a quiz. Given the following document, please generate a question and answer based on that document. @@ -35,6 +17,9 @@ These questions should be detailed and be based explicitly on information in the {doc} """ +output_parser = RegexParser( + regex=r"QUESTION: (.*?)\nANSWER: (.*)", output_keys=["question", "answer"] +) PROMPT = PromptTemplate( - input_variables=["doc"], template=template, output_parser=QAGenerationOutputParser() + input_variables=["doc"], template=template, output_parser=output_parser ) diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py index c7b708a345..5221ff3645 100644 --- a/langchain/prompts/base.py +++ b/langchain/prompts/base.py @@ -1,5 +1,6 @@ """BasePrompt schema definition.""" import json +import re from abc import ABC, abstractmethod from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Union @@ -55,7 +56,7 @@ class BaseOutputParser(ABC): """Parse the output of an LLM call.""" -class ListOutputParser(ABC): +class ListOutputParser(BaseOutputParser): """Class to parse the output of an LLM call to a list.""" @abstractmethod @@ -63,6 +64,21 @@ class ListOutputParser(ABC): """Parse the output of an LLM call.""" +class RegexParser(BaseOutputParser, BaseModel): + """Class to parse the output into a dictionary.""" + + regex: str + output_keys: List[str] + + def parse(self, text: str) -> Dict[str, str]: + """Parse the output of an LLM call.""" + match = re.search(self.regex, text) + if match: + return {key: match.group(i) for i, key in enumerate(self.output_keys)} + else: + raise ValueError(f"Could not parse output: {text}") + + class BasePromptTemplate(BaseModel, ABC): """Base prompt should expose the format method, returning a prompt."""