diff --git a/data.py b/data.py index 8a0dd83f..7d61154d 100644 --- a/data.py +++ b/data.py @@ -9,10 +9,6 @@ from transformers import DefaultDataCollator def tokenize_inputs(config, tokenizer, examples): max_length = config["max_length"] - # ignore bos - newline_tokens = tokenizer("\n", return_tensors="pt")["input_ids"][0] - if newline_tokens[0] == tokenizer.bos_token_id: - newline_tokens = newline_tokens[1:] # hacky backward compatible different_eos = tokenizer.eos_token != "" @@ -22,7 +18,7 @@ def tokenize_inputs(config, tokenizer, examples): if response.count("") > 0: response = response.replace("", tokenizer.eos_token) - prompt_len = len(tokenizer(prompt, return_tensors="pt")["input_ids"][0]) + prompt_len = len(tokenizer(prompt + "\n", return_tensors="pt")["input_ids"][0]) # hack if our prompt is super long # we need to include some labels so we arbitrarily trunacate at max_length // 2 @@ -33,7 +29,7 @@ def tokenize_inputs(config, tokenizer, examples): new_len = min(max_length // 2, len(prompt) // 2) prompt = prompt[:new_len] # get new prompt length - prompt_len = tokenizer(prompt, return_tensors="pt", max_length=max_length // 2, truncation=True).input_ids.ne(tokenizer.pad_token_id).sum().item() + prompt_len = tokenizer(prompt + "\n", return_tensors="pt", max_length=max_length // 2, truncation=True).input_ids.ne(tokenizer.pad_token_id).sum().item() assert prompt_len <= max_length // 2, f"prompt length {prompt_len} exceeds max length {max_length}" @@ -41,11 +37,13 @@ def tokenize_inputs(config, tokenizer, examples): truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze() labels = input_tokens.clone() - labels[:prompt_len + len(newline_tokens)] = -100 + labels[:prompt_len] = -100 if len(labels) < max_length: # pad to max_length with -100 labels = torch.cat([labels, torch.full((max_length - len(labels),), -100)]) + assert (labels == -100).sum() < len(labels), f"Labels are all -100, something wrong. prompt length {prompt_len} exceeds max length {max_length}" + if (labels == -100).sum() == len(labels) - 1: print(prompt) print(response)