diff --git a/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb b/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb
index 018d7bd..788a43c 100644
--- a/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb
+++ b/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb
@@ -6,7 +6,7 @@
       "provenance": [],
       "machine_shape": "hm",
       "gpuType": "A100",
-      "authorship_tag": "ABX9TyOJJCuqxZQnS1q+Fvz5+URG",
+      "authorship_tag": "ABX9TyNuIN7/ICiXCX5xELzN1Y3R",
       "include_colab_link": true
     },
     "kernelspec": {
@@ -380,6 +380,8 @@
       "source": [
         "# Fine-tune a Mistral-7b model with DPO\n",
         "\n",
+        "> 🗣️ [Large Language Model Course](https://github.com/mlabonne/llm-course)\n",
+        "\n",
         "❤️ Created by [@maximelabonne](https://twitter.com/maximelabonne)."
       ],
       "metadata": {
@@ -469,10 +471,10 @@
         "    prompt = tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=True)\n",
         "\n",
         "    # Format chosen answer\n",
-        "    chosen = example['chatgpt'] + \"<|im_end|>\\n\"\n",
+        "    chosen = example['chosen'] + \"<|im_end|>\\n\"\n",
         "\n",
         "    # Format rejected answer\n",
-        "    rejected = example['llama2-13b-chat'] + \"<|im_end|>\\n\"\n",
+        "    rejected = example['rejected'] + \"<|im_end|>\\n\"\n",
         "\n",
         "    return {\n",
         "        \"prompt\": system + prompt,\n",
@@ -561,13 +563,6 @@
         ")\n",
         "model.config.use_cache = False\n",
         "\n",
-        "# Reference model\n",
-        "ref_model = AutoModelForCausalLM.from_pretrained(\n",
-        "    model_name,\n",
-        "    torch_dtype=torch.float16,\n",
-        "    load_in_4bit=True\n",
-        ")\n",
-        "\n",
         "# Training arguments\n",
         "training_args = TrainingArguments(\n",
         "    per_device_train_batch_size=4,\n",
@@ -588,7 +583,6 @@
         "# Create DPO trainer\n",
         "dpo_trainer = DPOTrainer(\n",
         "    model,\n",
-        "    ref_model,\n",
         "    args=training_args,\n",
         "    train_dataset=dataset,\n",
         "    tokenizer=tokenizer,\n",
@@ -624,7 +618,7 @@
         "tokenizer.save_pretrained(\"final_checkpoint\")\n",
         "\n",
         "# Flush memory\n",
-        "del dpo_trainer, model, ref_model\n",
+        "del dpo_trainer, model\n",
         "gc.collect()\n",
         "torch.cuda.empty_cache()\n",
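
The key behavioral change in this patch, besides pointing the prompt formatter at the dataset's actual 'chosen'/'rejected' columns, is that DPOTrainer no longer receives an explicitly loaded reference model. This relies on TRL's documented fallback: when ref_model is left as None, the trainer derives the frozen reference policy itself, and for a PEFT/LoRA model it computes reference log-probs by temporarily disabling the adapters on `model`, so a second 7B copy never has to sit in GPU memory (which is also why `ref_model` disappears from the `del` cleanup line). Below is a minimal sketch of the resulting call, reusing the notebook's own variable names (`model`, `training_args`, `dataset`, `tokenizer`) and omitting the notebook's other keyword arguments:

    from trl import DPOTrainer

    # No ref_model argument: DPOTrainer falls back to ref_model=None and
    # recovers the frozen reference policy from `model` itself (for a PEFT
    # model, by scoring with the adapters disabled).
    dpo_trainer = DPOTrainer(
        model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
    )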