diff --git a/docs/use_cases/evaluation/question_answering.ipynb b/docs/use_cases/evaluation/question_answering.ipynb
index 02428539..274e77a2 100644
--- a/docs/use_cases/evaluation/question_answering.ipynb
+++ b/docs/use_cases/evaluation/question_answering.ipynb
@@ -234,6 +234,93 @@
     "evalchain.evaluate(examples, predictions, question_key=\"question\", answer_key=\"answer\", prediction_key=\"text\")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "cb1cf335",
+   "metadata": {},
+   "source": [
+    "## Evaluation without Ground Truth\n",
+    "It's possible to evaluate question answering systems without ground truth. You need a `\"context\"` input that reflects the information the LLM uses to answer the question. This context can be obtained from any retrieval system. Here's an example of how it works:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6c59293f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "context_examples = [\n",
+    "    {\n",
+    "        \"question\": \"How old am I?\",\n",
+    "        \"context\": \"I am 30 years old. I live in New York and take the train to work every day.\",\n",
+    "    },\n",
+    "    {\n",
+    "        \"question\": \"Who won the NFC championship game in 2023?\",\n",
+    "        \"context\": \"NFC Championship Game 2023: Philadelphia Eagles 31, San Francisco 49ers 7\"\n",
+    "    }\n",
+    "]\n",
+    "QA_PROMPT = \"Answer the question based on the context\\nContext:{context}\\nQuestion:{question}\\nAnswer:\"\n",
+    "template = PromptTemplate(input_variables=[\"context\", \"question\"], template=QA_PROMPT)\n",
+    "qa_chain = LLMChain(llm=llm, prompt=template)\n",
+    "predictions = qa_chain.apply(context_examples)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "e500d0cc",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[{'text': 'You are 30 years old.'},\n",
+       " {'text': ' The Philadelphia Eagles won the NFC championship game in 2023.'}]"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "predictions"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "6d8cbc1d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.evaluation.qa import ContextQAEvalChain\n",
+    "eval_chain = ContextQAEvalChain.from_llm(llm)\n",
+    "graded_outputs = eval_chain.evaluate(context_examples, predictions, question_key=\"question\", prediction_key=\"text\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "id": "6c5262d0",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[{'text': ' CORRECT'}, {'text': ' CORRECT'}]"
+      ]
+     },
+     "execution_count": 13,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "graded_outputs"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "aaa61f0c",
    "metadata": {},
@@ -329,7 +416,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.1"
+   "version": "3.9.16"
   },
   "vscode": {
    "interpreter": {
diff --git a/langchain/evaluation/qa/__init__.py b/langchain/evaluation/qa/__init__.py
index 72863316..e2d639f9 100644
--- a/langchain/evaluation/qa/__init__.py
+++ b/langchain/evaluation/qa/__init__.py
@@ -1,5 +1,9 @@
 """Chains and utils related to evaluating question answering functionality."""
-from langchain.evaluation.qa.eval_chain import QAEvalChain
+from langchain.evaluation.qa.eval_chain import (
+    ContextQAEvalChain,
+    CotQAEvalChain,
+    QAEvalChain,
+)
 from langchain.evaluation.qa.generate_chain import QAGenerateChain
 
-__all__ = ["QAEvalChain", "QAGenerateChain"]
+__all__ = ["QAEvalChain", "QAGenerateChain", "ContextQAEvalChain", "CotQAEvalChain"]
diff --git a/langchain/evaluation/qa/eval_chain.py b/langchain/evaluation/qa/eval_chain.py
index 382f9f55..43e3be65 100644
--- a/langchain/evaluation/qa/eval_chain.py
+++ b/langchain/evaluation/qa/eval_chain.py
@@ -5,7 +5,7 @@ from typing import Any, List
 
 from langchain import PromptTemplate
 from langchain.chains.llm import LLMChain
-from langchain.evaluation.qa.eval_prompt import PROMPT
+from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT
 from langchain.llms.base import BaseLLM
 
 
@@ -58,3 +58,69 @@ class QAEvalChain(LLMChain):
         ]
 
         return self.apply(inputs)
+
+
+class ContextQAEvalChain(LLMChain):
+    """LLM Chain specifically for evaluating QA without ground truth, based on context."""
+
+    @classmethod
+    def _validate_input_vars(cls, prompt: PromptTemplate) -> None:
+        expected_input_vars = {"query", "context", "result"}
+        if expected_input_vars != set(prompt.input_variables):
+            raise ValueError(
+                f"Input variables should be {expected_input_vars}, "
+                f"but got {prompt.input_variables}"
+            )
+
+    @classmethod
+    def from_llm(
+        cls, llm: BaseLLM, prompt: PromptTemplate = CONTEXT_PROMPT, **kwargs: Any
+    ) -> ContextQAEvalChain:
+        """Load QA Eval Chain from LLM.
+
+        Args:
+            llm (BaseLLM): the base language model to use.
+
+            prompt (PromptTemplate): A prompt template containing the input_variables:
+                'query', 'context' and 'result' that will be used as the prompt
+                for evaluation.
+                Defaults to CONTEXT_PROMPT.
+
+            **kwargs: additional keyword arguments.
+
+        Returns:
+            ContextQAEvalChain: the loaded QA eval chain.
+        """
+        cls._validate_input_vars(prompt)
+        return cls(llm=llm, prompt=prompt, **kwargs)
+
+    def evaluate(
+        self,
+        examples: List[dict],
+        predictions: List[dict],
+        question_key: str = "query",
+        context_key: str = "context",
+        prediction_key: str = "result",
+    ) -> List[dict]:
+        """Evaluate question answering examples and predictions."""
+        inputs = [
+            {
+                "query": example[question_key],
+                "context": example[context_key],
+                "result": predictions[i][prediction_key],
+            }
+            for i, example in enumerate(examples)
+        ]
+
+        return self.apply(inputs)
+
+
+class CotQAEvalChain(ContextQAEvalChain):
+    """LLM Chain specifically for evaluating QA using chain of thought reasoning."""
+
+    @classmethod
+    def from_llm(
+        cls, llm: BaseLLM, prompt: PromptTemplate = COT_PROMPT, **kwargs: Any
+    ) -> CotQAEvalChain:
+        cls._validate_input_vars(prompt)
+        return cls(llm=llm, prompt=prompt, **kwargs)
diff --git a/langchain/evaluation/qa/eval_prompt.py b/langchain/evaluation/qa/eval_prompt.py
index 4e6c7c6a..5fa5c732 100644
--- a/langchain/evaluation/qa/eval_prompt.py
+++ b/langchain/evaluation/qa/eval_prompt.py
@@ -19,3 +19,44 @@ GRADE:"""
 PROMPT = PromptTemplate(
     input_variables=["query", "result", "answer"], template=template
 )
+
+context_template = """You are a teacher grading a quiz.
+You are given a question, the context the question is about, and the student's answer. You are asked to score the student's answer as either CORRECT or INCORRECT, based on the context.
+
+Example Format:
+QUESTION: question here
+CONTEXT: context the question is about here
+STUDENT ANSWER: student's answer here
+GRADE: CORRECT or INCORRECT here
+
+Please remember to grade them based on factual accuracy. Begin!
+
+QUESTION: {query}
+CONTEXT: {context}
+STUDENT ANSWER: {result}
+GRADE:"""
+CONTEXT_PROMPT = PromptTemplate(
+    input_variables=["query", "context", "result"], template=context_template
+)
+
+
+cot_template = """You are a teacher grading a quiz.
+You are given a question, the context the question is about, and the student's answer. You are asked to score the student's answer as either CORRECT or INCORRECT, based on the context.
+Write out in a step by step manner your reasoning to be sure that your conclusion is correct. Avoid simply stating the correct answer at the outset.
+
+Example Format:
+QUESTION: question here
+CONTEXT: context the question is about here
+STUDENT ANSWER: student's answer here
+EXPLANATION: step by step reasoning here
+GRADE: CORRECT or INCORRECT here
+
+Please remember to grade them based on factual accuracy. Begin!
+
+QUESTION: {query}
+CONTEXT: {context}
+STUDENT ANSWER: {result}
+EXPLANATION:"""
+COT_PROMPT = PromptTemplate(
+    input_variables=["query", "context", "result"], template=cot_template
+)
diff --git a/tests/unit_tests/evaluation/__init__.py b/tests/unit_tests/evaluation/__init__.py
new file mode 100644
index 00000000..f5aefd92
--- /dev/null
+++ b/tests/unit_tests/evaluation/__init__.py
@@ -0,0 +1 @@
+"""Unit tests for the evaluation module."""
diff --git a/tests/unit_tests/evaluation/qa/__init__.py b/tests/unit_tests/evaluation/qa/__init__.py
new file mode 100644
index 00000000..791c1731
--- /dev/null
+++ b/tests/unit_tests/evaluation/qa/__init__.py
@@ -0,0 +1 @@
+"""Tests for QA evaluation chains."""
diff --git a/tests/unit_tests/evaluation/qa/test_eval_chain.py b/tests/unit_tests/evaluation/qa/test_eval_chain.py
new file mode 100644
index 00000000..ac77a97f
--- /dev/null
+++ b/tests/unit_tests/evaluation/qa/test_eval_chain.py
@@ -0,0 +1,46 @@
+"""Test QA evaluation chain functionality."""
+import sys
+from typing import Type
+
+import pytest
+
+from langchain.evaluation.qa.eval_chain import (
+    ContextQAEvalChain,
+    CotQAEvalChain,
+    QAEvalChain,
+)
+from tests.unit_tests.llms.fake_llm import FakeLLM
+
+
+@pytest.mark.skipif(
+    sys.platform.startswith("win"), reason="Test not supported on Windows"
+)
+def test_eval_chain() -> None:
+    """Test a simple eval chain."""
+    example = {"query": "What's my name", "answer": "John Doe"}
+    prediction = {"result": "John Doe"}
+    fake_qa_eval_chain = QAEvalChain.from_llm(FakeLLM())
+
+    outputs = fake_qa_eval_chain.evaluate([example, example], [prediction, prediction])
+    assert outputs[0] == outputs[1]
+    assert "text" in outputs[0]
+    assert outputs[0]["text"] == "foo"
+
+
+@pytest.mark.skipif(
+    sys.platform.startswith("win"), reason="Test not supported on Windows"
+)
+@pytest.mark.parametrize("chain_cls", [ContextQAEvalChain, CotQAEvalChain])
+def test_context_eval_chain(chain_cls: Type[ContextQAEvalChain]) -> None:
+    """Test a context-based eval chain."""
+    example = {
+        "query": "What's my name",
+        "context": "The name of this person is John Doe",
+    }
+    prediction = {"result": "John Doe"}
+    fake_qa_eval_chain = chain_cls.from_llm(FakeLLM())
+
+    outputs = fake_qa_eval_chain.evaluate([example, example], [prediction, prediction])
+    assert outputs[0] == outputs[1]
+    assert "text" in outputs[0]
+    assert outputs[0]["text"] == "foo"
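As a quick sanity check of the new API, here is a minimal usage sketch of the chains this diff adds. It assumes the `OpenAI` LLM wrapper is available (any `BaseLLM` would do) and uses illustrative example data; only the default input keys (`query`, `context`, `result`) are exercised.

```python
from langchain.evaluation.qa import ContextQAEvalChain, CotQAEvalChain
from langchain.llms import OpenAI

llm = OpenAI(temperature=0)  # any BaseLLM works; OpenAI is just an example

examples = [
    {
        "query": "How old am I?",
        "context": "I am 30 years old. I live in New York and take the train to work every day.",
    }
]
predictions = [{"result": "You are 30 years old."}]

# Grade the prediction against the retrieved context; no ground-truth answer needed.
eval_chain = ContextQAEvalChain.from_llm(llm)
graded = eval_chain.evaluate(examples, predictions)  # default keys: query / context / result
print(graded)  # e.g. [{'text': ' CORRECT'}]

# Same inputs, but the grader writes out step-by-step reasoning before the grade.
cot_chain = CotQAEvalChain.from_llm(llm)
graded_cot = cot_chain.evaluate(examples, predictions)
```

The CoT variant reuses the context chain's input variables; only the grading prompt (`COT_PROMPT`) changes.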