You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
reflexion-human-eval/webshop_runs/main.py

119 lines
4.1 KiB
Python

import os
import json
import argparse
from webshop_trial import run_trial
from generate_reflections import update_memory
from typing import Any, List, Dict
def get_args(argv=None) -> argparse.Namespace:
    """Parse and validate the command-line arguments for a WebShop run.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default), argparse reads ``sys.argv[1:]``, as before.

    Returns:
        The parsed ``argparse.Namespace``.

    Invalid or missing arguments are reported through ``parser.error``. This
    prints a usage message and raises ``SystemExit`` instead of failing later
    with an ``AssertionError`` or ``TypeError``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_trials", type=int, required=True, help="The number of trials to run")
    parser.add_argument("--num_envs", type=int, required=True, help="The number of environments per trial")
    parser.add_argument("--run_name", type=str, required=True, help="The name of the run")
    parser.add_argument("--use_memory", action='store_true', help="Allow the Agent to use memory")
    parser.add_argument("--is_resume", action='store_true', help="To resume run")
    parser.add_argument("--resume_dir", type=str, help="If resume, the logging directory", default="")
    parser.add_argument("--start_trial_num", type=int, help="If resume, the start trial num", default=0)
    args = parser.parse_args(argv)

    # Validate with parser.error rather than `assert`: asserts are removed
    # under `python -O`, and parser.error gives the user a usage message.
    if args.num_trials <= 0:
        parser.error("Number of trials should be positive")
    if args.num_envs <= 0:
        parser.error("Number of environments should be positive")
    if args.is_resume and args.start_trial_num < 1:
        # Resuming loads env_results_trial_{start_trial_num - 1}.json, so at
        # least one completed trial must exist to resume from.
        parser.error("--start_trial_num must be >= 1 when --is_resume is given")
    return args
def _load_resumed_env_configs(resume_dir: str, start_trial_num: int) -> List[Dict[str, Any]]:
    """Load the per-environment configs written at the end of trial ``start_trial_num - 1``.

    Raises:
        ValueError: if ``resume_dir`` does not exist, or if it does not contain
            ``env_results_trial_{start_trial_num - 1}.json``.
    """
    if not os.path.exists(resume_dir):
        raise ValueError(f"Resume directory `{resume_dir}` does not exist")
    env_config_path = os.path.join(resume_dir, f'env_results_trial_{start_trial_num - 1}.json')
    if not os.path.exists(env_config_path):
        raise ValueError(f"Environment config file `{env_config_path}` does not exist")
    with open(env_config_path, 'r', encoding='utf-8') as rf:
        return json.load(rf)


def _fresh_env_configs(num_envs: int) -> List[Dict[str, Any]]:
    """Build ``num_envs`` initial env configs with empty memory and ``is_success`` False."""
    return [
        {'name': f'env_{i}', 'memory': [], 'is_success': False}
        for i in range(num_envs)
    ]


def _print_start_status(args, logging_dir: str) -> None:
    """Print a banner summarizing the run parameters and the log directory."""
    lines = [
        "-----",
        f"{'Resuming' if args.is_resume else 'Starting'} run with the following parameters:",
        f"Run name: {logging_dir}",
        f"Number of trials: {args.num_trials}",
        f"Number of environments: {args.num_envs}",
        f"Use memory: {args.use_memory}",
    ]
    if args.is_resume:
        lines.append(f"Resume trial number: {args.start_trial_num}")
    lines.append(f"Sending all logs to `{logging_dir}`")
    lines.append("-----")
    print("\n" + "\n".join(lines) + "\n")


def _append_world_log(world_log_path: str, text: str) -> None:
    """Append ``text`` to the run-wide ``world.log`` file."""
    with open(world_log_path, 'a', encoding='utf-8') as wf:
        wf.write(text)


def main(args) -> None:
    """Run trials ``start_trial_num .. num_trials - 1`` and log their results.

    All logs go to a single directory:
    - when resuming, ``args.resume_dir``;
    - otherwise, ``args.run_name``, which is created if needed.

    Previously, a resumed run wrote ``world.log`` to ``resume_dir`` but its
    trial logs and env results to ``run_name``.

    Per trial, this function:
    1. Writes start/end markers to ``world.log``.
    2. Calls ``run_trial``.
    3. If ``args.use_memory`` is set, replaces the env configs with the result
       of ``update_memory``.
    4. Dumps the env configs to ``env_results_trial_{idx}.json``.

    Args:
        args: Namespace with ``is_resume``, ``resume_dir``, ``run_name``,
            ``start_trial_num``, ``num_trials``, ``num_envs`` and ``use_memory``.

    Raises:
        ValueError: when resuming and the resume directory or the previous
            trial's env-results file is missing.
    """
    if args.is_resume:
        logging_dir = args.resume_dir
        env_configs = _load_resumed_env_configs(args.resume_dir, args.start_trial_num)
    else:
        logging_dir = args.run_name
        os.makedirs(logging_dir, exist_ok=True)
        env_configs = _fresh_env_configs(args.num_envs)

    world_log_path = os.path.join(logging_dir, 'world.log')
    _print_start_status(args, logging_dir)

    for trial_idx in range(args.start_trial_num, args.num_trials):
        _append_world_log(world_log_path, f'\n\n***** Start Trial #{trial_idx} *****\n\n')

        trial_log_path = os.path.join(logging_dir, f'trial_{trial_idx}.log')
        trial_env_configs_log_path = os.path.join(logging_dir, f'env_results_trial_{trial_idx}.json')
        # Empty any files left over from an earlier attempt at this trial index
        # (presumably because run_trial appends to the trial log; confirm).
        for stale_path in (trial_log_path, trial_env_configs_log_path):
            if os.path.exists(stale_path):
                open(stale_path, 'w').close()

        run_trial(trial_log_path, world_log_path, trial_idx, env_configs, args.use_memory)

        if args.use_memory:
            env_configs = update_memory(trial_log_path, env_configs)

        with open(trial_env_configs_log_path, 'w', encoding='utf-8') as wf:
            json.dump(env_configs, wf, indent=4)

        _append_world_log(world_log_path, f'\n\n***** End Trial #{trial_idx} *****\n\n')
if __name__ == '__main__':
    # Script entry point: parse the CLI arguments and hand them to main().
    main(get_args())