fix: max_seq_length → max_length, warmup_ratio → warmup_steps (TRL 1.3 compat)
train.py CHANGED
@@ -11,7 +11,7 @@ Usage:
     python train.py --base_model Qwen/Qwen2.5-7B-Instruct --lora_r 64
 """
 
-import argparse, os, json, torch
+import argparse, os, json, torch, math
 from datetime import datetime
 
 from datasets import load_dataset
@@ -42,8 +42,8 @@ def parse_args():
     p.add_argument("--lr", type=float, default=1e-4)
     p.add_argument("--batch_size", type=int, default=4)
     p.add_argument("--grad_accum", type=int, default=8)
-    p.add_argument("--max_seq_length", …)
-    p.add_argument("--warmup_ratio", …)
+    p.add_argument("--max_length", type=int, default=4096)
+    p.add_argument("--warmup_steps", type=int, default=100)
     p.add_argument("--weight_decay", type=float, default=0.01)
     # Output
     p.add_argument("--output_dir", type=str, default="./output")
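With the renamed flags, a typical invocation mirrors the Usage line in the docstring (the values here are illustrative, not the defaults):

    python train.py --base_model Qwen/Qwen2.5-7B-Instruct --lora_r 64 \
        --max_length 8192 --warmup_steps 200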
@@ -52,8 +52,8 @@ def parse_args():
     p.add_argument("--push_to_hub", action="store_true", default=False)
     # Misc
     p.add_argument("--seed", type=int, default=42)
-    p.add_argument("--bf16", action="store_true", default=True)
     p.add_argument("--flash_attn", action="store_true", default=True)
+    p.add_argument("--no_flash_attn", dest="flash_attn", action="store_false")
     return p.parse_args()
 
 
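Note: the new `--no_flash_attn` flag uses argparse's shared-`dest` pattern: both flags write to `args.flash_attn`, so the default stays True and the negative flag switches it off. A minimal standalone sketch of that behavior:

    import argparse

    p = argparse.ArgumentParser()
    p.add_argument("--flash_attn", action="store_true", default=True)
    p.add_argument("--no_flash_attn", dest="flash_attn", action="store_false")

    assert p.parse_args([]).flash_attn is True                   # default on
    assert p.parse_args(["--no_flash_attn"]).flash_attn is False # opt out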
@@ -63,16 +63,17 @@ def main():
     print("=" * 70)
     print("TMF921 Intent Translation — QLoRA Training")
     print("=" * 70)
-    print(f"Base model …")
-    print(f"Dataset …")
-    print(f"LoRA r/alpha …")
-    print(f"Epochs …")
-    print(f"LR …")
-    print(f"Batch size …"
-          f"{args.batch_size * args.grad_accum} effective")
-    print(f"Max …")
-    print(f"…")
-    print(f"…")
+    print(f"Base model : {args.base_model}")
+    print(f"Dataset : {args.dataset}")
+    print(f"LoRA r/alpha : {args.lora_r}/{args.lora_alpha}")
+    print(f"Epochs : {args.epochs}")
+    print(f"LR : {args.lr}")
+    print(f"Batch size : {args.batch_size} × {args.grad_accum} grad_accum = "
+          f"{args.batch_size * args.grad_accum} effective")
+    print(f"Max length : {args.max_length}")
+    print(f"Flash attn : {args.flash_attn}")
+    print(f"Output : {args.output_dir}")
+    print(f"Push to Hub : {args.push_to_hub} → {args.hub_model_id}")
     print("=" * 70)
 
     # ── 1. Load dataset ──────────────────────────────────────────────
@@ -126,6 +127,10 @@ def main():
     # ── 4. SFT Training ─────────────────────────────────────────────
     print("\n[4/4] Starting SFT training …")
 
+    # Compute eval steps: ~4 evals per epoch
+    steps_per_epoch = math.ceil(len(train_ds) / (args.batch_size * args.grad_accum))
+    eval_steps = max(steps_per_epoch // 4, 50)
+
     run_name = (
         f"tmf921-{args.base_model.split('/')[-1]}"
         f"-r{args.lora_r}-lr{args.lr}-ep{args.epochs}"
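To see what schedule this produces, assume a hypothetical training split of 8,000 examples with the default batch_size=4 and grad_accum=8:

    import math

    effective_batch = 4 * 8                              # batch_size × grad_accum
    steps_per_epoch = math.ceil(8000 / effective_batch)  # 250
    eval_steps = max(steps_per_epoch // 4, 50)           # 62 → ~4 evals per epoch

The max(…, 50) floor keeps very small datasets from evaluating every handful of steps.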
@@ -142,14 +147,14 @@ def main():
         num_train_epochs=args.epochs,
         learning_rate=args.lr,
         lr_scheduler_type="cosine",
-        warmup_ratio=args.warmup_ratio,
+        warmup_steps=args.warmup_steps,
         weight_decay=args.weight_decay,
         # Precision & memory
-        bf16=args.bf16,
+        bf16=True,
         gradient_checkpointing=True,
         gradient_checkpointing_kwargs={"use_reentrant": False},
         # Sequence
-        max_seq_length=args.max_seq_length,
+        max_length=args.max_length,
         # Loss: train only on assistant outputs
         assistant_only_loss=True,
         # Logging
@@ -159,10 +164,10 @@ def main():
         disable_tqdm=False,
         # Eval
         eval_strategy="steps",
-        eval_steps=…,
+        eval_steps=eval_steps,
         # Save
         save_strategy="steps",
-        save_steps=…,
+        save_steps=eval_steps,
         save_total_limit=3,
         load_best_model_at_end=True,
         metric_for_best_model="eval_loss",
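For reference, a stripped-down config under the new parameter names. This is a sketch with illustrative values, not the full SFTConfig from the file, and it assumes a TRL release whose SFTConfig exposes max_length, warmup_steps, and assistant_only_loss:

    from trl import SFTConfig

    sft_config = SFTConfig(
        output_dir="./output",
        learning_rate=1e-4,
        lr_scheduler_type="cosine",
        warmup_steps=100,          # was: warmup_ratio
        bf16=True,                 # hardcoded now that --bf16 is gone
        max_length=4096,           # was: max_seq_length
        assistant_only_loss=True,
        eval_strategy="steps",
        eval_steps=62,             # now derived from dataset size, not fixed
        save_strategy="steps",
        save_steps=62,
    )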
@@ -177,6 +182,10 @@ def main():
         dataloader_pin_memory=True,
     )
 
+    print(f" Steps/epoch: {steps_per_epoch}")
+    print(f" Eval every: {eval_steps} steps")
+    print(f" Total steps: ~{steps_per_epoch * args.epochs}")
+
     trainer = SFTTrainer(
         model=model,
         args=sft_config,
|
     metrics = train_result.metrics
     metrics["train_samples"] = len(train_ds)
     metrics["eval_samples"] = len(eval_ds)
+    metrics["base_model"] = args.base_model
+    metrics["lora_r"] = args.lora_r
+    metrics["lora_alpha"] = args.lora_alpha
+    metrics["learning_rate"] = args.lr
+    metrics["epochs"] = args.epochs
+    metrics["effective_batch_size"] = args.batch_size * args.grad_accum
+
     with open(os.path.join(args.output_dir, "train_metrics.json"), "w") as f:
         json.dump(metrics, f, indent=2)
     print(f" Metrics saved to {args.output_dir}/train_metrics.json")
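With the extra keys above, train_metrics.json records the run's hyperparameters next to the standard HF Trainer metrics. Its rough shape as a Python dict, with invented values:

    # Hypothetical contents of train_metrics.json (all values invented)
    metrics = {
        "train_runtime": 5231.7,       # standard Trainer keys
        "train_loss": 0.412,
        "train_samples": 8000,
        "eval_samples": 500,
        "base_model": "Qwen/Qwen2.5-7B-Instruct",
        "lora_r": 64,
        "lora_alpha": 128,
        "learning_rate": 0.0001,
        "epochs": 3,
        "effective_batch_size": 32,
    }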