Compare commits


17 Commits

Author SHA1 Message Date
cassanof 49485efc8f undo 9 months ago
cassanof 9c8dc6617e decent few-shot 9 months ago
cassanof 7247f0c306 fix 9 months ago
cassanof 468ed1f8bf exec 9 months ago
cassanof cc6b8f1192 rstrip codellama 9 months ago
cassanof f748264b35 allow to get smol boi or big boi 9 months ago
cassanof 0a9ebab46b big boi llama 9 months ago
cassanof 45c7f2c50e fix 9 months ago
cassanof 872e56c928 a 9 months ago
cassanof 1f62bab132 fix parse 9 months ago
cassanof 1d70a0026f fix 9 months ago
cassanof 3dec2edc92 out extract 9 months ago
cassanof 3b2de1a67e fix 9 months ago
cassanof f7d1613f8e tensor 9 months ago
cassanof 0365be2c6e factory 9 months ago
cassanof afcfd427a2 code llaam 9 months ago
cassanof e5cdf8c260 codellama 9 months ago

@@ -1,7 +1,7 @@
from .py_generate import PyGenerator
from .rs_generate import RsGenerator
from .generator_types import Generator
from .model import ModelBase, GPT4, GPT35, StarChat, GPTDavinci
from .model import CodeLlama, ModelBase, GPT4, GPT35, StarChat, GPTDavinci
def generator_factory(lang: str) -> Generator:
@@ -20,6 +20,12 @@ def model_factory(model_name: str) -> ModelBase:
return GPT35()
elif model_name == "starchat":
return StarChat()
elif model_name.startswith("codellama"):
# if it has `-` in the name, version was specified
kwargs = {}
if "-" in model_name:
kwargs["version"] = model_name.split("-")[1]
return CodeLlama(**kwargs)
elif model_name.startswith("text-davinci"):
return GPTDavinci(model_name)
else:
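
With this change, a bare "codellama" model name passes no kwargs, so the constructor default applies (the 34b checkpoint, per the new __init__ below), while a suffixed name such as "codellama-13b" is split on "-" and the second part is forwarded as the version keyword. A minimal sketch of that mapping (the _codellama_kwargs helper is illustrative, not code from the repo):

# Illustration only: how the new factory branch turns a model name
# into CodeLlama constructor arguments.
def _codellama_kwargs(model_name: str) -> dict:
    kwargs = {}
    if "-" in model_name:
        # e.g. "codellama-13b" -> {"version": "13b"}
        kwargs["version"] = model_name.split("-")[1]
    return kwargs

assert _codellama_kwargs("codellama") == {}
assert _codellama_kwargs("codellama-13b") == {"version": "13b"}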

@@ -205,4 +205,4 @@ def print_messages(system_message_text: str, user_message_text: str) -> None:
def print_generated_func_body(func_body_str: str) -> None:
print(f"""--------------------- GENERATED FUNC BODY ---------------------
{func_body_str}
------------------------------------------""")
------------------------------------------""")

@@ -7,7 +7,6 @@ from tenacity import (
wait_random_exponential, # type: ignore
)
import openai
from transformers import AutoModelForCausalLM, AutoTokenizer
MessageRole = Literal["system", "user", "assistant"]
@@ -18,6 +17,14 @@ class Message():
content: str
def message_to_str(message: Message) -> str:
return f"{message.role}: {message.content}"
def messages_to_str(messages: List[Message]) -> str:
return "\n".join([message_to_str(message) for message in messages])
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def gpt_completion(
model: str,
@@ -152,7 +159,7 @@ class HFModelBase(ModelBase):
else:
return outs # type: ignore
def prepare_prompt(self, messages: List[Message]) -> List[int]:
def prepare_prompt(self, messages: List[Message]):
raise NotImplementedError
def extract_output(self, output: str) -> str:
@@ -162,6 +169,7 @@ class HFModelBase(ModelBase):
class StarChat(HFModelBase):
def __init__(self):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained(
"HuggingFaceH4/starchat-beta",
torch_dtype=torch.bfloat16,
@@ -170,9 +178,9 @@ class StarChat(HFModelBase):
tokenizer = AutoTokenizer.from_pretrained(
"HuggingFaceH4/starchat-beta",
)
super().__init__("star-chat", model, tokenizer, eos_token_id=49155)
super().__init__("starchat", model, tokenizer, eos_token_id=49155)
def prepare_prompt(self, messages: List[Message]) -> List[int]:
def prepare_prompt(self, messages: List[Message]):
prompt = ""
for i, message in enumerate(messages):
prompt += f"<|{message.role}|>\n{message.content}\n<|end|>\n"
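
Only the prepare_prompt signature changes here; the StarChat prompt format itself is untouched. For reference, a system + user exchange is serialized by the loop above roughly as the string below (illustrative; the trailing <|assistant|> tag for the reply is assumed, as it is appended after the final message in the full method):

# Rough shape of the StarChat prompt built by the loop above.
prompt = (
    "<|system|>\nYou are a helpful coding assistant.\n<|end|>\n"
    "<|user|>\nWrite a function that adds two numbers.\n<|end|>\n"
    "<|assistant|>\n"  # assumed final tag prompting the model's reply
)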
@@ -198,27 +206,58 @@ You are a helpful, respectful and honest assistant. Always answer as helpfully a
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
def __init__(self):
super().__init__("code-llama", "codellama/CodeLlama-34b-Instruct-hf", 2)
self.tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_name,
def __init__(self, version: Literal["34b", "13b", "7b"] = "34b"):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
f"codellama/CodeLlama-{version}-Instruct-hf",
add_eos_token=True,
add_bos_token=True,
padding_side='left'
)
def prepare_prompt(self, messages: List[Message]) -> str:
prompt = ""
for i, message in enumerate(messages):
prompt += f"<|{message.role}|>\n{message.content}\n<|end|>\n"
if i == len(messages) - 1:
prompt += "<|assistant|>\n"
return prompt
model = AutoModelForCausalLM.from_pretrained(
f"codellama/CodeLlama-{version}-Instruct-hf",
torch_dtype=torch.bfloat16,
device_map="auto",
)
super().__init__("codellama", model, tokenizer)
def prepare_prompt(self, messages: List[Message]):
if messages[0].role != "system":
messages = [
Message(role="system", content=self.DEFAULT_SYSTEM_PROMPT)
] + messages
messages = [
Message(role=messages[1].role, content=self.B_SYS +
messages[0].content + self.E_SYS + messages[1].content)
] + messages[2:]
assert all([msg.role == "user" for msg in messages[::2]]) and all(
[msg.role == "assistant" for msg in messages[1::2]]
), (
"model only supports 'system', 'user' and 'assistant' roles, "
"starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
)
messages_tokens: List[int] = sum(
[
self.tokenizer.encode(
f"{self.B_INST} {(prompt.content).strip()} {self.E_INST} {(answer.content).strip()} ",
)
for prompt, answer in zip(
messages[::2],
messages[1::2],
)
],
[],
)
assert messages[-1].role == "user", f"Last message must be from user, got {messages[-1].role}"
messages_tokens += self.tokenizer.encode(
f"{self.B_INST} {(messages[-1].content).strip()} {self.E_INST}",
)
# remove eos token from last message
messages_tokens = messages_tokens[:-1]
import torch
return torch.tensor([messages_tokens]).to(self.model.device)
def extract_output(self, output: str) -> str:
out = output.split("<|assistant|>")[1]
if out.endswith("<|end|>"):
out = out[:-len("<|end|>")]
out = output.split("[/INST]")[-1].split("</s>")[0].strip()
return out
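
The rewritten prepare_prompt folds the system prompt into the first user turn (B_SYS/E_SYS), encodes alternating user/assistant turns between the B_INST/E_INST markers, and returns token ids on the model's device; extract_output correspondingly now takes whatever follows the last "[/INST]" up to "</s>". A string-level sketch, assuming the standard Llama-2 chat marker values (which is what the "[/INST]" split implies):

# Illustration only: string-level view of the prompt layout. The real
# prepare_prompt returns token ids, not a string; the marker values
# below are the usual Llama-2 chat ones and are assumed here.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

system = "You are a helpful, respectful and honest assistant."
user = "Write a Python function that adds two numbers."
prompt = f"{B_INST} {B_SYS}{system}{E_SYS}{user} {E_INST}"

# If the model then generates "def add(a, b): ... </s>",
# extract_output recovers just the completion:
generation = prompt + " def add(a, b):\n    return a + b</s>"
out = generation.split("[/INST]")[-1].split("</s>")[0].strip()
print(out)  # def add(a, b):\n    return a + b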

@@ -5,34 +5,45 @@ from typing import Optional
def parse_code_block(string: str, lang: str) -> Optional[str]:
code_pattern = fr"```{lang}\n(.*?)\n```"
match = re.search(code_pattern, string, re.DOTALL)
if match:
return match.group(1)
else:
return parse_first_func(string, lang)
generic_code_pattern = r"```\n(.*?)\n```"
match = re.search(generic_code_pattern, string, re.DOTALL)
if match:
return match.group(1)
return parse_first_func(string, lang)
def parse_first_func(code: str, lang: str) -> Optional[str]:
assert lang == "python", "Only python is supported for now. TODO: Rust"
code_lines = code.split("\n")
def_i = 0
def_i = -1
last_i = 0
got_return = False
for i, line in enumerate(code_lines):
if line.startswith("def "):
if def_i == 0:
if def_i == -1:
def_i = i
else:
break
if line == "" and def_i != 0:
elif "return" in line and def_i != -1:
got_return = True
if line == "" and def_i != -1 and got_return:
last_i = i
break
if last_i == 0:
last_i = len(code_lines) - 1
if def_i == 0:
if def_i == -1:
return None
return "\n".join(code_lines[def_i:last_i+1])
return "\n".join(code_lines[def_i:last_i+1]).rstrip("[/PYTHON]")
def add_code_block(string: str, lang: str) -> str:
return f"```{lang}\n{string}\n```"
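
With these changes, parse_code_block also tries a plain ``` block before giving up, and parse_first_func only ends a function at a blank line once a return has been seen, trimming a trailing "[/PYTHON]" marker emitted by CodeLlama. A quick check of the fallback path, in the style of the test calls at the bottom of this file (with parse_code_block as defined above):

# Example of the new fallback path: no fenced block, so the parser falls
# through to parse_first_func and keeps only the first complete function.
raw = """Here is my solution:
def add(a, b):
    return a + b

def unused_helper():
    pass
"""
print(parse_code_block(raw, "python"))
# def add(a, b):
#     return a + b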
@@ -63,3 +74,33 @@ def bleh():
return aaa
"""
print(parse_code_block(CODE, "python"))
CODE = """def total_match(lst1: List[str], lst2: List[str]) -> List[str]:
\"\"\"
Write a function that accepts two lists of strings and returns the list that has
total number of chars in the all strings of the list less than the other list.
if the two lists have the same number of chars, return the first list.
Examples
>>> total_match([], [])
[]
>>> total_match(['hi', 'admin'], ['hI', 'Hi'])
['hI', 'Hi']
>>> total_match(['hi', 'admin'], ['hi', 'hi', 'admin', 'project'])
['hi', 'admin']
>>> total_match(['hi', 'admin'], ['hI', 'hi', 'hi'])
['hI', 'hi', 'hi']
>>> total_match(['4'], ['1', '2', '3', '4', '5'])
['4']
\"\"\"
total_chars_lst1 = sum(len(word) for word in lst1)
total_chars_lst2 = sum(len(word) for word in lst2)
if total_chars_lst1 < total_chars_lst2:
return lst1
elif total_chars_lst1 > total_chars_lst2:
return lst2
else:
return lst1
"""
print(parse_code_block(CODE, "python"))

@@ -1,4 +1,4 @@
from generators.model import ModelBase
from generators.model import ModelBase, message_to_str
from .generator_types import Generator
from .generator_utils import generic_generate_func_impl, generic_generate_internal_tests, generic_generate_self_reflection
@@ -223,22 +223,18 @@ END OF EXAMPLES
PY_TEST_GENERATION_FEW_SHOT = """Examples:
func signature:
def has_close_elements(numbers: List[float], threshold: float) -> bool:
\"\"\" Check if in given list of numbers, are any two numbers closer to each other than
given threshold.
>>> has_close_elements([1.0, 2.0, 3.0], 0.5)
False
>>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
True
def add3Numbers(x, y, z):
\"\"\" Add three numbers together.
This function takes three numbers as input and returns the sum of the three numbers.
\"\"\"
unit tests:
assert has_close_elements([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) == True
assert has_close_elements([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05) == False
assert has_close_elements([1.0, 2.0, 5.9, 4.0, 5.0], 0.95) == True
assert has_close_elements([1.0, 2.0, 5.9, 4.0, 5.0], 0.8) == False
assert has_close_elements([1.0, 2.0, 3.0, 4.0, 5.0, 2.0], 0.1) == True
assert has_close_elements([1.1, 2.2, 3.1, 4.1, 5.1], 1.0) == True
assert has_close_elements([1.1, 2.2, 3.1, 4.1, 5.1], 0.5) == False"""
assert add3Numbers(1, 2, 3) == 6
assert add3Numbers(-1, 2, 3) == 4
assert add3Numbers(1, -2, 3) == 2
assert add3Numbers(1, 2, -3) == 0
assert add3Numbers(-3, -2, -1) == -6
assert add3Numbers(0, 0, 0) == 0
"""
PY_TEST_GENERATION_COMPLETION_INSTRUCTION = f"""You are an AI coding assistant that can write unique, diverse, and intuitive unit tests for functions given the signature and docstring.
@@ -247,7 +243,6 @@ PY_TEST_GENERATION_COMPLETION_INSTRUCTION = f"""You are an AI coding assistant t
PY_TEST_GENERATION_CHAT_INSTRUCTION = """You are an AI coding assistant that can write unique, diverse, and intuitive unit tests for functions given the signature and docstring."""
class PyGenerator(Generator):
def self_reflection(self, func: str, feedback: str, model: ModelBase) -> str:
return generic_generate_self_reflection(

@@ -47,31 +47,20 @@ fn add(a: i32, b: i32) -> i32 {
END EXAMPLES
'''
RS_TEST_GENERATION_FEW_SHOT = """For example:
func signature:
```rust
/// For a given number n, find the largest number that divides n evenly, smaller than n
/// >>> largest_divisor(15)
/// 5
fn largest_divisor(n: isize) -> isize {
for i in (1..n).rev() {
if n % i == 0 {
return i;
}
}
// if no divisor is found, return 1
1
}
```
/// Add three numbers together.
/// This function takes three numbers as input and returns the sum of the three numbers.
fn add3Numbers(x: i32, y: i32, z: i32) -> i32 {
unit tests:
assert_eq!(candidate(3), 1);
assert_eq!(candidate(7), 1);
assert_eq!(candidate(10), 5);
assert_eq!(candidate(100), 50);
assert_eq!(candidate(49), 7);
assert_eq!(add3Numbers(1, 2, 3), 6);
assert_eq!(add3Numbers(-1, 2, 3), 4);
assert_eq!(add3Numbers(1, -2, 3), 2);
assert_eq!(add3Numbers(1, 2, -3), 0);
assert_eq!(add3Numbers(-3, -2, -1), -6);
assert_eq!(add3Numbers(0, 0, 0), 0);
"""
RS_SELF_REFLECTION_FEW_SHOT = '''Example 1:

@@ -0,0 +1,10 @@
python main.py \
--run_name "test_simple_run_codellama" \
--root_dir "root" \
--dataset_path ./benchmarks/humaneval-py.jsonl \
--strategy "simple" \
--language "py" \
--model "codellama" \
--pass_at_k "1" \
--max_iters "1" \
--verbose

@@ -64,9 +64,11 @@ def enumerate_resume(dataset, results_path):
continue
yield i, item
def resume_success_count(dataset) -> int:
count = 0
for item in dataset:
if item["is_solved"]:
if "is_solved" in item and item["is_solved"]:
count += 1
return count
return count
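
The added membership check lets resume_success_count tolerate dataset items that were never attempted and therefore carry no "is_solved" field; previously those raised a KeyError. A small illustrative call (the dataset literals are made up), with resume_success_count as defined above:

# Illustration only: mixed dataset with an item missing the "is_solved" key.
dataset = [
    {"name": "t0", "is_solved": True},
    {"name": "t1", "is_solved": False},
    {"name": "t2"},  # not attempted yet; previously this raised KeyError
]
print(resume_success_count(dataset))  # 1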
