{
"cells": [
{
"cell_type": "markdown",
"id": "984169ca",
"metadata": {},
"source": [
"# Agent VectorDB Question Answering Benchmarking\n",
"\n",
"Here we go over how to benchmark performance on a question answering task using an agent to route between multiple vector databases.\n",
"\n",
"It is highly recommended that you do any evaluation/benchmarking with tracing enabled. See [here](https://python.langchain.com/guides/tracing/) for an explanation of what tracing is and how to set it up."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "7b57a50f",
"metadata": {},
"outputs": [],
"source": [
"# Comment this out if you are NOT using tracing\n",
"import os\n",
"\n",
"os.environ[\"LANGCHAIN_HANDLER\"] = \"langchain\""
]
},
{
"cell_type": "markdown",
"id": "8a16b75d",
"metadata": {},
"source": [
"## Loading the data\n",
"First, let's load the data."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "5b2d5e98",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Found cached dataset json (/Users/qt/.cache/huggingface/datasets/LangChainDatasets___json/LangChainDatasets--agent-vectordb-qa-sota-pg-d3ae24016b514f92/0.0.0/fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e)\n",
"100%|██████████| 1/1 [00:00<00:00, 414.42it/s]\n"
]
}
],
"source": [
"from langchain.evaluation.loading import load_dataset\n",
"\n",
"dataset = load_dataset(\"agent-vectordb-qa-sota-pg\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "61375342",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'question': 'What is the purpose of the NATO Alliance?',\n",
" 'answer': 'The purpose of the NATO Alliance is to secure peace and stability in Europe after World War 2.',\n",
" 'steps': [{'tool': 'State of Union QA System', 'tool_input': None},\n",
" {'tool': None, 'tool_input': 'What is the purpose of the NATO Alliance?'}]}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset[0]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "02500304",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'question': 'What is the purpose of YC?',\n",
" 'answer': 'The purpose of YC is to cause startups to be founded that would not otherwise have existed.',\n",
" 'steps': [{'tool': 'Paul Graham QA System', 'tool_input': None},\n",
" {'tool': None, 'tool_input': 'What is the purpose of YC?'}]}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset[-1]"
]
},
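{
"cell_type": "markdown",
"id": "1c2d3e4f",
"metadata": {},
"source": [
"Before running anything expensive, it can be useful to check how many examples the dataset holds (`load_dataset` returns a plain list of dicts, so `len` works directly)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9a8b7c6d",
"metadata": {},
"outputs": [],
"source": [
"# Number of question/answer examples in the benchmark dataset.\n",
"len(dataset)"
]
},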
{
"cell_type": "markdown",
"id": "4ab6a716",
"metadata": {},
"source": [
"## Setting up a chain\n",
"Now we need to create some pipelines for doing question answering. The first step is creating indexes over the data in question."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c18680b5",
"metadata": {},
"outputs": [],
"source": [
"from langchain.document_loaders import TextLoader\n",
"\n",
"loader = TextLoader(\"../../modules/state_of_the_union.txt\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7f0de2b3",
"metadata": {},
"outputs": [],
"source": [
"from langchain.indexes import VectorstoreIndexCreator"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "ef84ff99",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using embedded DuckDB without persistence: data will be transient\n"
]
}
],
"source": [
"vectorstore_sota = (\n",
"    VectorstoreIndexCreator(vectorstore_kwargs={\"collection_name\": \"sota\"})\n",
"    .from_loaders([loader])\n",
"    .vectorstore\n",
")"
]
},
{
"cell_type": "markdown",
"id": "f0b5d8f6",
"metadata": {},
"source": [
"Now we can create a question answering chain."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "8843cb0c",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import RetrievalQA\n",
"from langchain.llms import OpenAI"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "573719a0",
"metadata": {},
"outputs": [],
"source": [
"chain_sota = RetrievalQA.from_chain_type(\n",
"    llm=OpenAI(temperature=0),\n",
"    chain_type=\"stuff\",\n",
"    retriever=vectorstore_sota.as_retriever(),\n",
"    input_key=\"question\",\n",
")"
]
},
{
"cell_type": "markdown",
"id": "e48b03d8",
"metadata": {},
"source": [
"Now we do the same for the Paul Graham data."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "c2dbb014",
"metadata": {},
"outputs": [],
"source": [
"loader = TextLoader(\"../../modules/paul_graham_essay.txt\")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "98d16f08",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using embedded DuckDB without persistence: data will be transient\n"
]
}
],
"source": [
"vectorstore_pg = (\n",
"    VectorstoreIndexCreator(vectorstore_kwargs={\"collection_name\": \"paul_graham\"})\n",
"    .from_loaders([loader])\n",
"    .vectorstore\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "ec0aab02",
"metadata": {},
"outputs": [],
"source": [
"chain_pg = RetrievalQA.from_chain_type(\n",
"    llm=OpenAI(temperature=0),\n",
"    chain_type=\"stuff\",\n",
"    retriever=vectorstore_pg.as_retriever(),\n",
"    input_key=\"question\",\n",
")"
]
},
{
"cell_type": "markdown",
"id": "76b5f8fb",
"metadata": {},
"source": [
"We can now set up an agent to route between them."
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "ade1aafa",
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import initialize_agent, Tool\n",
"from langchain.agents import AgentType\n",
"\n",
"tools = [\n",
"    Tool(\n",
"        name=\"State of Union QA System\",\n",
"        func=chain_sota.run,\n",
"        description=\"useful for when you need to answer questions about the most recent state of the union address. Input should be a fully formed question.\",\n",
"    ),\n",
"    Tool(\n",
"        name=\"Paul Graham QA System\",\n",
"        func=chain_pg.run,\n",
"        description=\"useful for when you need to answer questions about Paul Graham. Input should be a fully formed question.\",\n",
"    ),\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "104853f8",
"metadata": {},
"outputs": [],
"source": [
"agent = initialize_agent(\n",
"    tools,\n",
"    OpenAI(temperature=0),\n",
"    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n",
"    max_iterations=4,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "7f036641",
"metadata": {},
"source": [
"## Make a prediction\n",
"\n",
"First, we can make predictions one datapoint at a time. Doing it at this level of granularity allows us to explore the outputs in detail, and it is also a lot cheaper than running over multiple datapoints."
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "4664e79f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'The purpose of the NATO Alliance is to secure peace and stability in Europe after World War 2.'"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.run(dataset[0][\"question\"])"
]
},
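{
"cell_type": "markdown",
"id": "b2f0ca41",
"metadata": {},
"source": [
"As a quick sanity check on that single prediction, we can put it next to the reference answer from the dataset, which lives in the `answer` field we saw above."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e3c1a5f9",
"metadata": {},
"outputs": [],
"source": [
"# The dataset stores the reference answer under the \"answer\" key.\n",
"dataset[0][\"answer\"]"
]
},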
{
"cell_type": "markdown",
"id": "d0c16cd7",
"metadata": {},
"source": [
"## Make many predictions\n",
"Now we can make predictions across the whole dataset."
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "799f6c17",
"metadata": {},
"outputs": [],
"source": [
"predictions = []\n",
"predicted_dataset = []\n",
"error_dataset = []\n",
"for data in dataset:\n",
"    new_data = {\"input\": data[\"question\"], \"answer\": data[\"answer\"]}\n",
"    try:\n",
"        # Agent runs can fail (e.g. on an unparsable LLM response),\n",
"        # so collect failures separately and keep going.\n",
"        predictions.append(agent(new_data))\n",
"        predicted_dataset.append(new_data)\n",
"    except Exception:\n",
"        error_dataset.append(new_data)"
]
},
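{
"cell_type": "markdown",
"id": "5a1b2c3d",
"metadata": {},
"source": [
"As a quick sanity check, we can see how many runs succeeded and how many errored out."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6e7f8a9b",
"metadata": {},
"outputs": [],
"source": [
"# Successful runs vs. runs that raised an exception.\n",
"len(predicted_dataset), len(error_dataset)"
]
},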
{
"cell_type": "markdown",
"id": "49d969fb",
"metadata": {},
"source": [
"## Evaluate performance\n",
"Now we can evaluate the predictions. The first thing we can do is look at them by eye."
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "1d583f03",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input': 'What is the purpose of the NATO Alliance?',\n",
" 'answer': 'The purpose of the NATO Alliance is to secure peace and stability in Europe after World War 2.',\n",
" 'output': 'The purpose of the NATO Alliance is to secure peace and stability in Europe after World War 2.'}"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predictions[0]"
]
},
{
"cell_type": "markdown",
"id": "4783344b",
"metadata": {},
"source": [
"Next, we can use a language model to score them programmatically."
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "d0a9341d",
"metadata": {},
"outputs": [],
"source": [
"from langchain.evaluation.qa import QAEvalChain"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "1612dec1",
"metadata": {},
"outputs": [],
"source": [
"llm = OpenAI(temperature=0)\n",
"eval_chain = QAEvalChain.from_llm(llm)\n",
"graded_outputs = eval_chain.evaluate(\n",
"    predicted_dataset, predictions, question_key=\"input\", prediction_key=\"output\"\n",
")"
]
},
{
"cell_type": "markdown",
"id": "79587806",
"metadata": {},
"source": [
"We can add the graded output to the `predictions` dicts and then get a count of the grades."
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "2a689df5",
"metadata": {},
"outputs": [],
"source": [
"for i, prediction in enumerate(predictions):\n",
"    prediction[\"grade\"] = graded_outputs[i][\"text\"]"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "27b61215",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Counter({' CORRECT': 28, ' INCORRECT': 5})"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from collections import Counter\n",
"\n",
"Counter([pred[\"grade\"] for pred in predictions])"
]
},
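{
"cell_type": "markdown",
"id": "3f4a5b6c",
"metadata": {},
"source": [
"From those counts it is easy to turn the grades into a single accuracy number. A small sketch: note the grade strings carry a leading space straight from the LLM output, so we strip before comparing."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c8d9e0f",
"metadata": {},
"outputs": [],
"source": [
"# Fraction of graded predictions marked CORRECT.\n",
"num_correct = sum(pred[\"grade\"].strip() == \"CORRECT\" for pred in predictions)\n",
"num_correct / len(predictions)"
]
},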
{
"cell_type": "markdown",
"id": "12fe30f4",
"metadata": {},
"source": [
"We can also filter the datapoints to the incorrect examples and look at them."
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "47c692a1",
"metadata": {},
"outputs": [],
"source": [
"incorrect = [pred for pred in predictions if pred[\"grade\"] == \" INCORRECT\"]"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "0ef976c1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input': 'What are the four common sense steps that the author suggests to move forward safely?',\n",
" 'answer': 'The four common sense steps suggested by the author to move forward safely are: stay protected with vaccines and treatments, prepare for new variants, end the shutdown of schools and businesses, and stay vigilant.',\n",
" 'output': 'The four common sense steps suggested in the most recent State of the Union address are: cutting the cost of prescription drugs, providing a pathway to citizenship for Dreamers, revising laws so businesses have the workers they need and families don’t wait decades to reunite, and protecting access to health care and preserving a woman’s right to choose.',\n",
" 'grade': ' INCORRECT'}"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"incorrect[0]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}