mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
cc60fed3be
Notebook shows preference scoring between two chains and reports the Wilson score interval and p-value. I think I'll add the option to insert ground truth labels, but that doesn't have to be in this PR.
448 lines
14 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Comparing Chain Outputs\n",
|
|
"\n",
|
|
"Suppose you have two different prompts (or LLMs). How do you know which will generate \"better\" results?\n",
|
|
"\n",
|
|
"One automated way to predict the preferred configuration is to use a `PairwiseStringEvaluator` like the `PairwiseStringEvalChain`<a name=\"cite_ref-1\"></a>[<sup>[1]</sup>](#cite_note-1). This chain prompts an LLM to select which output is preferred, given a specific input.\n",
|
|
"\n",
|
|
"For this evalution, we will need 3 things:\n",
|
|
"1. An evaluator\n",
|
|
"2. A dataset of inputs\n",
|
|
"3. 2 (or more) LLMs, Chains, or Agents to compare\n",
|
|
"\n",
|
|
"Then we will aggregate the restults to determine the preferred model.\n",
|
|
"\n",
|
|
"### Step 1. Create the Evaluator\n",
|
|
"\n",
|
|
"In this example, you will use gpt-4 to select which output is preferred."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Optional if you are tracing the notebook\n",
|
|
"%env LANGCHAIN_PROJECT=\"Comparing Chain Outputs\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain.chat_models import ChatOpenAI\n",
|
|
"from langchain.evaluation.comparison import PairwiseStringEvalChain\n",
|
|
"\n",
|
|
"llm = ChatOpenAI(model=\"gpt-4\")\n",
|
|
"\n",
|
|
"eval_chain = PairwiseStringEvalChain.from_llm(llm=llm)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Step 2. Select Dataset\n",
|
|
"\n",
|
|
"If you already have real usage data for your LLM, you can use a representative sample. More examples\n",
|
|
"provide more reliable results. We will use some example queries someone might have about how to use langchain here."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Found cached dataset parquet (/Users/wfh/.cache/huggingface/datasets/LangChainDatasets___parquet/LangChainDatasets--langchain-howto-queries-bbb748bbee7e77aa/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "d852a1884480457292c90d8bd9d4f1e6",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/1 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"from langchain.evaluation.loading import load_dataset\n",
|
|
"\n",
|
|
"dataset = load_dataset(\"langchain-howto-queries\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Step 3. Define Models to Compare\n",
|
|
"\n",
|
|
"We will be comparing two agents in this case."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain import SerpAPIWrapper\n",
|
|
"from langchain.agents import initialize_agent, Tool\n",
|
|
"from langchain.agents import AgentType\n",
|
|
"from langchain.chat_models import ChatOpenAI\n",
|
|
"\n",
|
|
"\n",
|
|
"# Initialize the language model\n",
|
|
"# You can add your own OpenAI API key by adding openai_api_key=\"<your_api_key>\" \n",
|
|
"llm = ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\")\n",
|
|
"\n",
|
|
"# Initialize the SerpAPIWrapper for search functionality\n",
|
|
"#Replace <your_api_key> in openai_api_key=\"<your_api_key>\" with your actual SerpAPI key.\n",
|
|
"search = SerpAPIWrapper()\n",
|
|
"\n",
|
|
"# Define a list of tools offered by the agent\n",
|
|
"tools = [\n",
|
|
" Tool(\n",
|
|
" name=\"Search\",\n",
|
|
" func=search.run,\n",
|
|
" coroutine=search.arun,\n",
|
|
" description=\"Useful when you need to answer questions about current events. You should ask targeted questions.\"\n",
|
|
" ),\n",
|
|
"]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"functions_agent = initialize_agent(tools, llm, agent=AgentType.OPENAI_MULTI_FUNCTIONS, verbose=False)\n",
|
|
"conversations_agent = initialize_agent(tools, llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"list(zip(*[iter(batch_results)]*2)### Step 4. Generate Responses\n",
|
|
"\n",
|
|
"We will generate outputs for each of the models before evaluating them."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "b076d6bf6680422aa9082d4bad4d98a3",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/20 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Retrying langchain.chat_models.openai.acompletion_with_retry.<locals>._completion_with_retry in 1.0 seconds as it raised ServiceUnavailableError: The server is overloaded or not ready yet..\n",
|
|
"Retrying langchain.chat_models.openai.acompletion_with_retry.<locals>._completion_with_retry in 1.0 seconds as it raised ServiceUnavailableError: The server is overloaded or not ready yet..\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from tqdm.notebook import tqdm\n",
|
|
"import asyncio\n",
|
|
"\n",
|
|
"results = []\n",
|
|
"agents = [functions_agent, conversations_agent]\n",
|
|
"concurrency_level = 6 # How many concurrent agents to run. May need to decrease if OpenAI is rate limiting.\n",
|
|
"\n",
|
|
"# We will only run the first 20 examples of this dataset to speed things up\n",
|
|
"# This will lead to larger confidence intervals downstream.\n",
|
|
"batch = []\n",
|
|
"for example in tqdm(dataset[:20]):\n",
|
|
" batch.extend([agent.acall(example['inputs']) for agent in agents])\n",
|
|
" if len(batch) >= concurrency_level:\n",
|
|
" batch_results = await asyncio.gather(*batch, return_exceptions=True)\n",
|
|
" results.extend(list(zip(*[iter(batch_results)]*2)))\n",
|
|
" batch = []\n",
|
|
"if batch:\n",
|
|
" batch_results = await asyncio.gather(*batch, return_exceptions=True)\n",
|
|
" results.extend(list(zip(*[iter(batch_results)]*2)))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Step 5. Evaluate Pairs\n",
|
|
"\n",
|
|
"Now it's time to evaluate the results. For each agent response, run the evaluation chain to select which output is preferred (or return a tie).\n",
|
|
"\n",
|
|
"Randomly select the input order to reduce the likelihood that one model will be preferred just because it is presented first."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import random\n",
|
|
"\n",
|
|
"def predict_preferences(dataset, results) -> list:\n",
|
|
" preferences = []\n",
|
|
"\n",
|
|
" for example, (res_a, res_b) in zip(dataset, results):\n",
|
|
" input_ = example['inputs']\n",
|
|
" # Flip a coin to reduce persistent position bias\n",
|
|
" if random.random() < 0.5:\n",
|
|
" pred_a, pred_b = res_a, res_b\n",
|
|
" a, b = \"a\", \"b\"\n",
|
|
" else:\n",
|
|
" pred_a, pred_b = res_b, res_a\n",
|
|
" a, b = \"b\", \"a\"\n",
|
|
" eval_res = eval_chain.evaluate_string_pairs(\n",
|
|
" output_a=pred_a['output'] if isinstance(pred_a, dict) else str(pred_a),\n",
|
|
" output_b=pred_b['output'] if isinstance(pred_b, dict) else str(pred_b),\n",
|
|
" input=input_\n",
|
|
" )\n",
|
|
" if eval_res[\"value\"] == \"A\":\n",
|
|
" preferences.append(a)\n",
|
|
" elif eval_res[\"value\"] == \"B\":\n",
|
|
" preferences.append(b)\n",
|
|
" else:\n",
|
|
" preferences.append(None) # No preference\n",
|
|
" return preferences"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"preferences = predict_preferences(dataset, results)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"source": [
|
|
"**Print out the ratio of preferences.**"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"OpenAI Functions Agent: 90.00%\n",
|
|
"Structured Chat Agent: 10.00%\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from collections import Counter\n",
|
|
"\n",
|
|
"name_map = {\n",
|
|
" \"a\": \"OpenAI Functions Agent\",\n",
|
|
" \"b\": \"Structured Chat Agent\",\n",
|
|
"}\n",
|
|
"counts = Counter(preferences)\n",
|
|
"pref_ratios = {\n",
|
|
" k: v/len(preferences) for k, v in\n",
|
|
" counts.items()\n",
|
|
"}\n",
|
|
"for k, v in pref_ratios.items():\n",
|
|
" print(f\"{name_map.get(k)}: {v:.2%}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Estimate Confidence Intervals\n",
|
|
"\n",
|
|
"The results seem pretty clear, but if you want to have a better sense of how confident we are, that model \"A\" (the OpenAI Functions Agent) is the preferred model, we can calculate confidence intervals. \n",
|
|
"\n",
|
|
"Below, use the Wilson score to estimate the confidence interval."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from math import sqrt\n",
|
|
"\n",
|
|
"def wilson_score_interval(preferences: list, which: str = \"a\", z: float = 1.96) -> tuple:\n",
|
|
" \"\"\"Estimate the confidence interval using the Wilson score.\n",
|
|
" \n",
|
|
" See: https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval#Wilson_score_interval\n",
|
|
" for more details, including when to use it and when it should not be used.\n",
|
|
" \"\"\"\n",
|
|
" total_preferences = preferences.count('a') + preferences.count('b')\n",
|
|
" n_s = preferences.count(which)\n",
|
|
"\n",
|
|
" if total_preferences == 0:\n",
|
|
" return (0, 0)\n",
|
|
"\n",
|
|
" p_hat = n_s / total_preferences\n",
|
|
"\n",
|
|
" denominator = 1 + (z**2) / total_preferences\n",
|
|
" adjustment = (z / denominator) * sqrt(p_hat*(1-p_hat)/total_preferences + (z**2)/(4*total_preferences*total_preferences))\n",
|
|
" center = (p_hat + (z**2) / (2*total_preferences)) / denominator\n",
|
|
" lower_bound = min(max(center - adjustment, 0.0), 1.0)\n",
|
|
" upper_bound = min(max(center + adjustment, 0.0), 1.0)\n",
|
|
"\n",
|
|
" return (lower_bound, upper_bound)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"The \"OpenAI Functions Agent\" would be preferred between 69.90% and 97.21% percent of the time (with 95% confidence).\n",
|
|
"The \"Structured Chat Agent\" would be preferred between 2.79% and 30.10% percent of the time (with 95% confidence).\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"for which_, name in name_map.items():\n",
|
|
" low, high = wilson_score_interval(preferences, which=which_)\n",
|
|
" print(f'The \"{name}\" would be preferred between {low:.2%} and {high:.2%} percent of the time (with 95% confidence).')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"**Print out the p-value.**"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"The p-value is 0.00040. If the null hypothesis is true (i.e., if the selected eval chain actually has no preference between the models),\n",
|
|
"then there is a 0.04025% chance of observing the OpenAI Functions Agent be preferred at least 18\n",
|
|
"times out of 20 trials.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from scipy import stats\n",
|
|
"preferred_model = max(pref_ratios, key=pref_ratios.get)\n",
|
|
"successes = preferences.count(preferred_model)\n",
|
|
"n = len(preferences) - preferences.count(None)\n",
|
|
"p_value = stats.binom_test(successes, n, p=0.5, alternative='two-sided')\n",
|
|
"print(f\"\"\"The p-value is {p_value:.5f}. If the null hypothesis is true (i.e., if the selected eval chain actually has no preference between the models),\n",
|
|
"then there is a {p_value:.5%} chance of observing the {name_map.get(preferred_model)} be preferred at least {successes}\n",
|
|
"times out of {n} trials.\"\"\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"<a name=\"cite_note-1\"></a>_1. Note: Automated evals are still an open research topic and are best used alongside other evaluation approaches. \n",
|
|
"LLM preferences exhibit biases, including banal ones like the order of outputs.\n",
|
|
"In choosing preferences, \"ground truth\" may not be taken into account, which may lead to scores that aren't grounded in utility._"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
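As a sanity check on the Wilson score cell in the notebook above, the following minimal sketch recomputes the interval directly from the counts reported in the recorded outputs (18 of 20 decided preferences for the OpenAI Functions Agent). The helper name `wilson_interval` and its signature are illustrative assumptions, not part of LangChain; it uses only the standard library.

```python
from math import sqrt


def wilson_interval(successes: int, trials: int, z: float = 1.96) -> tuple:
    """Wilson score interval for a binomial proportion (hypothetical helper)."""
    if trials == 0:
        # No decided comparisons, so there is nothing to estimate.
        return (0.0, 0.0)
    p_hat = successes / trials
    denominator = 1 + z**2 / trials
    # The center is pulled from p_hat toward 0.5, more strongly for small samples.
    center = (p_hat + z**2 / (2 * trials)) / denominator
    half_width = (z / denominator) * sqrt(
        p_hat * (1 - p_hat) / trials + z**2 / (4 * trials**2)
    )
    return (max(center - half_width, 0.0), min(center + half_width, 1.0))


# Counts taken from the recorded notebook output: 18 wins out of 20 trials.
low, high = wilson_interval(18, 20)
print(f"{low:.2%} to {high:.2%}")
```

With these counts the bounds agree with the 69.90% to 97.21% range in the recorded cell output. Unlike the simpler normal-approximation interval, the Wilson bounds stay inside [0, 1] even when the observed ratio is close to 0 or 1, which is the typical situation with only 20 examples.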
|
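The p-value cell above delegates the test to SciPy. To make explicit what a two-sided exact binomial test computes, here is a minimal standard-library sketch. It assumes the common convention of summing the probabilities of every outcome that is no more likely than the observed one; the function names `binom_pmf` and `two_sided_binom_p_value` are hypothetical and exist only for this illustration.

```python
from math import comb


def binom_pmf(k: int, n: int, p: float) -> float:
    """Probability of exactly k successes in n independent trials."""
    return comb(n, k) * p**k * (1 - p) ** (n - k)


def two_sided_binom_p_value(successes: int, n: int, p: float = 0.5) -> float:
    """Two-sided exact binomial p-value (hypothetical helper).

    Sums the probabilities of all outcomes that are at most as likely
    as the observed count under the null hypothesis.
    """
    observed = binom_pmf(successes, n, p)
    tolerance = 1 + 1e-7  # guards against floating-point ties
    total = sum(
        prob
        for k in range(n + 1)
        if (prob := binom_pmf(k, n, p)) <= observed * tolerance
    )
    return min(1.0, total)


# Counts taken from the recorded notebook output: 18 wins out of 20 trials.
p_value = two_sided_binom_p_value(18, 20)
print(f"{p_value:.5f}")
```

For 18 successes in 20 trials under a null of p = 0.5, the outcomes counted are 0, 1, 2, 18, 19 and 20 successes, giving 422 / 2^20, which rounds to the 0.00040 shown in the recorded output. Because the test is two-sided, the value covers results at least as extreme in either direction, not only 18 or more wins for the preferred agent.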