Adding an in-context QA evaluation chain + chain of thought reasoning chain for improved accuracy (#2444)

Right now, eval chains require a ground-truth answer for every question. Collecting
that ground truth is cumbersome, so this PR works around the issue in two ways:

* Adding a `context` param in `ContextQAEvalChain` and simply evaluating
whether the question is answered accurately from that context (see the sketch below)
* Adding chain-of-thought explanation prompting (`CotQAEvalChain`) to improve
the accuracy of this evaluation w/o GT.

This also gets us to feature parity with openai/evals, which has the same
contextual eval w/o GT.
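
A minimal usage sketch of the context-based grader (illustrative only; assumes an OpenAI key is configured):

```python
from langchain.evaluation.qa import ContextQAEvalChain
from langchain.llms import OpenAI

llm = OpenAI(temperature=0)
eval_chain = ContextQAEvalChain.from_llm(llm)

# No ground-truth answer needed: the grader checks the prediction against
# the retrieved context only.
examples = [{"query": "How old am I?", "context": "I am 30 years old."}]
predictions = [{"result": "You are 30 years old."}]
graded = eval_chain.evaluate(examples, predictions)
```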

TODO in follow-up:
* Better prompt inheritance. There should be no need for a separate prompt for CoT
reasoning; figure out how to merge the two prompts.

---------

Co-authored-by: Vashisht Madhavan <vashishtmadhavan@Vashs-MacBook-Pro.local>

@@ -234,6 +234,93 @@
"evalchain.evaluate(examples, predictions, question_key=\"question\", answer_key=\"answer\", prediction_key=\"text\")"
]
},
{
"cell_type": "markdown",
"id": "cb1cf335",
"metadata": {},
"source": [
"## Evaluation without Ground Truth\n",
"Its possible to evaluate question answering systems without ground truth. You would need a `\"context\"` input that reflects what the information the LLM uses to answer the question. This context can be obtained by any retreival system. Here's an example of how it works:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6c59293f",
"metadata": {},
"outputs": [],
"source": [
"context_examples = [\n",
" {\n",
" \"question\": \"How old am I?\",\n",
" \"context\": \"I am 30 years old. I live in New York and take the train to work everyday.\",\n",
" },\n",
" {\n",
" \"question\": 'Who won the NFC championship game in 2023?\"',\n",
" \"context\": \"NFC Championship Game 2023: Philadelphia Eagles 31, San Francisco 49ers 7\"\n",
" }\n",
"]\n",
"QA_PROMPT = \"Answer the question based on the context\\nContext:{context}\\nQuestion:{question}\\nAnswer:\"\n",
"template = PromptTemplate(input_variables=[\"context\", \"question\"], template=QA_PROMPT)\n",
"qa_chain = LLMChain(llm=llm, prompt=template)\n",
"predictions = qa_chain.apply(context_examples)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "e500d0cc",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'text': 'You are 30 years old.'},\n",
" {'text': ' The Philadelphia Eagles won the NFC championship game in 2023.'}]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predictions"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "6d8cbc1d",
"metadata": {},
"outputs": [],
"source": [
"from langchain.evaluation.qa import ContextQAEvalChain\n",
"eval_chain = ContextQAEvalChain.from_llm(llm)\n",
"graded_outputs = eval_chain.evaluate(context_examples, predictions, question_key=\"question\", prediction_key=\"text\")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "6c5262d0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'text': ' CORRECT'}, {'text': ' CORRECT'}]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"graded_outputs"
]
},
{
"cell_type": "markdown",
"id": "aaa61f0c",
@@ -329,7 +416,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.9.16"
},
"vscode": {
"interpreter": {

@@ -1,5 +1,9 @@
"""Chains and utils related to evaluating question answering functionality."""
from langchain.evaluation.qa.eval_chain import QAEvalChain
from langchain.evaluation.qa.eval_chain import (
    ContextQAEvalChain,
    CotQAEvalChain,
    QAEvalChain,
)
from langchain.evaluation.qa.generate_chain import QAGenerateChain
__all__ = ["QAEvalChain", "QAGenerateChain"]
__all__ = ["QAEvalChain", "QAGenerateChain", "ContextQAEvalChain", "CotQAEvalChain"]

@@ -5,7 +5,7 @@ from typing import Any, List
from langchain import PromptTemplate
from langchain.chains.llm import LLMChain
from langchain.evaluation.qa.eval_prompt import PROMPT
from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT
from langchain.llms.base import BaseLLM
@@ -58,3 +58,69 @@ class QAEvalChain(LLMChain):
        ]
        return self.apply(inputs)


class ContextQAEvalChain(LLMChain):
    """LLM Chain specifically for evaluating QA w/o GT based on context."""

    @classmethod
    def _validate_input_vars(cls, prompt: PromptTemplate) -> None:
        expected_input_vars = {"query", "context", "result"}
        if expected_input_vars != set(prompt.input_variables):
            raise ValueError(
                f"Input variables should be {expected_input_vars}, "
                f"but got {prompt.input_variables}"
            )

    @classmethod
    def from_llm(
        cls, llm: BaseLLM, prompt: PromptTemplate = CONTEXT_PROMPT, **kwargs: Any
    ) -> ContextQAEvalChain:
        """Load QA Eval Chain from LLM.

        Args:
            llm (BaseLLM): the base language model to use.
            prompt (PromptTemplate): A prompt template containing the input_variables:
                'query', 'context' and 'result' that will be used as the prompt
                for evaluation. Defaults to CONTEXT_PROMPT.
            **kwargs: additional keyword arguments.

        Returns:
            ContextQAEvalChain: the loaded QA eval chain.
        """
        cls._validate_input_vars(prompt)
        return cls(llm=llm, prompt=prompt, **kwargs)

    def evaluate(
        self,
        examples: List[dict],
        predictions: List[dict],
        question_key: str = "query",
        context_key: str = "context",
        prediction_key: str = "result",
    ) -> List[dict]:
        """Evaluate question answering examples and predictions."""
        inputs = [
            {
                "query": example[question_key],
                "context": example[context_key],
                "result": predictions[i][prediction_key],
            }
            for i, example in enumerate(examples)
        ]
        return self.apply(inputs)


class CotQAEvalChain(ContextQAEvalChain):
    """LLM Chain specifically for evaluating QA using chain of thought reasoning."""

    @classmethod
    def from_llm(
        cls, llm: BaseLLM, prompt: PromptTemplate = COT_PROMPT, **kwargs: Any
    ) -> CotQAEvalChain:
        """Load the chain-of-thought QA eval chain, validating the prompt variables."""
        cls._validate_input_vars(prompt)
        return cls(llm=llm, prompt=prompt, **kwargs)
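
A hedged usage sketch of the new `CotQAEvalChain` (illustrative, not part of the diff; assumes an OpenAI key is configured and uses the chain's default `query`/`context`/`result` keys):

```python
from langchain.evaluation.qa import CotQAEvalChain
from langchain.llms import OpenAI

llm = OpenAI(temperature=0)
cot_eval_chain = CotQAEvalChain.from_llm(llm)  # loads COT_PROMPT by default

examples = [
    {
        "query": "Who won the NFC championship game in 2023?",
        "context": "NFC Championship Game 2023: Philadelphia Eagles 31, San Francisco 49ers 7",
    }
]
predictions = [{"result": "The Philadelphia Eagles won the NFC championship game in 2023."}]

# evaluate() is inherited from ContextQAEvalChain; the CoT prompt asks the grader
# to write a step-by-step EXPLANATION before emitting CORRECT or INCORRECT.
graded_outputs = cot_eval_chain.evaluate(examples, predictions)
```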

@@ -19,3 +19,44 @@ GRADE:"""
PROMPT = PromptTemplate(
    input_variables=["query", "result", "answer"], template=template
)

context_template = """You are a teacher grading a quiz.
You are given a question, the context the question is about, and the student's answer. You are asked to score the student's answer as either CORRECT or INCORRECT, based on the context.

Example Format:
QUESTION: question here
CONTEXT: context the question is about here
STUDENT ANSWER: student's answer here
GRADE: CORRECT or INCORRECT here

Please remember to grade them based on being factually accurate. Begin!

QUESTION: {query}
CONTEXT: {context}
STUDENT ANSWER: {result}
GRADE:"""
CONTEXT_PROMPT = PromptTemplate(
    input_variables=["query", "context", "result"], template=context_template
)

cot_template = """You are a teacher grading a quiz.
You are given a question, the context the question is about, and the student's answer. You are asked to score the student's answer as either CORRECT or INCORRECT, based on the context.
Write out in a step by step manner your reasoning to be sure that your conclusion is correct. Avoid simply stating the correct answer at the outset.

Example Format:
QUESTION: question here
CONTEXT: context the question is about here
STUDENT ANSWER: student's answer here
EXPLANATION: step by step reasoning here
GRADE: CORRECT or INCORRECT here

Please remember to grade them based on being factually accurate. Begin!

QUESTION: {query}
CONTEXT: {context}
STUDENT ANSWER: {result}
EXPLANATION:"""
COT_PROMPT = PromptTemplate(
    input_variables=["query", "context", "result"], template=cot_template
)

@@ -0,0 +1 @@
"""New unit tests for the evaluation module."""

@@ -0,0 +1 @@
"""Tests for QA evaluation chains."""

@@ -0,0 +1,46 @@
"""Test LLM Bash functionality."""
import sys
from typing import Type
import pytest
from langchain.evaluation.qa.eval_chain import (
ContextQAEvalChain,
CotQAEvalChain,
QAEvalChain,
)
from tests.unit_tests.llms.fake_llm import FakeLLM
@pytest.mark.skipif(
sys.platform.startswith("win"), reason="Test not supported on Windows"
)
def test_eval_chain() -> None:
"""Test a simple eval chain."""
example = {"query": "What's my name", "answer": "John Doe"}
prediction = {"result": "John Doe"}
fake_qa_eval_chain = QAEvalChain.from_llm(FakeLLM())
outputs = fake_qa_eval_chain.evaluate([example, example], [prediction, prediction])
assert outputs[0] == outputs[1]
assert "text" in outputs[0]
assert outputs[0]["text"] == "foo"
@pytest.mark.skipif(
sys.platform.startswith("win"), reason="Test not supported on Windows"
)
@pytest.mark.parametrize("chain_cls", [ContextQAEvalChain, CotQAEvalChain])
def test_context_eval_chain(chain_cls: Type[ContextQAEvalChain]) -> None:
"""Test a simple eval chain."""
example = {
"query": "What's my name",
"context": "The name of this person is John Doe",
}
prediction = {"result": "John Doe"}
fake_qa_eval_chain = chain_cls.from_llm(FakeLLM())
outputs = fake_qa_eval_chain.evaluate([example, example], [prediction, prediction])
assert outputs[0] == outputs[1]
assert "text" in outputs[0]
assert outputs[0]["text"] == "foo"