@@ -9,7 +9,6 @@ from transformers import DefaultDataCollator

def tokenize_inputs(config, tokenizer, examples):
    max_length = config["max_length"]
    input_ids = torch.full((len(examples["prompt"]), max_length), tokenizer.pad_token_id)
    # ignore bos
    newline_tokens = tokenizer("\n", return_tensors="pt")["input_ids"][0]
    if newline_tokens[0] == tokenizer.bos_token_id:
@@ -29,6 +28,7 @@ def tokenize_inputs(config, tokenizer, examples):
        # we need to include some labels
        if prompt_len >= max_length - 1:
            prompt = prompt[:len(prompt) // 2]
            prompt_len = len(tokenizer(prompt, truncation=True, return_tensors="pt")["input_ids"][0])

        input_tokens = tokenizer(prompt + "\n" + response + tokenizer.eos_token,
                                 truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze()
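
For reference, a minimal sketch of how a batched tokenization function like this is typically applied with the Hugging Face datasets API. It assumes the full tokenize_inputs from this file is in scope and returns a dict of lists usable by map; the config values, tokenizer name, and toy data below are illustrative assumptions, not taken from this diff.

    from datasets import Dataset
    from transformers import AutoTokenizer

    # Hypothetical config; the real values come from the training config file.
    config = {"max_length": 1024}
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
    if tokenizer.pad_token_id is None:
        # torch.full(..., tokenizer.pad_token_id) above needs a real pad id.
        tokenizer.pad_token = tokenizer.eos_token

    # Toy prompt/response pairs standing in for the real training data.
    data = Dataset.from_dict({
        "prompt": ["What is the capital of France?"],
        "response": ["The capital of France is Paris."],
    })

    # tokenize_inputs is written for batched mapping: `examples` is a dict of
    # lists, so it is mapped with batched=True.
    tokenized = data.map(
        lambda examples: tokenize_inputs(config, tokenizer, examples),
        batched=True,
        remove_columns=["prompt", "response"],
    )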