Ellaft commited on
Commit
61c951d
·
verified ·
1 Parent(s): c29559f

Update train_v2.py: add --dataset flag to switch between old proxy data (dataset_real) and new built dataset (dataset_v2)

Browse files
Files changed (1) hide show
  1. src/train_v2.py +111 -212
src/train_v2.py CHANGED
@@ -6,13 +6,18 @@ Changes from v1:
6
  - Asymmetric learning rates: higher for visual branch, lower for audio
7
  - Auxiliary loss logging (loss_fusion, loss_visual, loss_audio per epoch)
8
  - OGM-GE stats logging (visual_conf, audio_conf, modulation coefficients)
9
- - Uses models_v2 (auxiliary heads) and dataset_real (real data)
10
 
11
  Usage:
 
12
  python train_v2.py --mode multimodal --finetune lora --eval_robustness
 
 
 
 
 
 
13
  python train_v2.py --mode visual_only --finetune lora --no_push
14
- python train_v2.py --mode audio_only --finetune lora --no_push
15
- python train_v2.py --mode multimodal --finetune full --lr 2e-5
16
  python train_v2.py --quick_test --no_push
17
 
18
  References:
@@ -30,7 +35,6 @@ from torch.optim.lr_scheduler import OneCycleLR
30
  from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_recall_fscore_support
31
 
32
  from config import ExperimentConfig, FAULT_CLASSES, NUM_CLASSES
33
- from dataset_real import RealPCFaultDataset as PCFaultDataset, multimodal_collate_fn
34
  from models_v2 import create_model, get_processors, OGMGEModulator
35
 
36
 
@@ -73,7 +77,8 @@ class MultimodalTrainerV2:
73
 
74
  def __init__(self, model, train_dataset, val_dataset, config, device,
75
  use_ogm=True, ogm_alpha=0.3, ogm_noise_sigma=0.1,
76
- visual_lr_multiplier=3.0, audio_lr_multiplier=0.5):
 
77
  self.model = model.to(device)
78
  self.device = device
79
  self.config = config
@@ -91,7 +96,7 @@ class MultimodalTrainerV2:
91
  train_dataset,
92
  batch_size=config.per_device_train_batch_size,
93
  shuffle=True,
94
- collate_fn=multimodal_collate_fn,
95
  num_workers=2,
96
  pin_memory=True,
97
  drop_last=True)
@@ -99,7 +104,7 @@ class MultimodalTrainerV2:
99
  val_dataset,
100
  batch_size=config.per_device_eval_batch_size,
101
  shuffle=False,
102
- collate_fn=multimodal_collate_fn,
103
  num_workers=2,
104
  pin_memory=True)
105
 
@@ -126,25 +131,13 @@ class MultimodalTrainerV2:
126
  self.history = {
127
  "train_loss": [], "val_loss": [],
128
  "val_accuracy": [], "val_macro_f1": [],
129
- # v2 additions
130
  "train_loss_fusion": [], "train_loss_visual": [], "train_loss_audio": [],
131
  "ogm_visual_conf": [], "ogm_audio_conf": [],
132
  "ogm_coeff_visual": [], "ogm_coeff_audio": [],
133
  }
134
 
135
  def _get_param_groups(self, visual_lr_multiplier, audio_lr_multiplier):
136
- """
137
- Create 3 parameter groups with asymmetric learning rates.
138
-
139
- For LoRA mode: uses lora_learning_rate as base.
140
- Visual branch gets multiplier > 1 (boost weak modality).
141
- Audio branch gets multiplier < 1 (slow down dominant modality).
142
- Fusion head + auxiliary heads get base LR.
143
- """
144
- visual_params = []
145
- audio_params = []
146
- fusion_params = [] # fusion head + auxiliary heads
147
-
148
  for name, param in self.model.named_parameters():
149
  if not param.requires_grad:
150
  continue
@@ -155,9 +148,7 @@ class MultimodalTrainerV2:
155
  else:
156
  fusion_params.append(param)
157
 
158
- # Determine base LR
159
- base_lr = self.config.lora_learning_rate # default: 5e-3
160
-
161
  groups = []
162
  if visual_params:
163
  vlr = base_lr * visual_lr_multiplier
@@ -170,24 +161,15 @@ class MultimodalTrainerV2:
170
  if fusion_params:
171
  groups.append({"params": fusion_params, "lr": base_lr, "name": "fusion_heads"})
172
  print(f"[Trainer v2] fusion_heads: {len(fusion_params)} tensors, lr={base_lr:.2e}")
173
-
174
  if not groups:
175
  raise ValueError("No trainable parameters!")
176
  return groups
177
 
178
  def train_epoch(self, epoch):
179
- """Train one epoch with OGM-GE gradient modulation."""
180
  self.model.train()
181
- total_loss = 0.0
182
- total_loss_fusion = 0.0
183
- total_loss_visual = 0.0
184
- total_loss_audio = 0.0
185
  num_batches = 0
186
-
187
- # OGM-GE stats accumulators
188
- ogm_v_confs, ogm_a_confs = [], []
189
- ogm_cv, ogm_ca = [], []
190
-
191
  self.optimizer.zero_grad()
192
 
193
  for batch_idx, batch in enumerate(self.train_loader):
@@ -195,7 +177,6 @@ class MultimodalTrainerV2:
195
  av = batch["audio_values"].to(self.device)
196
  labels = batch["labels"].to(self.device)
197
 
198
- # Forward pass
199
  if self.scaler:
200
  with torch.amp.autocast("cuda"):
201
  outputs = self.model(pixel_values=pv, audio_values=av, labels=labels)
@@ -206,7 +187,6 @@ class MultimodalTrainerV2:
206
  loss = outputs["loss"] / self.config.gradient_accumulation_steps
