{ "cells": [ { "cell_type": "markdown", "id": "984169ca", "metadata": {}, "source": [ "# Agent VectorDB Question Answering Benchmarking\n", "\n", "Here we go over how to benchmark performance on a question answering task using an agent to route between multiple vector databases.\n", "\n", "It is highly recommended that you do any evaluation/benchmarking with tracing enabled. See [here](https://langchain.readthedocs.io/en/latest/tracing.html) for an explanation of what tracing is and how to set it up." ] }, { "cell_type": "code", "execution_count": 1, "id": "7b57a50f", "metadata": {}, "outputs": [], "source": [ "# Comment this out if you are NOT using tracing\n", "import os\n", "os.environ[\"LANGCHAIN_HANDLER\"] = \"langchain\"" ] }, { "cell_type": "markdown", "id": "8a16b75d", "metadata": {}, "source": [ "## Loading the data\n", "First, let's load the data." ] }, { "cell_type": "code", "execution_count": 2, "id": "5b2d5e98", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Found cached dataset json (/Users/qt/.cache/huggingface/datasets/LangChainDatasets___json/LangChainDatasets--agent-vectordb-qa-sota-pg-d3ae24016b514f92/0.0.0/fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e)\n", "100%|██████████| 1/1 [00:00<00:00, 414.42it/s]\n" ] } ], "source": [ "from langchain.evaluation.loading import load_dataset\n", "dataset = load_dataset(\"agent-vectordb-qa-sota-pg\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "61375342", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'question': 'What is the purpose of the NATO Alliance?',\n", " 'answer': 'The purpose of the NATO Alliance is to secure peace and stability in Europe after World War 2.',\n", " 'steps': [{'tool': 'State of Union QA System', 'tool_input': None},\n", " {'tool': None, 'tool_input': 'What is the purpose of the NATO Alliance?'}]}" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset[0]" ] 
}, { "cell_type": "code", "execution_count": 4, "id": "02500304", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'question': 'What is the purpose of YC?',\n", " 'answer': 'The purpose of YC is to cause startups to be founded that would not otherwise have existed.',\n", " 'steps': [{'tool': 'Paul Graham QA System', 'tool_input': None},\n", " {'tool': None, 'tool_input': 'What is the purpose of YC?'}]}" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset[-1]" ] }, { "cell_type": "markdown", "id": "4ab6a716", "metadata": {}, "source": [ "## Setting up a chain\n", "Now we need to create some pipelines for question answering. The first step is to create an index over each of the two source documents." ] }, { "cell_type": "code", "execution_count": 5, "id": "c18680b5", "metadata": {}, "outputs": [], "source": [ "from langchain.document_loaders import TextLoader\n", "loader = TextLoader(\"../../modules/state_of_the_union.txt\")" ] }, { "cell_type": "code", "execution_count": 6, "id": "7f0de2b3", "metadata": {}, "outputs": [], "source": [ "from langchain.indexes import VectorstoreIndexCreator" ] }, { "cell_type": "code", "execution_count": 12, "id": "ef84ff99", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using embedded DuckDB without persistence: data will be transient\n" ] } ], "source": [ "vectorstore_sota = VectorstoreIndexCreator(vectorstore_kwargs={\"collection_name\":\"sota\"}).from_loaders([loader]).vectorstore" ] }, { "cell_type": "markdown", "id": "f0b5d8f6", "metadata": {}, "source": [ "Now we can create a question answering chain." 
] }, { "cell_type": "code", "execution_count": 13, "id": "8843cb0c", "metadata": {}, "outputs": [], "source": [ "from langchain.chains import RetrievalQA\n", "from langchain.llms import OpenAI" ] }, { "cell_type": "code", "execution_count": 16, "id": "573719a0", "metadata": {}, "outputs": [], "source": [ "chain_sota = RetrievalQA.from_chain_type(llm=OpenAI(temperature=0), chain_type=\"stuff\", retriever=vectorstore_sota.as_retriever(), input_key=\"question\")\n" ] }, { "cell_type": "markdown", "id": "e48b03d8", "metadata": {}, "source": [ "Now we do the same for the Paul Graham data." ] }, { "cell_type": "code", "execution_count": 17, "id": "c2dbb014", "metadata": {}, "outputs": [], "source": [ "loader = TextLoader(\"../../modules/paul_graham_essay.txt\")" ] }, { "cell_type": "code", "execution_count": 19, "id": "98d16f08", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using embedded DuckDB without persistence: data will be transient\n" ] } ], "source": [ "vectorstore_pg = VectorstoreIndexCreator(vectorstore_kwargs={\"collection_name\":\"paul_graham\"}).from_loaders([loader]).vectorstore" ] }, { "cell_type": "code", "execution_count": 20, "id": "ec0aab02", "metadata": {}, "outputs": [], "source": [ "chain_pg = RetrievalQA.from_chain_type(llm=OpenAI(temperature=0), chain_type=\"stuff\", retriever=vectorstore_pg.as_retriever(), input_key=\"question\")\n" ] }, { "cell_type": "markdown", "id": "76b5f8fb", "metadata": {}, "source": [ "We can now set up an agent to route between them." ] }, { "cell_type": "code", "execution_count": 22, "id": "ade1aafa", "metadata": {}, "outputs": [], "source": [ "from langchain.agents import initialize_agent, Tool\n", "from langchain.agents import AgentType\n", "tools = [\n", " Tool(\n", " name = \"State of Union QA System\",\n", " func=chain_sota.run,\n", " description=\"useful for when you need to answer questions about the most recent state of the union address. 
Input should be a fully formed question.\"\n", " ),\n", " Tool(\n", " name = \"Paul Graham System\",\n", " func=chain_pg.run,\n", " description=\"useful for when you need to answer questions about Paul Graham. Input should be a fully formed question.\"\n", " ),\n", "]" ] }, { "cell_type": "code", "execution_count": 34, "id": "104853f8", "metadata": {}, "outputs": [], "source": [ "agent = initialize_agent(tools, OpenAI(temperature=0), agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, max_iterations=4)" ] }, { "cell_type": "markdown", "id": "7f036641", "metadata": {}, "source": [ "## Make a prediction\n", "\n", "First, we can make predictions one datapoint at a time. Working at this level of granularity allows us to explore the outputs in detail, and it is also a lot cheaper than running over multiple datapoints." ] }, { "cell_type": "code", "execution_count": 35, "id": "4664e79f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'The purpose of the NATO Alliance is to secure peace and stability in Europe after World War 2.'" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "agent.run(dataset[0]['question'])" ] }, { "cell_type": "markdown", "id": "d0c16cd7", "metadata": {}, "source": [ "## Make many predictions\n", "Now we can run the agent over every datapoint in the dataset. Datapoints on which the agent raises an exception are collected in `error_dataset` instead of stopping the run, so `predictions` and `predicted_dataset` stay aligned one-to-one." ] }, { "cell_type": "code", "execution_count": 36, "id": "799f6c17", "metadata": {}, "outputs": [], "source": [ "predictions = []\n", "predicted_dataset = []\n", "error_dataset = []\n", "for data in dataset:\n", "    new_data = {\"input\": data[\"question\"], \"answer\": data[\"answer\"]}\n", "    try:\n", "        predictions.append(agent(new_data))\n", "        predicted_dataset.append(new_data)\n", "    except Exception:\n", "        error_dataset.append(new_data)" ] }, { "cell_type": "markdown", "id": "49d969fb", "metadata": {}, "source": [ "## Evaluate performance\n", "Now we can evaluate the predictions. The first thing we can do is look at them by eye." 
] }, { "cell_type": "code", "execution_count": 37, "id": "1d583f03", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'input': 'What is the purpose of the NATO Alliance?',\n", " 'answer': 'The purpose of the NATO Alliance is to secure peace and stability in Europe after World War 2.',\n", " 'output': 'The purpose of the NATO Alliance is to secure peace and stability in Europe after World War 2.'}" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictions[0]" ] }, { "cell_type": "markdown", "id": "4783344b", "metadata": {}, "source": [ "Next, we can use a language model to score them programmatically. We pass `predicted_dataset` rather than `dataset` because it contains only the datapoints that ran without error, so it lines up one-to-one with `predictions`." ] }, { "cell_type": "code", "execution_count": 38, "id": "d0a9341d", "metadata": {}, "outputs": [], "source": [ "from langchain.evaluation.qa import QAEvalChain" ] }, { "cell_type": "code", "execution_count": 39, "id": "1612dec1", "metadata": {}, "outputs": [], "source": [ "llm = OpenAI(temperature=0)\n", "eval_chain = QAEvalChain.from_llm(llm)\n", "graded_outputs = eval_chain.evaluate(predicted_dataset, predictions, question_key=\"input\", prediction_key=\"output\")" ] }, { "cell_type": "markdown", "id": "79587806", "metadata": {}, "source": [ "We can add each graded output to the corresponding dict in `predictions` and then count the grades." 
] }, { "cell_type": "code", "execution_count": 40, "id": "2a689df5", "metadata": {}, "outputs": [], "source": [ "for i, prediction in enumerate(predictions):\n", " prediction['grade'] = graded_outputs[i]['text']" ] }, { "cell_type": "code", "execution_count": 41, "id": "27b61215", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Counter({' CORRECT': 28, ' INCORRECT': 5})" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from collections import Counter\n", "Counter([pred['grade'] for pred in predictions])" ] }, { "cell_type": "markdown", "id": "12fe30f4", "metadata": {}, "source": [ "We can also filter the datapoints to the incorrect examples and look at them." ] }, { "cell_type": "code", "execution_count": 42, "id": "47c692a1", "metadata": {}, "outputs": [], "source": [ "incorrect = [pred for pred in predictions if pred['grade'] == \" INCORRECT\"]" ] }, { "cell_type": "code", "execution_count": 43, "id": "0ef976c1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'input': 'What are the four common sense steps that the author suggests to move forward safely?',\n", " 'answer': 'The four common sense steps suggested by the author to move forward safely are: stay protected with vaccines and treatments, prepare for new variants, end the shutdown of schools and businesses, and stay vigilant.',\n", " 'output': 'The four common sense steps suggested in the most recent State of the Union address are: cutting the cost of prescription drugs, providing a pathway to citizenship for Dreamers, revising laws so businesses have the workers they need and families don’t wait decades to reunite, and protecting access to health care and preserving a woman’s right to choose.',\n", " 'grade': ' INCORRECT'}" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "incorrect[0]" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": 
"python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.15" } }, "nbformat": 4, "nbformat_minor": 5 }