diff --git a/data.py b/data.py index 72dc4574..4457d93e 100644 --- a/data.py +++ b/data.py @@ -86,7 +86,7 @@ def load_data(config, tokenizer): **kwargs ) val_dataset = val_dataset.map( - lambda ele: tokenize_inputs(config, tokenizer, ele), + lambda ele: tokenize_inputs(config, tokenizer, ele), batched=True, remove_columns=["source", "prompt"], **kwargs