207
  loss.backward()
208
 
209
- # Accumulate losses
210
  total_loss += loss.item() * self.config.gradient_accumulation_steps
211
  num_batches += 1
212
  if "loss_fusion" in outputs:
@@ -214,7 +194,6 @@ class MultimodalTrainerV2:
214
  total_loss_visual += outputs["loss_visual"]
215
  total_loss_audio += outputs["loss_audio"]
216
 
217
- # Collect OGM-GE stats every batch (but only apply at accumulation boundary)
218
  if (self.use_ogm and self.ogm is not None
219
  and "visual_logits" in outputs and "audio_logits" in outputs):
220
  _cv, _ca, _stats = self.ogm.compute_modulation_coefficients(
@@ -224,20 +203,12 @@ class MultimodalTrainerV2:
224
  ogm_cv.append(_stats["coeff_visual"])
225
  ogm_ca.append(_stats["coeff_audio"])
226
 
227
- # Optimizer step (at accumulation boundary)
228
  if (batch_idx + 1) % self.config.gradient_accumulation_steps == 0:
229
  if self.scaler:
230
  self.scaler.unscale_(self.optimizer)
231
-
232
- # ==== OGM-GE: modulate gradients AFTER unscale, BEFORE step ====
233
  if (self.use_ogm and self.ogm is not None and ogm_cv):
234
- # Use the most recent coefficients
235
- self.ogm.apply_gradient_modulation(
236
- self.model, ogm_cv[-1], ogm_ca[-1])
237
-
238
- torch.nn.utils.clip_grad_norm_(
239
- self.model.parameters(), self.config.max_grad_norm)
240
-
241
  if self.scaler:
242
  self.scaler.step(self.optimizer)
243
  self.scaler.update()
@@ -246,102 +217,72 @@ class MultimodalTrainerV2:
246
  self.scheduler.step()
247
  self.optimizer.zero_grad()
248
 
249
- # Logging
250
  if (batch_idx + 1) % self.config.logging_steps == 0 or batch_idx == 0:
251
  avg_loss = total_loss / num_batches
252
  msg = (f" [Epoch {epoch+1}] Step {batch_idx+1}/{len(self.train_loader)} "
253
- f"| Loss: {avg_loss:.4f} "
254
- f"| LR_v: {self.optimizer.param_groups[0]['lr']:.2e}")
255
  if "loss_fusion" in outputs:
256
  msg += (f" | L_fus: {total_loss_fusion/num_batches:.4f}"
257
  f" L_vis: {total_loss_visual/num_batches:.4f}"
258
  f" L_aud: {total_loss_audio/num_batches:.4f}")
259
  if ogm_cv:
260
- msg += (f" | OGM c_v: {ogm_cv[-1]:.3f}"
261
- f" c_a: {ogm_ca[-1]:.3f}")
262
  print(msg)
263
 
264
- # Epoch-level OGM stats
265
  n = max(num_batches, 1)
266
- epoch_stats = {
267
- "train_loss": total_loss / n,
268
- "loss_fusion": total_loss_fusion / n,
269
- "loss_visual": total_loss_visual / n,
270
- "loss_audio": total_loss_audio / n,
271
- }
272
  if ogm_v_confs:
273
- epoch_stats["ogm_visual_conf"] = np.mean(ogm_v_confs)
274
- epoch_stats["ogm_audio_conf"] = np.mean(ogm_a_confs)
275
- epoch_stats["ogm_coeff_visual"] = np.mean(ogm_cv)
276
- epoch_stats["ogm_coeff_audio"] = np.mean(ogm_ca)
277
-
278
  return epoch_stats
279
 
280
  @torch.no_grad()
281
  def evaluate(self, modality_mask=None):
282
- """Evaluate on validation set. Optionally mask a modality for robustness test."""
283
  self.model.eval()
284
- all_preds, all_labels = [], []
285
- total_loss = 0.0
286
- num_batches = 0
287
-
288
  for batch in self.val_loader:
289
  pv = batch["pixel_values"].to(self.device)
290
  av = batch["audio_values"].to(self.device)
291
  labels = batch["labels"].to(self.device)
292
-
293
  if modality_mask:
294
- if modality_mask.get("visual", 1.0) == 0.0:
295
- pv = torch.zeros_like(pv)
296
- if modality_mask.get("audio", 1.0) == 0.0:
297
- av = torch.zeros_like(av)
298
-
299
  outputs = self.model(pixel_values=pv, audio_values=av, labels=labels)
300
  total_loss += outputs["loss"].item()
301
  num_batches += 1
302
  all_preds.extend(outputs["logits"].argmax(dim=-1).cpu().numpy())
303
  all_labels.extend(labels.cpu().numpy())
304
-
305
  metrics = compute_metrics(np.array(all_preds), np.array(all_labels))
306
  metrics["val_loss"] = total_loss / max(num_batches, 1)
307
  return metrics
308
 
309
  def train(self):
310
- """Full training loop with OGM-GE and detailed logging."""
311
  print(f"\n{'='*60}")
312
  print(f"Training v2: mode={self.model.mode}, epochs={self.config.num_epochs}, "
313
  f"batch={self.config.per_device_train_batch_size}, device={self.device}")
314
  print(f"OGM-GE: {'ENABLED' if self.use_ogm else 'DISABLED'}")
315
  if self.model.mode == "multimodal":
316
- print(f"Auxiliary loss weights: λ_visual={self.model.lambda_visual}, "
317
- f"λ_audio={self.model.lambda_audio}")
318
  print(f"{'='*60}\n")
319
 
320
  for epoch in range(self.config.num_epochs):
321
  t0 = time.time()
322
  train_stats = self.train_epoch(epoch)
323
  val_metrics = self.evaluate()
324
-
325
- # Print epoch summary
326
  elapsed = time.time() - t0
 
327
  print(f"\n[Epoch {epoch+1}/{self.config.num_epochs}] ({elapsed:.1f}s)")
328
  loss_msg = f" Train Loss: {train_stats['train_loss']:.4f}"
329
  if train_stats.get("loss_fusion", 0) > 0:
330
  loss_msg += (f" (fusion={train_stats['loss_fusion']:.4f} "
331
- f"visual={train_stats['loss_visual']:.4f} "
332
- f"audio={train_stats['loss_audio']:.4f})")
333
  print(loss_msg)
334
- print(f" Val Loss: {val_metrics['val_loss']:.4f} "
335
- f"| Acc: {val_metrics['accuracy']:.4f} "
336
- f"| F1: {val_metrics['macro_f1']:.4f}")
337
-
338
  if "ogm_visual_conf" in train_stats:
339
- print(f" OGM-GE: visual_conf={train_stats['ogm_visual_conf']:.4f} "
340
- f"audio_conf={train_stats['ogm_audio_conf']:.4f} "
341
- f"| coeff_v={train_stats['ogm_coeff_visual']:.4f} "
342
- f"coeff_a={train_stats['ogm_coeff_audio']:.4f}")
343
 
344
- # Update history
345
  self.history["train_loss"].append(train_stats["train_loss"])
346
  self.history["val_loss"].append(val_metrics["val_loss"])
347
  self.history["val_accuracy"].append(val_metrics["accuracy"])
@@ -355,52 +296,35 @@ class MultimodalTrainerV2:
355
  self.history["ogm_coeff_visual"].append(train_stats["ogm_coeff_visual"])
356
  self.history["ogm_coeff_audio"].append(train_stats["ogm_coeff_audio"])
357
 
358
- # Save best model
359
  if val_metrics[self.config.metric_for_best_model] > self.best_metric:
360
  self.best_metric = val_metrics[self.config.metric_for_best_model]
361
  self.best_epoch = epoch + 1
362
  os.makedirs(self.config.output_dir, exist_ok=True)
363
- torch.save({
364
- "model_state_dict": self.model.state_dict(),
365
- "epoch": epoch + 1,
366
- "metrics": val_metrics,
367
- }, os.path.join(self.config.output_dir, "best_model.pt"))
368
  print(f" ✓ Best model saved (F1={self.best_metric:.4f})")
369
 
370
- print(f"\nTraining complete. Best epoch={self.best_epoch}, "
371
- f"Best F1={self.best_metric:.4f}")
372
  return self.history
373
 
374
  def run_robustness_evaluation(self):
375
- """Test with missing modalities to evaluate robustness."""
376
  print("\n=== Missing Modality Robustness Evaluation ===")
377
  results = {}
378
- scenarios = [
379
- ("both_modalities", None),
380
- ("visual_only", {"visual": 1.0, "audio": 0.0}),
381
- ("audio_only", {"visual": 0.0, "audio": 1.0}),
382
- ]
383
- for name, mask in scenarios:
384
  m = self.evaluate(modality_mask=mask)
385
  results[name] = {"accuracy": m["accuracy"], "macro_f1": m["macro_f1"]}
386
  print(f" {name:20s}: Acc={m['accuracy']:.4f} F1={m['macro_f1']:.4f}")
387
-
388
- # Per-class breakdown
389
  for cls, cls_m in m["per_class"].items():
390
- print(f" {cls:25s} P:{cls_m['precision']:.3f} "
391
- f"R:{cls_m['recall']:.3f} F1:{cls_m['f1']:.3f}")
392
-
393
- # Compute improvement vs v1 baseline if available
394
  print("\n [Target] Visual-only should improve from ~0.23 acc / 0.08 F1 (v1)")
395
  return results
396
 
397
 
398
  def main():
399
  parser = argparse.ArgumentParser(description="Multimodal PC Fault Detection Training v2")
400
- parser.add_argument("--mode", default="multimodal",
401
- choices=["multimodal", "visual_only", "audio_only"])
402
- parser.add_argument("--finetune", default="lora",
403
- choices=["lora", "full", "linear_probe"])
404
  parser.add_argument("--epochs", type=int)
405
  parser.add_argument("--batch_size", type=int)
406
  parser.add_argument("--lr", type=float)
@@ -412,21 +336,24 @@ def main():
412
  parser.add_argument("--eval_robustness", action="store_true")
413
  parser.add_argument("--quick_test", action="store_true")
414
 
 
 
 
 
 
 
 
 
 
 
415
  # v2-specific arguments
416
- parser.add_argument("--no_ogm", action="store_true",
417
- help="Disable OGM-GE gradient modulation")
418
- parser.add_argument("--ogm_alpha", type=float, default=None,
419
- help="OGM-GE modulation strength (default from config)")
420
- parser.add_argument("--ogm_noise_sigma", type=float, default=None,
421
- help="OGM-GE noise sigma (default from config)")
422
- parser.add_argument("--lambda_visual", type=float, default=None,
423
- help="Visual auxiliary loss weight (default from config)")
424
- parser.add_argument("--lambda_audio", type=float, default=None,
425
- help="Audio auxiliary loss weight (default from config)")
426
- parser.add_argument("--visual_lr_mult", type=float, default=None,
427
- help="LR multiplier for visual branch (default from config)")
428
- parser.add_argument("--audio_lr_mult", type=float, default=None,
429
- help="LR multiplier for audio branch (default from config)")
430
 
431
  args = parser.parse_args()
432
 
@@ -437,84 +364,72 @@ def main():
437
  config.train.finetune_method = args.finetune
438
  config.model.fusion_type = args.fusion
439
 
440
- if args.epochs:
441
- config.train.num_epochs = args.epochs
442
- if args.batch_size:
443
- config.train.per_device_train_batch_size = args.batch_size
444
- if args.lr:
445
- config.train.learning_rate = args.lr
446
- config.train.lora_learning_rate = args.lr
447
- if args.modality_dropout is not None:
448
- config.model.modality_dropout_p = args.modality_dropout
449
- if args.output_dir:
450
- config.train.output_dir = args.output_dir
451
- if args.hub_model_id:
452
- config.train.hub_model_id = args.hub_model_id
453
- if args.no_push:
454
- config.train.push_to_hub = False
455
  if args.quick_test:
456
- config.train.num_epochs = 2
457
- config.train.per_device_train_batch_size = 4
458
- config.train.per_device_eval_batch_size = 4
459
- config.train.gradient_accumulation_steps = 1
460
  config.train.logging_steps = 2
461
- if args.finetune != "lora":
462
- config.lora.enabled = False
463
 
464
- # v2 hyperparameters from config (with CLI overrides)
465
  ogm_alpha = args.ogm_alpha if args.ogm_alpha is not None else config.ogm_alpha
466
- ogm_noise_sigma = (args.ogm_noise_sigma if args.ogm_noise_sigma is not None
467
- else config.ogm_noise_sigma)
468
- lambda_visual = (args.lambda_visual if args.lambda_visual is not None
469
- else config.lambda_visual)
470
- lambda_audio = (args.lambda_audio if args.lambda_audio is not None
471
- else config.lambda_audio)
472
- visual_lr_mult = (args.visual_lr_mult if args.visual_lr_mult is not None
473
- else config.visual_lr_multiplier)
474
- audio_lr_mult = (args.audio_lr_mult if args.audio_lr_mult is not None
475
- else config.audio_lr_multiplier)
476
  use_ogm = not args.no_ogm
477
 
478
- # Device and seeds
479
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
480
  torch.manual_seed(config.train.seed)
481
  np.random.seed(config.train.seed)
482
- if torch.cuda.is_available():
483
- torch.cuda.manual_seed_all(config.train.seed)
484
 
485
  print(f"\n{'='*60}")
486
  print(f"Multimodal PC Fault Detection v2")
487
  print(f"{'='*60}")
488
  print(f"Mode: {args.mode} | Finetune: {args.finetune} | Device: {device}")
 
 
489
  print(f"OGM-GE: {'ON' if use_ogm else 'OFF'} (alpha={ogm_alpha}, sigma={ogm_noise_sigma})")
490
  print(f"Aux loss weights: λ_visual={lambda_visual}, λ_audio={lambda_audio}")
491
  print(f"LR multipliers: visual={visual_lr_mult}x, audio={audio_lr_mult}x")
492
  print(f"{'='*60}\n")
493
 
494
- # Load processors and dataset
495
  vit_proc, ast_ext = get_processors(config.model)
496
- train_ds = PCFaultDataset(
497
- config.data, config.model, "train", vit_proc, ast_ext, True)
498
- val_ds = PCFaultDataset(
499
- config.data, config.model, "val", vit_proc, ast_ext, False)
 
 
 
 
 
 
 
 
 
 
 
 
500
 
501
  # Create model
502
- model = create_model(
503
- config.model, config.lora,
504
- mode=args.mode,
505
- finetune_method=args.finetune,
506
- use_ogm=use_ogm,
507
- lambda_visual=lambda_visual,
508
- lambda_audio=lambda_audio)
509
 
510
  # Create trainer
511
  trainer = MultimodalTrainerV2(
512
  model, train_ds, val_ds, config.train, device,
513
- use_ogm=use_ogm,
514
- ogm_alpha=ogm_alpha,
515
- ogm_noise_sigma=ogm_noise_sigma,
516
- visual_lr_multiplier=visual_lr_mult,
517
- audio_lr_multiplier=audio_lr_mult)
518
 
519
  # Train
520
  history = trainer.train()
@@ -524,10 +439,8 @@ def main():
524
  print(f"\nFinal Evaluation:")
525
  print(f" Acc={final['accuracy']:.4f} F1={final['macro_f1']:.4f}")
526
  for cls, m in final["per_class"].items():
527
- print(f" {cls:25s} P:{m['precision']:.3f} R:{m['recall']:.3f} "
528
- f"F1:{m['f1']:.3f} N:{m['support']}")
529
 
530
- # Robustness evaluation
531
  robustness_results = None
532
  if args.eval_robustness and config.train.mode == "multimodal":
533
  robustness_results = trainer.run_robustness_evaluation()
@@ -535,47 +448,33 @@ def main():
535
  # Save results
536
  os.makedirs(config.train.output_dir, exist_ok=True)
537
  results = {
538
- "experiment": config.experiment_name,
539
- "version": "v2",
540
- "mode": config.train.mode,
541
- "finetune_method": config.train.finetune_method,
542
  "anti_collapse_config": {
543
- "ogm_ge": use_ogm,
544
- "ogm_alpha": ogm_alpha,
545
- "ogm_noise_sigma": ogm_noise_sigma,
546
- "lambda_visual": lambda_visual,
547
- "lambda_audio": lambda_audio,
548
- "visual_lr_multiplier": visual_lr_mult,
549
- "audio_lr_multiplier": audio_lr_mult,
550
  },
551
  "final_metrics": {
552
- "accuracy": final["accuracy"],
553
- "macro_f1": final["macro_f1"],
554
- "weighted_f1": final["weighted_f1"],
555
- "per_class": final["per_class"],
556
  "confusion_matrix": final["confusion_matrix"],
557
  },
558
- "history": history,
559
- "best_epoch": trainer.best_epoch,
560
- "best_metric": trainer.best_metric,
561
  }
562
- if robustness_results:
563
- results["robustness"] = robustness_results
564
 
565
  with open(os.path.join(config.train.output_dir, "results_v2.json"), "w") as f:
566
  json.dump(results, f, indent=2)
567
  print(f"\nResults saved to {config.train.output_dir}/results_v2.json")
568
 
569
- # Push to hub
570
  if config.train.push_to_hub:
571
  try:
572
  from huggingface_hub import HfApi, login
573
  login(token=os.environ.get("HF_TOKEN"))
574
- HfApi().upload_folder(
575
- folder_path=config.train.output_dir,
576
- repo_id=config.train.hub_model_id,
577
- repo_type="model",
578
- commit_message=f"Training v2: {config.experiment_name} (OGM-GE)")
579
  print(f"✓ Pushed to https://huggingface.co/{config.train.hub_model_id}")
580
  except Exception as e:
581
  print(f"✗ Push failed: {e}")
 
6
  - Asymmetric learning rates: higher for visual branch, lower for audio
7
  - Auxiliary loss logging (loss_fusion, loss_visual, loss_audio per epoch)
8
  - OGM-GE stats logging (visual_conf, audio_conf, modulation coefficients)
9
+ - Supports both old proxy data (dataset_real) and new built data (dataset_v2)
10
 
11
  Usage:
12
+ # With old proxy data (ToyADMOS + MVTec, default)
13
  python train_v2.py --mode multimodal --finetune lora --eval_robustness
14
+
15
+ # With new built dataset (from build_dataset.py)
16
+ python train_v2.py --dataset local --dataset_dir ../data/dataset_build --eval_robustness
17
+ python train_v2.py --dataset hub --hub_dataset Ellaft/pc-fault-real-dataset
18
+
19
+ # Other options
20
  python train_v2.py --mode visual_only --finetune lora --no_push
 
 
21
  python train_v2.py --quick_test --no_push
22
 
23
  References:
 
35
  from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_recall_fscore_support
36
 
37
  from config import ExperimentConfig, FAULT_CLASSES, NUM_CLASSES
 
38
  from models_v2 import create_model, get_processors, OGMGEModulator
39
 
40
 
 
77
 
78
  def __init__(self, model, train_dataset, val_dataset, config, device,
79
  use_ogm=True, ogm_alpha=0.3, ogm_noise_sigma=0.1,
80
+ visual_lr_multiplier=3.0, audio_lr_multiplier=0.5,
81
+ collate_fn=None):
82
  self.model = model.to(device)
83
  self.device = device
84
  self.config = config
 
96
  train_dataset,
97
  batch_size=config.per_device_train_batch_size,
98
  shuffle=True,
99
+ collate_fn=collate_fn,
100
  num_workers=2,
101
  pin_memory=True,
102
  drop_last=True)
 
