add indentation parse

This commit is contained in:
Noah Shinn 2023-03-28 17:59:26 -04:00
parent 985a397921
commit 2223481bfe
4 changed files with 55 additions and 17 deletions

View File

@ -94,11 +94,11 @@ def get_output(func: str, assert_statement: str, timeout: int = 5) -> str:
except TimeoutError:
return "TIMEOUT"
except Exception as e:
return str(type(e).__name__)
return str(e)
if __name__ == "__main__":
pass
# Test the function
# func = "def add(a, b):\n while True:\n x = 1\n return a + b"
# tests = ["assert add(1, 2) == 3", "assert add(1, 2) == 4"]
# print(execute_with_feedback(func, tests, timeout=1))
func = "def add(a, b):\n while True:\n x = 1\n return a + b"
tests = ["assert add(1, 2) == 3", "assert add(1, 2) == 4"]
print(py_execute(func, tests, timeout=1))

View File

@ -74,3 +74,5 @@ def parse_body(text):
if 'return' in lines[i]:
return '\n'.join(lines[:i+1])
return text

View File

@ -43,17 +43,17 @@ def py_generate_self_reflection(func: str, feedback: str, model: str) -> str:
reflection = gpt_completion(model, f'{PY_SELF_REFLECTION_COMPLETION_INSTRUCTION}\n{func}\n\n{feedback}\n\nExplanation:')
return reflection # type: ignore
# fixes the indentation of the function body.
# only checks if the first line is indented correctly, and if not, fixes it.
def py_fix_indentation(func: str) -> str:
lines = func.splitlines()
if len(lines) == 0:
return func
first_line = lines[0]
if first_line.startswith(' '):
return func
else:
return ' ' + func
# # fixes the indentation of the function body.
# # only checks if the first line is indented correctly, and if not, fixes it.
# def py_fix_indentation(func: str) -> str:
# lines = func.splitlines()
# if len(lines) == 0:
# return func
# first_line = lines[0]
# if first_line.startswith(' '):
# return func
# else:
# return ' ' + func
def py_generate_func_impl(
func_sig: str,
@ -62,8 +62,8 @@ def py_generate_func_impl(
prev_func_impl: Optional[str] = None,
feedback: Optional[str] = None,
self_reflection: Optional[str] = None,
num_comps = 1,
temperature = 0.0,
num_comps: int = 1,
temperature: float = 0.0,
) -> str | List[str]:
if strategy != "reflexion" and strategy != "simple":
raise ValueError(f"Invalid strategy: given `{strategy}` but expected one of `reflexion` or `simple`")
@ -116,3 +116,39 @@ def py_generate_internal_tests(func_sig: str, model: str, committee_size: int=1)
# cur_refinement_num += 1
return cur_tests
DUMMY_FUNC_SIG = "def func():"
DUMMY_FUNC_CALL = "func()"
def handle_first_line_indent(func_body: str) -> str:
if func_body.startswith(" "):
return func_body
split = func_body.splitlines()
return f" {split[0]}\n" + "\n".join(split[1:])
def handle_entire_body_indent(func_body: str) -> str:
split = func_body.splitlines()
res = "\n".join([" " + line for line in split])
return res
def py_fix_indentation(func_body: str) -> str:
"""
3 cases:
1. good syntax
2. first line not good
3. entire body not good
"""
def parse_indent_rec(f_body: str, cur_state: int) -> str:
if cur_state > 1:
return f_body
code = f'{DUMMY_FUNC_SIG}\n{f_body}\n{DUMMY_FUNC_CALL}'
try:
exec(code)
return f_body
except (IndentationError, SyntaxError):
p_func = handle_first_line_indent if cur_state == 0 else handle_entire_body_indent
return parse_indent_rec(p_func(func_body), cur_state + 1)
except Exception:
return f_body
return parse_indent_rec(func_body, 0)