mirror of https://github.com/hwchase17/langchain
Harrison/summarizer chain (#1356)
Co-authored-by: Tim Asp <707699+timothyasp@users.noreply.github.com>pull/1382/head
parent
cfae03042d
commit
1cd8996074
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,7 @@
|
||||
"""Summarization checker chain for verifying accuracy of text generation.
|
||||
|
||||
Chain that tries to verify the accuracy of text generation by splitting it into a
|
||||
list of facts, then checking if those facts are true or not, and rewriting
|
||||
the text to make it more truthful. It will repeat this loop until it hits `max_checks`
|
||||
or gets to a "true" output.
|
||||
"""
|
@ -0,0 +1,133 @@
|
||||
"""Chain for summarization with self-verification."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.sequential import SequentialChain
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
# Directory holding the prompt template text files; it sits next to this module.
PROMPTS_DIR = Path(__file__).parent.joinpath("prompts")

# Step 1 prompt: takes ``summary`` and asks for a bulleted list of facts.
CREATE_ASSERTIONS_PROMPT = PromptTemplate.from_file(
    PROMPTS_DIR.joinpath("create_facts.txt"),
    ["summary"],
)

# Step 2 prompt: takes ``assertions`` and asks for each one to be fact-checked.
CHECK_ASSERTIONS_PROMPT = PromptTemplate.from_file(
    PROMPTS_DIR.joinpath("check_facts.txt"),
    ["assertions"],
)

# Step 3 prompt: takes ``checked_assertions`` and the original ``summary`` and
# asks for a rewritten summary.
REVISED_SUMMARY_PROMPT = PromptTemplate.from_file(
    PROMPTS_DIR.joinpath("revise_summary.txt"),
    ["checked_assertions", "summary"],
)

# Step 4 prompt: takes ``checked_assertions`` and asks for a "True"/"False"
# verdict on whether all of them were true.
ARE_ALL_TRUE_PROMPT = PromptTemplate.from_file(
    PROMPTS_DIR.joinpath("are_all_true_prompt.txt"),
    ["checked_assertions"],
)
|
||||
|
||||
|
||||
class LLMSummarizationCheckerChain(Chain, BaseModel):
    """Chain for summarization with self-verification.

    Each verification pass runs four LLM steps in sequence: extract a list of
    assertions from the summary, fact-check those assertions, rewrite the
    summary using the checked assertions, and ask whether every assertion was
    true. The revised summary is fed back in as the next pass's input until the
    verdict is "True" or ``max_checks`` passes have been made.

    Example:
        .. code-block:: python

            from langchain import OpenAI, LLMSummarizationCheckerChain
            llm = OpenAI(temperature=0.0)
            checker_chain = LLMSummarizationCheckerChain(llm=llm)
    """

    llm: BaseLLM
    """LLM wrapper to use."""

    # Prompts for the four steps of a verification pass; each defaults to the
    # template loaded at module import time and may be overridden per instance.
    create_assertions_prompt: PromptTemplate = CREATE_ASSERTIONS_PROMPT
    check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT
    revised_summary_prompt: PromptTemplate = REVISED_SUMMARY_PROMPT
    are_all_true_prompt: PromptTemplate = ARE_ALL_TRUE_PROMPT

    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    max_checks: int = 2
    """Maximum number of times to check the assertions. Default to double-checking."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        return [self.output_key]

    def _build_check_chain(self) -> SequentialChain:
        """Assemble the four-step pipeline that performs one verification pass."""
        # (prompt, output_key) pairs in execution order; each step's output key
        # is the input variable of a later step's prompt.
        steps = [
            (self.create_assertions_prompt, "assertions"),
            (self.check_assertions_prompt, "checked_assertions"),
            (self.revised_summary_prompt, "revised_summary"),
            (self.are_all_true_prompt, "all_true"),
        ]
        return SequentialChain(
            chains=[
                LLMChain(
                    llm=self.llm,
                    prompt=prompt,
                    output_key=output_key,
                    verbose=self.verbose,
                )
                for prompt, output_key in steps
            ],
            input_variables=["summary"],
            output_variables=["all_true", "revised_summary"],
            # Follow the caller's verbosity setting, like the inner chains do,
            # instead of always logging.
            verbose=self.verbose,
        )

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Check and revise the input summary.

        Args:
            inputs: Mapping that holds the text to verify under ``input_key``.

        Returns:
            A single-entry dict mapping ``output_key`` to the stripped revised
            summary produced by the last verification pass.

        Raises:
            ValueError: If no verification pass ran (``max_checks`` below 1).
        """
        # The pipeline does not depend on the loop state, so build it once.
        chain = self._build_check_chain()
        summary = inputs[self.input_key]
        output = None

        for _ in range(self.max_checks):
            output = chain({"summary": summary})

            if output["all_true"].strip() == "True":
                break

            if self.verbose:
                print(output["revised_summary"])

            # Re-check the rewritten text on the next pass.
            summary = output["revised_summary"]

        if not output:
            raise ValueError("No output from chain")

        return {self.output_key: output["revised_summary"].strip()}

    @property
    def _chain_type(self) -> str:
        """Return the identifier string for this chain type."""
        return "llm_summarization_checker_chain"
|
@ -0,0 +1,38 @@
|
||||
Below are some assertions that have been fact checked and are labeled as true or false.
|
||||
|
||||
If all of the assertions are true, return "True". If any of the assertions are false, return "False".
|
||||
|
||||
Here are some examples:
|
||||
===
|
||||
|
||||
Checked Assertions: """
|
||||
- The sky is red: False
|
||||
- Water is made of lava: False
|
||||
- The sun is a star: True
|
||||
"""
|
||||
Result: False
|
||||
|
||||
===
|
||||
|
||||
Checked Assertions: """
|
||||
- The sky is blue: True
|
||||
- Water is wet: True
|
||||
- The sun is a star: True
|
||||
"""
|
||||
Result: True
|
||||
|
||||
===
|
||||
|
||||
Checked Assertions: """
|
||||
- The sky is blue - True
|
||||
- Water is made of lava- False
|
||||
- The sun is a star - True
|
||||
"""
|
||||
Result: False
|
||||
|
||||
===
|
||||
|
||||
Checked Assertions: """
|
||||
{checked_assertions}
|
||||
"""
|
||||
Result:
|
@ -0,0 +1,10 @@
|
||||
You are an expert fact checker. You have been hired by a major news organization to fact check a very important story.
|
||||
|
||||
Here is a bullet point list of facts:
|
||||
"""
|
||||
{assertions}
|
||||
"""
|
||||
|
||||
For each fact, determine whether it is true or false about the subject. If you are unable to determine whether the fact is true or false, output "Undetermined".
|
||||
If the fact is false, explain why.
|
||||
|
@ -0,0 +1,10 @@
|
||||
Given some text, extract a list of facts from the text.
|
||||
|
||||
Format your output as a bulleted list.
|
||||
|
||||
Text:
|
||||
"""
|
||||
{summary}
|
||||
"""
|
||||
|
||||
Facts:
|
@ -0,0 +1,17 @@
|
||||
Below are some assertions that have been fact checked and are labeled as true or false. If the answer is false, a suggestion is given for a correction.
|
||||
|
||||
Checked Assertions:
|
||||
"""
|
||||
{checked_assertions}
|
||||
"""
|
||||
|
||||
Original Summary:
|
||||
"""
|
||||
{summary}
|
||||
"""
|
||||
|
||||
Using these checked assertions, rewrite the original summary to be completely true.
|
||||
|
||||
The output should have the same structure and formatting as the original summary.
|
||||
|
||||
Summary:
|
@ -0,0 +1,44 @@
|
||||
# flake8: noqa E501
|
||||
|
||||
"""Test LLMSummarization functionality."""
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.chains.llm_summarization_checker.base import (
|
||||
ARE_ALL_TRUE_PROMPT,
|
||||
CHECK_ASSERTIONS_PROMPT,
|
||||
CREATE_ASSERTIONS_PROMPT,
|
||||
REVISED_SUMMARY_PROMPT,
|
||||
LLMSummarizationCheckerChain,
|
||||
)
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
@pytest.fixture
def fake_llm_summarization_checker_chain() -> LLMSummarizationCheckerChain:
    """Build an LLMSummarizationCheckerChain around a FakeLLM.

    The FakeLLM is constructed with a mapping from each fully formatted prompt
    of one verification pass (starting from the summary "a") to a fixed reply.
    """
    checked = "- b - True"

    create_prompt = CREATE_ASSERTIONS_PROMPT.format(summary="a")
    check_prompt = CHECK_ASSERTIONS_PROMPT.format(assertions="b")
    revise_prompt = REVISED_SUMMARY_PROMPT.format(
        checked_assertions=checked, summary="a"
    )
    verdict_prompt = ARE_ALL_TRUE_PROMPT.format(checked_assertions=checked)

    replies = {
        create_prompt: "b",
        check_prompt: checked,
        revise_prompt: "b",
        verdict_prompt: "True",
    }
    return LLMSummarizationCheckerChain(
        llm=FakeLLM(queries=replies),
        input_key="q",
        output_key="a",
    )
|
||||
|
||||
|
||||
def test_simple_text(
    fake_llm_summarization_checker_chain: LLMSummarizationCheckerChain,
) -> None:
    """Check that running the chain on the text "a" returns "b"."""
    result = fake_llm_summarization_checker_chain.run("a")
    assert result == "b"
|
Loading…
Reference in New Issue