Mirror of https://github.com/hwchase17/langchain
Harrison/summarizer chain (#1356)
Co-authored-by: Tim Asp <707699+timothyasp@users.noreply.github.com>
This commit is contained in:
parent cfae03042d
commit 1cd8996074

docs/modules/chains/examples/llm_summarization_checker.ipynb (new file, 1124 lines)
File diff suppressed because it is too large.
langchain/chains/__init__.py

@@ -11,6 +11,7 @@ from langchain.chains.llm_bash.base import LLMBashChain
 from langchain.chains.llm_checker.base import LLMCheckerChain
 from langchain.chains.llm_math.base import LLMMathChain
 from langchain.chains.llm_requests import LLMRequestsChain
+from langchain.chains.llm_summarization_checker.base import LLMSummarizationCheckerChain
 from langchain.chains.loading import load_chain
 from langchain.chains.mapreduce import MapReduceChain
 from langchain.chains.moderation import OpenAIModerationChain
@@ -30,6 +31,7 @@ __all__ = [
     "LLMChain",
     "LLMBashChain",
     "LLMCheckerChain",
+    "LLMSummarizationCheckerChain",
     "LLMMathChain",
     "PALChain",
     "QAWithSourcesChain",
langchain/chains/llm_summarization_checker/__init__.py (new file, 7 lines)

@@ -0,0 +1,7 @@
"""Summarization checker chain for verifying accuracy of text generation.

Chain that tries to verify the accuracy of text generation by splitting it into a
list of facts, then checking if those facts are true or not, and rewriting
the text to make it more truthful. It will repeat this loop until it hits `max_checks`
or gets to a "true" output.
"""
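For orientation, a minimal usage sketch of the new chain, following the docstring example in base.py below (the input text is illustrative and an OpenAI API key is assumed to be configured):

from langchain import OpenAI, LLMSummarizationCheckerChain

llm = OpenAI(temperature=0.0)
checker_chain = LLMSummarizationCheckerChain(llm=llm, max_checks=2, verbose=True)
# The summary below is deliberately a mix of true and false claims.
checked_text = checker_chain.run("Mammals can lay eggs. The largest mammal is the blue whale.")
print(checked_text)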
langchain/chains/llm_summarization_checker/base.py (new file, 133 lines)

@@ -0,0 +1,133 @@
"""Chain for summarization with self-verification."""

from pathlib import Path
from typing import Dict, List

from pydantic import BaseModel, Extra

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.sequential import SequentialChain
from langchain.llms.base import BaseLLM
from langchain.prompts.prompt import PromptTemplate

PROMPTS_DIR = Path(__file__).parent / "prompts"

CREATE_ASSERTIONS_PROMPT = PromptTemplate.from_file(
    PROMPTS_DIR / "create_facts.txt", ["summary"]
)
CHECK_ASSERTIONS_PROMPT = PromptTemplate.from_file(
    PROMPTS_DIR / "check_facts.txt", ["assertions"]
)
REVISED_SUMMARY_PROMPT = PromptTemplate.from_file(
    PROMPTS_DIR / "revise_summary.txt", ["checked_assertions", "summary"]
)
ARE_ALL_TRUE_PROMPT = PromptTemplate.from_file(
    PROMPTS_DIR / "are_all_true_prompt.txt", ["checked_assertions"]
)


class LLMSummarizationCheckerChain(Chain, BaseModel):
    """Chain for summarization with self-verification.

    Example:
        .. code-block:: python

            from langchain import OpenAI, LLMSummarizationCheckerChain
            llm = OpenAI(temperature=0.0)
            checker_chain = LLMSummarizationCheckerChain(llm=llm)
    """

    llm: BaseLLM
    """LLM wrapper to use."""

    create_assertions_prompt: PromptTemplate = CREATE_ASSERTIONS_PROMPT
    check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT
    revised_summary_prompt: PromptTemplate = REVISED_SUMMARY_PROMPT
    are_all_true_prompt: PromptTemplate = ARE_ALL_TRUE_PROMPT

    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    max_checks: int = 2
    """Maximum number of times to check the assertions. Default to double-checking."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        return [self.output_key]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        all_true = False
        count = 0
        output = None
        original_input = inputs[self.input_key]
        chain_input = original_input

        while not all_true and count < self.max_checks:
            chain = SequentialChain(
                chains=[
                    LLMChain(
                        llm=self.llm,
                        prompt=self.create_assertions_prompt,
                        output_key="assertions",
                        verbose=self.verbose,
                    ),
                    LLMChain(
                        llm=self.llm,
                        prompt=self.check_assertions_prompt,
                        output_key="checked_assertions",
                        verbose=self.verbose,
                    ),
                    LLMChain(
                        llm=self.llm,
                        prompt=self.revised_summary_prompt,
                        output_key="revised_summary",
                        verbose=self.verbose,
                    ),
                    LLMChain(
                        llm=self.llm,
                        output_key="all_true",
                        prompt=self.are_all_true_prompt,
                        verbose=self.verbose,
                    ),
                ],
                input_variables=["summary"],
                output_variables=["all_true", "revised_summary"],
                verbose=True,
            )
            output = chain({"summary": chain_input})
            count += 1

            if output["all_true"].strip() == "True":
                break

            if self.verbose:
                print(output["revised_summary"])

            chain_input = output["revised_summary"]

        if not output:
            raise ValueError("No output from chain")

        return {self.output_key: output["revised_summary"].strip()}

    @property
    def _chain_type(self) -> str:
        return "llm_summarization_checker_chain"
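The _call method above boils down to a bounded verify-and-revise loop over the four prompts. A library-free sketch of that control flow, under the assumption that each prompt step is wrapped as a plain callable (the function and parameter names are illustrative, not part of the diff):

from typing import Callable, Dict, Optional


def check_summary(
    summary: str,
    extract_facts: Callable[[str], str],        # create_facts prompt -> assertions
    check_facts: Callable[[str], str],          # check_facts prompt -> checked assertions
    revise_summary: Callable[[str, str], str],  # revise_summary prompt -> revised summary
    are_all_true: Callable[[str], str],         # are_all_true prompt -> "True"/"False"
    max_checks: int = 2,
) -> str:
    """Bounded verify-and-revise loop mirroring LLMSummarizationCheckerChain._call."""
    text = summary
    output: Optional[Dict[str, str]] = None
    for _ in range(max_checks):
        assertions = extract_facts(text)
        checked = check_facts(assertions)
        revised = revise_summary(checked, text)
        output = {"all_true": are_all_true(checked), "revised_summary": revised}
        if output["all_true"].strip() == "True":
            break  # stop early once the checker reports every assertion as true
        text = revised  # otherwise feed the rewritten summary back through the loop
    if output is None:
        raise ValueError("No output from chain")
    return output["revised_summary"].strip()

Each pass extracts facts, checks them, rewrites the summary, and stops early once the checker reports "True"; otherwise it retries up to max_checks times.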
langchain/chains/llm_summarization_checker/prompts/are_all_true_prompt.txt (new file, 38 lines)

@@ -0,0 +1,38 @@
Below are some assertions that have been fact checked and are labeled as true or false.

If all of the assertions are true, return "True". If any of the assertions are false, return "False".

Here are some examples:
===

Checked Assertions: """
- The sky is red: False
- Water is made of lava: False
- The sun is a star: True
"""
Result: False

===

Checked Assertions: """
- The sky is blue: True
- Water is wet: True
- The sun is a star: True
"""
Result: True

===

Checked Assertions: """
- The sky is blue - True
- Water is made of lava - False
- The sun is a star - True
"""
Result: False

===

Checked Assertions:"""
{checked_assertions}
"""
Result:
langchain/chains/llm_summarization_checker/prompts/check_facts.txt (new file, 10 lines)

@@ -0,0 +1,10 @@
You are an expert fact checker. You have been hired by a major news organization to fact check a very important story.

Here is a bullet point list of facts:
"""
{assertions}
"""

For each fact, determine whether it is true or false about the subject. If you are unable to determine whether the fact is true or false, output "Undetermined".
If the fact is false, explain why.
langchain/chains/llm_summarization_checker/prompts/create_facts.txt (new file, 10 lines)

@@ -0,0 +1,10 @@
Given some text, extract a list of facts from the text.

Format your output as a bulleted list.

Text:
"""
{summary}
"""

Facts:
langchain/chains/llm_summarization_checker/prompts/revise_summary.txt (new file, 17 lines)

@@ -0,0 +1,17 @@
Below are some assertions that have been fact checked and are labeled as true or false. If the answer is false, a suggestion is given for a correction.

Checked Assertions:
"""
{checked_assertions}
"""

Original Summary:
"""
{summary}
"""

Using these checked assertions, rewrite the original summary to be completely true.

The output should have the same structure and formatting as the original summary.

Summary:
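For concreteness, this is how the chain fills in two of the prompt files above at run time; the same .format calls appear in the unit test further down (the fact strings here are illustrative):

from langchain.chains.llm_summarization_checker.base import (
    CHECK_ASSERTIONS_PROMPT,
    CREATE_ASSERTIONS_PROMPT,
)

# create_facts.txt is filled with the candidate summary...
facts_request = CREATE_ASSERTIONS_PROMPT.format(summary="The sun is a star. The sky is green.")
# ...and check_facts.txt with the bulleted assertions extracted from it.
check_request = CHECK_ASSERTIONS_PROMPT.format(
    assertions="- The sun is a star.\n- The sky is green."
)
print(facts_request)
print(check_request)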
langchain/prompts/base.py

@@ -46,8 +46,11 @@ def check_valid_template(
     try:
         formatter_func = DEFAULT_FORMATTER_MAPPING[template_format]
         formatter_func(template, **dummy_inputs)
-    except KeyError:
-        raise ValueError("Invalid prompt schema.")
+    except KeyError as e:
+        raise ValueError(
+            "Invalid prompt schema; check for mismatched or missing input parameters. "
+            + str(e)
+        )


 class BaseOutputParser(BaseModel, ABC):
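A quick sketch of how the friendlier message surfaces, assuming PromptTemplate still validates its template against input_variables on construction (the misspelled placeholder is deliberate and illustrative):

from langchain.prompts.prompt import PromptTemplate

try:
    # "nme" does not match input_variables, so formatting with dummy inputs
    # raises KeyError, which check_valid_template now wraps with a hint.
    PromptTemplate(template="Hello {nme}!", input_variables=["name"])
except ValueError as err:
    print(err)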
langchain/prompts/prompt.py

@@ -1,8 +1,9 @@
 """Prompt schema definition."""
 from __future__ import annotations

+from pathlib import Path
 from string import Formatter
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Union

 from pydantic import BaseModel, Extra, root_validator

@@ -105,7 +106,7 @@ class PromptTemplate(BasePromptTemplate, BaseModel):

     @classmethod
     def from_file(
-        cls, template_file: str, input_variables: List[str]
+        cls, template_file: Union[str, Path], input_variables: List[str]
     ) -> PromptTemplate:
         """Load a prompt from a file.

@@ -116,7 +117,7 @@ class PromptTemplate(BasePromptTemplate, BaseModel):
         Returns:
             The prompt loaded from the file.
         """
-        with open(template_file, "r") as f:
+        with open(str(template_file), "r") as f:
             template = f.read()
         return cls(input_variables=input_variables, template=template)
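With Union[str, Path] accepted, prompt files can be loaded directly with pathlib objects, which is exactly how base.py above builds its four templates. A small sketch (the relative path is illustrative):

from pathlib import Path

from langchain.prompts.prompt import PromptTemplate

# Load a prompt template from disk using a Path rather than a str.
prompt = PromptTemplate.from_file(Path("prompts") / "create_facts.txt", ["summary"])
print(prompt.format(summary="The sun is a star."))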
tests/unit_tests/chains/test_llm_summarization_checker.py (new file, 44 lines)

@@ -0,0 +1,44 @@
# flake8: noqa E501

"""Test LLMSummarization functionality."""

import pytest

from langchain.chains.llm_summarization_checker.base import (
    ARE_ALL_TRUE_PROMPT,
    CHECK_ASSERTIONS_PROMPT,
    CREATE_ASSERTIONS_PROMPT,
    REVISED_SUMMARY_PROMPT,
    LLMSummarizationCheckerChain,
)
from tests.unit_tests.llms.fake_llm import FakeLLM


@pytest.fixture
def fake_llm_summarization_checker_chain() -> LLMSummarizationCheckerChain:
    """Fake LLMSummarizationCheckerChain for testing."""
    queries = {
        CREATE_ASSERTIONS_PROMPT.format(
            summary="a",
        ): "b",
        CHECK_ASSERTIONS_PROMPT.format(
            assertions="b",
        ): "- b - True",
        REVISED_SUMMARY_PROMPT.format(
            checked_assertions="- b - True", summary="a"
        ): "b",
        ARE_ALL_TRUE_PROMPT.format(
            checked_assertions="- b - True",
        ): "True",
    }
    fake_llm = FakeLLM(queries=queries)
    return LLMSummarizationCheckerChain(llm=fake_llm, input_key="q", output_key="a")


def test_simple_text(
    fake_llm_summarization_checker_chain: LLMSummarizationCheckerChain,
) -> None:
    """Test a simple summary flowing through the checker chain."""
    question = "a"
    output = fake_llm_summarization_checker_chain.run(question)
    assert output == "b"