mirror of
https://github.com/nomic-ai/gpt4all
synced 2024-11-08 07:10:32 +00:00
fix: prompt len for larger
This commit is contained in:
parent
63ff39653d
commit
8dd99cc00a
2
data.py
2
data.py
@@ -9,7 +9,6 @@ from transformers import DefaultDataCollator
|
||||
|
||||
def tokenize_inputs(config, tokenizer, examples):
|
||||
max_length = config["max_length"]
|
||||
input_ids = torch.full((len(examples["prompt"]), max_length), tokenizer.pad_token_id)
|
||||
# ignore bos
|
||||
newline_tokens = tokenizer("\n", return_tensors="pt")["input_ids"][0]
|
||||
if newline_tokens[0] == tokenizer.bos_token_id:
|
||||
@@ -29,6 +28,7 @@ def tokenize_inputs(config, tokenizer, examples):
|
||||
# we need to include some labels
|
||||
if prompt_len >= max_length - 1:
|
||||
prompt = prompt[:len(prompt) // 2]
|
||||
prompt_len = len(tokenizer(prompt, truncation=True, return_tensors="pt")["input_ids"][0])
|
||||
|
||||
input_tokens = tokenizer(prompt + "\n" + response + tokenizer.eos_token,
|
||||
truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze()
|
||||
|
Loading…
Reference in New Issue
Block a user