diff --git a/data.py b/data.py
index 358dd007..ff519abb 100644
--- a/data.py
+++ b/data.py
@@ -22,24 +22,35 @@ def tokenize_inputs(config, tokenizer, examples):
         if response.count("</s>") > 0:
            response = response.replace("</s>", tokenizer.eos_token)

-        prompt_len = len(tokenizer(prompt, truncation=True, return_tensors="pt")["input_ids"][0])
+        prompt_len = len(tokenizer(prompt, return_tensors="pt")["input_ids"][0])

         # hack if our prompt is super long
-        # we need to include some labels
-        if prompt_len >= max_length - 1:
-            prompt = prompt[:len(prompt) // 2]
-            prompt_len = len(tokenizer(prompt, truncation=True, return_tensors="pt")["input_ids"][0])
+        # we need to include some labels, so we arbitrarily truncate at max_length // 2
+        # if the length is too long
+        if prompt_len >= max_length // 2:
+            # if the prompt is too long, truncate it,
+            # keeping at most max_length // 2 characters
+            new_len = min(max_length // 2, len(prompt) // 2)
+            prompt = prompt[:new_len]
+            # get new prompt length
+            prompt_len = tokenizer(prompt, return_tensors="pt", max_length=max_length // 2, truncation=True).input_ids.ne(tokenizer.pad_token_id).sum().item()
+
+        assert prompt_len <= max_length // 2, f"prompt length {prompt_len} exceeds max length {max_length // 2}"

         input_tokens = tokenizer(prompt + "\n" + response + tokenizer.eos_token,
                                  truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze()
-
         labels = input_tokens.clone()
         labels[:prompt_len + len(newline_tokens)] = -100
         if len(labels) < max_length:
             # pad to max_length with -100
             labels = torch.cat([labels, torch.full((max_length - len(labels),), -100)])

+        if (labels == -100).sum() == len(labels) - 1:
+            print(prompt)
+            print(response)
+            raise
+
         input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"]
         out["labels"].append(labels)
         out["input_ids"].append(input_tokens)
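For reference, here is a minimal, self-contained sketch of the label-masking scheme this hunk implements, using made-up token ids in place of real tokenizer output. The values of `max_length`, `prompt_len`, `newline_tokens`, and `input_tokens` below are illustrative stand-ins, not taken from the repo:

```python
import torch

# Illustrative stand-ins for the tokenizer outputs used in tokenize_inputs.
max_length = 16
prompt_len = 5                      # tokens in the (possibly truncated) prompt
newline_tokens = [11]               # stand-in for tokenizer("\n")["input_ids"]
input_tokens = torch.arange(1, 11)  # 10 ids: prompt + newline + response + eos

# Labels start as a copy of the input ids; the prompt and the newline
# separator are masked with -100 so the loss only covers the response.
labels = input_tokens.clone()
labels[:prompt_len + len(newline_tokens)] = -100

# Pad labels out to max_length with -100 so padding positions are ignored too.
if len(labels) < max_length:
    labels = torch.cat([labels, torch.full((max_length - len(labels),), -100)])

# The diff's new sanity check: if everything but a single position is masked,
# the example carries essentially no training signal, so fail loudly.
if (labels == -100).sum() == len(labels) - 1:
    raise ValueError("nearly all labels are masked; prompt truncation went wrong")

print(labels)
# tensor([-100, -100, -100, -100, -100, -100,    7,    8,    9,   10,
#         -100, -100, -100, -100, -100, -100])
```

-100 is the default `ignore_index` of PyTorch's `CrossEntropyLoss`, which is why both the prompt mask and the label padding use it: those positions contribute nothing to the loss.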