diff --git a/data.py b/data.py
index 7d61154d..dc404af1 100644
--- a/data.py
+++ b/data.py
@@ -15,8 +15,8 @@ def tokenize_inputs(config, tokenizer, examples):
     out = {"labels": [], "input_ids": []}
     for prompt, response in zip(examples["prompt"], examples["response"]):
         if different_eos:
-            if response.count("</s>") > 0:
-                response = response.replace("</s>", tokenizer.eos_token)
+            if response.count("</s> \n") > 0:
+                response = response.replace("</s> \n", f"{tokenizer.eos_token} \n")
         prompt_len = len(tokenizer(prompt + "\n", return_tensors="pt")["input_ids"][0])
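For context, here is a minimal sketch of the behavioral difference, assuming `different_eos` flags a tokenizer whose eos token is not `</s>` (the flag's definition is outside this hunk). The `<|endoftext|>` value below is a hypothetical stand-in; matching the full `</s> \n` marker presumably keeps the newline-delimited turn structure intact and avoids rewriting incidental `</s>` substrings.

```python
# Sketch only: SimpleNamespace stands in for a real Hugging Face tokenizer,
# and "<|endoftext|>" is a hypothetical eos_token (any value other than
# "</s>" would take the different_eos branch).
from types import SimpleNamespace

tokenizer = SimpleNamespace(eos_token="<|endoftext|>")
response = "Turn one.</s> \na reply quoting the literal string </s> inline"

# Old match: every bare "</s>" is rewritten, including the quoted one.
old = response.replace("</s>", tokenizer.eos_token)
assert old == ("Turn one.<|endoftext|> \n"
               "a reply quoting the literal string <|endoftext|> inline")

# New match: only the end-of-turn marker "</s> \n" is rewritten, and the
# " \n" separator is preserved after the substituted token.
new = response
if new.count("</s> \n") > 0:
    new = new.replace("</s> \n", f"{tokenizer.eos_token} \n")
assert new == ("Turn one.<|endoftext|> \n"
               "a reply quoting the literal string </s> inline")
```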