fix: multi-turn data breaks

This commit is contained in:
Zach Nussbaum 2023-04-12 03:51:29 +00:00
parent 60155de2a6
commit b1e361882d

View File

@ -15,8 +15,8 @@ def tokenize_inputs(config, tokenizer, examples):
out = {"labels": [], "input_ids": []}
for prompt, response in zip(examples["prompt"], examples["response"]):
if different_eos:
if response.count("</s>") > 0:
response = response.replace("</s>", tokenizer.eos_token)
if response.count("</s> \n") > 0:
response = response.replace("</s> \n", f"{tokenizer.eos_token} \n")
prompt_len = len(tokenizer(prompt + "\n", return_tensors="pt")["input_ids"][0])