import os
import json
import argparse
from typing import Any, List, Dict

from webshop_trial import run_trial
from generate_reflections import update_memory


def get_args():
    """Parse and validate the command-line arguments for a run.

    Returns:
        The parsed ``argparse.Namespace``.

    Invalid or missing values are reported through ``parser.error``, which
    prints the usage message and exits the process with status 2.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_trials", type=int, help="The number of trials to run")
    parser.add_argument("--num_envs", type=int, help="The number of environments per trial")
    parser.add_argument("--run_name", type=str, help="The name of the run")
    parser.add_argument("--use_memory", action='store_true', help="Allow the Agent to use memory")
    parser.add_argument("--is_resume", action='store_true', help="To resume run")
    parser.add_argument("--resume_dir", type=str, help="If resume, the logging directory", default="")
    parser.add_argument("--start_trial_num", type=int, help="If resume, the start trial num", default=0)

    args = parser.parse_args()

    # Explicit checks instead of `assert`: asserts are stripped under `-O`,
    # and an omitted option is None, which cannot be compared with 0.
    if args.num_trials is None or args.num_trials <= 0:
        parser.error("Number of trials should be positive")
    if args.num_envs is None or args.num_envs <= 0:
        parser.error("Number of environments should be positive")
    # A fresh run uses --run_name as its logging directory; a resumed run
    # logs into --resume_dir instead, so the name is only required here.
    if not args.is_resume and not args.run_name:
        parser.error("--run_name is required unless --is_resume is set")

    return args


def _print_start_status(args, logging_dir: str) -> None:
    """Print the run parameters and the log destination to stdout."""
    verb = "Resuming" if args.is_resume else "Starting"
    lines = [
        "-----",
        f"{verb} run with the following parameters:",
        f"Run name: {logging_dir}",
        f"Number of trials: {args.num_trials}",
        f"Number of environments: {args.num_envs}",
        f"Use memory: {args.use_memory}",
    ]
    if args.is_resume:
        lines.append(f"Resume trial number: {args.start_trial_num}")
    lines += ["", f"Sending all logs to `{logging_dir}`", "-----"]
    print("\n" + "\n".join(lines) + "\n")


def main(args) -> None:
    """Run trials ``start_trial_num`` .. ``num_trials - 1`` and log each one.

    Args:
        args: Parsed arguments as produced by ``get_args``. Reads
            ``is_resume``, ``resume_dir``, ``start_trial_num``, ``run_name``,
            ``num_envs``, ``num_trials`` and ``use_memory``.

    Raises:
        ValueError: When resuming and either ``resume_dir`` or the env
            results file of the trial preceding ``start_trial_num`` is
            missing.

    All output (``world.log``, ``trial_<i>.log`` and
    ``env_results_trial_<i>.json``) goes to a single logging directory:
    ``resume_dir`` when resuming, ``run_name`` otherwise.
    """
    if args.is_resume:
        if not os.path.exists(args.resume_dir):
            raise ValueError(f"Resume directory `{args.resume_dir}` does not exist")
        logging_dir = args.resume_dir

        # Load the environment configs saved by the previous trial;
        # `num_envs` is not used to build them in this branch.
        env_config_path: str = os.path.join(
            args.resume_dir, f'env_results_trial_{args.start_trial_num - 1}.json'
        )
        if not os.path.exists(env_config_path):
            raise ValueError(f"Environment config file `{env_config_path}` does not exist")
        with open(env_config_path, 'r') as rf:
            env_configs: List[Dict[str, Any]] = json.load(rf)
    else:
        # Create the run directory
        os.makedirs(args.run_name, exist_ok=True)
        logging_dir = args.run_name

        # Initialize one fresh config per environment
        env_configs = [
            {'name': f'env_{i}', 'memory': [], 'is_success': False}
            for i in range(args.num_envs)
        ]

    world_log_path: str = os.path.join(logging_dir, 'world.log')

    # print start status to user
    _print_start_status(args, logging_dir)

    # run trials
    for trial_idx in range(args.start_trial_num, args.num_trials):
        with open(world_log_path, 'a') as wf:
            wf.write(f'\n\n***** Start Trial #{trial_idx} *****\n\n')

        # Per-trial files live next to world.log. Building these from
        # `args.run_name` would split a resumed run across two directories.
        trial_log_path: str = os.path.join(logging_dir, f'trial_{trial_idx}.log')
        trial_env_configs_log_path: str = os.path.join(
            logging_dir, f'env_results_trial_{trial_idx}.json'
        )

        # Empty any files left over from an earlier run of this trial index.
        for leftover_path in (trial_log_path, trial_env_configs_log_path):
            if os.path.exists(leftover_path):
                with open(leftover_path, 'w'):
                    pass

        # run trial
        run_trial(trial_log_path, world_log_path, trial_idx, env_configs, args.use_memory)

        # update memory if needed
        if args.use_memory:
            env_configs = update_memory(trial_log_path, env_configs)

        # log env configs for trial
        with open(trial_env_configs_log_path, 'w') as wf:
            json.dump(env_configs, wf, indent=4)

        # log world for trial
        with open(world_log_path, 'a') as wf:
            wf.write(f'\n\n***** End Trial #{trial_idx} *****\n\n')


if __name__ == '__main__':
    args = get_args()
    main(args)