diff --git a/data.py b/data.py
index 4457d93e..e5a7fb14 100644
--- a/data.py
+++ b/data.py
@@ -31,7 +31,7 @@ def tokenize_inputs(config, tokenizer, examples):
         # add target tokens, remove bos
         input_ids[i, newline_plus_inputs: newline_plus_inputs + len(target_tokens)] = target_tokens
-        # add eos token, enforce stopping if we don't truncate
+        # add eos token, enforce stopping if we don't truncate
         # we don't want long code to stop generating if truncated during training
         if newline_plus_inputs + len(target_tokens) < max_length:
             input_ids[i, newline_plus_inputs + len(target_tokens)] = tokenizer.eos_token_id
@@ -67,7 +67,7 @@ def load_data(config, tokenizer):
         dataset = load_dataset("json", data_files=files, split="train")
 
     else:
-        dataset = load_dataset(dataset_path, split='train')
+        dataset = load_dataset(dataset_path, split="train")
 
     dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])