resume support and optimizations

pull/8/head
elleven11 2 years ago
parent 4fe89a1439
commit 8609d18678

@ -176,27 +176,35 @@ class RsExecutor(Executor):
TODO: do it actually
"""
tmp_dir, tmp_path = create_temp_project()
print(f"Evaluating\n{func + test}", flush=True)
write_to_file_toplevel(tmp_path, func + test)
res = run_with_timeout(
"cargo check --message-format=json", tmp_dir, timeout=timeout)
"cargo check --message-format=json", tmp_dir, timeout=timeout, print_debug=True)
assert res is not None, "Timeout in cargo check, wow"
errs = grab_compile_errs(res[0]) # (check returns stdin)
if len(errs) > 0:
# cleanup the temp directory
os.system(f"rm -rf {tmp_dir}")
print("Compile errors. Failed eval", flush=True)
return False
# compile and run the binary
res = run_with_timeout("cargo run", tmp_dir,
timeout=timeout, print_debug=False)
timeout=timeout, print_debug=True)
os.system(f"rm -rf {tmp_dir}")
if res is None:
print("Timeout?. Failed eval", flush=True)
return False
else:
errs = grab_runtime_errs(res[1])
if len(errs) > 0:
print("Runtime errors. Failed eval", flush=True)
return False
print("Passed eval", flush=True)
return len(errs) == 0
@ -295,16 +303,26 @@ def grab_runtime_errs(inp: str) -> List[RuntimeErr]:
curr_left = None
panic_reason = None
for line in split:
if "panicked at" in line:
if "fatal runtime" in line:
# we have a panic
panic_idx = line.index("fatal runtime")
panic_reason = line[panic_idx + len("fatal runtime") + 1:]
elif "panicked at" in line:
panic_idx = line.index("panicked at")
# strip source line if it exists
if "src/main.rs" in line:
line = line[:line.index("src/main.rs")]
panic_reason = line[panic_idx + len("panicked at") + 1:]
elif "left:" in line:
curr_left = line.split("`")[1]
split = line.split("`")
if len(split) < 2:
continue
curr_left = split[1]
elif "right:" in line:
curr_right = line.split("`")[1]
split = line.split("`")
if len(split) < 2:
continue
curr_right = split[1]
# get the line and column number
fileinto = line.split(",")[-1]
line = int(fileinto.split(":")[1])

@ -16,6 +16,7 @@ def download_dataset(dataset_name: str):
print(entry)
item["entry_point"] = entry
item["test"] = item["tests"]
item["test"] = item["test"][1:] # there is some garbage at the start
del item["tests"]
final.append(item)

@ -48,9 +48,6 @@ def main(args):
log_dir, f"{dataset_name}_{args.strategy}_{args.max_iters}_{args.model}_pass_at_k_{args.pass_at_k}_{args.language}.jsonl")
if not os.path.exists(log_dir):
os.makedirs(log_dir)
if os.path.exists(log_path):
raise ValueError(
f"Log path `{log_path}` already exists in `{log_dir}`")
# check if the strategy is valid
if args.strategy not in ["simple", "reflexion", "reflexion-ucs"]:

@ -1,4 +1,4 @@
from utils import write_jsonl
from utils import enumerate_resume, write_jsonl
from executors import executor_factory
from generators import generator_factory
@ -19,7 +19,7 @@ def run_reflexion(
num_items = len(dataset)
num_success = 0
for i, item in enumerate(dataset):
for i, item in enumerate_resume(dataset, log_path):
cur_pass = 0
is_solved = False
reflections = []

@ -1,6 +1,6 @@
import warnings
from lazzzy.ucs import ucs
from utils import write_jsonl
from utils import enumerate_resume, write_jsonl
from executors import executor_factory
from generators import generator_factory
@ -52,7 +52,7 @@ def run_reflexion_ucs(
num_items = len(dataset)
num_success = 0
for i, item in enumerate(dataset):
for i, item in enumerate_resume(dataset, log_path):
cur_pass = 0
is_solved = False
reflections = []

@ -1,4 +1,4 @@
from utils import write_jsonl
from utils import enumerate_resume, write_jsonl
from executors import executor_factory
from generators import generator_factory
@ -20,7 +20,7 @@ def run_simple(
num_items = len(dataset)
num_success = 0
for i, item in enumerate(dataset):
for i, item in enumerate_resume(dataset, log_path):
cur_pass = 0
is_solved = False
cur_func_impl = ""

@ -8,6 +8,7 @@ from typing import List
openai.api_key = os.getenv("OPENAI_API_KEY")
def read_jsonl(path: str) -> List[dict]:
if not os.path.exists(path):
raise FileNotFoundError(f"File `{path}` does not exist.")
@ -19,14 +20,36 @@ def read_jsonl(path: str) -> List[dict]:
items += [item]
return items
def write_jsonl(path: str, data: List[dict], append: bool = False):
    """Serialize *data* to *path* in JSON Lines format.

    Overwrites the file by default; pass ``append=True`` to add to an
    existing log instead.
    """
    mode = 'a' if append else 'w'
    with jsonlines.open(path, mode=mode) as writer:
        writer.write_all(data)
def read_jsonl_gz(path: str) -> List[dict]:
    """Load a gzip-compressed JSON Lines file into a list of dicts.

    Raises:
        ValueError: if *path* does not end with ``.jsonl.gz``.
    """
    if not path.endswith(".jsonl.gz"):
        raise ValueError(f"File `{path}` is not a jsonl.gz file.")
    # "rt" decompresses to text so each line is a JSON document string.
    with gzip.open(path, "rt") as f:
        return [json.loads(line) for line in f]
def enumerate_resume(dataset, results_path):
    """Yield ``(index, item)`` pairs from *dataset*, resuming a prior run.

    If *results_path* exists, it is treated as a JSON Lines log of items
    already processed; the first N dataset items (where N is the number of
    logged records) are skipped so a restarted run continues where the
    previous one stopped.

    Args:
        dataset: sequence (or iterable) of items to process.
        results_path: path to the JSON Lines results log.

    Yields:
        Tuples of (original dataset index, item) for unprocessed items.
    """
    done = 0
    if os.path.exists(results_path):
        # Only the record count matters for resuming; don't keep the
        # parsed items around, just count them.
        with jsonlines.open(results_path) as reader:
            done = sum(1 for _ in reader)
    for i, item in enumerate(dataset):
        # Skip everything already present in the log from a previous run.
        if i >= done:
            yield i, item

Loading…
Cancel
Save