forked from Archives/langchain
Add LLMCheckerChain (#281)
Implementation of https://github.com/jagilley/fact-checker. Works pretty well. <img width="993" alt="Screenshot 2022-12-07 at 4 41 47 PM" src="https://user-images.githubusercontent.com/101075607/206302751-356a19ff-d000-4798-9aee-9c38b7f532b9.png"> Verifying this manually: 1. "Only two kinds of egg-laying mammals are left on the planet today—the duck-billed platypus and the echidna, or spiny anteater." https://www.scientificamerican.com/article/extreme-monotremes/ 2. "An [Echidna] egg weighs 1.5 to 2 grams (0.05 to 0.07 oz)[[19]](https://en.wikipedia.org/wiki/Echidna#cite_note-19) and is about 1.4 centimetres (0.55 in) long." https://en.wikipedia.org/wiki/Echidna#:~:text=sleep%20is%20suppressed.-,Reproduction,a%20reptile%2Dlike%20egg%20tooth. 3. "A [platypus] lays one to three (usually two) small, leathery eggs (similar to those of reptiles), about 11 mm (7⁄16 in) in diameter and slightly rounder than bird eggs." https://en.wikipedia.org/wiki/Platypus#:~:text=It%20lays%20one%20to%20three,slightly%20rounder%20than%20bird%20eggs. 4. Therefore, an Echidna is the mammal that lays the biggest eggs. cc @hwchase17
parent
43c9bd869f
commit
5267ebce2d
@ -0,0 +1,58 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# LLMCheckerChain\n",
|
||||
"This notebook showcases how to use LLMCheckerChain."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chains import LLMCheckerChain\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"\n",
|
||||
"llm = OpenAI(temperature=0.7)\n",
|
||||
"\n",
|
||||
"text = \"What type of mammal lays the biggest eggs?\"\n",
|
||||
"\n",
|
||||
"checker_chain = LLMCheckerChain(llm=llm, verbose=True)\n",
|
||||
"\n",
|
||||
"checker_chain.run(text)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
"""Chain that tries to verify assumptions before answering a question.
|
||||
|
||||
Heavily borrowed from https://github.com/jagilley/fact-checker
|
||||
"""
|
@ -0,0 +1,98 @@
|
||||
"""Chain for question-answering with self-verification."""
|
||||
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_checker.prompt import (
|
||||
CHECK_ASSERTIONS_PROMPT,
|
||||
CREATE_DRAFT_ANSWER_PROMPT,
|
||||
LIST_ASSERTIONS_PROMPT,
|
||||
REVISED_ANSWER_PROMPT,
|
||||
)
|
||||
from langchain.chains.sequential import SequentialChain
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
|
||||
class LLMCheckerChain(Chain, BaseModel):
    """Chain for question-answering with self-verification.

    Pipeline: draft an answer to the question, list the assumptions the
    draft relies on, check each assumption, then produce a revised answer
    in light of those checks.

    Example:
        .. code-block:: python

            from langchain import OpenAI, LLMCheckerChain
            llm = OpenAI(temperature=0.7)
            checker_chain = LLMCheckerChain(llm=llm)
    """

    llm: LLM
    """LLM wrapper to use."""
    create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT
    """Prompt used to produce the initial draft answer."""
    list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT
    """Prompt used to list the assumptions behind the draft answer."""
    check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT
    """Prompt used to check each listed assertion."""
    revised_answer_prompt: PromptTemplate = REVISED_ANSWER_PROMPT
    """Prompt used to produce the final answer from the checked assertions."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        return [self.output_key]

    def _question_to_checked_assertions_chain(self) -> SequentialChain:
        """Assemble the draft -> list -> check -> revise pipeline."""
        create_draft_answer_chain = LLMChain(
            llm=self.llm,
            prompt=self.create_draft_answer_prompt,
            output_key="statement",
        )
        list_assertions_chain = LLMChain(
            llm=self.llm,
            prompt=self.list_assertions_prompt,
            output_key="assertions",
        )
        check_assertions_chain = LLMChain(
            llm=self.llm,
            prompt=self.check_assertions_prompt,
            output_key="checked_assertions",
        )
        revised_answer_chain = LLMChain(
            llm=self.llm,
            prompt=self.revised_answer_prompt,
            output_key="revised_statement",
        )
        return SequentialChain(
            chains=[
                create_draft_answer_chain,
                list_assertions_chain,
                check_assertions_chain,
                revised_answer_chain,
            ],
            input_variables=["question"],
            output_variables=["revised_statement"],
            # Bug fix: was hard-coded to verbose=True, which ignored the
            # ``verbose`` setting callers pass to this chain.
            verbose=self.verbose,
        )

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Run the self-verification pipeline on the input question."""
        question = inputs[self.input_key]
        output = self._question_to_checked_assertions_chain()({"question": question})
        return {self.output_key: output["revised_statement"]}
|
@ -0,0 +1,31 @@
|
||||
# flake8: noqa
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
_CREATE_DRAFT_ANSWER_TEMPLATE = """{question}\n\n"""
|
||||
CREATE_DRAFT_ANSWER_PROMPT = PromptTemplate(
|
||||
input_variables=["question"], template=_CREATE_DRAFT_ANSWER_TEMPLATE
|
||||
)
|
||||
|
||||
_LIST_ASSERTIONS_TEMPLATE = """Here is a statement:
|
||||
{statement}
|
||||
Make a bullet point list of the assumptions you made when producing the above statement.\n\n"""
|
||||
LIST_ASSERTIONS_PROMPT = PromptTemplate(
|
||||
input_variables=["statement"], template=_LIST_ASSERTIONS_TEMPLATE
|
||||
)
|
||||
|
||||
_CHECK_ASSERTIONS_TEMPLATE = """Here is a bullet point list of assertions:
|
||||
{assertions}
|
||||
For each assertion, determine whether it is true or false. If it is false, explain why.\n\n"""
|
||||
CHECK_ASSERTIONS_PROMPT = PromptTemplate(
|
||||
input_variables=["assertions"], template=_CHECK_ASSERTIONS_TEMPLATE
|
||||
)
|
||||
|
||||
_REVISED_ANSWER_TEMPLATE = """{checked_assertions}
|
||||
|
||||
Question: In light of the above assertions and checks, how would you answer the question '{question}'?
|
||||
|
||||
Answer:"""
|
||||
REVISED_ANSWER_PROMPT = PromptTemplate(
|
||||
input_variables=["checked_assertions", "question"],
|
||||
template=_REVISED_ANSWER_TEMPLATE,
|
||||
)
|
@ -0,0 +1,43 @@
|
||||
# flake8: noqa E501
|
||||
|
||||
"""Test LLMCheckerChain functionality."""
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.chains.llm_checker.base import LLMCheckerChain
|
||||
from langchain.chains.llm_checker.prompt import (
|
||||
_CHECK_ASSERTIONS_TEMPLATE,
|
||||
_CREATE_DRAFT_ANSWER_TEMPLATE,
|
||||
_LIST_ASSERTIONS_TEMPLATE,
|
||||
_REVISED_ANSWER_TEMPLATE,
|
||||
)
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
@pytest.fixture
def fake_llm_checker_chain() -> LLMCheckerChain:
    """Fake LLMCheckerChain for testing."""
    # Canned intermediate outputs, threaded through the four prompt stages
    # so each stage's formatted prompt maps to the next stage's input.
    question = "Which mammal lays the biggest eggs?"
    draft = "I don't know which mammal layers the biggest eggs."
    assertions = (
        "1) I know that mammals lay eggs.\n"
        "2) I know that birds lay eggs.\n"
        "3) I know that birds are mammals."
    )
    checked = (
        "1) I know that mammals lay eggs. TRUE\n"
        "2) I know that birds lay eggs. TRUE\n"
        "3) I know that birds are mammals. TRUE"
    )
    queries = {
        _CREATE_DRAFT_ANSWER_TEMPLATE.format(question=question): draft,
        _LIST_ASSERTIONS_TEMPLATE.format(statement=draft): assertions,
        _CHECK_ASSERTIONS_TEMPLATE.format(assertions=assertions): checked,
        _REVISED_ANSWER_TEMPLATE.format(
            checked_assertions=checked, question=question
        ): "I still don't know.",
    }
    return LLMCheckerChain(
        llm=FakeLLM(queries=queries), input_key="q", output_key="a"
    )
|
||||
|
||||
|
||||
def test_simple_question(fake_llm_checker_chain: LLMCheckerChain) -> None:
    """Chain returns the revised answer produced by the canned fake LLM."""
    question = "Which mammal lays the biggest eggs?"
    output = fake_llm_checker_chain.run(question)
    assert output == "I still don't know."
|
Loading…
Reference in New Issue