diff --git a/programming_runs/generators/generator_utils.py b/programming_runs/generators/generator_utils.py index fecb7d1..1f4798a 100644 --- a/programming_runs/generators/generator_utils.py +++ b/programming_runs/generators/generator_utils.py @@ -1,6 +1,5 @@ from generators.model import ModelBase, Message import random -from parse import parse_code_block, add_code_block from typing import Union, List, Optional, Callable @@ -155,7 +154,8 @@ def generic_generate_self_reflection( model: ModelBase, self_reflection_chat_instruction: str, self_reflection_completion_instruction: str, - self_reflection_few_shot: Optional[str] = None + add_code_block: Callable[[str], str], + self_reflection_few_shot: Optional[str] = None, ) -> str: if model.is_chat: if self_reflection_few_shot is not None: diff --git a/programming_runs/generators/py_generate.py b/programming_runs/generators/py_generate.py index 9ff5de9..4701e03 100644 --- a/programming_runs/generators/py_generate.py +++ b/programming_runs/generators/py_generate.py @@ -5,7 +5,7 @@ from .generator_utils import generic_generate_func_impl, generic_generate_intern from typing import Optional, List, Union import ast import re -from parse import parse_code_block, add_code_block +from .parse import parse_code_block, add_code_block PY_SIMPLE_COMPLETION_INSTRUCTION = "# Write the body of this function only." PY_REFLEXION_COMPLETION_INSTRUCTION = "You are a Python writing assistant. You will be given your past function implementation, a series of unit tests, and a hint to change the implementation appropriately. Write your full implementation (restate the function signature).\n\n-----" @@ -255,6 +255,7 @@ class PyGenerator(Generator): model=model, self_reflection_chat_instruction=PY_SELF_REFLECTION_CHAT_INSTRUCTION, self_reflection_completion_instruction=PY_SELF_REFLECTION_COMPLETION_INSTRUCTION, + add_code_block=lambda x: add_code_block(x, "python"), self_reflection_few_shot=PY_SELF_REFLECTION_FEW_SHOT ) diff --git a/programming_runs/generators/rs_generate.py b/programming_runs/generators/rs_generate.py index 95e7aee..2da41b7 100644 --- a/programming_runs/generators/rs_generate.py +++ b/programming_runs/generators/rs_generate.py @@ -1,7 +1,7 @@ from generators.model import ModelBase from .generator_types import Generator from .generator_utils import generic_generate_func_impl, generic_generate_internal_tests, generic_generate_self_reflection -from parse import parse_code_block, add_code_block +from .parse import parse_code_block, add_code_block from typing import List, Optional, Union @@ -149,6 +149,7 @@ class RsGenerator(Generator): model=model, self_reflection_chat_instruction=RS_SELF_REFLECTION_CHAT_INSTRUCTION, self_reflection_completion_instruction=RS_SELF_REFLECTION_COMPLETION_INSTRUCTION, + add_code_block=lambda x: add_code_block(x, "rust"), self_reflection_few_shot=RS_SELF_REFLECTION_FEW_SHOT, )