mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
10dab053b4
This pull request adds an enum class for the various types of agents used in the project, located in the `agent_types.py` file. Currently, the project is using hardcoded strings for the initialization of these agents, which can lead to errors and make the code harder to maintain. With the introduction of the new enums, the code will be more readable and less error-prone. The new enum members include:

- ZERO_SHOT_REACT_DESCRIPTION
- REACT_DOCSTORE
- SELF_ASK_WITH_SEARCH
- CONVERSATIONAL_REACT_DESCRIPTION
- CHAT_ZERO_SHOT_REACT_DESCRIPTION
- CHAT_CONVERSATIONAL_REACT_DESCRIPTION

In this PR, I have also replaced the hardcoded strings with the appropriate enum members throughout the codebase, ensuring a smooth transition to the new approach.
505 lines
12 KiB
Plaintext
{
"cells": [
{
"cell_type": "markdown",
"id": "984169ca",
"metadata": {},
"source": [
"# Agent VectorDB Question Answering Benchmarking\n",
"\n",
"Here we go over how to benchmark performance on a question answering task using an agent to route between multiple vector databases.\n",
"\n",
"It is highly recommended that you do any evaluation/benchmarking with tracing enabled. See [here](https://langchain.readthedocs.io/en/latest/tracing.html) for an explanation of what tracing is and how to set it up."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "7b57a50f",
"metadata": {},
"outputs": [],
"source": [
"# Comment this out if you are NOT using tracing\n",
"import os\n",
"os.environ[\"LANGCHAIN_HANDLER\"] = \"langchain\""
]
},
{
"cell_type": "markdown",
"id": "8a16b75d",
"metadata": {},
"source": [
"## Loading the data\n",
"First, let's load the data."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "5b2d5e98",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Found cached dataset json (/Users/harrisonchase/.cache/huggingface/datasets/LangChainDatasets___json/LangChainDatasets--agent-vectordb-qa-sota-pg-d3ae24016b514f92/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4c389519842e4b65afc33006a531dcbc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from langchain.evaluation.loading import load_dataset\n",
"dataset = load_dataset(\"agent-vectordb-qa-sota-pg\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "61375342",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'question': 'What is the purpose of the NATO Alliance?',\n",
" 'answer': 'The purpose of the NATO Alliance is to secure peace and stability in Europe after World War 2.',\n",
" 'steps': [{'tool': 'State of Union QA System', 'tool_input': None},\n",
" {'tool': None, 'tool_input': 'What is the purpose of the NATO Alliance?'}]}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset[0]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "02500304",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'question': 'What is the purpose of YC?',\n",
" 'answer': 'The purpose of YC is to cause startups to be founded that would not otherwise have existed.',\n",
" 'steps': [{'tool': 'Paul Graham QA System', 'tool_input': None},\n",
" {'tool': None, 'tool_input': 'What is the purpose of YC?'}]}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset[-1]"
]
},
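{
"cell_type": "markdown",
"id": "3c1f7a92",
"metadata": {},
"source": [
"As a quick sanity check, we can count which QA system each example expects the agent to call first. This is a minimal sketch that assumes every example has a non-empty `steps` list whose first entry names the tool, as in `dataset[0]` and `dataset[-1]` above."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5e8b2d07",
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"# Tally the first tool named in each example's reference steps\n",
"Counter(example['steps'][0]['tool'] for example in dataset)"
]
},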
{
"cell_type": "markdown",
"id": "4ab6a716",
"metadata": {},
"source": [
"## Setting up a chain\n",
"Now we need to create some pipelines for doing question answering. The first step is to create an index over each source document."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c18680b5",
"metadata": {},
"outputs": [],
"source": [
"from langchain.document_loaders import TextLoader\n",
"loader = TextLoader(\"../../modules/state_of_the_union.txt\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7f0de2b3",
"metadata": {},
"outputs": [],
"source": [
"from langchain.indexes import VectorstoreIndexCreator"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "ef84ff99",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running Chroma using direct local API.\n",
"Using DuckDB in-memory for database. Data will be transient.\n"
]
}
],
"source": [
"vectorstore_sota = VectorstoreIndexCreator(vectorstore_kwargs={\"collection_name\":\"sota\"}).from_loaders([loader]).vectorstore"
]
},
{
"cell_type": "markdown",
"id": "f0b5d8f6",
"metadata": {},
"source": [
"Now we can create a question answering chain."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "8843cb0c",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import RetrievalQA\n",
"from langchain.llms import OpenAI"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "573719a0",
"metadata": {},
"outputs": [],
"source": [
"chain_sota = RetrievalQA.from_chain_type(llm=OpenAI(temperature=0), chain_type=\"stuff\", retriever=vectorstore_sota.as_retriever(), input_key=\"question\")\n"
]
},
{
"cell_type": "markdown",
"id": "e48b03d8",
"metadata": {},
"source": [
"Now we do the same for the Paul Graham data."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "c2dbb014",
"metadata": {},
"outputs": [],
"source": [
"loader = TextLoader(\"../../modules/paul_graham_essay.txt\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "98d16f08",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running Chroma using direct local API.\n",
"Using DuckDB in-memory for database. Data will be transient.\n"
]
}
],
"source": [
"vectorstore_pg = VectorstoreIndexCreator(vectorstore_kwargs={\"collection_name\":\"paul_graham\"}).from_loaders([loader]).vectorstore"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "ec0aab02",
"metadata": {},
"outputs": [],
"source": [
"chain_pg = RetrievalQA.from_chain_type(llm=OpenAI(temperature=0), chain_type=\"stuff\", retriever=vectorstore_pg.as_retriever(), input_key=\"question\")\n"
]
},
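{
"cell_type": "markdown",
"id": "a4d6e1b3",
"metadata": {},
"source": [
"Before adding an agent, it can help to query each chain directly. This sketch sends the first example question (State of the Union) and the last one (Paul Graham) straight to the matching chain. That gives a rough, routing-free reference point for the agent's answers later on. It assumes, as the two examples shown earlier suggest, that those questions belong to those sources."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f2c8e50",
"metadata": {},
"outputs": [],
"source": [
"# Bypass the agent and ask each retrieval chain directly\n",
"print(chain_sota.run(dataset[0]['question']))\n",
"print(chain_pg.run(dataset[-1]['question']))"
]
},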
{
"cell_type": "markdown",
"id": "76b5f8fb",
"metadata": {},
"source": [
"We can now set up an agent to route between them."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "ade1aafa",
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import initialize_agent, Tool\n",
"from langchain.agents.agent_types import AgentType\n",
"tools = [\n",
" Tool(\n",
" name = \"State of Union QA System\",\n",
" func=chain_sota.run,\n",
" description=\"useful for when you need to answer questions about the most recent state of the union address. Input should be a fully formed question.\"\n",
" ),\n",
" Tool(\n",
" name = \"Paul Graham System\",\n",
" func=chain_pg.run,\n",
" description=\"useful for when you need to answer questions about Paul Graham. Input should be a fully formed question.\"\n",
" ),\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "104853f8",
"metadata": {},
"outputs": [],
"source": [
"agent = initialize_agent(tools, OpenAI(temperature=0), agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, max_iterations=3)"
]
},
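{
"cell_type": "markdown",
"id": "b1e7d4a6",
"metadata": {},
"source": [
"The agent picks a system based on the tool names and descriptions embedded in its prompt, so it can be useful to print that prompt. This sketch assumes the zero-shot ReAct agent is exposed as `agent.agent` with an `llm_chain` attribute. That attribute path may differ across LangChain versions."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "62c0f9d8",
"metadata": {},
"outputs": [],
"source": [
"# Show the prompt, including the tool descriptions the agent routes on\n",
"print(agent.agent.llm_chain.prompt.template)"
]
},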
{
"cell_type": "markdown",
"id": "7f036641",
"metadata": {},
"source": [
"## Make a prediction\n",
"\n",
"First, we can make predictions one datapoint at a time. Doing it at this level of granularity allows us to explore the outputs in detail, and it is also a lot cheaper than running over multiple datapoints."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "4664e79f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'The purpose of the NATO Alliance is to promote peace and security in the North Atlantic region by providing a collective defense against potential threats.'"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.run(dataset[0]['question'])"
]
},
{
"cell_type": "markdown",
"id": "d0c16cd7",
"metadata": {},
"source": [
"## Make many predictions\n",
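"Now we can make predictions over the whole dataset. Any example that raises an exception is collected in `error_dataset` instead of stopping the loop."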
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "799f6c17",
"metadata": {},
"outputs": [],
"source": [
"predictions = []\n",
"predicted_dataset = []\n",
"error_dataset = []\n",
"for data in dataset:\n",
" new_data = {\"input\": data[\"question\"], \"answer\": data[\"answer\"]}\n",
" try:\n",
" predictions.append(agent(new_data))\n",
" predicted_dataset.append(new_data)\n",
" except Exception:\n",
" error_dataset.append(new_data)"
]
},
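{
"cell_type": "markdown",
"id": "d8a3b5c1",
"metadata": {},
"source": [
"Because failing examples are caught and set aside, it is worth checking how many examples actually produced a prediction. Only the successful ones in `predicted_dataset` are graded below."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0e7f4a29",
"metadata": {},
"outputs": [],
"source": [
"# Examples in error_dataset are not graded\n",
"print(f'Predicted: {len(predicted_dataset)}, errored: {len(error_dataset)}')"
]
},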
{
"cell_type": "markdown",
"id": "49d969fb",
"metadata": {},
"source": [
"## Evaluate performance\n",
"Now we can evaluate the predictions. The first thing we can do is look at them by eye."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1d583f03",
"metadata": {},
"outputs": [],
"source": [
"predictions[0]"
]
},
{
"cell_type": "markdown",
"id": "4783344b",
"metadata": {},
"source": [
"Next, we can use a language model to score them programmatically."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d0a9341d",
"metadata": {},
"outputs": [],
"source": [
"from langchain.evaluation.qa import QAEvalChain"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "1612dec1",
"metadata": {},
"outputs": [],
"source": [
"llm = OpenAI(temperature=0)\n",
"eval_chain = QAEvalChain.from_llm(llm)\n",
"graded_outputs = eval_chain.evaluate(predicted_dataset, predictions, question_key=\"input\", prediction_key=\"output\")"
]
},
{
"cell_type": "markdown",
"id": "79587806",
"metadata": {},
"source": [
"We can add the graded output to each dict in `predictions` and then get a count of the grades."
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "2a689df5",
"metadata": {},
"outputs": [],
"source": [
"for i, prediction in enumerate(predictions):\n",
" prediction['grade'] = graded_outputs[i]['text']"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "27b61215",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Counter({' CORRECT': 19, ' INCORRECT': 14})"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from collections import Counter\n",
"Counter([pred['grade'] for pred in predictions])"
]
},
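{
"cell_type": "markdown",
"id": "7c5d2e83",
"metadata": {},
"source": [
"The counts can also be collapsed into a single accuracy figure. This minimal sketch compares stripped grade strings, since the grader output above carries a leading space. Like the count above, it ignores examples that raised an error during prediction."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f3a90b16",
"metadata": {},
"outputs": [],
"source": [
"# Fraction of graded predictions marked CORRECT\n",
"num_correct = sum(pred['grade'].strip() == 'CORRECT' for pred in predictions)\n",
"num_correct / len(predictions)"
]
},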
{
"cell_type": "markdown",
"id": "12fe30f4",
"metadata": {},
"source": [
"We can also filter the datapoints to the incorrect examples and look at them."
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "47c692a1",
"metadata": {},
"outputs": [],
"source": [
"incorrect = [pred for pred in predictions if pred['grade'] == \" INCORRECT\"]"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "0ef976c1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input': 'What is the purpose of the Bipartisan Innovation Act mentioned in the text?',\n",
" 'answer': 'The Bipartisan Innovation Act will make record investments in emerging technologies and American manufacturing to level the playing field with China and other competitors.',\n",
" 'output': 'The purpose of the Bipartisan Innovation Act is to promote innovation and entrepreneurship in the United States by providing tax incentives and other support for startups and small businesses.',\n",
" 'grade': ' INCORRECT'}"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"incorrect[0]"
]
},
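{
"cell_type": "markdown",
"id": "4b8e6c2d",
"metadata": {},
"source": [
"To review all of the incorrect examples rather than just the first, a minimal sketch is to print each one's question, reference answer, and prediction side by side."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e95a1f70",
"metadata": {},
"outputs": [],
"source": [
"# Print each incorrect example for manual review\n",
"for pred in incorrect:\n",
"    print('Question: ', pred['input'])\n",
"    print('Expected: ', pred['answer'])\n",
"    print('Predicted:', pred['output'])\n",
"    print()"
]
},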
{
"cell_type": "code",
"execution_count": null,
"id": "7710401a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}