mirror of https://github.com/GammaTauAI/reflexion-human-eval
synced 2024-11-16 00:12:59 +00:00
working on UCS (uniform-cost search) version of reflexion
This commit is contained in:
parent 891dbd3f85
commit 1ee99dfcb7
.gitmodules (3 lines, vendored, Normal file)
@@ -0,0 +1,3 @@
[submodule "lazzzy"]
path = lazzzy
url = https://github.com/gammatauai/lazzzy
@@ -1,5 +1,7 @@
from typing import NamedTuple
from typing import NamedTuple, List, Tuple

class ExecuteResult(NamedTuple):
is_passing: bool
feedback: str
state: Tuple[bool]
@@ -32,6 +32,15 @@ def py_execute(func: str, tests: List[str], timeout: int = 5) -> ExecuteResult:
failed_tests += [f"{tests[i]} # output: {output}"]
is_passing = False

state = []
for test in tests:
if test in success_tests:
state += [True]
else:
state += [False]

state = tuple(state)

feedback = "Tested passed:"
for test in success_tests:
feedback += f"\n{test}"
@@ -39,7 +48,7 @@ def py_execute(func: str, tests: List[str], timeout: int = 5) -> ExecuteResult:
for test in failed_tests:
feedback += f"\n{test}"

return ExecuteResult(is_passing, feedback)
return ExecuteResult(is_passing, feedback, state)

def py_evaluate(name: str, func: str, test: str, timeout: int = 5) -> bool:
"""
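The executor change above threads a per-test pass/fail tuple (state) through ExecuteResult alongside the boolean verdict and the textual feedback. A minimal, self-contained sketch of that shape (the run_tests helper and its exec-based checking are illustrative assumptions, not the repo's py_execute):

from typing import List, NamedTuple, Tuple


class ExecuteResult(NamedTuple):
    is_passing: bool
    feedback: str
    state: Tuple[bool, ...]  # one pass/fail flag per internal test, in test order


def run_tests(func: str, tests: List[str]) -> ExecuteResult:
    # Hypothetical stand-in for py_execute: define the function once,
    # then run each assert-style test and record whether it passed.
    env: dict = {}
    exec(func, env)
    flags, passed, failed = [], [], []
    for test in tests:
        try:
            exec(test, env)
            flags.append(True)
            passed.append(test)
        except Exception as e:
            flags.append(False)
            failed.append(f"{test} # raised: {e!r}")
    feedback = "Tests passed:\n" + "\n".join(passed)
    feedback += "\n\nTests failed:\n" + "\n".join(failed)
    return ExecuteResult(all(flags), feedback, tuple(flags))

Keeping the tuple in test order is what later lets the UCS strategy key search nodes on exactly which subset of internal tests passes.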
@@ -5,24 +5,24 @@ import openai
import jsonlines
from tenacity import (
retry,
stop_after_attempt, # type: ignore
wait_random_exponential, # type: ignore
stop_after_attempt, # type: ignore
wait_random_exponential, # type: ignore
)

from typing import Union, List, Optional

openai.api_key = os.getenv("OPENAI_API_KEY")

@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def gpt_completion(
model: str,
prompt: Union[str, List[str]],
prompt: str,
max_tokens: int = 256,
stop_strs: Optional[List[str]] = None,
temperature: float = 0.0,
) -> Union[str, List[str]]:
# check if batched or not
is_batched = isinstance(prompt, list)
num_comps=1,
) -> List[str] | str:
response = openai.Completion.create(
model=model,
prompt=prompt,
@@ -32,22 +32,23 @@ def gpt_completion(
frequency_penalty=0.0,
presence_penalty=0.0,
stop=stop_strs,
n=num_comps,
)
if is_batched:
res: List[str] = [""] * len(prompt)
for choice in response.choices: # type: ignore
res[choice.index] = choice.text
return res
return response.choices[0].text # type: ignore
if num_comps == 1:
return response.choices[0].text # type: ignore

@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
return [choice.text for choice in response.choices] # type: ignore

@retry(wait=wait_random_exponential(min=1, max=180), stop=stop_after_attempt(6))
def gpt_chat(
model: str,
system_message: str,
user_message: str,
max_tokens: int = 256,
temperature: float = 0.0,
) -> str:
model: str,
system_message: str,
user_message: str,
max_tokens: int = 256,
temperature: float = 0.0,
num_comps=1,
) -> List[str] | str:
response = openai.ChatCompletion.create(
model=model,
messages=[
@@ -59,8 +60,13 @@ def gpt_chat(
top_p=1,
frequency_penalty=0.0,
presence_penalty=0.0,
n=num_comps,
)
return response.choices[0].message.content # type: ignore
if num_comps == 1:
return response.choices[0].message.content # type: ignore

return [choice.message.content for choice in response.choices] # type: ignore

def parse_body(text):
lines = text.split('\n')
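The wrapper changes above replace the old prompt-batching path with a num_comps parameter: n completions are requested in a single API call, and the return value is a plain string when num_comps == 1 and a list otherwise (note the List[str] | str annotation requires Python 3.10+, or a __future__ annotations import). A sketch of the same convention against the pinned openai==0.27.0 client (the chat wrapper name is illustrative, not from the repo):

import os
from typing import List, Union

import openai  # assumes openai==0.27.x, as pinned in requirements.txt

openai.api_key = os.getenv("OPENAI_API_KEY")


def chat(model: str, system_message: str, user_message: str,
         max_tokens: int = 256, temperature: float = 0.0,
         num_comps: int = 1) -> Union[str, List[str]]:
    response = openai.ChatCompletion.create(
        model=model,
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_message},
        ],
        max_tokens=max_tokens,
        temperature=temperature,
        n=num_comps,  # number of sampled completions per request
    )
    if num_comps == 1:
        return response.choices[0].message.content
    return [choice.message.content for choice in response.choices]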
@@ -43,14 +43,28 @@ def py_generate_self_reflection(func: str, feedback: str, model: str) -> str:
reflection = gpt_completion(model, f'{PY_SELF_REFLECTION_COMPLETION_INSTRUCTION}\n{func}\n\n{feedback}\n\nExplanation:')
return reflection # type: ignore

# fixes the indentation of the function body.
# only checks if the first line is indented correctly, and if not, fixes it.
def py_fix_indentation(func: str) -> str:
lines = func.splitlines()
if len(lines) == 0:
return func
first_line = lines[0]
if first_line.startswith(' '):
return func
else:
return ' ' + func

def py_generate_func_impl(
func_sig: str,
model: str,
strategy: str,
prev_func_impl: Optional[str] = None,
feedback: Optional[str] = None,
self_reflection: Optional[str] = None
) -> str:
self_reflection: Optional[str] = None,
num_comps = 1,
temperature = 0.0,
) -> str | List[str]:
if strategy != "reflexion" and strategy != "simple":
raise ValueError(f"Invalid strategy: given `{strategy}` but expected one of `reflexion` or `simple`")
if strategy == "reflexion" and (prev_func_impl is None or feedback is None or self_reflection is None):
@@ -59,17 +73,23 @@ def py_generate_func_impl(
if model == "gpt-4" or model == "gpt-3.5-turbo":
if strategy == "reflexion":
message = f"previous implementation:\n{prev_func_impl}\n\nunit tests:\n{feedback}\n\nhint:\n{self_reflection}\n\n# improved implementation\n{func_sig}"
func_body = gpt_chat(model, PY_REFLEXION_CHAT_INSTRUCTION, message)
# func_bodies is a really bad name, as it can also be just 1 string
func_bodies = gpt_chat(model, PY_REFLEXION_CHAT_INSTRUCTION, message, num_comps=num_comps, temperature=temperature)
else:
func_body = gpt_chat(model, PY_SIMPLE_CHAT_INSTRUCTION if strategy == "simple" else PY_REFLEXION_CHAT_INSTRUCTION, func_sig)
func_bodies = gpt_chat(model, PY_SIMPLE_CHAT_INSTRUCTION if strategy == "simple" else PY_REFLEXION_CHAT_INSTRUCTION, func_sig, num_comps=num_comps, temperature=temperature)
else:
if strategy == "reflexion":
prompt = f"{PY_REFLEXION_COMPLETION_INSTRUCTION}\n{prev_func_impl}\n\nunit tests:\n{feedback}\n\nhint:\n{self_reflection}\n\n# improved implementation\n{func_sig}"
func_body = gpt_completion(model, prompt)
func_bodies = gpt_completion(model, prompt, num_comps=num_comps, temperature=temperature)
else:
prompt = f"{PY_SIMPLE_COMPLETION_INSTRUCTION}\n{func_sig}"
func_body = gpt_completion(model, prompt)
return func_sig + func_body # type: ignore
func_bodies = gpt_completion(model, prompt, num_comps=num_comps, temperature=temperature)

if num_comps == 1:
assert isinstance(func_bodies, str)
return func_sig + py_fix_indentation(func_bodies)
else:
return [func_sig + py_fix_indentation(func_body) for func_body in func_bodies]

def py_generate_internal_tests(func_sig: str, model: str, committee_size: int=1) -> List[str]:
def parse_tests(tests: str) -> List[str]:
@@ -80,10 +100,10 @@ def py_generate_internal_tests(func_sig: str, model: str, committee_size: int=1)
"""
if model == "gpt-4" or model == "gpt-3.5-turbo":
message = f'{PY_TEST_GENERATION_FEW_SHOT}\n\nfunc signature:\n{func_sig}\nunit tests:'
output = gpt_chat(model, PY_TEST_GENERATION_CHAT_INSTRUCTION, message)
output = gpt_chat(model, PY_TEST_GENERATION_CHAT_INSTRUCTION, message, max_tokens=1024)
else:
prompt = f'{PY_TEST_GENERATION_COMPLETION_INSTRUCTION}\n\nfunc signature:\n{func_sig}\nunit tests:'
output = gpt_completion(model, prompt)
output = gpt_completion(model, prompt, max_tokens=1024)
cur_tests: List[str] = parse_tests(output) # type: ignore

# TODO: NOT SUPPORTED YET
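Because py_generate_func_impl now returns either a single body or a list of bodies depending on num_comps, callers have to normalize the two shapes before prepending the signature and fixing indentation. A small sketch of that pattern (helper names are illustrative; the indentation fix here indents every line of an unindented body, a slightly more thorough variant of the py_fix_indentation heuristic above):

from typing import List, Union


def fix_indentation(body: str) -> str:
    # If the generated body is not indented at all, indent each line
    # one level so it can be appended directly after the signature.
    lines = body.splitlines()
    if not lines or lines[0].startswith(" "):
        return body
    return "\n".join("    " + line for line in lines)


def assemble(func_sig: str, bodies: Union[str, List[str]]) -> List[str]:
    # Normalize a single completion (str) and a batch (List[str]) into
    # one list of complete implementations.
    if isinstance(bodies, str):
        bodies = [bodies]
    return [func_sig + fix_indentation(body) for body in bodies]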
lazzzy (1 line, Submodule)
@@ -0,0 +1 @@
Subproject commit 10aa52640c74ccb7512a7e4cb122a3c1c548ce7f
main.py (51 lines)
@@ -3,25 +3,37 @@ import argparse

from simple import run_simple
from reflexion import run_reflexion
from reflexion_ucs import run_reflexion_ucs
from utils import read_jsonl, read_jsonl_gz

def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--run_name", type=str, help="The name of the run")
parser.add_argument("--root_dir", type=str, help="The root logging directory", default="root")
parser.add_argument("--dataset_path", type=str, help="The path to the benchmark dataset", default="root")
parser.add_argument("--strategy", type=str, help="Strategy: `simple`, `reflexion`")
parser.add_argument("--root_dir", type=str,
help="The root logging directory", default="root")
parser.add_argument("--dataset_path", type=str,
help="The path to the benchmark dataset", default="root")
parser.add_argument("--strategy", type=str,
help="Strategy: `simple`, `reflexion`")
parser.add_argument("--language", type=str, help="Strategy: `py`")
parser.add_argument("--model", type=str, help="OpenAI models only for now. For best results, use GPT-4")
parser.add_argument("--pass_at_k", type=int, help="Pass@k metric", default=1)
parser.add_argument("--max_iters", type=int, help="The maximum number of self-improvement iterations", default=10)
parser.add_argument("--verbose", action='store_true', help="To print live logs")
parser.add_argument(
"--model", type=str, help="OpenAI models only for now. For best results, use GPT-4")
parser.add_argument("--pass_at_k", type=int,
help="Pass@k metric", default=1)
parser.add_argument("--max_iters", type=int,
help="The maximum number of self-improvement iterations", default=10)
parser.add_argument("--expansion_factor", type=int,
help="The expansion factor for the reflexion UCS and A* strategy", default=3)
parser.add_argument("--verbose", action='store_true',
help="To print live logs")
# TODO: implement this
# parser.add_argument("--is_resume", action='store_true', help="To resume run")
# parser.add_argument("--resume_dir", type=str, help="If resume, the logging directory", default="")
args = parser.parse_args()
return args

def main(args):
# check if the root dir exists and create it if not
if not os.path.exists(args.root_dir):
@@ -29,14 +41,16 @@ def main(args):

# check if log path already exists
log_dir = os.path.join(args.root_dir, args.run_name)
log_path = os.path.join(log_dir, f"{args.strategy}_{args.max_iters}_{args.model}_pass_at_k_{args.pass_at_k}.jsonl")
log_path = os.path.join(
log_dir, f"{args.strategy}_{args.max_iters}_{args.model}_pass_at_k_{args.pass_at_k}.jsonl")
if not os.path.exists(log_dir):
os.makedirs(log_dir)
if os.path.exists(log_path):
raise ValueError(f"Log path `{log_path}` already exists in `{log_dir}`")
raise ValueError(
f"Log path `{log_path}` already exists in `{log_dir}`")

# check if the strategy is valid
if args.strategy not in ["simple", "reflexion"]:
if args.strategy not in ["simple", "reflexion", "reflexion-ucs"]:
raise ValueError(f"Strategy `{args.strategy}` is not supported")

# print starting message
@@ -56,8 +70,10 @@ pass@k: {args.pass_at_k}
elif args.dataset_path.endswith(".jsonl.gz"):
dataset = read_jsonl_gz(args.dataset_path)
else:
raise ValueError(f"Dataset path `{args.dataset_path}` is not supported")

raise ValueError(
f"Dataset path `{args.dataset_path}` is not supported")

print(f"Loaded {len(dataset)} examples")
# start the run
# evaluate with pass@k
if args.strategy == "simple":
@@ -79,6 +95,17 @@ pass@k: {args.pass_at_k}
log_path=log_path,
verbose=args.verbose
)
elif args.strategy == "reflexion-ucs":
run_reflexion_ucs(
dataset=dataset,
model=args.model,
language=args.language,
max_iters=args.max_iters,
pass_at_k=args.pass_at_k,
log_path=log_path,
verbose=args.verbose,
expansion_factor=args.expansion_factor
)

print(f"Done! Check out the logs in `{log_path}`")
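The dispatch in main.py grows one elif per strategy; an equivalent table-driven sketch (illustrative only, not how the repo structures it, and assuming the runners imported at the top of main.py) keeps the strategy-specific expansion_factor argument in one place:

STRATEGY_RUNNERS = {
    "simple": run_simple,
    "reflexion": run_reflexion,
    "reflexion-ucs": run_reflexion_ucs,
}


def run_strategy(args, dataset, log_path):
    # Look up the runner by strategy name, then build the shared kwargs.
    if args.strategy not in STRATEGY_RUNNERS:
        raise ValueError(f"Strategy `{args.strategy}` is not supported")
    kwargs = dict(dataset=dataset, model=args.model, language=args.language,
                  max_iters=args.max_iters, pass_at_k=args.pass_at_k,
                  log_path=log_path, verbose=args.verbose)
    if args.strategy == "reflexion-ucs":
        kwargs["expansion_factor"] = args.expansion_factor
    STRATEGY_RUNNERS[args.strategy](**kwargs)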
reflexion.py (34 lines)
@@ -6,14 +6,14 @@ from typing import List

def run_reflexion(
dataset: List[dict],
model: str,
language: str,
max_iters: int,
pass_at_k: int,
log_path: str,
verbose: bool
) -> None:
dataset: List[dict],
model: str,
language: str,
max_iters: int,
pass_at_k: int,
log_path: str,
verbose: bool
) -> None:
# should handle more languages later
# someone do this but arrange it better
evaluate = None
@@ -47,8 +47,10 @@ def run_reflexion(
tests_i = internal_test_generator(item["prompt"], model, 1)

# first attempt
cur_func_impl = parse_body(func_impl_generator(item["prompt"], model, "simple"))
is_passing, feedback = execute(cur_func_impl, tests_i)
cur_func_impl = parse_body(
func_impl_generator(item["prompt"], model, "simple"))
assert isinstance(cur_func_impl, str)
is_passing, feedback, _ = execute(cur_func_impl, tests_i)

# if solved, exit early
if is_passing:
@@ -61,7 +63,8 @@ def run_reflexion(
cur_feedback = feedback
while cur_iter < max_iters:
# get self-reflection
reflection = self_reflection_generator(cur_func_impl, cur_feedback, model)
reflection = self_reflection_generator(
cur_func_impl, cur_feedback, model)
reflections += [reflection]

# apply self-reflection in the next attempt
@@ -73,13 +76,15 @@ def run_reflexion(
feedback=cur_feedback,
self_reflection=reflection
))
assert isinstance(cur_func_impl, str)

# check if all internal unit tests pass
is_passing, cur_feedback = execute(cur_func_impl, tests_i)
is_passing, cur_feedback, _ = execute(cur_func_impl, tests_i)

# if solved, check if it passes the real tests, exit early
if is_passing or cur_iter == max_iters - 1:
is_passing = evaluate(item["entry_point"], cur_func_impl, item["test"], timeout=10)
is_passing = evaluate(
item["entry_point"], cur_func_impl, item["test"], timeout=10)
if is_passing:
item["solution"] = cur_func_impl
is_solved = True
@@ -95,4 +100,5 @@ def run_reflexion(
write_jsonl(log_path, [item], append=True)

if verbose:
print(f'completed {i+1}/{num_items}: acc = {round(num_success/(i+1), 2)}')
print(
f'completed {i+1}/{num_items}: acc = {round(num_success/(i+1), 2)}')
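Stripped of logging and pass@k bookkeeping, the control flow of run_reflexion is a generate/execute/reflect loop. The sketch below condenses it, with the executor, generator, and reflector passed in as callables (the parameter names are assumptions, and execute is expected to return the new three-field result shown above):

def reflexion_loop(prompt, tests, model, max_iters, execute, generate, reflect):
    # First attempt with the plain "simple" strategy.
    impl = generate(prompt, model, "simple")
    is_passing, feedback, _ = execute(impl, tests)

    cur_iter = 0
    while not is_passing and cur_iter < max_iters:
        # Ask the model why the internal tests failed, then retry with that hint.
        reflection = reflect(impl, feedback, model)
        impl = generate(prompt, model, "reflexion",
                        prev_func_impl=impl,
                        feedback=feedback,
                        self_reflection=reflection)
        is_passing, feedback, _ = execute(impl, tests)
        cur_iter += 1

    return impl, is_passing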
reflexion_ucs.py (190 lines, Normal file)
@@ -0,0 +1,190 @@
from lazzzy.ucs import ucs
from utils import write_jsonl, parse_body
from executors import py_evaluate, py_execute
from generators import py_generate_func_impl, py_generate_self_reflection, py_generate_internal_tests

from typing import List, Set, Tuple

DEBUG = True

def debug_print(*args):
if DEBUG:
print(*args, flush=True)

class State:
def __init__(self, code: str, feedback: str, reflection: str, state: Tuple[bool]):
self.code = code
self.feedback = feedback
self.reflection = reflection
self.state = state

def __repr__(self):
return f"State(code={self.code}, feedback={self.feedback}, reflection={self.reflection}, state={self.state})"

def is_goal(self):
return all(self.state)

def __hash__(self):
return hash((self.code, self.feedback, self.reflection))

def get_unique_id(self):
res = 0
for i in range(len(self.state)):
res += self.state[i] * (2**i)

return res

def run_reflexion_ucs(
dataset: List[dict],
model: str,
language: str,
max_iters: int,
pass_at_k: int,
log_path: str,
verbose: bool,
expansion_factor: int
) -> None:
evaluate = None
execute = None
self_reflection_generator = None
func_impl_generator = None
internal_test_generator = None
if language == "python" or language == "py":
evaluate = py_evaluate
execute = py_execute
self_reflection_generator = py_generate_self_reflection
func_impl_generator = py_generate_func_impl
internal_test_generator = py_generate_internal_tests
else:
raise NotImplementedError(f"language {language} not supported")

assert not evaluate is None
assert not execute is None
assert not self_reflection_generator is None
assert not func_impl_generator is None
assert not internal_test_generator is None

num_items = len(dataset)
num_success = 0
for i, item in enumerate(dataset):
cur_pass = 0
is_solved = False
reflections = []
cur_func_impl = ""
while cur_pass < pass_at_k and not is_solved:
debug_print(f"item {i} pass {cur_pass}")
tests_i = internal_test_generator(item["prompt"], model, 1)
# cut off at 5 tests or less
tests_i = tests_i[:min(5, len(tests_i))]

# first attempt
debug_print("first attempt")
cur_func_impl = parse_body(
func_impl_generator(item["prompt"], model, "simple"))
assert isinstance(cur_func_impl, str) # num_comps of 1
is_passing, feedback, state = execute(cur_func_impl, tests_i)

debug_print(f"first attempt: \n{cur_func_impl}\n{feedback}\n{state}")

# if solved, exit early
if is_passing:
debug_print("solved at first attempt")
is_solved = True
num_success += 1
break

reflection = self_reflection_generator(
cur_func_impl, feedback, model)
reflections.append(reflection)

start = State(cur_func_impl, feedback, reflection, state)

def expand(state: State) -> Set[Tuple[State, float]]:
nonlocal max_iters
nonlocal expansion_factor
nonlocal item
nonlocal model
nonlocal tests_i
nonlocal reflections

new_states: Set[Tuple[State, float]] = set()

debug_print(f"start expansion of: {state.state}")
new_funcs = func_impl_generator(
func_sig=item["prompt"],
model=model,
strategy="reflexion",
prev_func_impl=state.code,
feedback=state.feedback,
self_reflection=state.reflection,
num_comps=expansion_factor,
temperature=0.75
)
assert isinstance(new_funcs, list)
debug_print(f"generated num of funcs: {len(new_funcs)}")

already_seen = set()

for new_func in new_funcs:
if new_func in already_seen:
debug_print(f"skipping a func because already seen.")
continue

already_seen.add(new_func)

is_passing, feedback, new_state = execute(new_func, tests_i)
debug_print(f"expanding: \n{new_func}\n{feedback}\n{new_state}")

if is_passing:
# return immediately if solved
return set([(State(new_func, feedback, "", new_state), 0)])

new_reflection = self_reflection_generator(new_func, feedback, model)
reflections.append(new_reflection)

num_failing = len([x for x in new_state if not x])
new_states.add(
(State(new_func, feedback, new_reflection, new_state), num_failing))

debug_print(f"returning new states: {new_states}")

return new_states

def when_none(l: List[State]) -> State:
debug_print(f"when_none called on: {l}")
return min(l, key=lambda x: len([y for y in x.state if not y]))

# this is either the goal state, or if not found, the current best state (lowest failed tests)
best = ucs(
start=start,
expand=expand,
is_goal=lambda x: x.is_goal(),
# NOTE: this way we reduce our search space significantly
# the maximum number of nodes is 2^num_tests,
# which is 2^5 = 32
get_unique_id=lambda x: x.get_unique_id(),
when_none=when_none
)
assert best is not None # impossible due to our when_none

is_passing = evaluate(
item["entry_point"], best.code, item["test"], timeout=5)
if is_passing:
item["solution"] = best.code
is_solved = True
num_success += 1
break # breaking pass@k loop

cur_pass += 1

item["is_solved"] = is_solved
item["reflections"] = reflections
item["solution"] = cur_func_impl
write_jsonl(log_path, [item], append=True)

if verbose:
print(
f'completed {i+1}/{num_items}: acc = {round(num_success/(i+1), 2)}')
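reflexion_ucs.py turns the single retry loop into a search: every candidate implementation becomes a node, expand samples expansion_factor new candidates at temperature 0.75, the edge cost is the number of internal tests the new candidate still fails, and get_unique_id collapses nodes by which subset of tests passes, so with at most 5 internal tests the search touches at most 2^5 = 32 distinct states. Only the call signature of lazzzy.ucs is visible in the diff; the sketch below is a self-contained stand-in with the same shape (the when_none behavior and the exact lazzzy semantics are assumptions), plus the bitmask id on its own:

import heapq
from itertools import count
from typing import Callable, Iterable, List, Optional, Tuple, TypeVar

T = TypeVar("T")


def state_unique_id(state: Tuple[bool, ...]) -> int:
    # Bitmask of passing tests: (True, False, True) -> 0b101 = 5.
    return sum(int(flag) << i for i, flag in enumerate(state))


def ucs(start: T,
        expand: Callable[[T], Iterable[Tuple[T, float]]],
        is_goal: Callable[[T], bool],
        get_unique_id: Callable[[T], int],
        when_none: Callable[[List[T]], T]) -> Optional[T]:
    # Uniform-cost search: always pop the frontier node with the lowest
    # cumulative cost; de-duplicate by the caller-supplied id so each
    # pass/fail pattern is expanded at most once.
    tie = count()  # tie-breaker so heapq never compares nodes directly
    frontier = [(0.0, next(tie), start)]
    expanded_ids = set()
    seen_nodes: List[T] = []
    while frontier:
        cost, _, node = heapq.heappop(frontier)
        if is_goal(node):
            return node
        uid = get_unique_id(node)
        if uid in expanded_ids:
            continue
        expanded_ids.add(uid)
        seen_nodes.append(node)
        for child, step_cost in expand(node):
            heapq.heappush(frontier, (cost + step_cost, next(tie), child))
    # No goal found: hand every expanded node to when_none so the caller
    # can pick the best non-goal state (fewest failing tests, as above).
    return when_none(seen_nodes) if seen_nodes else None

Because the step cost is the failing-test count, cheaper paths run through candidates that already satisfy more of the internal tests, which is the bias the commit's "UCS version of reflexion" is after.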
@@ -1,4 +1,5 @@
jsonlines
jsonlines==3.1.0
openai==0.27.0
datasets
tenacity
datasets==2.7.0
tenacity==8.1.0
astunparse==1.6.3
File diff suppressed because one or more lines are too long
run_reflexion_ucs.sh (10 lines, Executable file)
@@ -0,0 +1,10 @@
python main.py \
--run_name "reflexion_ucs_scratch" \
--root_dir "root" \
--dataset_path ./human-eval/data/HumanEval.jsonl.gz \
--strategy "reflexion-ucs" \
--language "py" \
--model "gpt-4" \
--pass_at_k "1" \
--max_iters "5" \
--verbose
@@ -35,6 +35,7 @@ def run_simple(
cur_func_impl = ""
while cur_pass < pass_at_k:
cur_func_impl = func_impl_generator(item["prompt"], model, "simple")
assert isinstance(cur_func_impl, str)
is_passing = evaluate(item["entry_point"], cur_func_impl, item["test"], timeout=10)
if is_passing:
is_solved = True
@@ -1,7 +1,10 @@
import sys
import signal

from utils import read_jsonl

TIMEOUT = 5 # seconds

assert len(sys.argv) == 2, "Please provide a log file"
LOG_PATH = sys.argv[1]
@@ -26,7 +29,14 @@ def validate_py_results(log_path: str):
code = f'{func_impl}\n\n{item["test"]}\n\ncheck({item["entry_point"]})'
num_tests = count_test_cases(item["test"])
try:
def handler(signum, frame):
nonlocal i
raise Exception("timeout on test case" + str(i))

signal.signal(signal.SIGALRM, handler)
signal.alarm(TIMEOUT)
exec(code, globals())
signal.alarm(0)
green_text_out = green_text(f"passes {num_tests}/{num_tests} test cases")
print(f"Test {i}: {green_text_out}")
num_success += 1
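The validation script guards each exec of generated code with signal.alarm, raising inside the handler so a hung test case cannot stall the whole run. The same pattern as a reusable, self-contained sketch (POSIX-only; the wrapper name is illustrative, not from the repo):

import signal
from typing import Any, Callable


def run_with_timeout(fn: Callable[[], Any], seconds: int) -> Any:
    # Raise TimeoutError if fn() does not return within `seconds`.
    def handler(signum, frame):
        raise TimeoutError(f"timed out after {seconds}s")

    old_handler = signal.signal(signal.SIGALRM, handler)
    signal.alarm(seconds)           # schedule SIGALRM
    try:
        return fn()
    finally:
        signal.alarm(0)             # cancel any pending alarm
        signal.signal(signal.SIGALRM, old_handler)


# Usage, mirroring the validation loop above:
# run_with_timeout(lambda: exec(code, globals()), TIMEOUT)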