Organize notebooks

pull/13/head
Beck LaBash 1 year ago
parent 4924ce40f2
commit e531a5c0d6

File diff suppressed because one or more lines are too long

@@ -1,84 +0,0 @@
import joblib
from react_cls import CoTAgent
from mocks import DocStoreExplorerMock, LLMMock
import numpy as np


def summarize_trial(agents):
    correct = [a for a in agents if a.is_correct()]
    incorrect = [a for a in agents if a.is_finished() and not a.is_correct()]
    return correct, incorrect


def log_trial(agents, trial_n):
    correct, incorrect = summarize_trial(agents)

    log = f"""
########################################
BEGIN TRIAL {trial_n}
Trial summary: Correct: {len(correct)}, Incorrect: {len(incorrect)}
#######################################
"""

    log += '------------- BEGIN CORRECT AGENTS -------------\n\n'
    for agent in correct:
        log += f'Context: {agent.context} Question: {agent.question}{agent.scratchpad}\nCorrect answer: {agent.key}\n\n'

    log += '------------- BEGIN INCORRECT AGENTS -----------\n\n'
    for agent in incorrect:
        log += f'Context: {agent.context} Question: {agent.question}{agent.scratchpad}\nCorrect answer: {agent.key}\n\n'

    return log


if __name__ == '__main__':
    hotpot = joblib.load('data/hotpot-qa-distractor-sample.joblib').reset_index(drop = True)
    hotpot['supporting_paragraphs'] = None

    # Join the sentences of each supporting article into one paragraph.
    for ind, row in hotpot.iterrows():
        supporting_articles = row['supporting_facts']['title']
        articles = row['context']['title']
        sentences = row['context']['sentences']
        supporting_paragraphs = []
        for article in supporting_articles:
            supporting_paragraph = ''.join(sentences[np.where(articles == article)][0])
            supporting_paragraphs.append(supporting_paragraph)
        hotpot.at[ind, 'supporting_paragraphs'] = supporting_paragraphs

    for ind, row in hotpot.iterrows():
        supporting_paragraphs = row['supporting_paragraphs']
        supporting_paragraphs = '\n\n'.join(supporting_paragraphs)
        hotpot.at[ind, 'supporting_paragraphs'] = supporting_paragraphs

    agents = [CoTAgent(row['question'], row['supporting_paragraphs'], row['answer']) for _, row in hotpot.iterrows()]

    trial = 0
    log = ''
    for agent in [a for a in agents if not a.is_correct()]:
        agent.run(reflect = False)
        print(f'Answer: {agent.key}')
    trial += 1
    log += log_trial(agents, trial)
    correct, incorrect = summarize_trial(agents)
    print(f'Finished Trial {trial}, Correct: {len(correct)}, Incorrect: {len(incorrect)}')

    dicts = [dict(a.__dict__) for a in agents]
    for d in dicts:
        for k, v in d.items():
            d[k] = str(v)
    joblib.dump(dicts, 'output/base_cot/cot_reflect_50_correct_dicts-8-trials.joblib')

    print(log)
    with open('output/base_cot/100_questions_8_trials.txt', 'w') as f:
        f.write(log)

    trial = 0
    log = ''
    q = 0
    agents_to_run = [a for a in agents if not a.is_correct()]
    while q < len(agents_to_run):
        print(f'Trial: {trial} ({q}/{len(agents_to_run)})')
        agents_to_run[q].run()
        q += 1
    trial += 1
    log += log_trial(agents, trial)
    # summarize_trial returns two lists; the original three-way unpack
    # (correct, incorrect, halted) would raise a ValueError here.
    correct, incorrect = summarize_trial(agents)
    print(f'Finished Trial {trial}, Correct: {len(correct)}, Incorrect: {len(incorrect)}')
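Note: the helpers deleted here reappear in the notebooks below as shared imports (`from util import summarize_trial, log_trial` and, later, `save_agents`). The new util module itself is not shown in this diff, so the following is only a plausible sketch of it, reassembled from the deleted definitions above; `save_agents` in particular is an assumption inferred from its call site `save_agents(agents, root + '/CoT/no_context/last_trial_and_reflexion/')` and the joblib dump loop in the deleted script.

# Hypothetical util.py — a sketch, not the actual file from this commit.
import os
import joblib

def summarize_trial(agents):
    # Partition agents into correct answers and finished-but-wrong answers.
    correct = [a for a in agents if a.is_correct()]
    incorrect = [a for a in agents if a.is_finished() and not a.is_correct()]
    return correct, incorrect

def log_trial(agents, trial_n):
    correct, incorrect = summarize_trial(agents)
    log = f"""
########################################
BEGIN TRIAL {trial_n}
Trial summary: Correct: {len(correct)}, Incorrect: {len(incorrect)}
#######################################
"""
    log += '------------- BEGIN CORRECT AGENTS -------------\n\n'
    for agent in correct:
        log += f'Context: {agent.context}\nQuestion: {agent.question}{agent.scratchpad}\nCorrect answer: {agent.key}\n\n'
    log += '------------- BEGIN INCORRECT AGENTS -----------\n\n'
    for agent in incorrect:
        log += f'Context: {agent.context}\nQuestion: {agent.question}{agent.scratchpad}\nCorrect answer: {agent.key}\n\n'
    return log

def save_agents(agents, save_dir):
    # Assumed behavior: stringify each agent's attributes and dump them with
    # joblib, mirroring the serialization loop in the deleted script above.
    os.makedirs(save_dir, exist_ok=True)
    for i, agent in enumerate(agents):
        d = {k: str(v) for k, v in dict(agent.__dict__).items()}
        joblib.dump(d, os.path.join(save_dir, f'{i}.joblib'))

Centralizing these helpers in one module lets every notebook share the same trial-summary and logging logic instead of redefining it per notebook, which matches the commit's stated purpose.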

@@ -1,52 +1,42 @@
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Notebook for running CoT with context + Reflexion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import joblib\n",
    "from react_cls import CoTAgent\n",
    "from mocks import DocStoreExplorerMock, LLMMock\n",
    "import numpy as np"
    "import sys, os\n",
    "sys.path.append('../../')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def summarize_trial(agents):\n",
    "    correct = [a for a in agents if a.is_correct()]\n",
    "    incorrect = [a for a in agents if a.is_finished() and not a.is_correct()]\n",
    "    return correct, incorrect\n",
    "\n",
    "def remove_fewshot(prompt: str) -> str:\n",
    "    prefix = prompt.split('Here are some examples:')[0]\n",
    "    suffix = prompt.split('(END OF EXAMPLES)')[1]\n",
    "    return prefix.strip('\\n').strip() + '\\n' + suffix.strip('\\n').strip()\n",
    "\n",
    "def log_trial(agents, trial_n):\n",
    "    correct, incorrect = summarize_trial(agents)\n",
    "\n",
    "    log = f\"\"\"\n",
    "########################################\n",
    "BEGIN TRIAL {trial_n}\n",
    "Trial summary: Correct: {len(correct)}, Incorrect: {len(incorrect)}\n",
    "#######################################\n",
    "\"\"\"\n",
    "\n",
    "    log += '------------- BEGIN CORRECT AGENTS -------------\\n\\n'\n",
    "    for agent in correct:\n",
    "        log += remove_fewshot(agent._build_agent_prompt()) + f'\\nCorrect answer: {agent.key}\\n\\n'\n",
    "\n",
    "    log += '------------- BEGIN INCORRECT AGENTS -----------\\n\\n'\n",
    "    for agent in incorrect:\n",
    "        log += remove_fewshot(agent._build_agent_prompt()) + f'\\nCorrect answer: {agent.key}\\n\\n'\n",
    "\n",
    "    return log"
    "imp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import joblib\n",
    "from react_cls import CoTAgent\n",
    "from mocks import DocStoreExplorerMock, LLMMock\n",
    "import numpy as np"
   ]
  },
  {

@@ -9,38 +9,8 @@
    "import joblib\n",
    "from react_cls import CoTAgent\n",
    "from mocks import DocStoreExplorerMock, LLMMock\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def summarize_trial(agents):\n",
    "    correct = [a for a in agents if a.is_correct()]\n",
    "    incorrect = [a for a in agents if a.is_finished() and not a.is_correct()]\n",
    "    return correct, incorrect\n",
    "\n",
    "def log_trial(agents, trial_n):\n",
    "    correct, incorrect = summarize_trial(agents)\n",
    "\n",
    "    log = f\"\"\"\n",
    "########################################\n",
    "BEGIN TRIAL {trial_n}\n",
    "Trial summary: Correct: {len(correct)}, Incorrect: {len(incorrect)}\n",
    "#######################################\n",
    "\"\"\"\n",
    "\n",
    "    log += '------------- BEGIN CORRECT AGENTS -------------\\n\\n'\n",
    "    for agent in correct:\n",
    "        log += f'Context: {agent.context} Question: {agent.question}{agent.scratchpad}\\nCorrect answer: {agent.key}\\n\\n'\n",
    "\n",
    "    log += '------------- BEGIN INCORRECT AGENTS -----------\\n\\n'\n",
    "    for agent in incorrect:\n",
    "        log += f'Context: {agent.context} Question: {agent.question}{agent.scratchpad}\\nCorrect answer: {agent.key}\\n\\n'\n",
    "    return log"
    "import numpy as np\n",
    "from util import summarize_trial, log_trial"
   ]
  },
  {

File diff suppressed because it is too large

@@ -1,11 +1,31 @@
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Notebook for running Chain-of-Thought with no context"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "sys.path.append('../..')\n",
    "root = '../../root/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "from util import summarize_trial, log_trial, save_agents\n",
    "import joblib\n",
    "from react_cls import CoTAgent\n",
    "from mocks import DocStoreExplorerMock, LLMMock\n",
@@ -13,34 +33,11 @@
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "outputs": [],
   "source": [
    "def summarize_trial(agents):\n",
    "    correct = [a for a in agents if a.is_correct()]\n",
    "    incorrect = [a for a in agents if a.is_finished() and not a.is_correct()]\n",
    "    return correct, incorrect\n",
    "\n",
    "def log_trial(agents, trial_n):\n",
    "    correct, incorrect = summarize_trial(agents)\n",
    "\n",
    "    log = f\"\"\"\n",
    "########################################\n",
    "BEGIN TRIAL {trial_n}\n",
    "Trial summary: Correct: {len(correct)}, Incorrect: {len(incorrect)}\n",
    "#######################################\n",
    "\"\"\"\n",
    "\n",
    "    log += '------------- BEGIN CORRECT AGENTS -------------\\n\\n'\n",
    "    for agent in correct:\n",
    "        log += f'Context: {agent.context}\\nQuestion: {agent.question}{agent.scratchpad}\\nCorrect answer: {agent.key}\\n\\n'\n",
    "\n",
    "    log += '------------- BEGIN INCORRECT AGENTS -----------\\n\\n'\n",
    "    for agent in incorrect:\n",
    "        log += f'Context: {agent.context}\\nQuestion: {agent.question}{agent.scratchpad}\\nCorrect answer: {agent.key}\\n\\n'\n",
    "    return log"
    "#### Load the HotPotQA Sample"
   ]
  },
  {
@@ -52,6 +49,14 @@
    "hotpot = joblib.load('data/hotpot-qa-distractor-sample.joblib').reset_index(drop = True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Initialize a CoTAgent for each question"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
@@ -67,13 +72,11 @@
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "outputs": [],
   "source": [
    "trial = 0\n",
    "log = ''"
    "#### Run trials"
   ]
  },
  {
@@ -1201,6 +1204,8 @@
    }
   ],
   "source": [
    "trial = 0\n",
    "log = ''\n",
    "for i in range(5):\n",
    "    for agent in [a for a in agents if not a.is_correct()]:\n",
    "        agent.run(reflect = False)\n",
@@ -1212,93 +1217,23 @@
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('output/base_cot_no_context/100_questions_5_trials.txt', 'w') as f:\n",
    "    f.write(log)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['output/base_cot_no_context/cot_33_correct_dicts-5-trials.joblib']"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dicts = [dict(a.__dict__) for a in agents]\n",
    "for d in dicts:\n",
    "    for k, v in d.items():\n",
    "        d[k] = str(v)\n",
    "\n",
    "joblib.dump(dicts, 'output/base_cot_no_context/cot_33_correct_dicts-5-trials.joblib')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('output/base_cot/100_questions_8_trials.txt', 'w') as f:\n",
    "    f.write(log)"
    "#### Save the result log"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "dicts = joblib.load('output/base_cot/cot_reflect_50_correct_dicts-8-trials.joblib')"
    "with open(root + '/CoT/no_context/last_trial_and_reflexion/100_questions_5_trials.txt', 'w') as f:\n",
    "    f.write(log)\n",
    "save_agents(agents, root + '/CoT/no_context/last_trial_and_reflexion/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['question', 'context', 'key', 'agent_prompt', 'reflect_prompt', 'cot_examples', 'reflect_examples', 'llm', 'reflections', 'answer', 'step_n', 'scratchpad', 'finished'])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dicts[0].keys()\n",
    "for d in dicts:\n",
    "    agent = CoTAgent(d['question'], d['context'], d['key'])\n",
    "    agent.reflections = d['reflections']\n",
    "    agent.scratchpad = d['scratchpad']\n",
    "    agent.answer = d['answer']\n",
    "    agent.step_n = d['step_n']\n",
    "    agent.finished = d['finished']\n",
    "    agents.append(agent)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
Binary file not shown.

Some files were not shown because too many files have changed in this diff