104
  val_dataset,
105
  batch_size=config.per_device_eval_batch_size,
106
  shuffle=False,
107
+ collate_fn=collate_fn,
108
  num_workers=2,
109
  pin_memory=True)
110
 
 
131
  self.history = {
132
  "train_loss": [], "val_loss": [],
133
  "val_accuracy": [], "val_macro_f1": [],
 
134
  "train_loss_fusion": [], "train_loss_visual": [], "train_loss_audio": [],
135
  "ogm_visual_conf": [], "ogm_audio_conf": [],
136
  "ogm_coeff_visual": [], "ogm_coeff_audio": [],
137
  }
138
 
139
  def _get_param_groups(self, visual_lr_multiplier, audio_lr_multiplier):
140
+ visual_params, audio_params, fusion_params = [], [], []
 
 
 
 
 
 
 
 
 
 
 
141
  for name, param in self.model.named_parameters():
142
  if not param.requires_grad:
143
  continue
 
148
  else:
149
  fusion_params.append(param)
150
 
151
+ base_lr = self.config.lora_learning_rate
 
 
152
  groups = []
153
  if visual_params:
154
  vlr = base_lr * visual_lr_multiplier
 
161
  if fusion_params:
162
  groups.append({"params": fusion_params, "lr": base_lr, "name": "fusion_heads"})
