|
|
|
@ -22,24 +22,35 @@ def tokenize_inputs(config, tokenizer, examples):
|
|
|
|
|
if response.count("</s>") > 0:
|
|
|
|
|
response = response.replace("</s>", tokenizer.eos_token)
|
|
|
|
|
|
|
|
|
|
prompt_len = len(tokenizer(prompt, truncation=True, return_tensors="pt")["input_ids"][0])
|
|
|
|
|
prompt_len = len(tokenizer(prompt, return_tensors="pt")["input_ids"][0])
|
|
|
|
|
|
|
|
|
|
# hack if our prompt is super long
|
|
|
|
|
# we need to include some labels
|
|
|
|
|
if prompt_len >= max_length - 1:
|
|
|
|
|
prompt = prompt[:len(prompt) // 2]
|
|
|
|
|
prompt_len = len(tokenizer(prompt, truncation=True, return_tensors="pt")["input_ids"][0])
|
|
|
|
|
# we need to include some labels so we arbitrarily trunacate at max_length // 2
|
|
|
|
|
# if the length is too long
|
|
|
|
|
if prompt_len >= max_length // 2:
|
|
|
|
|
# if prompt is too long, truncate
|
|
|
|
|
# but make sure to truncate to at max 1024 tokens
|
|
|
|
|
new_len = min(max_length // 2, len(prompt) // 2)
|
|
|
|
|
prompt = prompt[:new_len]
|
|
|
|
|
# get new prompt length
|
|
|
|
|
prompt_len = tokenizer(prompt, return_tensors="pt", max_length=max_length // 2, truncation=True).input_ids.ne(tokenizer.pad_token_id).sum().item()
|
|
|
|
|
|
|
|
|
|
assert prompt_len <= max_length // 2, f"prompt length {prompt_len} exceeds max length {max_length}"
|
|
|
|
|
|
|
|
|
|
input_tokens = tokenizer(prompt + "\n" + response + tokenizer.eos_token,
|
|
|
|
|
truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
labels = input_tokens.clone()
|
|
|
|
|
labels[:prompt_len + len(newline_tokens)] = -100
|
|
|
|
|
if len(labels) < max_length:
|
|
|
|
|
# pad to max_length with -100
|
|
|
|
|
labels = torch.cat([labels, torch.full((max_length - len(labels),), -100)])
|
|
|
|
|
|
|
|
|
|
if (labels == -100).sum() == len(labels) - 1:
|
|
|
|
|
print(prompt)
|
|
|
|
|
print(response)
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"]
|
|
|
|
|
out["labels"].append(labels)
|
|
|
|
|
out["input_ids"].append(input_tokens)
|
|
|
|
|