mirror of
https://github.com/GammaTauAI/reflexion-human-eval
synced 2024-11-11 19:10:53 +00:00
Handle no == in get_call_str
This commit is contained in:
parent
e9407a6725
commit
b579fd61e0
@ -1,59 +1,15 @@
|
|||||||
def timeout_handler(_, __):
|
def timeout_handler(_, __):
|
||||||
raise TimeoutError()
|
raise TimeoutError()
|
||||||
|
|
||||||
import re
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
class SubmissionFormatter(ABC):
|
|
||||||
"""
|
|
||||||
Class that converts between HumanEval and Leetcode submission formats.
|
|
||||||
"""
|
|
||||||
@abstractmethod
|
|
||||||
def to_leetcode(self, humaneval_snippet: str):
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def to_humaneval(self, leetcode_snippet: str):
|
|
||||||
...
|
|
||||||
|
|
||||||
class PySubmissionFormatter(SubmissionFormatter):
|
|
||||||
def to_leetcode(self, humaneval_snippet: str):
|
|
||||||
return f"""\
|
|
||||||
class Solution:
|
|
||||||
{humaneval_snippet.strip()}
|
|
||||||
"""
|
|
||||||
|
|
||||||
def to_humaneval(self, leetcode_snippet: str):
|
|
||||||
pattern = re.compile(r"class Solution:\s+([\s\S]+)")
|
|
||||||
match = pattern.search(leetcode_snippet)
|
|
||||||
if match:
|
|
||||||
return match.group(1).strip()
|
|
||||||
return leetcode_snippet.strip()
|
|
||||||
|
|
||||||
class RsSubmissionFormatter(SubmissionFormatter):
|
|
||||||
def to_leetcode(self, humaneval_snippet: str):
|
|
||||||
return f"""\
|
|
||||||
impl Solution {{
|
|
||||||
{humaneval_snippet.strip()}
|
|
||||||
}}
|
|
||||||
"""
|
|
||||||
|
|
||||||
def to_humaneval(self, leetcode_snippet: str):
|
|
||||||
pattern = re.compile(r"impl Solution \{([\s\S]+)\}")
|
|
||||||
match = pattern.search(leetcode_snippet)
|
|
||||||
if match:
|
|
||||||
return match.group(1).strip()
|
|
||||||
return leetcode_snippet.strip()
|
|
||||||
|
|
||||||
# Py tests
|
# Py tests
|
||||||
|
|
||||||
if __name__ == "__main__":
|
# if __name__ == "__main__":
|
||||||
formatter = PySubmissionFormatter()
|
# formatter = PySubmissionFormatter()
|
||||||
leetcode_1 = 'class Solution:\n def solveSudoku(self, board: List[List[str]]) -> None:\n """\n Do not return anything, modify board in-place instead.\n """\n '
|
# leetcode_1 = 'class Solution:\n def solveSudoku(self, board: List[List[str]]) -> None:\n """\n Do not return anything, modify board in-place instead.\n """\n '
|
||||||
humaneval_1 = 'def solveSudoku(self, board: List[List[str]]) -> None:\n """\n Do not return anything, modify board in-place instead.\n """\n'
|
# humaneval_1 = 'def solveSudoku(self, board: List[List[str]]) -> None:\n """\n Do not return anything, modify board in-place instead.\n """\n'
|
||||||
|
|
||||||
assert leetcode_1 == formatter.to_leetcode(humaneval_1)
|
# assert leetcode_1 == formatter.to_leetcode(humaneval_1)
|
||||||
assert humaneval_1 == formatter.to_humaneval(leetcode_1)
|
# assert humaneval_1 == formatter.to_humaneval(leetcode_1)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -1 +1 @@
|
|||||||
Subproject commit bdafea818c1f3c968961fa25fbe91d1d970502e4
|
Subproject commit 61f4969745189177a9edd8306217c66d6b8f9edb
|
@ -10,7 +10,8 @@ from .executor_types import ExecuteResult, Executor
|
|||||||
class PyExecutor(Executor):
|
class PyExecutor(Executor):
|
||||||
def execute(self, func: str, tests: List[str], timeout: int = 5) -> ExecuteResult:
|
def execute(self, func: str, tests: List[str], timeout: int = 5) -> ExecuteResult:
|
||||||
# Combine function code and assert statement
|
# Combine function code and assert statement
|
||||||
func_test_list = [f'{func}\n{test}' for test in tests]
|
imports = 'from typing import *'
|
||||||
|
func_test_list = [f'{imports}\n{func}\n{test}' for test in tests]
|
||||||
|
|
||||||
# Run the tests and collect the results
|
# Run the tests and collect the results
|
||||||
success_tests = []
|
success_tests = []
|
||||||
@ -77,7 +78,12 @@ check({name})
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def get_call_str(assert_statement: str) -> str:
|
def get_call_str(assert_statement: str) -> str:
|
||||||
call_str = ast.parse(assert_statement).body[0].test.left # type: ignore
|
ast_parsed = ast.parse(assert_statement)
|
||||||
|
try:
|
||||||
|
call_str = ast_parsed.body[0].test.left # type: ignore
|
||||||
|
except:
|
||||||
|
call_str = ast_parsed.body[0].test # type: ignore
|
||||||
|
|
||||||
return astunparse.unparse(call_str).strip()
|
return astunparse.unparse(call_str).strip()
|
||||||
|
|
||||||
def get_output(func: str, assert_statement: str, timeout: int = 5) -> str:
|
def get_output(func: str, assert_statement: str, timeout: int = 5) -> str:
|
||||||
|
@ -62,7 +62,7 @@ class PyGenerator(Generator):
|
|||||||
temperature: float = 0.0,
|
temperature: float = 0.0,
|
||||||
) -> Union[str, List[str]]:
|
) -> Union[str, List[str]]:
|
||||||
x = generic_generate_func_impl(
|
x = generic_generate_func_impl(
|
||||||
func_sig=func_sig,
|
func_sig=f'from typing import *\n{func_sig}',
|
||||||
model=model,
|
model=model,
|
||||||
strategy=strategy,
|
strategy=strategy,
|
||||||
prev_func_impl=prev_func_impl,
|
prev_func_impl=prev_func_impl,
|
||||||
|
Loading…
Reference in New Issue
Block a user