Harrison/summarizer chain (#1356)

Co-authored-by: Tim Asp <707699+timothyasp@users.noreply.github.com>
Harrison Chase committed by GitHub
parent cfae03042d
commit 1cd8996074

@@ -11,6 +11,7 @@ from langchain.chains.llm_bash.base import LLMBashChain
 from langchain.chains.llm_checker.base import LLMCheckerChain
 from langchain.chains.llm_math.base import LLMMathChain
 from langchain.chains.llm_requests import LLMRequestsChain
+from langchain.chains.llm_summarization_checker.base import LLMSummarizationCheckerChain
 from langchain.chains.loading import load_chain
 from langchain.chains.mapreduce import MapReduceChain
 from langchain.chains.moderation import OpenAIModerationChain
@@ -30,6 +31,7 @@ __all__ = [
     "LLMChain",
     "LLMBashChain",
     "LLMCheckerChain",
+    "LLMSummarizationCheckerChain",
     "LLMMathChain",
     "PALChain",
     "QAWithSourcesChain",

@@ -0,0 +1,7 @@
"""Summarization checker chain for verifying accuracy of text generation.

Chain that tries to verify the accuracy of text generation by splitting it into a
list of facts, checking whether those facts are true, and rewriting the text to
make it more truthful. It repeats this loop until it hits `max_checks` or gets a
"true" output.
"""

@@ -0,0 +1,133 @@
"""Chain for summarization with self-verification."""

from pathlib import Path
from typing import Dict, List

from pydantic import BaseModel, Extra

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.sequential import SequentialChain
from langchain.llms.base import BaseLLM
from langchain.prompts.prompt import PromptTemplate

PROMPTS_DIR = Path(__file__).parent / "prompts"

CREATE_ASSERTIONS_PROMPT = PromptTemplate.from_file(
    PROMPTS_DIR / "create_facts.txt", ["summary"]
)
CHECK_ASSERTIONS_PROMPT = PromptTemplate.from_file(
    PROMPTS_DIR / "check_facts.txt", ["assertions"]
)
REVISED_SUMMARY_PROMPT = PromptTemplate.from_file(
    PROMPTS_DIR / "revise_summary.txt", ["checked_assertions", "summary"]
)
ARE_ALL_TRUE_PROMPT = PromptTemplate.from_file(
    PROMPTS_DIR / "are_all_true_prompt.txt", ["checked_assertions"]
)


class LLMSummarizationCheckerChain(Chain, BaseModel):
    """Chain for summarization with self-verification.

    Example:
        .. code-block:: python

            from langchain import OpenAI, LLMSummarizationCheckerChain
            llm = OpenAI(temperature=0.0)
            checker_chain = LLMSummarizationCheckerChain(llm=llm)
    """

    llm: BaseLLM
    """LLM wrapper to use."""
    create_assertions_prompt: PromptTemplate = CREATE_ASSERTIONS_PROMPT
    check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT
    revised_summary_prompt: PromptTemplate = REVISED_SUMMARY_PROMPT
    are_all_true_prompt: PromptTemplate = ARE_ALL_TRUE_PROMPT

    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    max_checks: int = 2
    """Maximum number of times to check the assertions. Defaults to double-checking."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        return [self.output_key]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        all_true = False
        count = 0
        output = None
        original_input = inputs[self.input_key]
        chain_input = original_input

        while not all_true and count < self.max_checks:
            chain = SequentialChain(
                chains=[
                    LLMChain(
                        llm=self.llm,
                        prompt=self.create_assertions_prompt,
                        output_key="assertions",
                        verbose=self.verbose,
                    ),
                    LLMChain(
                        llm=self.llm,
                        prompt=self.check_assertions_prompt,
                        output_key="checked_assertions",
                        verbose=self.verbose,
                    ),
                    LLMChain(
                        llm=self.llm,
                        prompt=self.revised_summary_prompt,
                        output_key="revised_summary",
                        verbose=self.verbose,
                    ),
                    LLMChain(
                        llm=self.llm,
                        output_key="all_true",
                        prompt=self.are_all_true_prompt,
                        verbose=self.verbose,
                    ),
                ],
                input_variables=["summary"],
                output_variables=["all_true", "revised_summary"],
                verbose=self.verbose,
            )
            output = chain({"summary": chain_input})
            count += 1

            if output["all_true"].strip() == "True":
                break

            if self.verbose:
                print(output["revised_summary"])

            chain_input = output["revised_summary"]

        if not output:
            raise ValueError("No output from chain")

        return {self.output_key: output["revised_summary"].strip()}

    @property
    def _chain_type(self) -> str:
        return "llm_summarization_checker_chain"

@@ -0,0 +1,38 @@
Below are some assertions that have been fact checked and are labeled as true or false.

If all of the assertions are true, return "True". If any of the assertions are false, return "False".

Here are some examples:
===

Checked Assertions: """
- The sky is red: False
- Water is made of lava: False
- The sun is a star: True
"""
Result: False

===

Checked Assertions: """
- The sky is blue: True
- Water is wet: True
- The sun is a star: True
"""
Result: True

===

Checked Assertions: """
- The sky is blue - True
- Water is made of lava - False
- The sun is a star - True
"""
Result: False

===

Checked Assertions: """
{checked_assertions}
"""
Result:
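These few-shot examples pin the model to a bare True/False token, which is exactly what the chain's stopping test expects. A small sketch of how the template renders and how the verdict is consumed (mirroring the comparison in _call):

from langchain.chains.llm_summarization_checker.base import ARE_ALL_TRUE_PROMPT

rendered = ARE_ALL_TRUE_PROMPT.format(
    checked_assertions="- The sun is a star: True"
)
# The LLM completion for `rendered` is then stripped and compared against the
# literal "True"; anything else triggers another revision pass.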

@@ -0,0 +1,10 @@
You are an expert fact checker. You have been hired by a major news organization to fact check a very important story.

Here is a bullet point list of facts:
"""
{assertions}
"""

For each fact, determine whether it is true or false about the subject. If you are unable to determine whether the fact is true or false, output "Undetermined".
If the fact is false, explain why.

@@ -0,0 +1,10 @@
Given some text, extract a list of facts from the text.

Format your output as a bulleted list.

Text:
"""
{summary}
"""

Facts:

@@ -0,0 +1,17 @@
Below are some assertions that have been fact checked and are labeled as true or false. If the answer is false, a suggestion is given for a correction.

Checked Assertions:
"""
{checked_assertions}
"""

Original Summary:
"""
{summary}
"""

Using these checked assertions, rewrite the original summary to be completely true.

The output should have the same structure and formatting as the original summary.

Summary:

@@ -46,8 +46,11 @@ def check_valid_template(
     try:
         formatter_func = DEFAULT_FORMATTER_MAPPING[template_format]
         formatter_func(template, **dummy_inputs)
-    except KeyError:
-        raise ValueError("Invalid prompt schema.")
+    except KeyError as e:
+        raise ValueError(
+            "Invalid prompt schema; check for mismatched or missing input parameters. "
+            + str(e)
+        )


 class BaseOutputParser(BaseModel, ABC):
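A sketch of what the richer message buys, assuming check_valid_template is the helper in langchain.prompts.base (the module path is inferred from context, not shown in this hunk):

from langchain.prompts.base import check_valid_template

try:
    # {name} is not declared in input_variables, so formatting raises
    # KeyError('name'), which now surfaces inside the ValueError text
    # instead of the bare "Invalid prompt schema."
    check_valid_template("Tell me about {name}.", "f-string", ["subject"])
except ValueError as err:
    print(err)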

@@ -1,8 +1,9 @@
 """Prompt schema definition."""
 from __future__ import annotations

+from pathlib import Path
 from string import Formatter
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Union

 from pydantic import BaseModel, Extra, root_validator
@@ -105,7 +106,7 @@ class PromptTemplate(BasePromptTemplate, BaseModel):
     @classmethod
     def from_file(
-        cls, template_file: str, input_variables: List[str]
+        cls, template_file: Union[str, Path], input_variables: List[str]
     ) -> PromptTemplate:
         """Load a prompt from a file.
@@ -116,7 +117,7 @@ class PromptTemplate(BasePromptTemplate, BaseModel):
         Returns:
             The prompt loaded from the file.
         """
-        with open(template_file, "r") as f:
+        with open(str(template_file), "r") as f:
             template = f.read()
         return cls(input_variables=input_variables, template=template)
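With the widened signature, a pathlib.Path now type-checks, and the str() call keeps open() maximally compatible. A usage sketch (the file path is illustrative):

from pathlib import Path
from langchain.prompts.prompt import PromptTemplate

# Both call styles are now accepted under Union[str, Path].
prompt = PromptTemplate.from_file(Path("prompts") / "create_facts.txt", ["summary"])
same = PromptTemplate.from_file("prompts/create_facts.txt", ["summary"])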

@@ -0,0 +1,44 @@
# flake8: noqa E501
"""Test LLMSummarizationCheckerChain functionality."""

import pytest

from langchain.chains.llm_summarization_checker.base import (
    ARE_ALL_TRUE_PROMPT,
    CHECK_ASSERTIONS_PROMPT,
    CREATE_ASSERTIONS_PROMPT,
    REVISED_SUMMARY_PROMPT,
    LLMSummarizationCheckerChain,
)
from tests.unit_tests.llms.fake_llm import FakeLLM


@pytest.fixture
def fake_llm_summarization_checker_chain() -> LLMSummarizationCheckerChain:
    """Fake LLMSummarizationCheckerChain for testing."""
    queries = {
        CREATE_ASSERTIONS_PROMPT.format(
            summary="a",
        ): "b",
        CHECK_ASSERTIONS_PROMPT.format(
            assertions="b",
        ): "- b - True",
        REVISED_SUMMARY_PROMPT.format(
            checked_assertions="- b - True", summary="a"
        ): "b",
        ARE_ALL_TRUE_PROMPT.format(
            checked_assertions="- b - True",
        ): "True",
    }
    fake_llm = FakeLLM(queries=queries)
    return LLMSummarizationCheckerChain(llm=fake_llm, input_key="q", output_key="a")


def test_simple_text(
    fake_llm_summarization_checker_chain: LLMSummarizationCheckerChain,
) -> None:
    """Test a simple summary that round-trips through the checker chain."""
    question = "a"
    output = fake_llm_summarization_checker_chain.run(question)
    assert output == "b"
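A natural follow-on (not in this diff) would exercise the revision loop itself: make the first verdict "False" so the chain feeds the revised summary back through. A sketch under the same FakeLLM conventions:

def test_two_pass_revision() -> None:
    """Sketch: first pass flags a false assertion, second pass succeeds."""
    queries = {
        # Pass 1: summary "a" yields a false assertion and revision "b".
        CREATE_ASSERTIONS_PROMPT.format(summary="a"): "a is true",
        CHECK_ASSERTIONS_PROMPT.format(assertions="a is true"): "- a is true - False",
        REVISED_SUMMARY_PROMPT.format(
            checked_assertions="- a is true - False", summary="a"
        ): "b",
        ARE_ALL_TRUE_PROMPT.format(checked_assertions="- a is true - False"): "False",
        # Pass 2: revised summary "b" checks out, so the loop stops.
        CREATE_ASSERTIONS_PROMPT.format(summary="b"): "b is true",
        CHECK_ASSERTIONS_PROMPT.format(assertions="b is true"): "- b is true - True",
        REVISED_SUMMARY_PROMPT.format(
            checked_assertions="- b is true - True", summary="b"
        ): "b",
        ARE_ALL_TRUE_PROMPT.format(checked_assertions="- b is true - True"): "True",
    }
    chain = LLMSummarizationCheckerChain(llm=FakeLLM(queries=queries))
    assert chain.run("a") == "b"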