163
  print(f"[Trainer v2] fusion_heads: {len(fusion_params)} tensors, lr={base_lr:.2e}")
 
164
  if not groups:
165
  raise ValueError("No trainable parameters!")
166
  return groups
167
 
168
  def train_epoch(self, epoch):
 
169
  self.model.train()
170
+ total_loss, total_loss_fusion, total_loss_visual, total_loss_audio = 0.0, 0.0, 0.0, 0.0
 
 
 
171
  num_batches = 0
172
+ ogm_v_confs, ogm_a_confs, ogm_cv, ogm_ca = [], [], [], []
 
 
 
 
173
  self.optimizer.zero_grad()
174
 
175
  for batch_idx, batch in enumerate(self.train_loader):
 
177
  av = batch["audio_values"].to(self.device)
178
  labels = batch["labels"].to(self.device)
179
 
 
180
  if self.scaler:
181
  with torch.amp.autocast("cuda"):
182
  outputs = self.model(pixel_values=pv, audio_values=av, labels=labels)
 
187
  loss = outputs["loss"] / self.config.gradient_accumulation_steps
188
  loss.backward()
189
 
 
190
  total_loss += loss.item() * self.config.gradient_accumulation_steps
191
  num_batches += 1
192
  if "loss_fusion" in outputs:
 
