mirror of
https://github.com/nomic-ai/gpt4all
synced 2024-11-10 01:10:35 +00:00
fix: clean up data, pad at end
This commit is contained in:
parent
2e2e9f4339
commit
5c5f41ba36
33
configs/train/finetune_gptj.yaml
Normal file
33
configs/train/finetune_gptj.yaml
Normal file
@ -0,0 +1,33 @@
|
||||
# model/tokenizer
|
||||
model_name: "EleutherAI/gpt-j-6B"
|
||||
tokenizer_name: "EleutherAI/gpt-j-6B"
|
||||
gradient_checkpointing: true
|
||||
save_name: "nomic-ai/gpt4all-gptj-multiturn-lr-aggressive"
|
||||
|
||||
# dataset
|
||||
streaming: false
|
||||
num_proc: 64
|
||||
dataset_path: "data_multiplus"
|
||||
max_length: 1024
|
||||
batch_size: 8
|
||||
|
||||
# train dynamics
|
||||
lr: 2.0e-5
|
||||
min_lr: 0
|
||||
weight_decay: 0.0
|
||||
eval_every: 200
|
||||
eval_steps: 105
|
||||
save_every: 400
|
||||
log_grads_every: 200
|
||||
output_dir: "ckpts/gpt4all-gptj-full-multiturn-lr-aggreive"
|
||||
checkpoint: null
|
||||
lora: false
|
||||
warmup_steps: 500
|
||||
num_epochs: 4
|
||||
|
||||
# logging
|
||||
wandb: true
|
||||
wandb_entity: vicuna
|
||||
wandb_project_name: vicuna
|
||||
seed: 42
|
||||
|
52
data.py
52
data.py
@ -11,42 +11,38 @@ def tokenize_inputs(config, tokenizer, examples):
|
||||
max_length = config["max_length"]
|
||||
input_ids = torch.full((len(examples["prompt"]), max_length), tokenizer.pad_token_id)
|
||||
# ignore bos
|
||||
newline_tokens = tokenizer("\n", return_tensors="pt")["input_ids"][0, 1:]
|
||||
newline_tokens = tokenizer("\n", return_tensors="pt")["input_ids"][0]
|
||||
if newline_tokens[0] == tokenizer.bos_token_id:
|
||||
newline_tokens = newline_tokens[1:]
|
||||
|
||||
out = {"labels": [], "attention_mask": []}
|
||||
for i, (prompt, response) in enumerate(zip(examples["prompt"], examples["response"])):
|
||||
input_tokens = tokenizer(prompt, truncation=True, max_length=max_length // 2, return_tensors="pt")["input_ids"].squeeze()
|
||||
input_len = len(input_tokens)
|
||||
# hacky backward compatible
|
||||
different_eos = tokenizer.eos_token != "</s>"
|
||||
out = {"labels": [], "input_ids": []}
|
||||
for prompt, response in zip(examples["prompt"], examples["response"]):
|
||||
if different_eos:
|
||||
if response.count("</s>") > 0:
|
||||
response = response.replace("</s>", tokenizer.eos_token)
|
||||
|
||||
# plus one since we remove bos from response
|
||||
# but we subtract one since we want to add eos token
|
||||
remaining_tokens = max_length - input_len - len(newline_tokens) + 1
|
||||
# remove bos
|
||||
target_tokens = tokenizer(response, truncation=True, max_length=remaining_tokens, return_tensors="pt")["input_ids"].squeeze()[1:]
|
||||
prompt_len = len(tokenizer(prompt, truncation=True, return_tensors="pt")["input_ids"][0])
|
||||
|
||||
input_ids[i, :input_len] = input_tokens
|
||||
# add newline between prompt and response
|
||||
newline_plus_inputs = input_len + len(newline_tokens)
|
||||
input_ids[i, input_len: newline_plus_inputs] = newline_tokens
|
||||
# hack if our prompt is super long
|
||||
# we need to include some labels
|
||||
if prompt_len >= max_length - 1:
|
||||
prompt = prompt[:len(prompt) // 2]
|
||||
|
||||
# add target tokens, remove bos
|
||||
input_ids[i, newline_plus_inputs: newline_plus_inputs + len(target_tokens)] = target_tokens
|
||||
# add eos token, enforce stopping if we don't truncate
|
||||
# we don't want long code to stop generating if truncated during training
|
||||
if newline_plus_inputs + len(target_tokens) < max_length:
|
||||
input_ids[i, newline_plus_inputs + len(target_tokens)] = tokenizer.eos_token_id
|
||||
input_tokens = tokenizer(prompt + "\n" + response + tokenizer.eos_token,
|
||||
truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze()
|
||||
|
||||
labels = input_ids[i].clone()
|
||||
labels[: newline_plus_inputs] = -100
|
||||
labels[labels == tokenizer.pad_token_id] = -100
|
||||
# to debug this, can set all values == -100 to the pad token, then assert that tokenizer.decode(labels, skip_special_tokens=True).strip() == response
|
||||
|
||||
attention_mask = input_ids[i].ne(tokenizer.pad_token_id).int()
|
||||
labels = input_tokens.clone()
|
||||
labels[:prompt_len + len(newline_tokens)] = -100
|
||||
if len(labels) < max_length:
|
||||
# pad to max_length with -100
|
||||
labels = torch.cat([labels, torch.full((max_length - len(labels),), -100)])
|
||||
|
||||
input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"]
|
||||
out["labels"].append(labels)
|
||||
out["attention_mask"].append(attention_mask)
|
||||
|
||||
out["input_ids"] = input_ids
|
||||
out["input_ids"].append(input_tokens)
|
||||
|
||||
out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user