main
cassanof 9 months ago
parent 9c8dc6617e
commit 49485efc8f

@ -1,4 +1,4 @@
from generators.model import ModelBase, Message, message_to_str, messages_to_str
from generators.model import ModelBase, Message
import random
from typing import Union, List, Optional, Callable
@ -41,7 +41,7 @@ def generic_generate_func_impl(
content=prompt,
),
Message(
role="user", # TODO: check this
role="user", # TODO: check this
content=reflexion_few_shot,
),
Message(
@ -61,8 +61,7 @@ def generic_generate_func_impl(
content=f"[improved impl]:\n{func_sig}",
),
]
func_bodies = model.generate_chat(
messages=messages, num_comps=num_comps, temperature=temperature)
func_bodies = model.generate_chat(messages=messages, num_comps=num_comps, temperature=temperature)
else:
system_prompt = f"{simple_chat_instruction}\n{code_block_instruction}"
print_messages(system_prompt, func_sig)
@ -76,8 +75,7 @@ def generic_generate_func_impl(
content=func_sig,
),
]
func_bodies = model.generate_chat(
messages=messages, num_comps=num_comps, temperature=temperature)
func_bodies = model.generate_chat(messages=messages, num_comps=num_comps, temperature=temperature)
else:
if strategy == "reflexion":
prompt = f"{reflexion_completion_instruction}\n{add_code_block(prev_func_impl)}\n\nunit tests:\n{feedback}\n\nhint:\n{self_reflection}\n\n# improved implementation\n{func_sig}\n{code_block_instruction}"
@ -95,8 +93,7 @@ def generic_generate_func_impl(
return func_body_str
else:
func_bodies = [parse_code_block(func_body)
for func_body in func_bodies]
func_bodies = [parse_code_block(func_body) for func_body in func_bodies]
print_generated_func_body("\n\n".join(func_bodies))
return func_bodies
@ -105,7 +102,7 @@ def generic_generate_internal_tests(
func_sig: str,
model: ModelBase,
max_num_tests: int,
test_generation_few_shot: List[Message],
test_generation_few_shot: str,
test_generation_chat_instruction: str,
test_generation_completion_instruction: str,
parse_tests: Callable[[str], List[str]],
@ -122,7 +119,7 @@ def generic_generate_internal_tests(
),
Message(
role="user",
content=f"{messages_to_str(test_generation_few_shot)}\n\n[func signature]:\n{func_sig}\n\n[think]:"
content=f"{test_generation_few_shot}\n\n[func signature]:\n{func_sig}\n\n[think]:"
)
]
output = model.generate_chat(messages=messages, max_tokens=1024)
@ -133,10 +130,9 @@ def generic_generate_internal_tests(
role="system",
content=test_generation_chat_instruction,
),
] + test_generation_few_shot + [
Message(
role="user",
content=f"{func_sig}"
content=f"{test_generation_few_shot}\n\n[func signature]:\n{func_sig}\n\n[unit tests]:",
)
]
output = model.generate_chat(messages=messages, max_tokens=1024)
@ -197,7 +193,6 @@ def sample_n_random(items: List[str], n: int) -> List[str]:
return items
return random.sample(items, n)
def print_messages(system_message_text: str, user_message_text: str) -> None:
print(f"""----------------------- SYSTEM MESSAGE -----------------------)
{system_message_text}
@ -207,9 +202,7 @@ def print_messages(system_message_text: str, user_message_text: str) -> None:
----------------------------------------------
""", flush=True)
def print_generated_func_body(func_body_str: str) -> None:
print(f"""--------------------- GENERATED FUNC BODY ---------------------
{func_body_str}
------------------------------------------""")

@ -1,4 +1,4 @@
from generators.model import Message, ModelBase, messages_to_str
from generators.model import ModelBase, message_to_str
from .generator_types import Generator
from .generator_utils import generic_generate_func_impl, generic_generate_internal_tests, generic_generate_self_reflection
@ -221,22 +221,24 @@ The implementation failed 4 out of the 7 test cases due to an IndexError. The is
END OF EXAMPLES
"""
PY_TEST_GENERATION_FEW_SHOT = [
Message(role="user", content="""def add3Numbers(x, y, z):
PY_TEST_GENERATION_FEW_SHOT = """Examples:
func signature:
def add3Numbers(x, y, z):
\"\"\" Add three numbers together.
This function takes three numbers as input and returns the sum of the three numbers.
\"\"\""""),
Message(role="assistant", content="""assert add3Numbers(1, 2, 3) == 6
\"\"\"
unit tests:
assert add3Numbers(1, 2, 3) == 6
assert add3Numbers(-1, 2, 3) == 4
assert add3Numbers(1, -2, 3) == 2
assert add3Numbers(1, 2, -3) == 0
assert add3Numbers(-3, -2, -1) == -6
assert add3Numbers(0, 0, 0) == 0""")
]
assert add3Numbers(0, 0, 0) == 0
"""
PY_TEST_GENERATION_COMPLETION_INSTRUCTION = f"""You are an AI coding assistant that can write unique, diverse, and intuitive unit tests for functions given the signature and docstring.
{messages_to_str(PY_TEST_GENERATION_FEW_SHOT)}"""
{PY_TEST_GENERATION_FEW_SHOT}"""
PY_TEST_GENERATION_CHAT_INSTRUCTION = """You are an AI coding assistant that can write unique, diverse, and intuitive unit tests for functions given the signature and docstring."""

@ -1,4 +1,4 @@
from generators.model import Message, ModelBase, messages_to_str
from generators.model import ModelBase
from .generator_types import Generator
from .generator_utils import generic_generate_func_impl, generic_generate_internal_tests, generic_generate_self_reflection
from .parse import parse_code_block, add_code_block
@ -47,18 +47,21 @@ fn add(a: i32, b: i32) -> i32 {
END EXAMPLES
'''
RS_TEST_GENERATION_FEW_SHOT = [
Message(role="user", content="""/// Add three numbers together.
RS_TEST_GENERATION_FEW_SHOT = """For example:
func signature:
/// Add three numbers together.
/// This function takes three numbers as input and returns the sum of the three numbers.
fn add3Numbers(x: i32, y: i32, z: i32) -> i32 {
"""),
Message(role="assistant", content="""assert_eq!(add3Numbers(1, 2, 3), 6);
unit tests:
assert_eq!(add3Numbers(1, 2, 3), 6);
assert_eq!(add3Numbers(-1, 2, 3), 4);
assert_eq!(add3Numbers(1, -2, 3), 2);
assert_eq!(add3Numbers(1, 2, -3), 0);
assert_eq!(add3Numbers(-3, -2, -1), -6);
assert_eq!(add3Numbers(0, 0, 0), 0);""")
]
assert_eq!(add3Numbers(0, 0, 0), 0);
"""
RS_SELF_REFLECTION_FEW_SHOT = '''Example 1:
[function impl]:

Loading…
Cancel
Save