Compare commits

...

7 Commits

Author SHA1 Message Date
cassanof abd6ed5a3d woops 9 months ago
cassanof f799068dec fix eos 9 months ago
cassanof be7fb52cf7 cahce 9 months ago
cassanof 3fec014d0a tok 9 months ago
cassanof 7470891d85 no pipeline 9 months ago
cassanof f42e651445 better design 9 months ago
cassanof c3cfcb8863 upd 9 months ago

@@ -1 +1 @@
-Subproject commit 2ab159560725cf8482600ca1d0adf55d1b315c14
+Subproject commit 228163abdc983712bebfd8e26f7e7d360830e648
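(Note: this one-line hunk bumps a git submodule pointer from 2ab15956… to 228163ab…; the submodule's path is not shown in this excerpt.)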

@@ -7,8 +7,11 @@ from tenacity import (
     wait_random_exponential, # type: ignore
 )
 import openai
+from transformers import AutoModelForCausalLM, AutoTokenizer
 MessageRole = Literal["system", "user", "assistant"]
 @dataclasses.dataclass()
 class Message():
     role: MessageRole
@@ -64,6 +67,7 @@ def gpt_chat(
     return [choice.message.content for choice in response.choices] # type: ignore
 class ModelBase():
     def __init__(self, name: str):
         self.name = name
@@ -106,47 +110,115 @@ class GPTDavinci(ModelBase):
         return gpt_completion(self.name, prompt, max_tokens, stop_strs, temperature, num_comps)
-class StarChat(ModelBase):
-    def __init__(self):
-        import torch
-        from transformers import pipeline
-        self.name = "star-chat"
-        self.pipe = pipeline(
-            "text-generation", model="HuggingFaceH4/starchat-beta", torch_dtype=torch.bfloat16, device_map="auto")
+class HFModelBase(ModelBase):
+    """
+    Base for huggingface chat models
+    """
+    def __init__(self, model_name: str, model, tokenizer, eos_token_id=None):
+        self.name = model_name
+        self.model = model
+        self.tokenizer = tokenizer
+        self.eos_token_id = eos_token_id if eos_token_id is not None else self.tokenizer.eos_token_id
+        self.is_chat = True
     def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
-        # NOTE: HF does not like temp of 0.0.
+        # NOTE: HF does not like temp of 0.0.
         if temperature < 0.0001:
             temperature = 0.0001
-        prompt = ""
-        for i, message in enumerate(messages):
-            prompt += f"<|{message.role}|>\n{message.content}\n<|end|>\n"
-            if i == len(messages) - 1:
-                prompt += "<|assistant|>\n"
-        outputs = self.pipe(
+        prompt = self.prepare_prompt(messages)
+        outputs = self.model.generate(
             prompt,
-            max_new_tokens=max_tokens,
+            max_new_tokens=min(
+                max_tokens, self.model.config.max_position_embeddings),
             use_cache=True,
             do_sample=True,
             temperature=temperature,
             top_p=0.95,
-            eos_token_id=49155,
+            eos_token_id=self.eos_token_id,
             num_return_sequences=num_comps,
         )
-        outs = [output['generated_text'] for output in outputs] # type: ignore
+        outs = self.tokenizer.batch_decode(outputs, skip_special_tokens=False)
         assert isinstance(outs, list)
         for i, out in enumerate(outs):
             assert isinstance(out, str)
-            out = out.split("<|assistant|>")[1]
-            if out.endswith("<|end|>"):
-                out = out[:-len("<|end|>")]
-            outs[i] = out
+            outs[i] = self.extract_output(out)
         if len(outs) == 1:
             return outs[0] # type: ignore
         else:
             return outs # type: ignore
+    def prepare_prompt(self, messages: List[Message]) -> List[int]:
+        raise NotImplementedError
+    def extract_output(self, output: str) -> str:
+        raise NotImplementedError
+class StarChat(HFModelBase):
+    def __init__(self):
+        import torch
+        model = AutoModelForCausalLM.from_pretrained(
+            "HuggingFaceH4/starchat-beta",
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+        )
+        tokenizer = AutoTokenizer.from_pretrained(
+            "HuggingFaceH4/starchat-beta",
+        )
+        super().__init__("star-chat", model, tokenizer, eos_token_id=49155)
+    def prepare_prompt(self, messages: List[Message]) -> List[int]:
+        prompt = ""
+        for i, message in enumerate(messages):
+            prompt += f"<|{message.role}|>\n{message.content}\n<|end|>\n"
+            if i == len(messages) - 1:
+                prompt += "<|assistant|>\n"
+        return self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device)
+    def extract_output(self, output: str) -> str:
+        out = output.split("<|assistant|>")[1]
+        if out.endswith("<|end|>"):
+            out = out[:-len("<|end|>")]
+        return out
+class CodeLlama(HFModelBase):
+    B_INST, E_INST = "[INST]", "[/INST]"
+    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+    DEFAULT_SYSTEM_PROMPT = """\
+You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
+    def __init__(self):
+        import torch
+        model = AutoModelForCausalLM.from_pretrained(
+            "codellama/CodeLlama-34b-Instruct-hf",
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+        )
+        tokenizer = AutoTokenizer.from_pretrained(
+            "codellama/CodeLlama-34b-Instruct-hf",
+            add_eos_token=True,
+            add_bos_token=True,
+            padding_side='left',
+        )
+        super().__init__("code-llama", model, tokenizer, eos_token_id=2)
+    def prepare_prompt(self, messages: List[Message]) -> List[int]:
+        # Fold the system prompt into the first user turn, then wrap each
+        # user message in [INST] ... [/INST] (Llama-2 chat format).
+        if messages[0].role != "system":
+            messages = [Message(role="system", content=self.DEFAULT_SYSTEM_PROMPT)] + messages
+        prompt = f"{self.B_INST} {self.B_SYS}{messages[0].content}{self.E_SYS}{messages[1].content} {self.E_INST}"
+        for message in messages[2:]:
+            if message.role == "assistant":
+                prompt += f" {message.content}"
+            else:
+                prompt += f" {self.B_INST} {message.content} {self.E_INST}"
+        tokens = self.tokenizer.encode(prompt, return_tensors="pt")
+        # add_eos_token=True appends </s>; drop it so the model keeps generating.
+        if tokens[0, -1].item() == self.tokenizer.eos_token_id:
+            tokens = tokens[:, :-1]
+        return tokens.to(self.model.device)
+    def extract_output(self, output: str) -> str:
+        # The completion is whatever follows the last [/INST] marker.
+        out = output.split(self.E_INST)[-1]
+        if out.endswith(self.tokenizer.eos_token):
+            out = out[:-len(self.tokenizer.eos_token)]
+        return out.strip()
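For reference, a minimal usage sketch of the refactored interface. This is not part of the diff: the driver code that actually calls these classes lives elsewhere in the repo, and the system/user strings below are illustrative assumptions.

# Hypothetical driver snippet (assumed, not from the diff): exercise the new
# HFModelBase interface end to end with either chat model.
model = StarChat()  # or CodeLlama(); both implement prepare_prompt/extract_output
messages = [
    Message(role="system", content="You are a helpful coding assistant."),
    Message(role="user", content="Write a Python function that reverses a string."),
]
# generate_chat builds the model-specific prompt via prepare_prompt, samples
# with model.generate, decodes, and strips the chat template via extract_output.
out = model.generate_chat(messages, max_tokens=256, temperature=0.2, num_comps=1)
print(out)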

@@ -0,0 +1,10 @@
+CUDA_VISIBLE_DEVICES=$1 python main.py \
+  --run_name "reflexion_codellama_$1" \
+  --root_dir "root" \
+  --dataset_path ./benchmarks/humaneval-py.jsonl \
+  --strategy "reflexion" \
+  --language "py" \
+  --model "codellama" \
+  --pass_at_k "1" \
+  --max_iters "2" \
+  --verbose | tee ./logs/reflexion_codellama_$1
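Usage note (assumed, since the script's file name is not shown in this hunk): the single positional argument is a GPU index, reused as the CUDA_VISIBLE_DEVICES value, the --run_name suffix, and the log-file suffix, so invoking the script with argument 0 runs Reflexion with CodeLlama on GPU 0 and tees output to ./logs/reflexion_codellama_0. tee does not create directories, so ./logs must already exist.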