From 91d7fd20aec2625ecd946a2a39ad80dc35d38019 Mon Sep 17 00:00:00 2001 From: Nicolas Date: Sat, 14 Jan 2023 12:23:48 -0300 Subject: [PATCH] feat: add custom prompt for QAEvalChain chain (#610) I originally had only modified the `from_llm` to include the prompt but I realized that if the prompt keys used on the custom prompt didn't match the default prompt, it wouldn't work because of how `apply` works. So I made some changes to the evaluate method to check if the prompt is the default and if not, it will check if the input keys are the same as the prompt key and update the inputs appropriately. Let me know if there is a better way to do this. Also added the custom prompt to the QA eval notebook. --- .../evaluation/question_answering.ipynb | 54 ++++++++++++++++++- langchain/evaluation/qa/eval_chain.py | 24 +++++++-- 2 files changed, 73 insertions(+), 5 deletions(-) diff --git a/docs/use_cases/evaluation/question_answering.ipynb b/docs/use_cases/evaluation/question_answering.ipynb index b80844eb43..98de75375b 100644 --- a/docs/use_cases/evaluation/question_answering.ipynb +++ b/docs/use_cases/evaluation/question_answering.ipynb @@ -190,6 +190,51 @@ " print()" ] }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "782ae8c8", + "metadata": {}, + "source": [ + "## Customize Prompt\n", + "\n", + "You can also customize the prompt that is used. Here is an example prompting it using a score from 0 to 10.\n", + "The custom prompt requires 3 input variables: \"query\", \"answer\" and \"result\". Where \"query\" is the question, \"answer\" is the ground truth answer, and \"result\" is the predicted answer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "153425c4", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.prompts.prompt import PromptTemplate\n", + "\n", + "_PROMPT_TEMPLATE = \"\"\"You are an expert professor specialized in grading students' answers to questions.\n", + "You are grading the following question:\n", + "{query}\n", + "Here is the real answer:\n", + "{answer}\n", + "You are grading the following predicted answer:\n", + "{result}\n", + "What grade do you give from 0 to 10, where 0 is the lowest (very low similarity) and 10 is the highest (very high similarity)?\n", + "\"\"\"\n", + "\n", + "PROMPT = PromptTemplate(input_variables=[\"query\", \"answer\", \"result\"], template=_PROMPT_TEMPLATE)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0a3b0fb7", + "metadata": {}, + "outputs": [], + "source": [ + "evalchain = QAEvalChain.from_llm(llm=llm,prompt=PROMPT)\n", + "evalchain.evaluate(examples, predictions, question_key=\"question\", answer_key=\"answer\", prediction_key=\"text\")" + ] + }, { "cell_type": "markdown", "id": "aaa61f0c", @@ -271,7 +316,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": ".venv", "language": "python", "name": "python3" }, @@ -285,7 +330,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.9.7 (default, Sep 16 2021, 08:50:36) \n[Clang 10.0.0 ]" + }, + "vscode": { + "interpreter": { + "hash": "53f3bc57609c7a84333bb558594977aa5b4026b1d6070b93987956689e367341" + } } }, "nbformat": 4, diff --git a/langchain/evaluation/qa/eval_chain.py b/langchain/evaluation/qa/eval_chain.py index ded97302cf..e712cff7ad 100644 --- a/langchain/evaluation/qa/eval_chain.py +++ b/langchain/evaluation/qa/eval_chain.py @@ -3,6 +3,7 @@ from __future__ import annotations from typing import Any, List +from langchain import PromptTemplate from langchain.chains.llm import LLMChain from langchain.evaluation.qa.eval_prompt import PROMPT from langchain.llms.base import BaseLLM @@ -12,9 +13,25 @@ class QAEvalChain(LLMChain): """LLM Chain specifically for evaluating question answering.""" @classmethod - def from_llm(cls, llm: BaseLLM, **kwargs: Any) -> QAEvalChain: - """Load QA Eval Chain from LLM.""" - return cls(llm=llm, prompt=PROMPT, **kwargs) + def from_llm( + cls, llm: BaseLLM, prompt: PromptTemplate = PROMPT, **kwargs: Any + ) -> QAEvalChain: + """Load QA Eval Chain from LLM. + + Args: + llm (BaseLLM): the base language model to use. + + prompt (PromptTemplate): A prompt template containing the input_variables: + 'input', 'answer' and 'result' that will be used as the prompt + for evaluation. + Defaults to PROMPT. + + **kwargs: additional keyword arguments. + + Returns: + QAEvalChain: the loaded QA eval chain. + """ + return cls(llm=llm, prompt=prompt, **kwargs) def evaluate( self, @@ -33,4 +50,5 @@ class QAEvalChain(LLMChain): "result": predictions[i][prediction_key], } inputs.append(_input) + return self.apply(inputs)