nraptisss committed
Commit 2fdbc71 · verified · 1 Parent(s): d389d4c

fix: max_seq_length → max_length, warmup_ratio → warmup_steps (TRL 1.3 compat)

Files changed (1):
  1. train.py +34 -18
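
Background for the commit title: newer TRL renamed the SFTConfig sequence-length field, so a script written against the old name fails at startup. If the script had to run across TRL versions, one option would be a small version guard instead of a hard rename. A minimal sketch, assuming the `max_seq_length` → `max_length` rename lands exactly at TRL 1.3 as the title says (`seq_len_kwarg` is a hypothetical helper, not part of train.py):

    # Sketch: pick the SFTConfig kwarg spelling by installed TRL version.
    # Assumption: the max_seq_length -> max_length rename happens at TRL 1.3
    # (taken from the commit title); adjust the pin if your install differs.
    from packaging import version
    import trl

    def seq_len_kwarg(n_tokens: int) -> dict:
        """Return {kwarg_name: n_tokens} for the installed TRL's SFTConfig."""
        if version.parse(trl.__version__) >= version.parse("1.3"):
            return {"max_length": n_tokens}       # newer spelling
        return {"max_seq_length": n_tokens}       # older spelling

    # Usage: SFTConfig(output_dir=args.output_dir, **seq_len_kwarg(args.max_length))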
train.py CHANGED

@@ -11,7 +11,7 @@ Usage:
     python train.py --base_model Qwen/Qwen2.5-7B-Instruct --lora_r 64
 """
 
-import argparse, os, json, torch
+import argparse, os, json, torch, math
 from datetime import datetime
 
 from datasets import load_dataset
@@ -42,8 +42,8 @@ def parse_args():
     p.add_argument("--lr", type=float, default=1e-4)
     p.add_argument("--batch_size", type=int, default=4)
     p.add_argument("--grad_accum", type=int, default=8)
-    p.add_argument("--max_seq_length", type=int, default=4096)
-    p.add_argument("--warmup_ratio", type=float, default=0.05)
+    p.add_argument("--max_length", type=int, default=4096)
+    p.add_argument("--warmup_steps", type=int, default=100)
     p.add_argument("--weight_decay", type=float, default=0.01)
     # Output
    p.add_argument("--output_dir", type=str, default="./output")
@@ -52,8 +52,8 @@ def parse_args():
     p.add_argument("--push_to_hub", action="store_true", default=False)
     # Misc
     p.add_argument("--seed", type=int, default=42)
-    p.add_argument("--bf16", action="store_true", default=True)
     p.add_argument("--flash_attn", action="store_true", default=True)
+    p.add_argument("--no_flash_attn", dest="flash_attn", action="store_false")
     return p.parse_args()
 
 
@@ -63,16 +63,17 @@ def main():
     print("=" * 70)
     print("TMF921 Intent Translation — QLoRA Training")
     print("=" * 70)
-    print(f"Base model : {args.base_model}")
-    print(f"Dataset : {args.dataset}")
-    print(f"LoRA r/alpha : {args.lora_r}/{args.lora_alpha}")
-    print(f"Epochs : {args.epochs}")
-    print(f"LR : {args.lr}")
-    print(f"Batch size : {args.batch_size} × {args.grad_accum} grad_accum = "
+    print(f"Base model   : {args.base_model}")
+    print(f"Dataset      : {args.dataset}")
+    print(f"LoRA r/alpha : {args.lora_r}/{args.lora_alpha}")
+    print(f"Epochs       : {args.epochs}")
+    print(f"LR           : {args.lr}")
+    print(f"Batch size   : {args.batch_size} × {args.grad_accum} grad_accum = "
           f"{args.batch_size * args.grad_accum} effective")
-    print(f"Max seq len : {args.max_seq_length}")
-    print(f"Output : {args.output_dir}")
-    print(f"Push to Hub : {args.push_to_hub} → {args.hub_model_id}")
+    print(f"Max length   : {args.max_length}")
+    print(f"Flash attn   : {args.flash_attn}")
+    print(f"Output       : {args.output_dir}")
+    print(f"Push to Hub  : {args.push_to_hub} → {args.hub_model_id}")
     print("=" * 70)
 
     # ── 1. Load dataset ──────────────────────────────────────────────
@@ -126,6 +127,10 @@ def main():
     # ── 4. SFT Training ─────────────────────────────────────────────
     print("\n[4/4] Starting SFT training …")
 
+    # Compute eval steps: ~4 evals per epoch
+    steps_per_epoch = math.ceil(len(train_ds) / (args.batch_size * args.grad_accum))
+    eval_steps = max(steps_per_epoch // 4, 50)
+
     run_name = (
         f"tmf921-{args.base_model.split('/')[-1]}"
         f"-r{args.lora_r}-lr{args.lr}-ep{args.epochs}"
@@ -142,14 +147,14 @@
         num_train_epochs=args.epochs,
         learning_rate=args.lr,
         lr_scheduler_type="cosine",
-        warmup_ratio=args.warmup_ratio,
+        warmup_steps=args.warmup_steps,
         weight_decay=args.weight_decay,
         # Precision & memory
-        bf16=args.bf16,
+        bf16=True,
         gradient_checkpointing=True,
         gradient_checkpointing_kwargs={"use_reentrant": False},
         # Sequence
-        max_seq_length=args.max_seq_length,
+        max_length=args.max_length,
         # Loss: train only on assistant outputs
         assistant_only_loss=True,
         # Logging
@@ -159,10 +164,10 @@
         disable_tqdm=False,
         # Eval
         eval_strategy="steps",
-        eval_steps=200,
+        eval_steps=eval_steps,
         # Save
         save_strategy="steps",
-        save_steps=200,
+        save_steps=eval_steps,
         save_total_limit=3,
         load_best_model_at_end=True,
         metric_for_best_model="eval_loss",
@@ -177,6 +182,10 @@
         dataloader_pin_memory=True,
     )
 
+    print(f" Steps/epoch: {steps_per_epoch}")
+    print(f" Eval every: {eval_steps} steps")
+    print(f" Total steps: ~{steps_per_epoch * args.epochs}")
+
     trainer = SFTTrainer(
         model=model,
         args=sft_config,
@@ -204,6 +213,13 @@
     metrics = train_result.metrics
     metrics["train_samples"] = len(train_ds)
     metrics["eval_samples"] = len(eval_ds)
+    metrics["base_model"] = args.base_model
+    metrics["lora_r"] = args.lora_r
+    metrics["lora_alpha"] = args.lora_alpha
+    metrics["learning_rate"] = args.lr
+    metrics["epochs"] = args.epochs
+    metrics["effective_batch_size"] = args.batch_size * args.grad_accum
+
     with open(os.path.join(args.output_dir, "train_metrics.json"), "w") as f:
         json.dump(metrics, f, indent=2)
     print(f" Metrics saved to {args.output_dir}/train_metrics.json")