Add LLMCheckerChain (#281)

Implementation of https://github.com/jagilley/fact-checker. Works pretty
well.

<img width="993" alt="Screenshot 2022-12-07 at 4 41 47 PM"
src="https://user-images.githubusercontent.com/101075607/206302751-356a19ff-d000-4798-9aee-9c38b7f532b9.png">

Verifying this manually:
1. "Only two kinds of egg-laying mammals are left on the planet
today—the duck-billed platypus and the echidna, or spiny anteater."
https://www.scientificamerican.com/article/extreme-monotremes/
2. "An [Echidna] egg weighs 1.5 to 2 grams (0.05 to 0.07
oz)[[19]](https://en.wikipedia.org/wiki/Echidna#cite_note-19) and is
about 1.4 centimetres (0.55 in) long."
https://en.wikipedia.org/wiki/Echidna#:~:text=sleep%20is%20suppressed.-,Reproduction,a%20reptile%2Dlike%20egg%20tooth.
3. "A [platypus] lays one to three (usually two) small, leathery eggs
(similar to those of reptiles), about 11 mm (7⁄16 in) in diameter and
slightly rounder than bird eggs."
https://en.wikipedia.org/wiki/Platypus#:~:text=It%20lays%20one%20to%20three,slightly%20rounder%20than%20bird%20eggs.
4. Therefore, an Echidna is the mammal that lays the biggest eggs.


cc @hwchase17
harrison/agent_multi_inputs
andersenchen 1 year ago committed by GitHub
parent 43c9bd869f
commit 5267ebce2d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,58 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# LLMCheckerChain\n",
"This notebook showcases how to use LLMCheckerChain."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import LLMCheckerChain\n",
"from langchain.llms import OpenAI\n",
"\n",
"llm = OpenAI(temperature=0.7)\n",
"\n",
"text = \"What type of mammal lays the biggest eggs?\"\n",
"\n",
"checker_chain = LLMCheckerChain(llm=llm, verbose=True)\n",
"\n",
"checker_chain.run(text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

@ -5,6 +5,7 @@ from langchain.chains import (
ConversationChain,
LLMBashChain,
LLMChain,
LLMCheckerChain,
LLMMathChain,
PALChain,
QAWithSourcesChain,
@ -27,6 +28,7 @@ from langchain.vectorstores import FAISS, ElasticVectorSearch
__all__ = [
"LLMChain",
"LLMBashChain",
"LLMCheckerChain",
"LLMMathChain",
"SelfAskWithSearchChain",
"SerpAPIWrapper",

@ -3,6 +3,7 @@ from langchain.chains.api.base import APIChain
from langchain.chains.conversation.base import ConversationChain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_bash.base import LLMBashChain
from langchain.chains.llm_checker.base import LLMCheckerChain
from langchain.chains.llm_math.base import LLMMathChain
from langchain.chains.llm_requests import LLMRequestsChain
from langchain.chains.mapreduce import MapReduceChain
@ -19,6 +20,7 @@ __all__ = [
"ConversationChain",
"LLMChain",
"LLMBashChain",
"LLMCheckerChain",
"LLMMathChain",
"PALChain",
"QAWithSourcesChain",

@ -0,0 +1,4 @@
"""Chain that tries to verify assumptions before answering a question.
Heavily borrowed from https://github.com/jagilley/fact-checker
"""

@ -0,0 +1,98 @@
"""Chain for question-answering with self-verification."""
from typing import Dict, List
from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_checker.prompt import (
CHECK_ASSERTIONS_PROMPT,
CREATE_DRAFT_ANSWER_PROMPT,
LIST_ASSERTIONS_PROMPT,
REVISED_ANSWER_PROMPT,
)
from langchain.chains.sequential import SequentialChain
from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate
class LLMCheckerChain(Chain, BaseModel):
    """Chain for question-answering with self-verification.

    The chain runs four LLM steps in sequence: draft an answer to the
    question, list the assumptions behind that draft, check each assumption,
    and finally produce a revised answer in light of those checks.

    Example:
        .. code-block:: python

            from langchain import OpenAI, LLMCheckerChain
            llm = OpenAI(temperature=0.7)
            checker_chain = LLMCheckerChain(llm=llm)
    """

    llm: LLM
    """LLM wrapper to use."""
    create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT
    """Prompt that turns the question into a draft statement."""
    list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT
    """Prompt that lists the assumptions behind the draft statement."""
    check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT
    """Prompt that checks each listed assumption."""
    revised_answer_prompt: PromptTemplate = REVISED_ANSWER_PROMPT
    """Prompt that answers the question again given the checked assumptions."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        return [self.output_key]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Run the draft -> list -> check -> revise pipeline.

        Args:
            inputs: Mapping that holds the question under ``self.input_key``.

        Returns:
            A single-entry dict mapping ``self.output_key`` to the revised
            answer produced by the last step.
        """
        question = inputs[self.input_key]

        # Each step writes to the output key that the next step's prompt
        # consumes as an input variable.
        create_draft_answer_chain = LLMChain(
            llm=self.llm, prompt=self.create_draft_answer_prompt, output_key="statement"
        )
        list_assertions_chain = LLMChain(
            llm=self.llm, prompt=self.list_assertions_prompt, output_key="assertions"
        )
        check_assertions_chain = LLMChain(
            llm=self.llm,
            prompt=self.check_assertions_prompt,
            output_key="checked_assertions",
        )
        revised_answer_chain = LLMChain(
            llm=self.llm,
            prompt=self.revised_answer_prompt,
            output_key="revised_statement",
        )

        chains = [
            create_draft_answer_chain,
            list_assertions_chain,
            check_assertions_chain,
            revised_answer_chain,
        ]

        question_to_checked_assertions_chain = SequentialChain(
            chains=chains,
            input_variables=["question"],
            output_variables=["revised_statement"],
            # Honor this chain's own verbosity setting instead of always
            # forcing verbose output on the inner pipeline.
            verbose=self.verbose,
        )
        output = question_to_checked_assertions_chain({"question": question})
        return {self.output_key: output["revised_statement"]}

@ -0,0 +1,31 @@
# flake8: noqa
from langchain.prompts.prompt import PromptTemplate
# Step 1: the draft-answer prompt is just the raw question followed by two
# newlines.
_CREATE_DRAFT_ANSWER_TEMPLATE = """{question}\n\n"""
CREATE_DRAFT_ANSWER_PROMPT = PromptTemplate(
input_variables=["question"], template=_CREATE_DRAFT_ANSWER_TEMPLATE
)
# Step 2: given a statement, ask for a bullet list of the assumptions behind it.
_LIST_ASSERTIONS_TEMPLATE = """Here is a statement:
{statement}
Make a bullet point list of the assumptions you made when producing the above statement.\n\n"""
LIST_ASSERTIONS_PROMPT = PromptTemplate(
input_variables=["statement"], template=_LIST_ASSERTIONS_TEMPLATE
)
# Step 3: given the list of assertions, ask for a true/false verdict on each,
# with an explanation for the false ones.
_CHECK_ASSERTIONS_TEMPLATE = """Here is a bullet point list of assertions:
{assertions}
For each assertion, determine whether it is true or false. If it is false, explain why.\n\n"""
CHECK_ASSERTIONS_PROMPT = PromptTemplate(
input_variables=["assertions"], template=_CHECK_ASSERTIONS_TEMPLATE
)
# Step 4: given the checked assertions and the original question, ask for a
# final answer.
_REVISED_ANSWER_TEMPLATE = """{checked_assertions}
Question: In light of the above assertions and checks, how would you answer the question '{question}'?
Answer:"""
REVISED_ANSWER_PROMPT = PromptTemplate(
input_variables=["checked_assertions", "question"],
template=_REVISED_ANSWER_TEMPLATE,
)

@ -0,0 +1,43 @@
# flake8: noqa E501
"""Test LLMCheckerChain functionality."""
import pytest
from langchain.chains.llm_checker.base import LLMCheckerChain
from langchain.chains.llm_checker.prompt import (
_CHECK_ASSERTIONS_TEMPLATE,
_CREATE_DRAFT_ANSWER_TEMPLATE,
_LIST_ASSERTIONS_TEMPLATE,
_REVISED_ANSWER_TEMPLATE,
)
from tests.unit_tests.llms.fake_llm import FakeLLM
@pytest.fixture
def fake_llm_checker_chain() -> LLMCheckerChain:
    """Build an LLMCheckerChain wired to a FakeLLM with canned responses.

    Each canned entry maps one fully formatted prompt to the reply the fake
    model is given for it, so the replies line up with the four prompts the
    chain formats in turn.
    """
    question = "Which mammal lays the biggest eggs?"
    draft = "I don't know which mammal layers the biggest eggs."
    assertions = "1) I know that mammals lay eggs.\n2) I know that birds lay eggs.\n3) I know that birds are mammals."
    checked = "1) I know that mammals lay eggs. TRUE\n2) I know that birds lay eggs. TRUE\n3) I know that birds are mammals. TRUE"
    revised = "I still don't know."

    # The reply to one prompt is the value substituted into the next prompt.
    canned_responses = {
        _CREATE_DRAFT_ANSWER_TEMPLATE.format(question=question): draft,
        _LIST_ASSERTIONS_TEMPLATE.format(statement=draft): assertions,
        _CHECK_ASSERTIONS_TEMPLATE.format(assertions=assertions): checked,
        _REVISED_ANSWER_TEMPLATE.format(
            checked_assertions=checked,
            question=question,
        ): revised,
    }
    # NOTE(review): presumably FakeLLM looks prompts up by exact match in
    # `queries` -- confirm against tests/unit_tests/llms/fake_llm.py.
    return LLMCheckerChain(
        llm=FakeLLM(queries=canned_responses), input_key="q", output_key="a"
    )
def test_simple_question(fake_llm_checker_chain: LLMCheckerChain) -> None:
    """Run the checker chain on one question and check the final answer."""
    answer = fake_llm_checker_chain.run("Which mammal lays the biggest eggs?")
    assert answer == "I still don't know."
Loading…
Cancel
Save