# Mirror of https://github.com/GammaTauAI/reflexion-human-eval
# Synced 2024-11-13 13:10:26 +00:00
# 120 lines, 4.2 KiB, Python
import os
|
|
import json
|
|
import argparse
|
|
|
|
from alfworld_trial import run_trial
|
|
from generate_reflections import update_memory
|
|
|
|
from typing import Any, List, Dict
|
|
|
|
def get_args():
    """Parse and validate the command-line options for a run.

    Options:
        --num_trials       (int, required) number of trials to run; must be > 0.
        --num_envs         (int, required) environments per trial; must be > 0.
        --run_name         (str) name of the run.
        --use_memory       (flag) allow the agent to use memory.
        --is_resume        (flag) resume an earlier run.
        --resume_dir       (str, default "") logging directory to resume from.
        --start_trial_num  (int, default 0) trial number to resume at.
        --model            (str) model to use.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: (status 2, raised by argparse) when a required option is
            missing or a count is not positive.
    """
    parser = argparse.ArgumentParser()
    # Both counts are compared against 0 below, so a missing value must be
    # rejected by argparse itself instead of surfacing as `None <= 0`.
    parser.add_argument("--num_trials", type=int, required=True, help="The number of trials to run")
    parser.add_argument("--num_envs", type=int, required=True, help="The number of environments per trial")
    parser.add_argument("--run_name", type=str, help="The name of the run")
    parser.add_argument("--use_memory", action='store_true', help="Allow the Agent to use memory")
    parser.add_argument("--is_resume", action='store_true', help="To resume run")
    parser.add_argument("--resume_dir", type=str, help="If resume, the logging directory", default="")
    parser.add_argument("--start_trial_num", type=int, help="If resume, the start trial num", default=0)
    parser.add_argument("--model", type=str, help="The model to use. One of `gpt-4`, `gpt-3.5-turbo`, or `text-davinci-003")

    args = parser.parse_args()

    # Report bad values through argparse rather than `assert`: assertions are
    # removed under `python -O`, and parser.error prints the usage line.
    if args.num_trials <= 0:
        parser.error("Number of trials should be positive")
    if args.num_envs <= 0:
        parser.error("Number of environments should be positive")

    return args
|
|
|
|
def _load_saved_env_configs(resume_dir: str, start_trial_num: int) -> List[Dict[str, Any]]:
    """Read the env configs saved for the trial just before `start_trial_num`."""
    env_config_path: str = os.path.join(resume_dir, f'env_results_trial_{start_trial_num - 1}.json')
    if not os.path.exists(env_config_path):
        raise ValueError(f"Environment config file `{env_config_path}` does not exist")
    with open(env_config_path, 'r') as rf:
        return json.load(rf)


def _initial_env_configs(num_envs: int) -> List[Dict[str, Any]]:
    """Build the starting config (empty memory, flags cleared) for `num_envs` environments."""
    return [
        {
            'name': f'env_{i}',
            'memory': [],
            'is_success': False,
            'skip': False,
        }
        for i in range(num_envs)
    ]


def _format_start_banner(args, logging_dir: str, output_dir: str) -> str:
    """Return the start-of-run status text shown to the user."""
    lines = [
        "",
        "-----",
        f"{'Resuming' if args.is_resume else 'Starting'} run with the following parameters:",
        f"Run name: {logging_dir}",
        f"Number of trials: {args.num_trials}",
        f"Number of environments: {args.num_envs}",
        f"Use memory: {args.use_memory}",
    ]
    if args.is_resume:
        lines.append(f"Resume trial number: {args.start_trial_num}")
    lines += ["", f"Sending all logs to `{output_dir}`", "-----", ""]
    return "\n".join(lines)


def main(args) -> None:
    """Run trials `args.start_trial_num` up to (excluding) `args.num_trials`.

    For each trial this appends start/end markers to `world.log`, calls
    `run_trial`, optionally calls `update_memory` (when `args.use_memory`),
    and dumps the environment configs to `env_results_trial_<n>.json`.

    Directories:
        * `world.log` lives in the resume directory when resuming, otherwise
          in `args.run_name`.
        * Per-trial files go to `args.run_name`; when resuming without a run
          name they go to the resume directory instead.

    Args:
        args: Parsed options as produced by `get_args` (reads `num_trials`,
            `num_envs`, `run_name`, `use_memory`, `is_resume`, `resume_dir`,
            `start_trial_num` and `model`).

    Raises:
        ValueError: If the resume directory or the saved environment config
            is missing, or if a fresh run is started without a run name.
    """
    if args.is_resume:
        if not os.path.exists(args.resume_dir):
            raise ValueError(f"Resume directory `{args.resume_dir}` does not exist")
        logging_dir = args.resume_dir

        # load environment configs
        env_configs: List[Dict[str, Any]] = _load_saved_env_configs(args.resume_dir, args.start_trial_num)

        # Without an explicit run name, joining per-trial paths onto `None`
        # would fail, so keep writing next to the files being resumed.
        output_dir = logging_dir if args.run_name is None else args.run_name
    else:
        if args.run_name is None:
            raise ValueError("Run name is required when not resuming a run")
        logging_dir = args.run_name
        output_dir = args.run_name

    # Create the output directory up front so a missing directory cannot
    # fail the run only after a whole trial has finished.
    os.makedirs(output_dir, exist_ok=True)

    if not args.is_resume:
        # initialize environment configs
        env_configs = _initial_env_configs(args.num_envs)

    world_log_path: str = os.path.join(logging_dir, 'world.log')

    # print start status to user
    print(_format_start_banner(args, logging_dir, output_dir))

    # run trials
    for trial_idx in range(args.start_trial_num, args.num_trials):
        with open(world_log_path, 'a') as wf:
            wf.write(f'\n\n***** Start Trial #{trial_idx} *****\n\n')

        # set paths to log files
        trial_log_path: str = os.path.join(output_dir, f'trial_{trial_idx}.log')
        trial_env_configs_log_path: str = os.path.join(output_dir, f'env_results_trial_{trial_idx}.json')

        # Truncate files that already exist at these paths so this trial
        # starts from empty logs.
        for existing_path in (trial_log_path, trial_env_configs_log_path):
            if os.path.exists(existing_path):
                with open(existing_path, 'w'):
                    pass

        # run trial
        run_trial(trial_log_path, world_log_path, trial_idx, env_configs, args.use_memory, args.model)

        # update memory if needed
        if args.use_memory:
            env_configs = update_memory(trial_log_path, env_configs)

        # log env configs for trial
        with open(trial_env_configs_log_path, 'w') as wf:
            json.dump(env_configs, wf, indent=4)

        # log world for trial
        with open(world_log_path, 'a') as wf:
            wf.write(f'\n\n***** End Trial #{trial_idx} *****\n\n')
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the command line, then hand the result to main.
    main(get_args())
|