#### Notebook for running Chain-of-Thought experiments with supporting context

In [1]:
import os
import sys

# Put the parent directory on the import path (presumably where the project
# modules imported by later cells live -- confirm against the repo layout).
sys.path.append('..')

# Base directory under which the result log and saved agents are written.
root = '../root/'

In [3]:
import joblib
import numpy as np
from agents import CoTAgent, ReflexionStrategy
from util import summarize_trial, log_trial, save_agents

#### Load the HotPotQA Sample

In [4]:
# Load the HotPotQA sample and give it a fresh 0..N-1 index.
# NOTE: joblib.load unpickles its input -- only load files from a trusted source.
hotpot = joblib.load('../data/hotpot-qa-distractor-sample.joblib').reset_index(drop = True)

# For every question, build the context string handed to the agent: the full
# text of each supporting article, separated by blank lines.
hotpot['supporting_paragraphs'] = None
for ind, row in hotpot.iterrows():
    # The title list may name the same article more than once (presumably one
    # entry per supporting sentence -- confirm against the dataset schema).
    # Keep each title once, in first-seen order, so a paragraph is not pasted
    # into the context repeatedly.
    supporting_articles = list(dict.fromkeys(row['supporting_facts']['title']))
    articles = row['context']['title']
    sentences = row['context']['sentences']
    supporting_paragraphs = []
    for article in supporting_articles:
        # Pick the sentence list of the first context article whose title
        # matches, and glue its sentences into one paragraph.
        supporting_paragraph = ''.join(sentences[np.where(articles == article)][0])
        supporting_paragraphs.append(supporting_paragraph)
    supporting_paragraphs = '\n\n'.join(supporting_paragraphs)
    hotpot.at[ind, 'supporting_paragraphs'] = supporting_paragraphs

#### Define the Reflexion Strategy

In [5]:
# Print the ReflexionStrategy docstring, which lists the available strategies.
print(ReflexionStrategy.__doc__)


 NONE: No reflection
 LAST_ATTEMPT: Use last reasoning trace in context 
 REFLEXION: Apply reflexion to the next reasoning trace 
 LAST_ATTEMPT_AND_REFLEXION: Use last reasoning trace in context and apply reflexion to the next reasoning trace 
 


In [6]:
# Strategy passed to every agent run below; its `.value` also names the
# sub-directory the results are saved under.
strategy: ReflexionStrategy = ReflexionStrategy.REFLEXION

#### Initialize a CoTAgent for each question

In [7]:
from prompts import cot_agent_prompt, cot_reflect_agent_prompt, cot_reflect_prompt
from fewshots import COT, COT_REFLECT

# With no reflection the plain CoT prompt is used; every other strategy gets
# the reflection-aware agent prompt.
if strategy == ReflexionStrategy.NONE:
    selected_agent_prompt = cot_agent_prompt
else:
    selected_agent_prompt = cot_reflect_agent_prompt

# One agent per question, given the question text, its supporting paragraphs
# as context, and the reference answer.
agents = [
    CoTAgent(
        question_row['question'],
        question_row['supporting_paragraphs'],
        question_row['answer'],
        agent_prompt=selected_agent_prompt,
        cot_examples=COT,
        reflect_prompt=cot_reflect_prompt,
        reflect_examples=COT_REFLECT,
    )
    for _, question_row in hotpot.iterrows()
]

#### Run `n` trials

In [8]:
# n:     number of trials run by one execution of the loop cell
# trial: running count of completed trials
# log:   accumulated text log across trials
n, trial, log = 5, 0, ''

In [9]:
for round_idx in range(n):
    # Snapshot the agents that are still incorrect before running any of them,
    # so only those are retried in this trial.
    unsolved_agents = [candidate for candidate in agents if not candidate.is_correct()]
    for agent in unsolved_agents:
        agent.run(reflexion_strategy=strategy)
        print(f'Answer: {agent.key}')

    # Record the finished trial: bump the counter, extend the log, report totals.
    trial += 1
    log += log_trial(agents, trial)
    correct, incorrect = summarize_trial(agents)
    print(f'Finished Trial {trial}, Correct: {len(correct)}, Incorrect: {len(incorrect)}')

ValidationError: 1 validation error for HumanMessage
content
 field required (type=value_error.missing)

#### Save the result log

In [27]:
# Output directory for this strategy: <root>/CoT/context/<strategy value>/
output_dir = os.path.join(root, 'CoT', 'context', strategy.value)
# open(..., 'w') does not create missing parent directories, so make sure the
# directory exists before writing the log.
os.makedirs(output_dir, exist_ok=True)

log_path = os.path.join(output_dir, f'{len(agents)}_questions_{trial}_trials.txt')
# Explicit UTF-8 so non-ASCII characters in the log are written the same way
# on every platform instead of depending on the locale's default encoding.
with open(log_path, 'w', encoding='utf-8') as f:
    f.write(log)
save_agents(agents, os.path.join(output_dir, 'agents'))