mirror of
https://github.com/nomic-ai/gpt4all
synced 2024-11-08 07:10:32 +00:00
fix: clean up data, pad at end
This commit is contained in:
parent
f45eb001a1
commit
c68311810a
33
configs/train/finetune_gptj.yaml
Normal file
33
configs/train/finetune_gptj.yaml
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
# model/tokenizer
|
||||||
|
model_name: "EleutherAI/gpt-j-6B"
|
||||||
|
tokenizer_name: "EleutherAI/gpt-j-6B"
|
||||||
|
gradient_checkpointing: true
|
||||||
|
save_name: "nomic-ai/gpt4all-gptj-multiturn-lr-aggressive"
|
||||||
|
|
||||||
|
# dataset
|
||||||
|
streaming: false
|
||||||
|
num_proc: 64
|
||||||
|
dataset_path: "data_multiplus"
|
||||||
|
max_length: 1024
|
||||||
|
batch_size: 8
|
||||||
|
|
||||||
|
# train dynamics
|
||||||
|
lr: 2.0e-5
|
||||||
|
min_lr: 0
|
||||||
|
weight_decay: 0.0
|
||||||
|
eval_every: 200
|
||||||
|
eval_steps: 105
|
||||||
|
save_every: 400
|
||||||
|
log_grads_every: 200
|
||||||
|
output_dir: "ckpts/gpt4all-gptj-full-multiturn-lr-aggreive"
|
||||||
|
checkpoint: null
|
||||||
|
lora: false
|
||||||
|
warmup_steps: 500
|
||||||
|
num_epochs: 4
|
||||||
|
|
||||||
|
# logging
|
||||||
|
wandb: true
|
||||||
|
wandb_entity: vicuna
|
||||||
|
wandb_project_name: vicuna
|
||||||
|
seed: 42
|
||||||
|
|
52
data.py
52
data.py
@ -11,42 +11,38 @@ def tokenize_inputs(config, tokenizer, examples):
|
|||||||
max_length = config["max_length"]
|
max_length = config["max_length"]
|
||||||
input_ids = torch.full((len(examples["prompt"]), max_length), tokenizer.pad_token_id)
|
input_ids = torch.full((len(examples["prompt"]), max_length), tokenizer.pad_token_id)
|
||||||
# ignore bos
|
# ignore bos
|
||||||
newline_tokens = tokenizer("\n", return_tensors="pt")["input_ids"][0, 1:]
|
newline_tokens = tokenizer("\n", return_tensors="pt")["input_ids"][0]
|
||||||
|
if newline_tokens[0] == tokenizer.bos_token_id:
|
||||||
|
newline_tokens = newline_tokens[1:]
|
||||||
|
|
||||||
out = {"labels": [], "attention_mask": []}
|
# hacky backward compatible
|
||||||
for i, (prompt, response) in enumerate(zip(examples["prompt"], examples["response"])):
|
different_eos = tokenizer.eos_token != "</s>"
|
||||||
input_tokens = tokenizer(prompt, truncation=True, max_length=max_length // 2, return_tensors="pt")["input_ids"].squeeze()
|
out = {"labels": [], "input_ids": []}
|
||||||
input_len = len(input_tokens)
|
for prompt, response in zip(examples["prompt"], examples["response"]):
|
||||||
|
if different_eos:
|
||||||
|
if response.count("</s>") > 0:
|
||||||
|
response = response.replace("</s>", tokenizer.eos_token)
|
||||||
|
|
||||||
# plus one since we remove bos from response
|
prompt_len = len(tokenizer(prompt, truncation=True, return_tensors="pt")["input_ids"][0])
|
||||||
# but we subtract one since we want to add eos token
|
|
||||||
remaining_tokens = max_length - input_len - len(newline_tokens) + 1
|
|
||||||
# remove bos
|
|
||||||
target_tokens = tokenizer(response, truncation=True, max_length=remaining_tokens, return_tensors="pt")["input_ids"].squeeze()[1:]
|
|
||||||
|
|
||||||
input_ids[i, :input_len] = input_tokens
|
# hack if our prompt is super long
|
||||||
# add newline between prompt and response
|
# we need to include some labels
|
||||||
newline_plus_inputs = input_len + len(newline_tokens)
|
if prompt_len >= max_length - 1:
|
||||||
input_ids[i, input_len: newline_plus_inputs] = newline_tokens
|
prompt = prompt[:len(prompt) // 2]
|
||||||
|
|
||||||
# add target tokens, remove bos
|
input_tokens = tokenizer(prompt + "\n" + response + tokenizer.eos_token,
|
||||||
input_ids[i, newline_plus_inputs: newline_plus_inputs + len(target_tokens)] = target_tokens
|
truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze()
|
||||||
# add eos token, enforce stopping if we don't truncate
|
|
||||||
# we don't want long code to stop generating if truncated during training
|
|
||||||
if newline_plus_inputs + len(target_tokens) < max_length:
|
|
||||||
input_ids[i, newline_plus_inputs + len(target_tokens)] = tokenizer.eos_token_id
|
|
||||||
|
|
||||||
labels = input_ids[i].clone()
|
|
||||||
labels[: newline_plus_inputs] = -100
|
|
||||||
labels[labels == tokenizer.pad_token_id] = -100
|
|
||||||
# to debug this, can set all values == -100 to the pad token, then assert that tokenizer.decode(labels, skip_special_tokens=True).strip() == response
|
|
||||||
|
|
||||||
attention_mask = input_ids[i].ne(tokenizer.pad_token_id).int()
|
labels = input_tokens.clone()
|
||||||
|
labels[:prompt_len + len(newline_tokens)] = -100
|
||||||
|
if len(labels) < max_length:
|
||||||
|
# pad to max_length with -100
|
||||||
|
labels = torch.cat([labels, torch.full((max_length - len(labels),), -100)])
|
||||||
|
|
||||||
|
input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"]
|
||||||
out["labels"].append(labels)
|
out["labels"].append(labels)
|
||||||
out["attention_mask"].append(attention_mask)
|
out["input_ids"].append(input_tokens)
|
||||||
|
|
||||||
out["input_ids"] = input_ids
|
|
||||||
|
|
||||||
out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}
|
out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user