@@ -11,42 +11,38 @@ def tokenize_inputs(config, tokenizer, examples):
     max_length = config["max_length"]
-    input_ids = torch.full((len(examples["prompt"]), max_length), tokenizer.pad_token_id)
-    # ignore bos
-    newline_tokens = tokenizer("\n", return_tensors="pt")["input_ids"][0, 1:]
-
-    out = {"labels": [], "attention_mask": []}
-    for i, (prompt, response) in enumerate(zip(examples["prompt"], examples["response"])):
-        input_tokens = tokenizer(prompt, truncation=True, max_length=max_length // 2, return_tensors="pt")["input_ids"].squeeze()
-        input_len = len(input_tokens)
-
-        # plus one since we remove bos from response
-        # but we subtract one since we want to add eos token
-        remaining_tokens = max_length - input_len - len(newline_tokens) + 1
-        # remove bos
-        target_tokens = tokenizer(response, truncation=True, max_length=remaining_tokens, return_tensors="pt")["input_ids"].squeeze()[1:]
-
-        input_ids[i, :input_len] = input_tokens
-        # add newline between prompt and response
-        newline_plus_inputs = input_len + len(newline_tokens)
-        input_ids[i, input_len: newline_plus_inputs] = newline_tokens
-
-        # add target tokens, remove bos
-        input_ids[i, newline_plus_inputs: newline_plus_inputs + len(target_tokens)] = target_tokens
-        # add eos token, enforce stopping if we don't truncate
-        # we don't want long code to stop generating if truncated during training
-        if newline_plus_inputs + len(target_tokens) < max_length:
-            input_ids[i, newline_plus_inputs + len(target_tokens)] = tokenizer.eos_token_id
-
-        labels = input_ids[i].clone()
-        labels[: newline_plus_inputs] = -100
-        labels[labels == tokenizer.pad_token_id] = -100
-        # to debug this, can set all values == -100 to the pad token, then assert that tokenizer.decode(labels, skip_special_tokens=True).strip() == response
-
-        attention_mask = input_ids[i].ne(tokenizer.pad_token_id).int()
+    newline_tokens = tokenizer("\n", return_tensors="pt")["input_ids"][0]
+    if newline_tokens[0] == tokenizer.bos_token_id:
+        newline_tokens = newline_tokens[1:]

-        out["labels"].append(labels)
-        out["attention_mask"].append(attention_mask)
+    # hacky backward compatible
+    different_eos = tokenizer.eos_token != "</s>"
+    out = {"labels": [], "input_ids": []}
+    for prompt, response in zip(examples["prompt"], examples["response"]):
+        if different_eos:
+            if response.count("</s>") > 0:
+                response = response.replace("</s>", tokenizer.eos_token)
+
+        prompt_len = len(tokenizer(prompt, truncation=True, return_tensors="pt")["input_ids"][0])
+
+        # hack if our prompt is super long
+        # we need to include some labels
+        if prompt_len >= max_length - 1:
+            prompt = prompt[:len(prompt) // 2]
+
+        input_tokens = tokenizer(prompt + "\n" + response + tokenizer.eos_token,
+                                 truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze()

-    out["input_ids"] = input_ids
+        labels = input_tokens.clone()
+        labels[:prompt_len + len(newline_tokens)] = -100
+        if len(labels) < max_length:
+            # pad to max_length with -100
+            labels = torch.cat([labels, torch.full((max_length - len(labels),), -100)])
+
+        input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"]
+        out["labels"].append(labels)
+        out["input_ids"].append(input_tokens)

     out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}
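
A side note on the removed "# to debug this" comment: it describes a sanity check for the label mask that can be written out roughly as follows. This is a sketch, not part of the patch, assuming a Hugging Face tokenizer; check_label_mask is a hypothetical helper name, and labels/response are the per-example values from the loop above.

    import torch

    def check_label_mask(tokenizer, labels: torch.Tensor, response: str) -> None:
        # Replace every masked (-100) position with the pad token, decode what is
        # left, and verify that only the response text survives the masking.
        debug = labels.clone()
        debug[debug == -100] = tokenizer.pad_token_id
        decoded = tokenizer.decode(debug, skip_special_tokens=True).strip()
        assert decoded == response.strip(), (decoded, response)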
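The -100 fill value is the default ignore_index of torch.nn.CrossEntropyLoss, so prompt, newline, and padding positions drop out of the loss when the batch is fed to a Hugging Face causal-LM model. A rough usage sketch of the new code path, where model, config, tokenizer, and examples are assumed to be set up elsewhere and are not shown in this diff:

    batch = tokenize_inputs(config, tokenizer, examples)
    # Positions labeled -100 (prompt, newline separator, padding) are ignored
    # by the model's cross-entropy loss; only response tokens are supervised.
    outputs = model(input_ids=batch["input_ids"], labels=batch["labels"])
    loss = outputs.loss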