# mirror of https://github.com/GammaTauAI/reflexion-human-eval
# synced 2024-11-11 19:10:53 +00:00
import warnings

from lazzzy.ucs import ucs

from utils import enumerate_resume, write_jsonl
from executors import executor_factory
from generators import generator_factory

from typing import List, Set, Tuple


DEBUG = True


def debug_print(*args):
    if DEBUG:
        print(*args, flush=True)


class State:
    def __init__(self, code: str, feedback: str, reflection: str, state: Tuple[bool, ...]):
        self.code = code
        self.feedback = feedback
        self.reflection = reflection
        self.state = state

    def __repr__(self):
        return f"State(code={self.code}, feedback={self.feedback}, reflection={self.reflection}, state={self.state})"

    def is_goal(self):
        return all(self.state)

    def __hash__(self):
        return hash((self.code, self.feedback, self.reflection))

    def get_unique_id(self):
        # encode the tuple of per-test pass/fail flags as a bitmask, so states
        # with the same pass/fail pattern map to the same search node
        res = 0
        for i in range(len(self.state)):
            res += self.state[i] * (2**i)

        return res
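

# Example of the State.get_unique_id encoding above: state = (True, False, True)
# gives 1*1 + 0*2 + 1*4 = 5, so candidates that pass exactly the same subset of
# internal tests collapse into a single node during the search.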
def run_reflexion_ucs(
    dataset: List[dict],
    model: str,
    language: str,
    max_iters: int,
    pass_at_k: int,
    log_path: str,
    verbose: bool,
    expansion_factor: int
) -> None:
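    """Reflexion-style repair guided by uniform-cost search (lazzzy.ucs).

    For each dataset item, generate a first implementation and run it against
    model-generated internal tests. If it fails, search over reflexion-refined
    candidates: each node is a (code, feedback, reflection, test-results) State
    and the step cost is the number of failing internal tests, until a
    candidate passes all internal tests or the frontier is exhausted (in which
    case the best partial candidate is kept). The winning code is then checked
    against the item's real test suite and the result is appended to log_path.
    """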
    exe = executor_factory(language)
    gen = generator_factory(language)

    num_items = len(dataset)
    num_success = 0
    for i, item in enumerate_resume(dataset, log_path):
        cur_pass = 0
        is_solved = False
        reflections = []
        cur_func_impl = ""
        while cur_pass < pass_at_k and not is_solved:
            debug_print(f"item {i} pass {cur_pass}")
            tests_i = gen.internal_tests(item["prompt"], model, 1)
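            # these model-generated tests drive the search; the item's real
            # test suite is only used for the final evaluation further below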
            if len(tests_i) == 0:
                warnings.warn(f"no internal tests generated for item {i}")

            # first attempt
            debug_print("first attempt")
            cur_func_impl = gen.func_impl(item["prompt"], model, "simple")
            assert isinstance(cur_func_impl, str)  # num_comps of 1
            is_passing, feedback, state = exe.execute(cur_func_impl, tests_i)
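            # `state` holds one pass/fail flag per internal test; it becomes
            # State.state for the search nodes below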

            debug_print(f"first attempt: \n{cur_func_impl}\n{feedback}\n{state}")

            # if the first attempt already passes the internal tests, evaluate
            # it against the real test suite and exit the pass@k loop early
            if is_passing:
                debug_print("solved at first attempt")
                is_solved = exe.evaluate(item["entry_point"], cur_func_impl, item["test"])
                num_success += 1 if is_solved else 0
                break

            reflection = gen.self_reflection(
                cur_func_impl, feedback, model)
            reflections.append(reflection)

            start = State(cur_func_impl, feedback, reflection, state)
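            # the search starts from the failing first attempt; expand() asks
            # the model for up to expansion_factor reflexion-refined children
            # per node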

            def expand(state: State) -> Set[Tuple[State, float]]:
                nonlocal max_iters
                nonlocal expansion_factor
                nonlocal item
                nonlocal model
                nonlocal tests_i
                nonlocal reflections

                new_states: Set[Tuple[State, float]] = set()

                debug_print(f"start expansion of: {state.state}")
                new_funcs = gen.func_impl(
                    func_sig=item["prompt"],
                    model=model,
                    strategy="reflexion",
                    prev_func_impl=state.code,
                    feedback=state.feedback,
                    self_reflection=state.reflection,
                    num_comps=expansion_factor,
                    temperature=0.75
                )
                assert isinstance(new_funcs, list)
                debug_print(f"generated num of funcs: {len(new_funcs)}")

                already_seen = set()

                for new_func in new_funcs:
                    if new_func in already_seen:
                        debug_print("skipping a func because already seen.")
                        continue

                    already_seen.add(new_func)

                    is_passing, feedback, new_state = exe.execute(new_func, tests_i)
                    debug_print(f"expanding: \n{new_func}\n{feedback}\n{new_state}")

                    if is_passing:
                        # return immediately if solved
                        return set([(State(new_func, feedback, "", new_state), 0)])

                    new_reflection = gen.self_reflection(new_func, feedback, model)
                    reflections.append(new_reflection)

                    num_failing = len([x for x in new_state if not x])
                    new_states.add(
                        (State(new_func, feedback, new_reflection, new_state), num_failing))

                debug_print(f"returning new states: {new_states}")

                return new_states
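
            # NOTE: the second element of each returned pair is the number of
            # failing internal tests, which ucs below treats as the step cost,
            # so candidates that fail fewer internal tests are explored first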

            def when_none(l: List[State]) -> State:
                debug_print(f"when_none called on: {l}")
                return min(l, key=lambda x: len([y for y in x.state if not y]))

            # this is either the goal state or, if no goal is found, the
            # current best state (fewest failing internal tests)
            best = ucs(
                start=start,
                expand=expand,
                is_goal=lambda x: x.is_goal(),
                # NOTE: deduplicating on the pass/fail bitmask reduces the
                # search space significantly: the maximum number of distinct
                # nodes is 2^num_tests (e.g. 2^5 = 32 for five internal tests)
                get_unique_id=lambda x: x.get_unique_id(),
                when_none=when_none
            )
            assert best is not None  # impossible due to our when_none

            print("BEST CODE:\n\n\n")
            print(best.code)
            is_passing = exe.evaluate(
                item["entry_point"], best.code, item["test"], timeout=5)
            if is_passing:
                item["solution"] = best.code
                # keep cur_func_impl in sync so the final item["solution"]
                # write below does not clobber the passing code with the
                # first attempt
                cur_func_impl = best.code
                is_solved = True
                num_success += 1
                break  # breaking pass@k loop

            cur_pass += 1

        item["is_solved"] = is_solved
        item["reflections"] = reflections
        item["solution"] = cur_func_impl
        write_jsonl(log_path, [item], append=True)

        if verbose:
            print(
                f'completed {i+1}/{num_items}: acc = {round(num_success/(i+1), 2)}')
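

# Minimal usage sketch (illustrative only; the concrete argument values below
# are assumptions, not values taken from this repo -- the dataset items only
# need the "prompt", "entry_point" and "test" fields used above):
#
#   dataset = [...]  # HumanEval-style dicts with "prompt", "entry_point", "test"
#   run_reflexion_ucs(
#       dataset=dataset,
#       model="gpt-4",        # any model name the generators accept
#       language="py",        # any language the executor/generator factories accept
#       max_iters=5,
#       pass_at_k=1,
#       log_path="./reflexion_ucs_log.jsonl",
#       verbose=True,
#       expansion_factor=3,
#   )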