fix: multi-turn data breaks

pull/335/head
Zach Nussbaum 1 year ago
parent 15f7c5b68f
commit 8a94a8c068

@ -15,8 +15,8 @@ def tokenize_inputs(config, tokenizer, examples):
out = {"labels": [], "input_ids": []}
for prompt, response in zip(examples["prompt"], examples["response"]):
if different_eos:
if response.count("</s>") > 0:
response = response.replace("</s>", tokenizer.eos_token)
if response.count("</s> \n") > 0:
response = response.replace("</s> \n", f"{tokenizer.eos_token} \n")
prompt_len = len(tokenizer(prompt + "\n", return_tensors="pt")["input_ids"][0])

Loading…
Cancel
Save