194
  total_loss_visual += outputs["loss_visual"]
195
  total_loss_audio += outputs["loss_audio"]
196
 
 
197
  if (self.use_ogm and self.ogm is not None
198
  and "visual_logits" in outputs and "audio_logits" in outputs):
199
  _cv, _ca, _stats = self.ogm.compute_modulation_coefficients(
 
203
  ogm_cv.append(_stats["coeff_visual"])
204
  ogm_ca.append(_stats["coeff_audio"])
205
 
 
206
  if (batch_idx + 1) % self.config.gradient_accumulation_steps == 0:
207
  if self.scaler:
208
  self.scaler.unscale_(self.optimizer)
 
 
209
  if (self.use_ogm and self.ogm is not None and ogm_cv):
210
+ self.ogm.apply_gradient_modulation(self.model, ogm_cv[-1], ogm_ca[-1])
211
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
 
 
 
 
 
212
  if self.scaler:
213
  self.scaler.step(self.optimizer)
214
  self.scaler.update()
 
217
  self.scheduler.step()
218
  self.optimizer.zero_grad()
219
 
 
220
  if (batch_idx + 1) % self.config.logging_steps == 0 or batch_idx == 0:
221
  avg_loss = total_loss / num_batches
222
  msg = (f" [Epoch {epoch+1}] Step {batch_idx+1}/{len(self.train_loader)} "
223
+ f"| Loss: {avg_loss:.4f} | LR_v: {self.optimizer.param_groups[0]['lr']:.2e}")
 
224
  if "loss_fusion" in outputs:
225
  msg += (f" | L_fus: {total_loss_fusion/num_batches:.4f}"
226
  f" L_vis: {total_loss_visual/num_batches:.4f}"
227
  f" L_aud: {total_loss_audio/num_batches:.4f}")
228
  if ogm_cv:
229
+ msg += f" | OGM c_v: {ogm_cv[-1]:.3f} c_a: {ogm_ca[-1]:.3f}"
 
230
  print(msg)
231
 
 
232
  n = max(num_batches, 1)
233
+ epoch_stats = {"train_loss": total_loss / n, "loss_fusion": total_loss_fusion / n,
234
+ "loss_visual": total_loss_visual / n, "loss_audio": total_loss_audio / n}
 
 
 
 
235
  if ogm_v_confs:
236
+ epoch_stats.update({"ogm_visual_conf": np.mean(ogm_v_confs), "ogm_audio_conf": np.mean(ogm_a_confs),
237
+ "ogm_coeff_visual": np.mean(ogm_cv), "ogm_coeff_audio": np.mean(ogm_ca)})
 
 
 
238
  return epoch_stats
239
 
240
  @torch.no_grad()
241
  def evaluate(self, modality_mask=None):
 
242
  self.model.eval()
243
+ all_preds, all_labels, total_loss, num_batches = [], [], 0.0, 0
 
 
 
244
  for batch in self.val_loader:
245
  pv = batch["pixel_values"].to(self.device)
246
  av = batch["audio_values"].to(self.device)
247
  labels = batch["labels"].to(self.device)
 
248
  if modality_mask:
249
+ if modality_mask.get("visual", 1.0) == 0.0: pv = torch.zeros_like(pv)
250
+ if modality_mask.get("audio", 1.0) == 0.0: av = torch.zeros_like(av)
 
 
 
251
  outputs = self.model(pixel_values=pv, audio_values=av, labels=labels)
252
  total_loss += outputs["loss"].item()
253
  num_batches += 1
254
  all_preds.extend(outputs["logits"].argmax(dim=-1).cpu().numpy())
255
  all_labels.extend(labels.cpu().numpy())
 
256
  metrics = compute_metrics(np.array(all_preds), np.array(all_labels))
257
  metrics["val_loss"] = total_loss / max(num_batches, 1)
258
  return metrics
259
 
260
  def train(self):
 
261
  print(f"\n{'='*60}")
262
  print(f"Training v2: mode={self.model.mode}, epochs={self.config.num_epochs}, "
263
  f"batch={self.config.per_device_train_batch_size}, device={self.device}")
264
  print(f"OGM-GE: {'ENABLED' if self.use_ogm else 'DISABLED'}")
265
  if self.model.mode == "multimodal":
266
+ print(f"Auxiliary loss weights: λ_visual={self.model.lambda_visual}, λ_audio={self.model.lambda_audio}")
 
267
  print(f"{'='*60}\n")
268
 
269
  for epoch in range(self.config.num_epochs):
270
  t0 = time.time()
271
  train_stats = self.train_epoch(epoch)
272
  val_metrics = self.evaluate()
 
 
273
  elapsed = time.time() - t0
274
+
275
  print(f"\n[Epoch {epoch+1}/{self.config.num_epochs}] ({elapsed:.1f}s)")
276
  loss_msg = f" Train Loss: {train_stats['train_loss']:.4f}"
277
  if train_stats.get("loss_fusion", 0) > 0:
278
  loss_msg += (f" (fusion={train_stats['loss_fusion']:.4f} "
279
+ f"visual={train_stats['loss_visual']:.4f} audio={train_stats['loss_audio']:.4f})")
 
280
  print(loss_msg)
281
+ print(f" Val Loss: {val_metrics['val_loss']:.4f} | Acc: {val_metrics['accuracy']:.4f} | F1: {val_metrics['macro_f1']:.4f}")
 
 
 
282
  if "ogm_visual_conf" in train_stats:
283
+ print(f" OGM-GE: visual_conf={train_stats['ogm_visual_conf']:.4f} audio_conf={train_stats['ogm_audio_conf']:.4f} "
284
+ f"| coeff_v={train_stats['ogm_coeff_visual']:.4f} coeff_a={train_stats['ogm_coeff_audio']:.4f}")
 
 
285
 
 
286
  self.history["train_loss"].append(train_stats["train_loss"])
287
  self.history["val_loss"].append(val_metrics["val_loss"])
288
  self.history["val_accuracy"].append(val_metrics["accuracy"])
 
296
  self.history["ogm_coeff_visual"].append(train_stats["ogm_coeff_visual"])
297
  self.history["ogm_coeff_audio"].append(train_stats["ogm_coeff_audio"])
298
 
 
299
  if val_metrics[self.config.metric_for_best_model] > self.best_metric:
300
  self.best_metric = val_metrics[self.config.metric_for_best_model]
301
  self.best_epoch = epoch + 1
302
  os.makedirs(self.config.output_dir, exist_ok=True)
303
+ torch.save({"model_state_dict": self.model.state_dict(), "epoch": epoch + 1,
304
+ "metrics": val_metrics}, os.path.join(self.config.output_dir, "best_model.pt"))
 
 
 
305
  print(f" ✓ Best model saved (F1={self.best_metric:.4f})")
306
 
307
+ print(f"\nTraining complete. Best epoch={self.best_epoch}, Best F1={self.best_metric:.4f}")
 
308
  return self.history
309
 
310
  def run_robustness_evaluation(self):
 
311
  print("\n=== Missing Modality Robustness Evaluation ===")
312
  results = {}
313
+ for name, mask in [("both_modalities", None), ("visual_only", {"visual": 1.0, "audio": 0.0}),
314
+ ("audio_only", {"visual": 0.0, "audio": 1.0})]:
 
 
 
 
315
  m = self.evaluate(modality_mask=mask)
316
  results[name] = {"accuracy": m["accuracy"], "macro_f1": m["macro_f1"]}
317
  print(f" {name:20s}: Acc={m['accuracy']:.4f} F1={m['macro_f1']:.4f}")
 
 
318
  for cls, cls_m in m["per_class"].items():
319
+ print(f" {cls:25s} P:{cls_m['precision']:.3f} R:{cls_m['recall']:.3f} F1:{cls_m['f1']:.3f}")
 
 
 
320
  print("\n [Target] Visual-only should improve from ~0.23 acc / 0.08 F1 (v1)")
321
  return results
322
 
323
 
324
  def main():
325
  parser = argparse.ArgumentParser(description="Multimodal PC Fault Detection Training v2")
326
+ parser.add_argument("--mode", default="multimodal", choices=["multimodal", "visual_only", "audio_only"])
327
+ parser.add_argument("--finetune", default="lora", choices=["lora", "full", "linear_probe"])
 
 
328
  parser.add_argument("--epochs", type=int)
329
  parser.add_argument("--batch_size", type=int)
330
  parser.add_argument("--lr", type=float)
 
336
  parser.add_argument("--eval_robustness", action="store_true")
337
  parser.add_argument("--quick_test", action="store_true")
338
 
339
+ # Dataset selection
340
+ parser.add_argument("--dataset", default="proxy",
341
+ choices=["proxy", "local", "hub"],
342
+ help="Dataset source: 'proxy' (ToyADMOS+MVTec, default), "
343
+ "'local' (build_dataset.py output), 'hub' (HF Hub dataset)")
344
+ parser.add_argument("--dataset_dir", default="./dataset_build",
345
+ help="Path to build_dataset.py output (for --dataset local)")
346
+ parser.add_argument("--hub_dataset", default="Ellaft/pc-fault-real-dataset",
347
+ help="HuggingFace dataset ID (for --dataset hub)")
348
+
349
  # v2-specific arguments
350
+ parser.add_argument("--no_ogm", action="store_true")
351
+ parser.add_argument("--ogm_alpha", type=float, default=None)
352
+ parser.add_argument("--ogm_noise_sigma", type=float, default=None)
353
+ parser.add_argument("--lambda_visual", type=float, default=None)
354
+ parser.add_argument("--lambda_audio", type=float, default=None)
355
+ parser.add_argument("--visual_lr_mult", type=float, default=None)
356
+ parser.add_argument("--audio_lr_mult", type=float, default=None)
 
 
 
 
 
 
 
357
 
358
  args = parser.parse_args()
359
 
 
364
  config.train.finetune_method = args.finetune
365
  config.model.fusion_type = args.fusion
366
 
367
+ if args.epochs: config.train.num_epochs = args.epochs
368
+ if args.batch_size: config.train.per_device_train_batch_size = args.batch_size
369
+ if args.lr: config.train.learning_rate = config.train.lora_learning_rate = args.lr
370
+ if args.modality_dropout is not None: config.model.modality_dropout_p = args.modality_dropout
371
+ if args.output_dir: config.train.output_dir = args.output_dir
372
+ if args.hub_model_id: config.train.hub_model_id = args.hub_model_id
373
+ if args.no_push: config.train.push_to_hub = False
 
 
 
 
 
 
 
 
374
  if args.quick_test:
375
+ config.train.num_epochs, config.train.per_device_train_batch_size = 2, 4
376
+ config.train.per_device_eval_batch_size, config.train.gradient_accumulation_steps = 4, 1
 
 
377
  config.train.logging_steps = 2
378
+ if args.finetune != "lora": config.lora.enabled = False
 
379
 
 
380
  ogm_alpha = args.ogm_alpha if args.ogm_alpha is not None else config.ogm_alpha
381
+ ogm_noise_sigma = args.ogm_noise_sigma if args.ogm_noise_sigma is not None else config.ogm_noise_sigma
382
+ lambda_visual = args.lambda_visual if args.lambda_visual is not None else config.lambda_visual
383
+ lambda_audio = args.lambda_audio if args.lambda_audio is not None else config.lambda_audio
384
+ visual_lr_mult = args.visual_lr_mult if args.visual_lr_mult is not None else config.visual_lr_multiplier
385
+ audio_lr_mult = args.audio_lr_mult if args.audio_lr_mult is not None else config.audio_lr_multiplier
 
 
 
 
 
386
  use_ogm = not args.no_ogm
387
 
 
388
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
389
  torch.manual_seed(config.train.seed)
390
  np.random.seed(config.train.seed)
391
+ if torch.cuda.is_available(): torch.cuda.manual_seed_all(config.train.seed)
 
392
 
393
  print(f"\n{'='*60}")
394
  print(f"Multimodal PC Fault Detection v2")
395
  print(f"{'='*60}")
396
  print(f"Mode: {args.mode} | Finetune: {args.finetune} | Device: {device}")
397
+ print(f"Dataset: {args.dataset}" + (f" ({args.dataset_dir})" if args.dataset == "local" else
398
+ f" ({args.hub_dataset})" if args.dataset == "hub" else " (ToyADMOS + MVTec proxy)"))
399
  print(f"OGM-GE: {'ON' if use_ogm else 'OFF'} (alpha={ogm_alpha}, sigma={ogm_noise_sigma})")
400
  print(f"Aux loss weights: λ_visual={lambda_visual}, λ_audio={lambda_audio}")
401
  print(f"LR multipliers: visual={visual_lr_mult}x, audio={audio_lr_mult}x")
402
  print(f"{'='*60}\n")
403
 
404
+ # Load processors
405
  vit_proc, ast_ext = get_processors(config.model)
406
+
407
+ # ---- Load dataset based on --dataset flag ----
408
+ if args.dataset in ("local", "hub"):
409
+ from dataset_v2 import BuiltDataset as PCFaultDataset, multimodal_collate_fn
410
+ source = args.dataset # "local" or "hub"
411
+ train_ds = PCFaultDataset(
412
+ config.data, config.model, "train", vit_proc, ast_ext, True,
413
+ source=source, dataset_dir=args.dataset_dir, hub_dataset=args.hub_dataset)
414
+ val_ds = PCFaultDataset(
415
+ config.data, config.model, "val", vit_proc, ast_ext, False,
416
+ source=source, dataset_dir=args.dataset_dir, hub_dataset=args.hub_dataset)
417
+ else:
418
+ # Default: old proxy data (ToyADMOS + MVTec)
419
+ from dataset_real import RealPCFaultDataset as PCFaultDataset, multimodal_collate_fn
420
+ train_ds = PCFaultDataset(config.data, config.model, "train", vit_proc, ast_ext, True)
421
+ val_ds = PCFaultDataset(config.data, config.model, "val", vit_proc, ast_ext, False)
422
 
423
  # Create model
424
+ model = create_model(config.model, config.lora, mode=args.mode, finetune_method=args.finetune,
425
+ use_ogm=use_ogm, lambda_visual=lambda_visual, lambda_audio=lambda_audio)
 
 
 
 
 
426
 
427
  # Create trainer
428
  trainer = MultimodalTrainerV2(
429
  model, train_ds, val_ds, config.train, device,
430
+ use_ogm=use_ogm, ogm_alpha=ogm_alpha, ogm_noise_sigma=ogm_noise_sigma,
431
+ visual_lr_multiplier=visual_lr_mult, audio_lr_multiplier=audio_lr_mult,
432
+ collate_fn=multimodal_collate_fn)
 
 
433
 
434
  # Train
435
  history = trainer.train()
 
439
  print(f"\nFinal Evaluation:")
440
  print(f" Acc={final['accuracy']:.4f} F1={final['macro_f1']:.4f}")
441
  for cls, m in final["per_class"].items():
442
+ print(f" {cls:25s} P:{m['precision']:.3f} R:{m['recall']:.3f} F1:{m['f1']:.3f} N:{m['support']}")
 
443
 
 
444
  robustness_results = None
445
  if args.eval_robustness and config.train.mode == "multimodal":
446
  robustness_results = trainer.run_robustness_evaluation()
 
448
  # Save results
449
  os.makedirs(config.train.output_dir, exist_ok=True)
450
  results = {
451
+ "experiment": config.experiment_name, "version": "v2",
452
+ "mode": config.train.mode, "finetune_method": config.train.finetune_method,
453
+ "dataset_source": args.dataset,
 
454
  "anti_collapse_config": {
455
+ "ogm_ge": use_ogm, "ogm_alpha": ogm_alpha, "ogm_noise_sigma": ogm_noise_sigma,
456
+ "lambda_visual": lambda_visual, "lambda_audio": lambda_audio,
457
+ "visual_lr_multiplier": visual_lr_mult, "audio_lr_multiplier": audio_lr_mult,
 
 
 
 
458
  },
459
  "final_metrics": {
460
+ "accuracy": final["accuracy"], "macro_f1": final["macro_f1"],
461
+ "weighted_f1": final["weighted_f1"], "per_class": final["per_class"],
 
 
462
  "confusion_matrix": final["confusion_matrix"],
463
  },
464
+ "history": history, "best_epoch": trainer.best_epoch, "best_metric": trainer.best_metric,
 
 
465
  }
466
+ if robustness_results: results["robustness"] = robustness_results
 
467
 
468
  with open(os.path.join(config.train.output_dir, "results_v2.json"), "w") as f:
469
  json.dump(results, f, indent=2)
470
  print(f"\nResults saved to {config.train.output_dir}/results_v2.json")
471
 
 
472
  if config.train.push_to_hub:
473
  try:
474
  from huggingface_hub import HfApi, login
475
  login(token=os.environ.get("HF_TOKEN"))
476
+ HfApi().upload_folder(folder_path=config.train.output_dir, repo_id=config.train.hub_model_id,
477
+ repo_type="model", commit_message=f"Training v2: {config.experiment_name} (OGM-GE)")
 
 
 
478
  print(f"✓ Pushed to https://huggingface.co/{config.train.hub_model_id}")
479
  except Exception as e:
480
  print(f"✗ Push failed: {e}")