From 1eb65193d98bfe2603c33f02b0b86c25b6a8beda Mon Sep 17 00:00:00 2001 From: Beck LaBash <55890162+becklabs@users.noreply.github.com> Date: Thu, 18 May 2023 19:53:30 -0400 Subject: [PATCH] Lazy imports for leetcode --- executors/factory.py | 9 ++++--- executors/leet_executor.py | 12 +++++---- executors/leetcode_env | 2 +- reflexion.py | 6 +++-- test.py | 55 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 72 insertions(+), 12 deletions(-) create mode 100644 test.py diff --git a/executors/factory.py b/executors/factory.py index b5e4009..e36454f 100644 --- a/executors/factory.py +++ b/executors/factory.py @@ -1,15 +1,14 @@ -from .leetcode_env.leetcode_env.leetcode_types import ProgrammingLanguage - from .py_executor import PyExecutor from .rs_executor import RsExecutor -from .leet_executor import LeetExecutor from .executor_types import Executor -from .leetcode_env.leetcode_env.utils import PySubmissionFormatter, RsSubmissionFormatter +from .leet_executor import LeetExecutor def executor_factory(lang: str, is_leet: bool = False) -> Executor: if lang == "py" or lang == "python": if is_leet: print("Using LeetCode Python executor") + from .leetcode_env.leetcode_env.leetcode_types import ProgrammingLanguage + from .leetcode_env.leetcode_env.utils import PySubmissionFormatter, RsSubmissionFormatter return LeetExecutor(ProgrammingLanguage.PYTHON3, PyExecutor(), PySubmissionFormatter) @@ -17,6 +16,8 @@ def executor_factory(lang: str, is_leet: bool = False) -> Executor: return PyExecutor() elif lang == "rs" or lang == "rust": if is_leet: + from .leetcode_env.leetcode_env.leetcode_types import ProgrammingLanguage + from .leetcode_env.leetcode_env.utils import PySubmissionFormatter, RsSubmissionFormatter return LeetExecutor(ProgrammingLanguage.RUST, RsExecutor(), RsSubmissionFormatter) diff --git a/executors/leet_executor.py b/executors/leet_executor.py index 52733e7..b176c68 100644 --- a/executors/leet_executor.py +++ b/executors/leet_executor.py @@ -1,16 +1,18 @@ +from __future__ import annotations + from typing import List from .executor_types import ExecuteResult, Executor from .executor_utils import to_jsonl from datetime import datetime -from .leetcode_env.leetcode_env.utils import SubmissionFormatter -from .leetcode_env.leetcode_env.leetcode_types import ProgrammingLanguage - class LeetExecutor(Executor): - - def __init__(self, lang: ProgrammingLanguage, executor: Executor, formatter: SubmissionFormatter): + def __init__(self, lang, executor: Executor, formatter): + from .leetcode_env.leetcode_env.utils import SubmissionFormatter + from .leetcode_env.leetcode_env.leetcode_types import ProgrammingLanguage from .leetcode_env.leetcode_env.environment import LeetCodeEnv + assert isinstance(formatter, SubmissionFormatter) + assert isinstance(lang, ProgrammingLanguage) self.lang = lang self.executor = executor self.formatter = formatter diff --git a/executors/leetcode_env b/executors/leetcode_env index db41e86..2ab1595 160000 --- a/executors/leetcode_env +++ b/executors/leetcode_env @@ -1 +1 @@ -Subproject commit db41e86d5d1777f4cde7a2973a606120c9a72163 +Subproject commit 2ab159560725cf8482600ca1d0adf55d1b315c14 diff --git a/reflexion.py b/reflexion.py index 079d5e0..79f7eb9 100644 --- a/reflexion.py +++ b/reflexion.py @@ -30,8 +30,10 @@ def run_reflexion( test_feedback = [] cur_func_impl = "" while cur_pass < pass_at_k and not is_solved: - # tests_i = gen.internal_tests(item["prompt"], model, 1) - tests_i = item['visible_tests'] + if is_leetcode: + tests_i = item['visible_tests'] + else: + tests_i = gen.internal_tests(item["prompt"], model, 1) # first attempt cur_func_impl = gen.func_impl(item["prompt"], model, "simple") diff --git a/test.py b/test.py new file mode 100644 index 0000000..4798a21 --- /dev/null +++ b/test.py @@ -0,0 +1,55 @@ +# Fails 2 +def minReverseOperations(self, n: int, p: int, banned: List[int], k: int) -> List[int]: + from collections import deque + banned = set(banned) + arr = tuple(0 if i in banned else (1 if i == p else 0) for i in range(n)) + queue = deque([(arr, p, 0)]) # Add a third element to the tuple to store the number of + ans = [-1] * n + visited = set() + + while queue: + cur_arr, cur_pos, ops = queue.popleft() + + if cur_pos not in visited: + visited.add(cur_pos) + ans[cur_pos] = ops + + for i in range(n): + for j in range(i + k, n + 1): + new_arr = cur_arr[:i] + tuple(reversed(cur_arr[i:j])) + cur_arr[j:] + new_pos = new_arr.index(1) + + if new_pos not in banned and (new_arr, new_pos) not in visited: + queue.append((new_arr, new_pos, ops + 1)) + + return ans + +# Fails 1 +def minReverseOperations(self, n: int, p: int, banned: List[int], k: int) -> List[int]: + from collections import deque + banned = set(banned) + arr = tuple(0 if i in banned else (1 if i == p else 0) for i in range(n)) + queue = deque([(arr, p, 0)]) # Add a third element to the tuple to store the number of operations + ans = [-1] * n + visited = set() + + while queue: + cur_arr, cur_pos, ops = queue.popleft() + + if cur_pos not in visited: + visited.add(cur_pos) + ans[cur_pos] = ops + + for i in range(n): + for j in range(i + k, n + 1): + # Check if the subarray to be reversed contains any banned positions + if any(cur_arr[i:x] for x in range(i, j) if x in banned): + continue + + new_arr = cur_arr[:i] + tuple(reversed(cur_arr[i:j])) + cur_arr[j:] + new_pos = new_arr.index(1) + + if new_pos not in banned and (new_arr, new_pos) not in visited: + queue.append((new_arr, new_pos, ops + 1)) + + return ans \ No newline at end of file