Working on UCS version of Reflexion

elleven11 2023-03-28 11:23:29 -04:00
parent 891dbd3f85
commit 1ee99dfcb7
14 changed files with 510 additions and 60 deletions

3
.gitmodules vendored Normal file

@ -0,0 +1,3 @@
[submodule "lazzzy"]
path = lazzzy
url = https://github.com/gammatauai/lazzzy


@ -1,5 +1,7 @@
from typing import NamedTuple
from typing import NamedTuple, List, Tuple
class ExecuteResult(NamedTuple):
is_passing: bool
feedback: str
state: Tuple[bool]
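
With the extra field, callers of the executor now unpack three values instead of two. A minimal sketch with hypothetical values; Tuple[bool, ...] is the variadic spelling of the annotation (the diff writes Tuple[bool]):

from typing import NamedTuple, Tuple

class ExecuteResult(NamedTuple):
    is_passing: bool
    feedback: str
    state: Tuple[bool, ...]  # one pass/fail flag per internal test, in test order

# hypothetical result where one of two internal tests passes
result = ExecuteResult(is_passing=False, feedback="...", state=(True, False))
is_passing, feedback, state = result  # NamedTuple unpacks positionally
assert state == (True, False)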


@ -32,6 +32,15 @@ def py_execute(func: str, tests: List[str], timeout: int = 5) -> ExecuteResult:
failed_tests += [f"{tests[i]} # output: {output}"]
is_passing = False
state = []
for test in tests:
if test in success_tests:
state += [True]
else:
state += [False]
state = tuple(state)
feedback = "Tested passed:"
for test in success_tests:
feedback += f"\n{test}"
@ -39,7 +48,7 @@ def py_execute(func: str, tests: List[str], timeout: int = 5) -> ExecuteResult:
for test in failed_tests:
feedback += f"\n{test}"
return ExecuteResult(is_passing, feedback)
return ExecuteResult(is_passing, feedback, state)
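
The per-test state assembled in the loop above could equally be built with a comprehension; a small sketch of the idea, with illustrative test strings rather than benchmark data:

def make_state(tests, success_tests):
    # True where the test passed, False otherwise; order follows `tests`
    return tuple(test in success_tests for test in tests)

tests = ["assert f(1) == 1", "assert f(2) == 4", "assert f(3) == 9"]
success_tests = ["assert f(1) == 1", "assert f(3) == 9"]
assert make_state(tests, success_tests) == (True, False, True)
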
def py_evaluate(name: str, func: str, test: str, timeout: int = 5) -> bool:
"""


@ -5,24 +5,24 @@ import openai
import jsonlines
from tenacity import (
retry,
stop_after_attempt, # type: ignore
wait_random_exponential, # type: ignore
stop_after_attempt, # type: ignore
wait_random_exponential, # type: ignore
)
from typing import Union, List, Optional
openai.api_key = os.getenv("OPENAI_API_KEY")
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def gpt_completion(
model: str,
prompt: Union[str, List[str]],
prompt: str,
max_tokens: int = 256,
stop_strs: Optional[List[str]] = None,
temperature: float = 0.0,
) -> Union[str, List[str]]:
# check if batched or not
is_batched = isinstance(prompt, list)
num_comps=1,
) -> List[str] | str:
response = openai.Completion.create(
model=model,
prompt=prompt,
@ -32,22 +32,23 @@ def gpt_completion(
frequency_penalty=0.0,
presence_penalty=0.0,
stop=stop_strs,
n=num_comps,
)
if is_batched:
res: List[str] = [""] * len(prompt)
for choice in response.choices: # type: ignore
res[choice.index] = choice.text
return res
return response.choices[0].text # type: ignore
if num_comps == 1:
return response.choices[0].text # type: ignore
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
return [choice.text for choice in response.choices] # type: ignore
@retry(wait=wait_random_exponential(min=1, max=180), stop=stop_after_attempt(6))
def gpt_chat(
model: str,
system_message: str,
user_message: str,
max_tokens: int = 256,
temperature: float = 0.0,
) -> str:
model: str,
system_message: str,
user_message: str,
max_tokens: int = 256,
temperature: float = 0.0,
num_comps=1,
) -> List[str] | str:
response = openai.ChatCompletion.create(
model=model,
messages=[
@ -59,8 +60,13 @@ def gpt_chat(
top_p=1,
frequency_penalty=0.0,
presence_penalty=0.0,
n=num_comps,
)
return response.choices[0].message.content # type: ignore
if num_comps == 1:
return response.choices[0].message.content # type: ignore
return [choice.message.content for choice in response.choices] # type: ignore
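
With num_comps in place, both wrappers return either one string or a list of strings. A hedged usage sketch based on the signatures above (it needs a valid OPENAI_API_KEY, and the model name and prompts are purely illustrative):

# one sample: a single string comes back
body = gpt_chat("gpt-4", "You complete Python functions.", "def add(a, b):")
assert isinstance(body, str)

# several samples: a list of num_comps strings comes back
bodies = gpt_chat("gpt-4", "You complete Python functions.", "def add(a, b):",
                  temperature=0.75, num_comps=3)
assert isinstance(bodies, list) and len(bodies) == 3
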
def parse_body(text):
lines = text.split('\n')


@ -43,14 +43,28 @@ def py_generate_self_reflection(func: str, feedback: str, model: str) -> str:
reflection = gpt_completion(model, f'{PY_SELF_REFLECTION_COMPLETION_INSTRUCTION}\n{func}\n\n{feedback}\n\nExplanation:')
return reflection # type: ignore
# fixes the indentation of the function body.
# only checks if the first line is indented correctly, and if not, fixes it.
def py_fix_indentation(func: str) -> str:
lines = func.splitlines()
if len(lines) == 0:
return func
first_line = lines[0]
if first_line.startswith(' '):
return func
else:
return ' ' + func
def py_generate_func_impl(
func_sig: str,
model: str,
strategy: str,
prev_func_impl: Optional[str] = None,
feedback: Optional[str] = None,
self_reflection: Optional[str] = None
) -> str:
self_reflection: Optional[str] = None,
num_comps = 1,
temperature = 0.0,
) -> str | List[str]:
if strategy != "reflexion" and strategy != "simple":
raise ValueError(f"Invalid strategy: given `{strategy}` but expected one of `reflexion` or `simple`")
if strategy == "reflexion" and (prev_func_impl is None or feedback is None or self_reflection is None):
@ -59,17 +73,23 @@ def py_generate_func_impl(
if model == "gpt-4" or model == "gpt-3.5-turbo":
if strategy == "reflexion":
message = f"previous implementation:\n{prev_func_impl}\n\nunit tests:\n{feedback}\n\nhint:\n{self_reflection}\n\n# improved implementation\n{func_sig}"
func_body = gpt_chat(model, PY_REFLEXION_CHAT_INSTRUCTION, message)
# func_bodies is a really bad name, as it can also be just 1 string
func_bodies = gpt_chat(model, PY_REFLEXION_CHAT_INSTRUCTION, message, num_comps=num_comps, temperature=temperature)
else:
func_body = gpt_chat(model, PY_SIMPLE_CHAT_INSTRUCTION if strategy == "simple" else PY_REFLEXION_CHAT_INSTRUCTION, func_sig)
func_bodies = gpt_chat(model, PY_SIMPLE_CHAT_INSTRUCTION if strategy == "simple" else PY_REFLEXION_CHAT_INSTRUCTION, func_sig, num_comps=num_comps, temperature=temperature)
else:
if strategy == "reflexion":
prompt = f"{PY_REFLEXION_COMPLETION_INSTRUCTION}\n{prev_func_impl}\n\nunit tests:\n{feedback}\n\nhint:\n{self_reflection}\n\n# improved implementation\n{func_sig}"
func_body = gpt_completion(model, prompt)
func_bodies = gpt_completion(model, prompt, num_comps=num_comps, temperature=temperature)
else:
prompt = f"{PY_SIMPLE_COMPLETION_INSTRUCTION}\n{func_sig}"
func_body = gpt_completion(model, prompt)
return func_sig + func_body # type: ignore
func_bodies = gpt_completion(model, prompt, num_comps=num_comps, temperature=temperature)
if num_comps == 1:
assert isinstance(func_bodies, str)
return func_sig + py_fix_indentation(func_bodies)
else:
return [func_sig + py_fix_indentation(func_body) for func_body in func_bodies]
def py_generate_internal_tests(func_sig: str, model: str, committee_size: int=1) -> List[str]:
def parse_tests(tests: str) -> List[str]:
@ -80,10 +100,10 @@ def py_generate_internal_tests(func_sig: str, model: str, committee_size: int=1)
"""
if model == "gpt-4" or model == "gpt-3.5-turbo":
message = f'{PY_TEST_GENERATION_FEW_SHOT}\n\nfunc signature:\n{func_sig}\nunit tests:'
output = gpt_chat(model, PY_TEST_GENERATION_CHAT_INSTRUCTION, message)
output = gpt_chat(model, PY_TEST_GENERATION_CHAT_INSTRUCTION, message, max_tokens=1024)
else:
prompt = f'{PY_TEST_GENERATION_COMPLETION_INSTRUCTION}\n\nfunc signature:\n{func_sig}\nunit tests:'
output = gpt_completion(model, prompt)
output = gpt_completion(model, prompt, max_tokens=1024)
cur_tests: List[str] = parse_tests(output) # type: ignore
# TODO: NOT SUPPORTED YET
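
The py_fix_indentation helper added above only inspects the first line of the generated body. A quick illustration of the effect; the exact indent width depends on the string literal in the source, so the checks stay deliberately loose:

# a body whose first line is already indented comes back unchanged
assert py_fix_indentation("    return a + b") == "    return a + b"

# a body starting at column 0 gets an indent prepended to the whole string,
# which re-indents only the first line
fixed = py_fix_indentation("return a + b\nprint(a)")
assert fixed.startswith(" ")
assert fixed.endswith("return a + b\nprint(a)")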

1
lazzzy Submodule

@ -0,0 +1 @@
Subproject commit 10aa52640c74ccb7512a7e4cb122a3c1c548ce7f

51
main.py

@ -3,25 +3,37 @@ import argparse
from simple import run_simple
from reflexion import run_reflexion
from reflexion_ucs import run_reflexion_ucs
from utils import read_jsonl, read_jsonl_gz
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--run_name", type=str, help="The name of the run")
parser.add_argument("--root_dir", type=str, help="The root logging directory", default="root")
parser.add_argument("--dataset_path", type=str, help="The path to the benchmark dataset", default="root")
parser.add_argument("--strategy", type=str, help="Strategy: `simple`, `reflexion`")
parser.add_argument("--root_dir", type=str,
help="The root logging directory", default="root")
parser.add_argument("--dataset_path", type=str,
help="The path to the benchmark dataset", default="root")
parser.add_argument("--strategy", type=str,
help="Strategy: `simple`, `reflexion`")
parser.add_argument("--language", type=str, help="Strategy: `py`")
parser.add_argument("--model", type=str, help="OpenAI models only for now. For best results, use GPT-4")
parser.add_argument("--pass_at_k", type=int, help="Pass@k metric", default=1)
parser.add_argument("--max_iters", type=int, help="The maximum number of self-improvement iterations", default=10)
parser.add_argument("--verbose", action='store_true', help="To print live logs")
parser.add_argument(
"--model", type=str, help="OpenAI models only for now. For best results, use GPT-4")
parser.add_argument("--pass_at_k", type=int,
help="Pass@k metric", default=1)
parser.add_argument("--max_iters", type=int,
help="The maximum number of self-improvement iterations", default=10)
parser.add_argument("--expansion_factor", type=int,
help="The expansion factor for the reflexion UCS and A* strategy", default=3)
parser.add_argument("--verbose", action='store_true',
help="To print live logs")
# TODO: implement this
# parser.add_argument("--is_resume", action='store_true', help="To resume run")
# parser.add_argument("--resume_dir", type=str, help="If resume, the logging directory", default="")
args = parser.parse_args()
return args
def main(args):
# check if the root dir exists and create it if not
if not os.path.exists(args.root_dir):
@ -29,14 +41,16 @@ def main(args):
# check if log path already exists
log_dir = os.path.join(args.root_dir, args.run_name)
log_path = os.path.join(log_dir, f"{args.strategy}_{args.max_iters}_{args.model}_pass_at_k_{args.pass_at_k}.jsonl")
log_path = os.path.join(
log_dir, f"{args.strategy}_{args.max_iters}_{args.model}_pass_at_k_{args.pass_at_k}.jsonl")
if not os.path.exists(log_dir):
os.makedirs(log_dir)
if os.path.exists(log_path):
raise ValueError(f"Log path `{log_path}` already exists in `{log_dir}`")
raise ValueError(
f"Log path `{log_path}` already exists in `{log_dir}`")
# check if the strategy is valid
if args.strategy not in ["simple", "reflexion"]:
if args.strategy not in ["simple", "reflexion", "reflexion-ucs"]:
raise ValueError(f"Strategy `{args.strategy}` is not supported")
# print starting message
@ -56,8 +70,10 @@ pass@k: {args.pass_at_k}
elif args.dataset_path.endswith(".jsonl.gz"):
dataset = read_jsonl_gz(args.dataset_path)
else:
raise ValueError(f"Dataset path `{args.dataset_path}` is not supported")
raise ValueError(
f"Dataset path `{args.dataset_path}` is not supported")
print(f"Loaded {len(dataset)} examples")
# start the run
# evaluate with pass@k
if args.strategy == "simple":
@ -79,6 +95,17 @@ pass@k: {args.pass_at_k}
log_path=log_path,
verbose=args.verbose
)
elif args.strategy == "reflexion-ucs":
run_reflexion_ucs(
dataset=dataset,
model=args.model,
language=args.language,
max_iters=args.max_iters,
pass_at_k=args.pass_at_k,
log_path=log_path,
verbose=args.verbose,
expansion_factor=args.expansion_factor
)
print(f"Done! Check out the logs in `{log_path}`")


@ -6,14 +6,14 @@ from typing import List
def run_reflexion(
dataset: List[dict],
model: str,
language: str,
max_iters: int,
pass_at_k: int,
log_path: str,
verbose: bool
) -> None:
dataset: List[dict],
model: str,
language: str,
max_iters: int,
pass_at_k: int,
log_path: str,
verbose: bool
) -> None:
# should handle more languages later
# someone do this but arrange it better
evaluate = None
@ -47,8 +47,10 @@ def run_reflexion(
tests_i = internal_test_generator(item["prompt"], model, 1)
# first attempt
cur_func_impl = parse_body(func_impl_generator(item["prompt"], model, "simple"))
is_passing, feedback = execute(cur_func_impl, tests_i)
cur_func_impl = parse_body(
func_impl_generator(item["prompt"], model, "simple"))
assert isinstance(cur_func_impl, str)
is_passing, feedback, _ = execute(cur_func_impl, tests_i)
# if solved, exit early
if is_passing:
@ -61,7 +63,8 @@ def run_reflexion(
cur_feedback = feedback
while cur_iter < max_iters:
# get self-reflection
reflection = self_reflection_generator(cur_func_impl, cur_feedback, model)
reflection = self_reflection_generator(
cur_func_impl, cur_feedback, model)
reflections += [reflection]
# apply self-reflection in the next attempt
@ -73,13 +76,15 @@ def run_reflexion(
feedback=cur_feedback,
self_reflection=reflection
))
assert isinstance(cur_func_impl, str)
# check if all internal unit tests pass
is_passing, cur_feedback = execute(cur_func_impl, tests_i)
is_passing, cur_feedback, _ = execute(cur_func_impl, tests_i)
# if solved, check if it passes the real tests, exit early
if is_passing or cur_iter == max_iters - 1:
is_passing = evaluate(item["entry_point"], cur_func_impl, item["test"], timeout=10)
is_passing = evaluate(
item["entry_point"], cur_func_impl, item["test"], timeout=10)
if is_passing:
item["solution"] = cur_func_impl
is_solved = True
@ -95,4 +100,5 @@ def run_reflexion(
write_jsonl(log_path, [item], append=True)
if verbose:
print(f'completed {i+1}/{num_items}: acc = {round(num_success/(i+1), 2)}')
print(
f'completed {i+1}/{num_items}: acc = {round(num_success/(i+1), 2)}')

190
reflexion_ucs.py Normal file

@ -0,0 +1,190 @@
from lazzzy.ucs import ucs
from utils import write_jsonl, parse_body
from executors import py_evaluate, py_execute
from generators import py_generate_func_impl, py_generate_self_reflection, py_generate_internal_tests
from typing import List, Set, Tuple
DEBUG = True
def debug_print(*args):
if DEBUG:
print(*args, flush=True)
class State:
def __init__(self, code: str, feedback: str, reflection: str, state: Tuple[bool]):
self.code = code
self.feedback = feedback
self.reflection = reflection
self.state = state
def __repr__(self):
return f"State(code={self.code}, feedback={self.feedback}, reflection={self.reflection}, state={self.state})"
def is_goal(self):
return all(self.state)
def __hash__(self):
return hash((self.code, self.feedback, self.reflection))
def get_unique_id(self):
res = 0
for i in range(len(self.state)):
res += self.state[i] * (2**i)
return res
def run_reflexion_ucs(
dataset: List[dict],
model: str,
language: str,
max_iters: int,
pass_at_k: int,
log_path: str,
verbose: bool,
expansion_factor: int
) -> None:
evaluate = None
execute = None
self_reflection_generator = None
func_impl_generator = None
internal_test_generator = None
if language == "python" or language == "py":
evaluate = py_evaluate
execute = py_execute
self_reflection_generator = py_generate_self_reflection
func_impl_generator = py_generate_func_impl
internal_test_generator = py_generate_internal_tests
else:
raise NotImplementedError(f"language {language} not supported")
assert not evaluate is None
assert not execute is None
assert not self_reflection_generator is None
assert not func_impl_generator is None
assert not internal_test_generator is None
num_items = len(dataset)
num_success = 0
for i, item in enumerate(dataset):
cur_pass = 0
is_solved = False
reflections = []
cur_func_impl = ""
while cur_pass < pass_at_k and not is_solved:
debug_print(f"item {i} pass {cur_pass}")
tests_i = internal_test_generator(item["prompt"], model, 1)
# cut off at 5 tests or less
tests_i = tests_i[:min(5, len(tests_i))]
# first attempt
debug_print("first attempt")
cur_func_impl = parse_body(
func_impl_generator(item["prompt"], model, "simple"))
assert isinstance(cur_func_impl, str) # num_comps of 1
is_passing, feedback, state = execute(cur_func_impl, tests_i)
debug_print(f"first attempt: \n{cur_func_impl}\n{feedback}\n{state}")
# if solved, exit early
if is_passing:
debug_print("solved at first attempt")
is_solved = True
num_success += 1
break
reflection = self_reflection_generator(
cur_func_impl, feedback, model)
reflections.append(reflection)
start = State(cur_func_impl, feedback, reflection, state)
def expand(state: State) -> Set[Tuple[State, float]]:
nonlocal max_iters
nonlocal expansion_factor
nonlocal item
nonlocal model
nonlocal tests_i
nonlocal reflections
new_states: Set[Tuple[State, float]] = set()
debug_print(f"start expansion of: {state.state}")
new_funcs = func_impl_generator(
func_sig=item["prompt"],
model=model,
strategy="reflexion",
prev_func_impl=state.code,
feedback=state.feedback,
self_reflection=state.reflection,
num_comps=expansion_factor,
temperature=0.75
)
assert isinstance(new_funcs, list)
debug_print(f"generated num of funcs: {len(new_funcs)}")
already_seen = set()
for new_func in new_funcs:
if new_func in already_seen:
debug_print(f"skipping a func because already seen.")
continue
already_seen.add(new_func)
is_passing, feedback, new_state = execute(new_func, tests_i)
debug_print(f"expanding: \n{new_func}\n{feedback}\n{new_state}")
if is_passing:
# return immediately if solved
return set([(State(new_func, feedback, "", new_state), 0)])
new_reflection = self_reflection_generator(new_func, feedback, model)
reflections.append(new_reflection)
num_failing = len([x for x in new_state if not x])
new_states.add(
(State(new_func, feedback, new_reflection, new_state), num_failing))
debug_print(f"returning new states: {new_states}")
return new_states
def when_none(l: List[State]) -> State:
debug_print(f"when_none called on: {l}")
return min(l, key=lambda x: len([y for y in x.state if not y]))
# this is either the goal state, or if not found, the current best state (lowest failed tests)
best = ucs(
start=start,
expand=expand,
is_goal=lambda x: x.is_goal(),
# NOTE: this way we reduce our search space significantly
# the maximum number of nodes is 2^num_tests,
# which is 2^5 = 32
get_unique_id=lambda x: x.get_unique_id(),
when_none=when_none
)
assert best is not None # impossible due to our when_none
is_passing = evaluate(
item["entry_point"], best.code, item["test"], timeout=5)
if is_passing:
item["solution"] = best.code
is_solved = True
num_success += 1
break # breaking pass@k loop
cur_pass += 1
item["is_solved"] = is_solved
item["reflections"] = reflections
item["solution"] = cur_func_impl
write_jsonl(log_path, [item], append=True)
if verbose:
print(
f'completed {i+1}/{num_items}: acc = {round(num_success/(i+1), 2)}')
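
The search itself is delegated to lazzzy.ucs, whose implementation is not part of this diff. The sketch below is only an assumption about what a compatible uniform-cost search could look like, matching the keyword arguments used above (start, expand, is_goal, get_unique_id, when_none); the real lazzzy API may differ. Note that get_unique_id folds the pass/fail tuple into a bitmask, e.g. (True, False, True) -> 1*1 + 0*2 + 1*4 = 5, which caps the number of distinct nodes at 2^num_tests.

import heapq
from itertools import count
from typing import Callable, Hashable, Iterable, List, Optional, Tuple, TypeVar

T = TypeVar("T")

def ucs_sketch(
    start: T,
    expand: Callable[[T], Iterable[Tuple[T, float]]],
    is_goal: Callable[[T], bool],
    get_unique_id: Callable[[T], Hashable],
    when_none: Callable[[List[T]], T],
) -> Optional[T]:
    """Uniform-cost search: always pop the cheapest frontier node, stop at the first goal."""
    tie = count()  # tie-breaker so heapq never has to compare two T objects
    frontier = [(0.0, next(tie), start)]
    best_cost = {get_unique_id(start): 0.0}
    visited: List[T] = []  # everything popped, used for the when_none fallback

    while frontier:
        cost, _, node = heapq.heappop(frontier)
        if is_goal(node):
            return node
        visited.append(node)
        for neighbor, step_cost in expand(node):
            uid = get_unique_id(neighbor)
            new_cost = cost + step_cost
            if uid not in best_cost or new_cost < best_cost[uid]:
                best_cost[uid] = new_cost
                heapq.heappush(frontier, (new_cost, next(tie), neighbor))

    # exhausted without reaching a goal: let the caller pick the least-bad state
    return when_none(visited) if visited else None

Under this interface, the expand defined above prices each new candidate by its number of failing internal tests, so cheaper frontier entries correspond to branches that have accumulated fewer test failures.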


@ -1,4 +1,5 @@
jsonlines
jsonlines==3.1.0
openai==0.27.0
datasets
tenacity
datasets==2.7.0
tenacity==8.1.0
astunparse==1.6.3

File diff suppressed because one or more lines are too long

10
run_reflexion_ucs.sh Executable file

@ -0,0 +1,10 @@
python main.py \
--run_name "reflexion_ucs_scratch" \
--root_dir "root" \
--dataset_path ./human-eval/data/HumanEval.jsonl.gz \
--strategy "reflexion-ucs" \
--language "py" \
--model "gpt-4" \
--pass_at_k "1" \
--max_iters "5" \
--verbose


@ -35,6 +35,7 @@ def run_simple(
cur_func_impl = ""
while cur_pass < pass_at_k:
cur_func_impl = func_impl_generator(item["prompt"], model, "simple")
assert isinstance(cur_func_impl, str)
is_passing = evaluate(item["entry_point"], cur_func_impl, item["test"], timeout=10)
if is_passing:
is_solved = True


@ -1,7 +1,10 @@
import sys
import signal
from utils import read_jsonl
TIMEOUT = 5 # seconds
assert len(sys.argv) == 2, "Please provide a log file"
LOG_PATH = sys.argv[1]
@ -26,7 +29,14 @@ def validate_py_results(log_path: str):
code = f'{func_impl}\n\n{item["test"]}\n\ncheck({item["entry_point"]})'
num_tests = count_test_cases(item["test"])
try:
def handler(signum, frame):
nonlocal i
raise Exception("timeout on test case" + str(i))
signal.signal(signal.SIGALRM, handler)
signal.alarm(TIMEOUT)
exec(code, globals())
signal.alarm(0)
green_text_out = green_text(f"passes {num_tests}/{num_tests} test cases")
print(f"Test {i}: {green_text_out}")
num_success += 1
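
The validator wraps exec in a SIGALRM alarm so a hanging test case cannot stall the whole run. A self-contained sketch of that pattern (Unix-only; the helper name is illustrative, not from the repo):

import signal

def run_with_timeout(fn, seconds: int = 5):
    # SIGALRM fires after `seconds`; the handler raises, aborting the slow call
    def handler(signum, frame):
        raise TimeoutError(f"timed out after {seconds}s")

    old_handler = signal.signal(signal.SIGALRM, handler)
    signal.alarm(seconds)
    try:
        return fn()
    finally:
        signal.alarm(0)  # cancel any pending alarm
        signal.signal(signal.SIGALRM, old_handler)

# example: a loop that would never finish on its own
# run_with_timeout(lambda: [x for x in iter(int, 1)], seconds=1)  # raises TimeoutError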