Harrison/summarizer chain (#1356)

Co-authored-by: Tim Asp <707699+timothyasp@users.noreply.github.com>
Harrison Chase 2023-03-01 20:59:07 -08:00 committed by GitHub
parent cfae03042d
commit 1cd8996074
11 changed files with 1394 additions and 5 deletions

File diff suppressed because it is too large

langchain/chains/__init__.py

@@ -11,6 +11,7 @@ from langchain.chains.llm_bash.base import LLMBashChain
 from langchain.chains.llm_checker.base import LLMCheckerChain
 from langchain.chains.llm_math.base import LLMMathChain
 from langchain.chains.llm_requests import LLMRequestsChain
+from langchain.chains.llm_summarization_checker.base import LLMSummarizationCheckerChain
 from langchain.chains.loading import load_chain
 from langchain.chains.mapreduce import MapReduceChain
 from langchain.chains.moderation import OpenAIModerationChain
@@ -30,6 +31,7 @@ __all__ = [
     "LLMChain",
     "LLMBashChain",
     "LLMCheckerChain",
+    "LLMSummarizationCheckerChain",
     "LLMMathChain",
     "PALChain",
     "QAWithSourcesChain",

langchain/chains/llm_summarization_checker/__init__.py

@@ -0,0 +1,7 @@
"""Summarization checker chain for verifying the accuracy of text generation.

Chain that tries to verify the accuracy of text generation by splitting it into a
list of facts, checking whether those facts are true, and rewriting the text to
make it more truthful. It repeats this loop until it hits `max_checks` or gets
a "true" output.
"""

langchain/chains/llm_summarization_checker/base.py

@@ -0,0 +1,133 @@
"""Chain for summarization with self-verification."""

from pathlib import Path
from typing import Dict, List

from pydantic import BaseModel, Extra

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.sequential import SequentialChain
from langchain.llms.base import BaseLLM
from langchain.prompts.prompt import PromptTemplate

PROMPTS_DIR = Path(__file__).parent / "prompts"

CREATE_ASSERTIONS_PROMPT = PromptTemplate.from_file(
    PROMPTS_DIR / "create_facts.txt", ["summary"]
)
CHECK_ASSERTIONS_PROMPT = PromptTemplate.from_file(
    PROMPTS_DIR / "check_facts.txt", ["assertions"]
)
REVISED_SUMMARY_PROMPT = PromptTemplate.from_file(
    PROMPTS_DIR / "revise_summary.txt", ["checked_assertions", "summary"]
)
ARE_ALL_TRUE_PROMPT = PromptTemplate.from_file(
    PROMPTS_DIR / "are_all_true_prompt.txt", ["checked_assertions"]
)

class LLMSummarizationCheckerChain(Chain, BaseModel):
    """Chain for verifying the factual accuracy of a summary, with self-checking.

    Example:
        .. code-block:: python

            from langchain import OpenAI, LLMSummarizationCheckerChain
            llm = OpenAI(temperature=0.0)
            checker_chain = LLMSummarizationCheckerChain(llm=llm)
    """

    llm: BaseLLM
    """LLM wrapper to use."""

    create_assertions_prompt: PromptTemplate = CREATE_ASSERTIONS_PROMPT
    check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT
    revised_summary_prompt: PromptTemplate = REVISED_SUMMARY_PROMPT
    are_all_true_prompt: PromptTemplate = ARE_ALL_TRUE_PROMPT

    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    max_checks: int = 2
    """Maximum number of times to check the assertions. Default to double-checking."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        return [self.output_key]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        all_true = False
        count = 0
        output = None
        original_input = inputs[self.input_key]
        chain_input = original_input
        # Extract assertions, fact-check them, revise the summary, and ask
        # whether everything is now true; repeat up to `max_checks` times.
        while not all_true and count < self.max_checks:
            chain = SequentialChain(
                chains=[
                    LLMChain(
                        llm=self.llm,
                        prompt=self.create_assertions_prompt,
                        output_key="assertions",
                        verbose=self.verbose,
                    ),
                    LLMChain(
                        llm=self.llm,
                        prompt=self.check_assertions_prompt,
                        output_key="checked_assertions",
                        verbose=self.verbose,
                    ),
                    LLMChain(
                        llm=self.llm,
                        prompt=self.revised_summary_prompt,
                        output_key="revised_summary",
                        verbose=self.verbose,
                    ),
                    LLMChain(
                        llm=self.llm,
                        output_key="all_true",
                        prompt=self.are_all_true_prompt,
                        verbose=self.verbose,
                    ),
                ],
                input_variables=["summary"],
                output_variables=["all_true", "revised_summary"],
                verbose=self.verbose,
            )
            output = chain({"summary": chain_input})
            count += 1
            if output["all_true"].strip() == "True":
                break
            if self.verbose:
                print(output["revised_summary"])
            # Feed the revised summary back in for another round of checking.
            chain_input = output["revised_summary"]
        if not output:
            raise ValueError("No output from chain")
        return {self.output_key: output["revised_summary"].strip()}

    @property
    def _chain_type(self) -> str:
        return "llm_summarization_checker_chain"

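For reference, a minimal end-to-end usage sketch (assumes an OpenAI API key is configured; the sample text and `max_checks=3` are illustrative, not part of this commit):

from langchain import OpenAI, LLMSummarizationCheckerChain

llm = OpenAI(temperature=0.0)
checker_chain = LLMSummarizationCheckerChain(llm=llm, max_checks=3, verbose=True)
# run() maps the single input key to the text and returns the revised summary.
revised = checker_chain.run("The Greenland shark is the longest-living mammal.")
print(revised)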
langchain/chains/llm_summarization_checker/prompts/are_all_true_prompt.txt

@@ -0,0 +1,38 @@
Below are some assertions that have been fact-checked and are labeled as true or false.
If all of the assertions are true, return "True". If any of the assertions are false, return "False".
Here are some examples:
===
Checked Assertions: """
- The sky is red: False
- Water is made of lava: False
- The sun is a star: True
"""
Result: False
===
Checked Assertions: """
- The sky is blue: True
- Water is wet: True
- The sun is a star: True
"""
Result: True
===
Checked Assertions: """
- The sky is blue - True
- Water is made of lava - False
- The sun is a star - True
"""
Result: False
===
Checked Assertions:"""
{checked_assertions}
"""
Result:

langchain/chains/llm_summarization_checker/prompts/check_facts.txt

@@ -0,0 +1,10 @@
You are an expert fact checker. You have been hired by a major news organization to fact check a very important story.
Here is a bullet point list of facts:
"""
{assertions}
"""
For each fact, determine whether it is true or false about the subject. If you are unable to determine whether the fact is true or false, output "Undetermined".
If the fact is false, explain why.

langchain/chains/llm_summarization_checker/prompts/create_facts.txt

@@ -0,0 +1,10 @@
Given some text, extract a list of facts from the text.
Format your output as a bulleted list.
Text:
"""
{summary}
"""
Facts:

langchain/chains/llm_summarization_checker/prompts/revise_summary.txt

@@ -0,0 +1,17 @@
Below are some assertions that have been fact-checked and are labeled as true or false. If the answer is false, a suggestion is given for a correction.
Checked Assertions:
"""
{checked_assertions}
"""
Original Summary:
"""
{summary}
"""
Using these checked assertions, rewrite the original summary to be completely true.
The output should have the same structure and formatting as the original summary.
Summary:

langchain/prompts/base.py

@ -46,8 +46,11 @@ def check_valid_template(
try: try:
formatter_func = DEFAULT_FORMATTER_MAPPING[template_format] formatter_func = DEFAULT_FORMATTER_MAPPING[template_format]
formatter_func(template, **dummy_inputs) formatter_func(template, **dummy_inputs)
except KeyError: except KeyError as e:
raise ValueError("Invalid prompt schema.") raise ValueError(
"Invalid prompt schema; check for mismatched or missing input parameters. "
+ str(e)
)
class BaseOutputParser(BaseModel, ABC): class BaseOutputParser(BaseModel, ABC):

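For context, a quick sketch of the improved error (a hypothetical template, and assuming check_valid_template's (template, template_format, input_variables) argument order; the appended detail is whatever message the underlying KeyError carries):

from langchain.prompts.base import check_valid_template

# "{topic}" is used by the template but not declared as an input variable,
# so formatting raises a KeyError whose message now reaches the user.
check_valid_template("Tell me a joke about {topic}.", "f-string", [])
# ValueError: Invalid prompt schema; check for mismatched or missing input parameters. 'topic'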
langchain/prompts/prompt.py

@ -1,8 +1,9 @@
"""Prompt schema definition.""" """Prompt schema definition."""
from __future__ import annotations from __future__ import annotations
from pathlib import Path
from string import Formatter from string import Formatter
from typing import Any, Dict, List from typing import Any, Dict, List, Union
from pydantic import BaseModel, Extra, root_validator from pydantic import BaseModel, Extra, root_validator
@ -105,7 +106,7 @@ class PromptTemplate(BasePromptTemplate, BaseModel):
@classmethod @classmethod
def from_file( def from_file(
cls, template_file: str, input_variables: List[str] cls, template_file: Union[str, Path], input_variables: List[str]
) -> PromptTemplate: ) -> PromptTemplate:
"""Load a prompt from a file. """Load a prompt from a file.
@ -116,7 +117,7 @@ class PromptTemplate(BasePromptTemplate, BaseModel):
Returns: Returns:
The prompt loaded from the file. The prompt loaded from the file.
""" """
with open(template_file, "r") as f: with open(str(template_file), "r") as f:
template = f.read() template = f.read()
return cls(input_variables=input_variables, template=template) return cls(input_variables=input_variables, template=template)

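With this change, PromptTemplate.from_file accepts a pathlib.Path as well as a str, which is what the new summarization checker chain relies on above. A small sketch (the file name and its contents are hypothetical):

from pathlib import Path
from langchain.prompts.prompt import PromptTemplate

# summarize.txt contains: "Summarize the following text: {text}"
prompt = PromptTemplate.from_file(Path("summarize.txt"), ["text"])
print(prompt.format(text="LangChain gained a summarization checker chain."))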
tests/unit_tests/chains/test_llm_summarization_checker.py

@@ -0,0 +1,44 @@
# flake8: noqa E501
"""Test LLMSummarizationCheckerChain functionality."""

import pytest

from langchain.chains.llm_summarization_checker.base import (
    ARE_ALL_TRUE_PROMPT,
    CHECK_ASSERTIONS_PROMPT,
    CREATE_ASSERTIONS_PROMPT,
    REVISED_SUMMARY_PROMPT,
    LLMSummarizationCheckerChain,
)
from tests.unit_tests.llms.fake_llm import FakeLLM


@pytest.fixture
def fake_llm_summarization_checker_chain() -> LLMSummarizationCheckerChain:
    """Fake LLMSummarizationCheckerChain for testing."""
    queries = {
        CREATE_ASSERTIONS_PROMPT.format(
            summary="a",
        ): "b",
        CHECK_ASSERTIONS_PROMPT.format(
            assertions="b",
        ): "- b - True",
        REVISED_SUMMARY_PROMPT.format(
            checked_assertions="- b - True", summary="a"
        ): "b",
        ARE_ALL_TRUE_PROMPT.format(
            checked_assertions="- b - True",
        ): "True",
    }
    fake_llm = FakeLLM(queries=queries)
    return LLMSummarizationCheckerChain(llm=fake_llm, input_key="q", output_key="a")


def test_simple_text(
    fake_llm_summarization_checker_chain: LLMSummarizationCheckerChain,
) -> None:
    """Test a trivial summary that the fake LLM marks as all true."""
    question = "a"
    output = fake_llm_summarization_checker_chain.run(question)
    assert output == "b"
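
To run just this test (using the file path inferred above):

pytest tests/unit_tests/chains/test_llm_summarization_checker.py -q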