fix: clean up data, pad at end

This commit is contained in:
Zach Nussbaum 2023-04-04 20:53:23 +00:00
parent f45eb001a1
commit c68311810a
2 changed files with 57 additions and 28 deletions

View File

@ -0,0 +1,33 @@
# model/tokenizer
model_name: "EleutherAI/gpt-j-6B"
tokenizer_name: "EleutherAI/gpt-j-6B"
gradient_checkpointing: true
save_name: "nomic-ai/gpt4all-gptj-multiturn-lr-aggressive"
# dataset
streaming: false
num_proc: 64
dataset_path: "data_multiplus"
max_length: 1024
batch_size: 8
# train dynamics
lr: 2.0e-5
min_lr: 0
weight_decay: 0.0
eval_every: 200
eval_steps: 105
save_every: 400
log_grads_every: 200
output_dir: "ckpts/gpt4all-gptj-full-multiturn-lr-aggreive"
checkpoint: null
lora: false
warmup_steps: 500
num_epochs: 4
# logging
wandb: true
wandb_entity: vicuna
wandb_project_name: vicuna
seed: 42

52
data.py
View File

@ -11,42 +11,38 @@ def tokenize_inputs(config, tokenizer, examples):
max_length = config["max_length"] max_length = config["max_length"]
input_ids = torch.full((len(examples["prompt"]), max_length), tokenizer.pad_token_id) input_ids = torch.full((len(examples["prompt"]), max_length), tokenizer.pad_token_id)
# ignore bos # ignore bos
newline_tokens = tokenizer("\n", return_tensors="pt")["input_ids"][0, 1:] newline_tokens = tokenizer("\n", return_tensors="pt")["input_ids"][0]
if newline_tokens[0] == tokenizer.bos_token_id:
newline_tokens = newline_tokens[1:]
out = {"labels": [], "attention_mask": []} # hacky backward compatible
for i, (prompt, response) in enumerate(zip(examples["prompt"], examples["response"])): different_eos = tokenizer.eos_token != "</s>"
input_tokens = tokenizer(prompt, truncation=True, max_length=max_length // 2, return_tensors="pt")["input_ids"].squeeze() out = {"labels": [], "input_ids": []}
input_len = len(input_tokens) for prompt, response in zip(examples["prompt"], examples["response"]):
if different_eos:
if response.count("</s>") > 0:
response = response.replace("</s>", tokenizer.eos_token)
# plus one since we remove bos from response prompt_len = len(tokenizer(prompt, truncation=True, return_tensors="pt")["input_ids"][0])
# but we subtract one since we want to add eos token
remaining_tokens = max_length - input_len - len(newline_tokens) + 1
# remove bos
target_tokens = tokenizer(response, truncation=True, max_length=remaining_tokens, return_tensors="pt")["input_ids"].squeeze()[1:]
input_ids[i, :input_len] = input_tokens # hack if our prompt is super long
# add newline between prompt and response # we need to include some labels
newline_plus_inputs = input_len + len(newline_tokens) if prompt_len >= max_length - 1:
input_ids[i, input_len: newline_plus_inputs] = newline_tokens prompt = prompt[:len(prompt) // 2]
# add target tokens, remove bos input_tokens = tokenizer(prompt + "\n" + response + tokenizer.eos_token,
input_ids[i, newline_plus_inputs: newline_plus_inputs + len(target_tokens)] = target_tokens truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze()
# add eos token, enforce stopping if we don't truncate
# we don't want long code to stop generating if truncated during training
if newline_plus_inputs + len(target_tokens) < max_length:
input_ids[i, newline_plus_inputs + len(target_tokens)] = tokenizer.eos_token_id
labels = input_ids[i].clone()
labels[: newline_plus_inputs] = -100
labels[labels == tokenizer.pad_token_id] = -100
# to debug this, can set all values == -100 to the pad token, then assert that tokenizer.decode(labels, skip_special_tokens=True).strip() == response
attention_mask = input_ids[i].ne(tokenizer.pad_token_id).int() labels = input_tokens.clone()
labels[:prompt_len + len(newline_tokens)] = -100
if len(labels) < max_length:
# pad to max_length with -100
labels = torch.cat([labels, torch.full((max_length - len(labels),), -100)])
input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"]
out["labels"].append(labels) out["labels"].append(labels)
out["attention_mask"].append(attention_mask) out["input_ids"].append(input_tokens)
out["input_ids"] = input_ids
out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()} out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}