move gen into class

pull/15/head
cassanof 1 year ago
parent 97d5190a7c
commit e60072c524

@ -1,7 +1,7 @@
from .py_generate import PyGenerator
from .rs_generate import RsGenerator
from .generator_types import Generator
from .model import ModelBase, GPT4, GPT35
from .model import ModelBase, GPT4, GPT35, StarChat, GPTDavinci
def generator_factory(lang: str) -> Generator:
@ -18,5 +18,9 @@ def model_factory(model_name: str) -> ModelBase:
return GPT4()
elif model_name == "gpt-3.5-turbo":
return GPT35()
elif model_name == "starchat":
return StarChat()
elif model_name.startswith("text-davinci"):
return GPTDavinci(model_name)
else:
raise ValueError(f"Invalid model name: {model_name}")

@ -36,7 +36,7 @@ def generic_generate_func_impl(
raise ValueError(
f"Invalid arguments: given `strategy=reflexion` but `prev_func_impl`, `feedback`, or `self_reflection` is None")
if model.name == "gpt-4" or model.name == "gpt-3.5-turbo":
if model.is_chat:
if strategy == "reflexion":
message = f"{REFLEXION_FEW_SHOT}\n[previous impl]:\n{prev_func_impl}\n\n[unit test results from previous impl]:\n{feedback}\n\n[reflection on previous impl]:\n{self_reflection}\n\n[improved impl]:\n{func_sig}"
# func_bodies is a really bad name, as it can also be just 1 string
@ -46,8 +46,8 @@ def generic_generate_func_impl(
print(' ----------------------- USER MESSAGE -----------------------')
print(message, flush=True)
print('----------------------------------------------')
func_bodies = gpt_chat(model.name, REFLEXION_CHAT_INSTRUCTION,
message, num_comps=num_comps, temperature=temperature)
func_bodies = model.generate_chat(REFLEXION_CHAT_INSTRUCTION,
message, num_comps=num_comps, temperature=temperature)
else:
print('----------------------- SYSTEM MESSAGE -----------------------')
print(SIMPLE_CHAT_INSTRUCTION)
@ -55,17 +55,17 @@ def generic_generate_func_impl(
print(' ----------------------- USER MESSAGE -----------------------')
print(func_sig, flush=True)
print('----------------------------------------------')
func_bodies = gpt_chat(model.name, SIMPLE_CHAT_INSTRUCTION if strategy ==
"simple" else REFLEXION_CHAT_INSTRUCTION, func_sig, num_comps=num_comps, temperature=temperature)
func_bodies = model.generate_chat(SIMPLE_CHAT_INSTRUCTION if strategy ==
"simple" else REFLEXION_CHAT_INSTRUCTION, func_sig, num_comps=num_comps, temperature=temperature)
else:
if strategy == "reflexion":
prompt = f"{REFLEXION_COMPLETION_INSTRUCTION}\n{prev_func_impl}\n\nunit tests:\n{feedback}\n\nhint:\n{self_reflection}\n\n# improved implementation\n{func_sig}"
func_bodies = gpt_completion(
model.name, prompt, num_comps=num_comps, temperature=temperature)
func_bodies = model.generate(
prompt, num_comps=num_comps, temperature=temperature)
else:
prompt = f"{SIMPLE_COMPLETION_INSTRUCTION}\n{func_sig}"
func_bodies = gpt_completion(
model.name, prompt, num_comps=num_comps, temperature=temperature)
func_bodies = model.generate(
prompt, num_comps=num_comps, temperature=temperature)
if num_comps == 1:
assert isinstance(func_bodies, str)
@ -97,19 +97,19 @@ def generic_generate_internal_tests(
Generates tests for a function using a refinement technique with the number
of specified committee members.
"""
if model.name == "gpt-4" or model.name == "gpt-3.5-turbo":
if model.is_chat:
if is_react:
message = f'{TEST_GENERATION_FEW_SHOT}\n\n[func signature]:\n{func_sig}\n\n[think]:'
output = gpt_chat(
model.name, TEST_GENERATION_CHAT_INSTRUCTION, message, max_tokens=1024)
output = model.generate_chat(
TEST_GENERATION_CHAT_INSTRUCTION, message, max_tokens=1024)
print(f'React test generation output: {output}')
else:
message = f'{TEST_GENERATION_FEW_SHOT}\n\nfunc signature:\n{func_sig}\nunit tests:'
output = gpt_chat(
model.name, TEST_GENERATION_CHAT_INSTRUCTION, message, max_tokens=1024)
output = model.generate_chat(
TEST_GENERATION_CHAT_INSTRUCTION, message, max_tokens=1024)
else:
prompt = f'{TEST_GENERATION_COMPLETION_INSTRUCTION}\n\nfunc signature:\n{func_sig}\nunit tests:'
output = gpt_completion(model.name, prompt, max_tokens=1024)
output = model.generate(prompt, max_tokens=1024)
all_tests = parse_tests(output) # type: ignore
valid_tests = [test for test in all_tests if is_syntax_valid(test)]
@ -128,21 +128,19 @@ def generic_generate_self_reflection(
SELF_REFLECTION_COMPLETION_INSTRUCTION: str,
SELF_REFLECTION_FEW_SHOT: Optional[str] = None
) -> str:
if model == "gpt-4" or model == "gpt-3.5-turbo":
if model.is_chat:
if SELF_REFLECTION_FEW_SHOT is not None:
reflection = gpt_chat(
model.name,
reflection = model.generate_chat(
SELF_REFLECTION_CHAT_INSTRUCTION,
f'{SELF_REFLECTION_FEW_SHOT}\n\n[function impl]:\n{func}\n\n[unit test results]:\n{feedback}\n\n[self-reflection]:')
print(f'Self reflection output: {reflection}')
else:
reflection = gpt_chat(
model.name,
reflection = model.generate_chat(
SELF_REFLECTION_CHAT_INSTRUCTION,
f'Function implementation:\n{func}\n\nUnit test results:\n{feedback}\n\nSelf-reflection:')
else:
reflection = gpt_completion(
model.name, f'{SELF_REFLECTION_COMPLETION_INSTRUCTION}\n{func}\n\n{feedback}\n\nExplanation:')
reflection = model.generate(
f'{SELF_REFLECTION_COMPLETION_INSTRUCTION}\n{func}\n\n{feedback}\n\nExplanation:')
return reflection # type: ignore

@ -1,16 +1,83 @@
from typing import List, Union, Optional
from generators.generator_utils import gpt_chat, gpt_completion
class ModelBase():
def __init__(self, name):
def __init__(self, name: str):
self.name = name
self.is_chat = False
def __repr__(self):
def __repr__(self) -> str:
return f'{self.name}'
def generate_chat(self, system_message: str, user_message: str, max_tokens=1024, temperature=0.2, num_comps=1) -> Union[List[str], str]:
raise NotImplementedError
def generate(self, prompt: str, max_tokens: int = 1024, stop_strs: Optional[List[str]] = None, temperature: float = 0.0, num_comps=1) -> Union[List[str], str]:
raise NotImplementedError
class GPTChat(ModelBase):
def __init__(self, model_name: str):
self.name = model_name
self.is_chat = True
class GPT4(ModelBase):
def generate_chat(self, system_message: str, user_message: str, max_tokens=1024, temperature=0.2, num_comps=1) -> Union[List[str], str]:
return gpt_chat(self.name, system_message, user_message,
max_tokens, temperature, num_comps)
class GPT4(GPTChat):
def __init__(self):
self.name = "gpt-4"
super().__init__("gpt-4")
class GPT35(ModelBase):
class GPT35(GPTChat):
def __init__(self):
self.name = "gpt-3.5-turbo"
super().__init__("gpt-3.5-turbo")
class GPTDavinci(ModelBase):
    """Wrapper for OpenAI text-davinci-* completion (non-chat) models.

    Unlike the chat models, generation goes through the plain-completion
    `generate` path; `is_chat` stays False so the generic generators route
    prompts to the completion branch.
    """

    def __init__(self, model_name: str):
        # Delegate to ModelBase.__init__ so both `self.name` and
        # `self.is_chat = False` are set. Assigning `self.name` directly
        # (as before) left `is_chat` unset, and the generic generators
        # dispatch on `model.is_chat`, which would raise AttributeError.
        super().__init__(model_name)

    def generate(self, prompt: str, max_tokens: int = 1024, stop_strs: Optional[List[str]] = None, temperature: float = 0, num_comps=1) -> Union[List[str], str]:
        """Generate `num_comps` completions for `prompt` via the module-level
        gpt_completion helper; returns a single str when num_comps == 1."""
        return gpt_completion(self.name, prompt, max_tokens, stop_strs, temperature, num_comps)
class StarChat(ModelBase):
    """Chat model backed by the HuggingFace `starchat-beta` text-generation
    pipeline. Sets `is_chat = True` so the generic generators use the chat
    path (`generate_chat`)."""

    def __init__(self):
        # Local import: transformers is only required when StarChat is used.
        from transformers import pipeline
        self.name = "star-chat"
        self.pipe = pipeline(
            "text-generation", model="HuggingFaceH4/starchat-beta")
        # StarChat prompt format: system / user / assistant turns delimited
        # by <|system|>, <|user|>, <|assistant|> and terminated with <|end|>.
        self.template = "<|system|>\n{system}<|end|>\n<|user|>\n{query}<|end|>\n<|assistant|>"
        self.is_chat = True

    def generate_chat(self, system_message: str, user_message: str, max_tokens=1024, temperature=0.2, num_comps=1) -> Union[List[str], str]:
        """Sample `num_comps` assistant replies for the given system/user
        messages. Returns a single str when num_comps == 1, otherwise a
        list of strings."""
        prompt = self.template.format(
            system=system_message, query=user_message)
        outputs = self.pipe(
            prompt,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=0.95,
            # NOTE(review): 49155 is presumably the <|end|> token id for the
            # starchat-beta tokenizer — confirm against the model's tokenizer.
            eos_token_id=49155,
            num_return_sequences=num_comps,
        )
        outs = [output['generated_text'] for output in outputs]  # type: ignore
        assert isinstance(outs, list)
        for i, out in enumerate(outs):
            assert isinstance(out, str)
            # Keep only the text after the <|assistant|> marker (the reply)
            # and strip a trailing <|end|> terminator if present.
            out = out.split("<|assistant|>")[1]
            if out.endswith("<|end|>"):
                out = out[:-len("<|end|>")]
            outs[i] = out
        if len(outs) == 1:
            return outs[0]  # type: ignore
        else:
            return outs  # type: ignore

Loading…
Cancel
Save