|
|
|
@ -1,8 +1,6 @@
|
|
|
|
|
import os
|
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler
|
|
|
|
|
from transformers.trainer_pt_utils import get_parameter_names
|
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler, LlamaForCausalLM
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
from torch.optim import AdamW
|
|
|
|
|
from argparse import ArgumentParser
|
|
|
|
|
from read import read_config
|
|
|
|
@ -45,7 +43,7 @@ def train(accelerator, config):
|
|
|
|
|
accelerator.print(f"Using {accelerator.num_processes} GPUs")
|
|
|
|
|
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
|
|
|
|
|
# llama has no pad token, set it to new token
|
|
|
|
|
# if no pad token, set it to eos
|
|
|
|
|
if tokenizer.pad_token is None:
|
|
|
|
|
tokenizer.pad_token = tokenizer.eos_token
|
|
|
|
|
|
|
|
|
@ -76,21 +74,9 @@ def train(accelerator, config):
|
|
|
|
|
else DummyOptim
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
no_decay = ["bias", "LayerNorm.weight"]
|
|
|
|
|
optimizer_grouped_parameters = [
|
|
|
|
|
{
|
|
|
|
|
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
|
|
|
|
"weight_decay": config["weight_decay"],
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
|
|
|
|
"weight_decay": 0.0,
|
|
|
|
|
},
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
# karpathy doesn't decay embeddding, maybe we should exclude
|
|
|
|
|
# https://github.com/karpathy/minGPT/commit/bbbdac74fa9b2e55574d70056163ffbae42310c1#diff-2075fa9c224b395be5bda85544dd36572b59c76c54562819eadadbf268602834R157s
|
|
|
|
|
optimizer = optimizer_cls(optimizer_grouped_parameters, lr=config["lr"])
|
|
|
|
|
optimizer = optimizer_cls(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])
|
|
|
|
|
|
|
|
|
|
if accelerator.state.deepspeed_plugin is not None:
|
|
|
|
|
gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
|
|
|
|
|