From 8609d1867854299c36d30808c44c121fc5113cd8 Mon Sep 17 00:00:00 2001 From: elleven11 Date: Mon, 3 Apr 2023 16:57:35 -0400 Subject: [PATCH] resume and optz --- executors/rs_executor.py | 28 +++++++++++++++++++++++----- generate_dataset.py | 1 + main.py | 3 --- reflexion.py | 4 ++-- reflexion_ucs.py | 4 ++-- simple.py | 4 ++-- utils.py | 23 +++++++++++++++++++++++ 7 files changed, 53 insertions(+), 14 deletions(-) diff --git a/executors/rs_executor.py b/executors/rs_executor.py index 21cfbab..613e91a 100644 --- a/executors/rs_executor.py +++ b/executors/rs_executor.py @@ -176,27 +176,35 @@ class RsExecutor(Executor): TODO: do it actually """ tmp_dir, tmp_path = create_temp_project() + print(f"Evaluating\n{func + test}", flush=True) write_to_file_toplevel(tmp_path, func + test) res = run_with_timeout( - "cargo check --message-format=json", tmp_dir, timeout=timeout) + "cargo check --message-format=json", tmp_dir, timeout=timeout, print_debug=True) assert res is not None, "Timeout in cargo check, wow" errs = grab_compile_errs(res[0]) # (check returns stdin) if len(errs) > 0: # cleanup the temp directory os.system(f"rm -rf {tmp_dir}") + print("Compile errors. Failed eval", flush=True) return False # compile and run the binary res = run_with_timeout("cargo run", tmp_dir, - timeout=timeout, print_debug=False) + timeout=timeout, print_debug=True) os.system(f"rm -rf {tmp_dir}") if res is None: + print("Timeout?. Failed eval", flush=True) return False else: errs = grab_runtime_errs(res[1]) + if len(errs) > 0: + print("Runtime errors. Failed eval", flush=True) + return False + + print("Passed eval", flush=True) return len(errs) == 0 @@ -295,16 +303,26 @@ def grab_runtime_errs(inp: str) -> List[RuntimeErr]: curr_left = None panic_reason = None for line in split: - if "panicked at" in line: + if "fatal runtime" in line: + # we have a panic + panic_idx = line.index("fatal runtime") + panic_reason = line[panic_idx + len("fatal runtime") + 1:] + elif "panicked at" in line: panic_idx = line.index("panicked at") # strip source line if it exists if "src/main.rs" in line: line = line[:line.index("src/main.rs")] panic_reason = line[panic_idx + len("panicked at") + 1:] elif "left:" in line: - curr_left = line.split("`")[1] + split = line.split("`") + if len(split) < 2: + continue + curr_left = split[1] elif "right:" in line: - curr_right = line.split("`")[1] + split = line.split("`") + if len(split) < 2: + continue + curr_right = split[1] # get the line and column number fileinto = line.split(",")[-1] line = int(fileinto.split(":")[1]) diff --git a/generate_dataset.py b/generate_dataset.py index 91de561..5dc5dde 100644 --- a/generate_dataset.py +++ b/generate_dataset.py @@ -16,6 +16,7 @@ def download_dataset(dataset_name: str): print(entry) item["entry_point"] = entry item["test"] = item["tests"] + item["test"] = item["test"][1:] # there is some garbage at the start del item["tests"] final.append(item) diff --git a/main.py b/main.py index 355561d..5d2b248 100644 --- a/main.py +++ b/main.py @@ -48,9 +48,6 @@ def main(args): log_dir, f"{dataset_name}_{args.strategy}_{args.max_iters}_{args.model}_pass_at_k_{args.pass_at_k}_{args.language}.jsonl") if not os.path.exists(log_dir): os.makedirs(log_dir) - if os.path.exists(log_path): - raise ValueError( - f"Log path `{log_path}` already exists in `{log_dir}`") # check if the strategy is valid if args.strategy not in ["simple", "reflexion", "reflexion-ucs"]: diff --git a/reflexion.py b/reflexion.py index bf84788..a5c16cc 100644 --- a/reflexion.py +++ b/reflexion.py @@ -1,4 +1,4 @@ -from utils import write_jsonl +from utils import enumerate_resume, write_jsonl from executors import executor_factory from generators import generator_factory @@ -19,7 +19,7 @@ def run_reflexion( num_items = len(dataset) num_success = 0 - for i, item in enumerate(dataset): + for i, item in enumerate_resume(dataset, log_path): cur_pass = 0 is_solved = False reflections = [] diff --git a/reflexion_ucs.py b/reflexion_ucs.py index d593e63..2abdbb6 100644 --- a/reflexion_ucs.py +++ b/reflexion_ucs.py @@ -1,6 +1,6 @@ import warnings from lazzzy.ucs import ucs -from utils import write_jsonl +from utils import enumerate_resume, write_jsonl from executors import executor_factory from generators import generator_factory @@ -52,7 +52,7 @@ def run_reflexion_ucs( num_items = len(dataset) num_success = 0 - for i, item in enumerate(dataset): + for i, item in enumerate_resume(dataset, log_path): cur_pass = 0 is_solved = False reflections = [] diff --git a/simple.py b/simple.py index 17e938f..2610c35 100644 --- a/simple.py +++ b/simple.py @@ -1,4 +1,4 @@ -from utils import write_jsonl +from utils import enumerate_resume, write_jsonl from executors import executor_factory from generators import generator_factory @@ -20,7 +20,7 @@ def run_simple( num_items = len(dataset) num_success = 0 - for i, item in enumerate(dataset): + for i, item in enumerate_resume(dataset, log_path): cur_pass = 0 is_solved = False cur_func_impl = "" diff --git a/utils.py b/utils.py index 220940f..4ae633b 100644 --- a/utils.py +++ b/utils.py @@ -8,6 +8,7 @@ from typing import List openai.api_key = os.getenv("OPENAI_API_KEY") + def read_jsonl(path: str) -> List[dict]: if not os.path.exists(path): raise FileNotFoundError(f"File `{path}` does not exist.") @@ -19,14 +20,36 @@ def read_jsonl(path: str) -> List[dict]: items += [item] return items + def write_jsonl(path: str, data: List[dict], append: bool = False): with jsonlines.open(path, mode='a' if append else 'w') as writer: for item in data: writer.write(item) + def read_jsonl_gz(path: str) -> List[dict]: if not path.endswith(".jsonl.gz"): raise ValueError(f"File `{path}` is not a jsonl.gz file.") with gzip.open(path, "rt") as f: data = [json.loads(line) for line in f] return data + + +# generator that returns the item and the index in the dataset. +# if the results_path exists, it will skip all items that have been processed +# before. +def enumerate_resume(dataset, results_path): + if not os.path.exists(results_path): + for i, item in enumerate(dataset): + yield i, item + else: + count = 0 + with jsonlines.open(results_path) as reader: + for item in reader: + count += 1 + + for i, item in enumerate(dataset): + # skip items that have been processed before + if i < count: + continue + yield i, item