mirror of
https://github.com/nomic-ai/gpt4all
synced 2024-11-08 07:10:32 +00:00
fix: tokenization error
This commit is contained in:
parent
df29c62521
commit
c0a9065032
12
data.py
12
data.py
@ -9,10 +9,6 @@ from transformers import DefaultDataCollator
|
|||||||
|
|
||||||
def tokenize_inputs(config, tokenizer, examples):
|
def tokenize_inputs(config, tokenizer, examples):
|
||||||
max_length = config["max_length"]
|
max_length = config["max_length"]
|
||||||
# ignore bos
|
|
||||||
newline_tokens = tokenizer("\n", return_tensors="pt")["input_ids"][0]
|
|
||||||
if newline_tokens[0] == tokenizer.bos_token_id:
|
|
||||||
newline_tokens = newline_tokens[1:]
|
|
||||||
|
|
||||||
# hacky backward compatible
|
# hacky backward compatible
|
||||||
different_eos = tokenizer.eos_token != "</s>"
|
different_eos = tokenizer.eos_token != "</s>"
|
||||||
@ -22,7 +18,7 @@ def tokenize_inputs(config, tokenizer, examples):
|
|||||||
if response.count("</s>") > 0:
|
if response.count("</s>") > 0:
|
||||||
response = response.replace("</s>", tokenizer.eos_token)
|
response = response.replace("</s>", tokenizer.eos_token)
|
||||||
|
|
||||||
prompt_len = len(tokenizer(prompt, return_tensors="pt")["input_ids"][0])
|
prompt_len = len(tokenizer(prompt + "\n", return_tensors="pt")["input_ids"][0])
|
||||||
|
|
||||||
# hack if our prompt is super long
|
# hack if our prompt is super long
|
||||||
# we need to include some labels so we arbitrarily trunacate at max_length // 2
|
# we need to include some labels so we arbitrarily trunacate at max_length // 2
|
||||||
@ -33,7 +29,7 @@ def tokenize_inputs(config, tokenizer, examples):
|
|||||||
new_len = min(max_length // 2, len(prompt) // 2)
|
new_len = min(max_length // 2, len(prompt) // 2)
|
||||||
prompt = prompt[:new_len]
|
prompt = prompt[:new_len]
|
||||||
# get new prompt length
|
# get new prompt length
|
||||||
prompt_len = tokenizer(prompt, return_tensors="pt", max_length=max_length // 2, truncation=True).input_ids.ne(tokenizer.pad_token_id).sum().item()
|
prompt_len = tokenizer(prompt + "\n", return_tensors="pt", max_length=max_length // 2, truncation=True).input_ids.ne(tokenizer.pad_token_id).sum().item()
|
||||||
|
|
||||||
assert prompt_len <= max_length // 2, f"prompt length {prompt_len} exceeds max length {max_length}"
|
assert prompt_len <= max_length // 2, f"prompt length {prompt_len} exceeds max length {max_length}"
|
||||||
|
|
||||||
@ -41,11 +37,13 @@ def tokenize_inputs(config, tokenizer, examples):
|
|||||||
truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze()
|
truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze()
|
||||||
|
|
||||||
labels = input_tokens.clone()
|
labels = input_tokens.clone()
|
||||||
labels[:prompt_len + len(newline_tokens)] = -100
|
labels[:prompt_len] = -100
|
||||||
if len(labels) < max_length:
|
if len(labels) < max_length:
|
||||||
# pad to max_length with -100
|
# pad to max_length with -100
|
||||||
labels = torch.cat([labels, torch.full((max_length - len(labels),), -100)])
|
labels = torch.cat([labels, torch.full((max_length - len(labels),), -100)])
|
||||||
|
|
||||||
|
assert (labels == -100).sum() < len(labels), f"Labels are all -100, something wrong. prompt length {prompt_len} exceeds max length {max_length}"
|
||||||
|
|
||||||
if (labels == -100).sum() == len(labels) - 1:
|
if (labels == -100).sum() == len(labels) - 1:
|
||||||
print(prompt)
|
print(prompt)
|
||||||
print(response)
|
print(response)
|
||||||
|
Loading…
Reference in New Issue
Block a user