Handle no == in get_call_str

This commit is contained in:
Beck LaBash 2023-04-11 20:47:41 -04:00
parent e9407a6725
commit b579fd61e0
4 changed files with 16 additions and 54 deletions

View File

@ -1,59 +1,15 @@
def timeout_handler(_, __):
raise TimeoutError()
import re
from abc import ABC, abstractmethod
class SubmissionFormatter(ABC):
"""
Class that converts between HumanEval and Leetcode submission formats.
"""
@abstractmethod
def to_leetcode(self, humaneval_snippet: str):
...
@abstractmethod
def to_humaneval(self, leetcode_snippet: str):
...
class PySubmissionFormatter(SubmissionFormatter):
def to_leetcode(self, humaneval_snippet: str):
return f"""\
class Solution:
{humaneval_snippet.strip()}
"""
def to_humaneval(self, leetcode_snippet: str):
pattern = re.compile(r"class Solution:\s+([\s\S]+)")
match = pattern.search(leetcode_snippet)
if match:
return match.group(1).strip()
return leetcode_snippet.strip()
class RsSubmissionFormatter(SubmissionFormatter):
def to_leetcode(self, humaneval_snippet: str):
return f"""\
impl Solution {{
{humaneval_snippet.strip()}
}}
"""
def to_humaneval(self, leetcode_snippet: str):
pattern = re.compile(r"impl Solution \{([\s\S]+)\}")
match = pattern.search(leetcode_snippet)
if match:
return match.group(1).strip()
return leetcode_snippet.strip()
# Py tests
if __name__ == "__main__":
formatter = PySubmissionFormatter()
leetcode_1 = 'class Solution:\n def solveSudoku(self, board: List[List[str]]) -> None:\n """\n Do not return anything, modify board in-place instead.\n """\n '
humaneval_1 = 'def solveSudoku(self, board: List[List[str]]) -> None:\n """\n Do not return anything, modify board in-place instead.\n """\n'
# if __name__ == "__main__":
# formatter = PySubmissionFormatter()
# leetcode_1 = 'class Solution:\n def solveSudoku(self, board: List[List[str]]) -> None:\n """\n Do not return anything, modify board in-place instead.\n """\n '
# humaneval_1 = 'def solveSudoku(self, board: List[List[str]]) -> None:\n """\n Do not return anything, modify board in-place instead.\n """\n'
assert leetcode_1 == formatter.to_leetcode(humaneval_1)
assert humaneval_1 == formatter.to_humaneval(leetcode_1)
# assert leetcode_1 == formatter.to_leetcode(humaneval_1)
# assert humaneval_1 == formatter.to_humaneval(leetcode_1)

@ -1 +1 @@
Subproject commit bdafea818c1f3c968961fa25fbe91d1d970502e4
Subproject commit 61f4969745189177a9edd8306217c66d6b8f9edb

View File

@ -10,7 +10,8 @@ from .executor_types import ExecuteResult, Executor
class PyExecutor(Executor):
def execute(self, func: str, tests: List[str], timeout: int = 5) -> ExecuteResult:
# Combine function code and assert statement
func_test_list = [f'{func}\n{test}' for test in tests]
imports = 'from typing import *'
func_test_list = [f'{imports}\n{func}\n{test}' for test in tests]
# Run the tests and collect the results
success_tests = []
@ -77,7 +78,12 @@ check({name})
return False
def get_call_str(assert_statement: str) -> str:
call_str = ast.parse(assert_statement).body[0].test.left # type: ignore
ast_parsed = ast.parse(assert_statement)
try:
call_str = ast_parsed.body[0].test.left # type: ignore
except:
call_str = ast_parsed.body[0].test # type: ignore
return astunparse.unparse(call_str).strip()
def get_output(func: str, assert_statement: str, timeout: int = 5) -> str:

View File

@ -62,7 +62,7 @@ class PyGenerator(Generator):
temperature: float = 0.0,
) -> Union[str, List[str]]:
x = generic_generate_func_impl(
func_sig=func_sig,
func_sig=f'from typing import *\n{func_sig}',
model=model,
strategy=strategy,
prev_func_impl=prev_func_impl,