@@ -31,7 +31,7 @@ def tokenize_inputs(config, tokenizer, examples):
         # add target tokens, remove bos
         input_ids[i, newline_plus_inputs: newline_plus_inputs + len(target_tokens)] = target_tokens
-        # add eos token, enforce stopping if we don't truncate
-        input_ids[i, newline_plus_inputs + len(target_tokens)] = tokenizer.eos_token_id
+        # add eos token, enforce stopping if we don't truncate
+        # we don't want long code to stop generating if truncated during training
+        if newline_plus_inputs + len(target_tokens) < max_length:
+            input_ids[i, newline_plus_inputs + len(target_tokens)] = tokenizer.eos_token_id
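A minimal sketch of why the new guard matters, using made-up sizes (`max_length = 10`, a 4-token prompt-plus-newline prefix) rather than anything from the PR: when the response is truncated so that it fills the row exactly, `newline_plus_inputs + len(target_tokens)` lands at `max_length`, so writing eos there would index past the end of the row, and a truncated example shouldn't teach the model to stop mid-code anyway.

```python
import torch

# Illustrative sizes, not values from the PR.
max_length = 10
newline_plus_inputs = 4   # prompt + newline tokens already written
eos_token_id, pad_token_id = 2, 0

for target_tokens in (
    torch.tensor([5, 5, 5]),                                 # short response: eos fits
    torch.tensor([5] * (max_length - newline_plus_inputs)),  # truncated: row is full
):
    row = torch.full((max_length,), pad_token_id)
    row[newline_plus_inputs:newline_plus_inputs + len(target_tokens)] = target_tokens
    # The guard from the diff: only append eos when the row has room left.
    # Without it, the second case would write at index 10 in a length-10 row.
    if newline_plus_inputs + len(target_tokens) < max_length:
        row[newline_plus_inputs + len(target_tokens)] = eos_token_id
    print(row.tolist())
```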
@@ -67,7 +67,7 @@ def load_data(config, tokenizer):
         dataset = load_dataset("json", data_files=files, split="train")
     else:
-        dataset = load_dataset(dataset_path, split='train')
+        dataset = load_dataset(dataset_path, split="train")
 
     dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])
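For reference, a hedged sketch of the code path this hunk sits in (the diff itself only changes quote style; behavior is identical). The dataset id and config values below are hypothetical stand-ins, and the call assumes the id exists on the Hugging Face Hub; the split call matches the hunk's context line, carving out a seeded 5% evaluation set.

```python
from datasets import load_dataset

# Hypothetical stand-ins; the real values come from the training config.
config = {"seed": 42}
dataset_path = "username/my-instruction-data"  # any Hugging Face Hub dataset id

# Mirrors the else-branch of load_data: pull the dataset from the Hub by name.
dataset = load_dataset(dataset_path, split="train")

# Deterministic 95/5 train/eval split, as in the context line above.
dataset = dataset.train_test_split(test_size=0.05, seed=config["seed"])
train_dataset, val_dataset = dataset["train"], dataset["test"]
```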