Ellaft committed · Commit 0c3f8a7 · verified · 1 Parent(s): 58a5fea

Add train_v2.py: OGM-GE training loop + asymmetric LRs + aux loss logging

Files changed (1): src/train_v2.py (+585, -0)
src/train_v2.py ADDED
@@ -0,0 +1,585 @@
+ """
+ Multimodal PC Fault Detection - Training Script v2
+ ====================================================
+ Changes from v1:
+ - OGM-GE gradient modulation after loss.backward(), before optimizer.step()
+ - Asymmetric learning rates: higher for the visual branch, lower for the audio branch
+ - Auxiliary loss logging (loss_fusion, loss_visual, loss_audio per epoch)
+ - OGM-GE stats logging (visual_conf, audio_conf, modulation coefficients)
+ - Uses models_v2 (auxiliary heads) and dataset_real (real data)
+
+ Usage:
+     python train_v2.py --mode multimodal --finetune lora --eval_robustness
+     python train_v2.py --mode visual_only --finetune lora --no_push
+     python train_v2.py --mode audio_only --finetune lora --no_push
+     python train_v2.py --mode multimodal --finetune full --lr 2e-5
+     python train_v2.py --quick_test --no_push
+
+ References:
+     OGM-GE: Peng et al., "Balanced Multimodal Learning via On-the-fly Gradient
+     Modulation", CVPR 2022.
+ """
+
+ import argparse
+ import json
+ import os
+ import time
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch.optim import AdamW
+ from torch.optim.lr_scheduler import OneCycleLR
+ from torch.utils.data import DataLoader
+ from sklearn.metrics import (accuracy_score, confusion_matrix, f1_score,
+                              precision_recall_fscore_support)
+
+ from config import ExperimentConfig, FAULT_CLASSES, NUM_CLASSES
+ from dataset_real import RealPCFaultDataset as PCFaultDataset, multimodal_collate_fn
+ from models_v2 import create_model, get_processors, OGMGEModulator
+
+
+ def compute_metrics(preds, labels, class_names=FAULT_CLASSES):
+     """Compute accuracy, F1, precision, recall, and the confusion matrix."""
+     accuracy = accuracy_score(labels, preds)
+     precision, recall, f1, support = precision_recall_fscore_support(
+         labels, preds, average=None, labels=list(range(len(class_names))), zero_division=0)
+     macro_f1 = f1_score(labels, preds, average="macro", zero_division=0)
+     weighted_f1 = f1_score(labels, preds, average="weighted", zero_division=0)
+     conf_matrix = confusion_matrix(labels, preds, labels=list(range(len(class_names))))
+
+     metrics = {
+         "accuracy": float(accuracy),
+         "macro_f1": float(macro_f1),
+         "weighted_f1": float(weighted_f1),
+         "confusion_matrix": conf_matrix.tolist(),
+         "per_class": {},
+     }
+     # Cast NumPy scalars to built-ins so the dict is JSON-serializable as-is
+     for i, name in enumerate(class_names):
+         metrics["per_class"][name] = {
+             "precision": float(precision[i]), "recall": float(recall[i]),
+             "f1": float(f1[i]), "support": int(support[i]),
+         }
+     return metrics
+
+
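+ # Hypothetical usage of compute_metrics (values invented for illustration):
+ #     m = compute_metrics(np.array([0, 2, 1, 1]), np.array([0, 1, 1, 1]))
+ #     m["accuracy"], m["per_class"]["<class name>"]["f1"]
+ # Everything is cast to plain Python floats/ints above, so the returned dict can
+ # go straight into json.dump without a custom encoder.
+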
+ class MultimodalTrainerV2:
+     """
+     Training loop v2 with OGM-GE gradient modulation.
+
+     Key differences from v1:
+     1. Three separate parameter groups with asymmetric LRs:
+        - visual_branch: higher LR (visual_lr_multiplier × base_lr)
+        - audio_branch: lower LR (audio_lr_multiplier × base_lr)
+        - fusion + auxiliary heads: base LR
+     2. OGM-GE applied after backward(), before optimizer.step()
+     3. Logs auxiliary losses and OGM-GE stats per epoch
+     """
+
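+     # Sketch of the OGM-GE rule this trainer relies on (implemented in
+     # models_v2.OGMGEModulator; the exact internals there may differ, so treat
+     # this as an assumption based on Peng et al., CVPR 2022):
+     #     s_v = sum_i softmax(visual_logits_i)[y_i]   # batch confidence, visual
+     #     s_a = sum_i softmax(audio_logits_i)[y_i]    # batch confidence, audio
+     #     rho_v = s_v / s_a                           # discrepancy ratio
+     #     coeff_v = 1 - tanh(alpha * rho_v) if rho_v > 1 else 1.0
+     # (and symmetrically for coeff_a). Each branch's gradients are then scaled by
+     # its coefficient, with zero-mean Gaussian noise (std ~ noise_sigma) added as
+     # the "GE" generalization-enhancement term.
+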
+     def __init__(self, model, train_dataset, val_dataset, config, device,
+                  use_ogm=True, ogm_alpha=0.3, ogm_noise_sigma=0.1,
+                  visual_lr_multiplier=3.0, audio_lr_multiplier=0.5):
+         self.model = model.to(device)
+         self.device = device
+         self.config = config
+         self.use_ogm = use_ogm and (model.mode == "multimodal")
+
+         # OGM-GE modulator
+         if self.use_ogm:
+             self.ogm = OGMGEModulator(alpha=ogm_alpha, noise_sigma=ogm_noise_sigma)
+             print(f"[Trainer v2] OGM-GE enabled: alpha={ogm_alpha}, noise_sigma={ogm_noise_sigma}")
+         else:
+             self.ogm = None
+
+         # Data loaders
+         self.train_loader = DataLoader(
+             train_dataset,
+             batch_size=config.per_device_train_batch_size,
+             shuffle=True,
+             collate_fn=multimodal_collate_fn,
+             num_workers=2,
+             pin_memory=True,
+             drop_last=True)
+         self.val_loader = DataLoader(
+             val_dataset,
+             batch_size=config.per_device_eval_batch_size,
+             shuffle=False,
+             collate_fn=multimodal_collate_fn,
+             num_workers=2,
+             pin_memory=True)
+
+         # Asymmetric parameter groups (order matters: OneCycleLR's max_lr list
+         # below must line up with this group order)
+         param_groups = self._get_param_groups(visual_lr_multiplier, audio_lr_multiplier)
+         self.optimizer = AdamW(param_groups, weight_decay=config.weight_decay)
+
+         total_steps = (len(self.train_loader) * config.num_epochs
+                        // config.gradient_accumulation_steps)
+         self.scheduler = OneCycleLR(
+             self.optimizer,
+             max_lr=[pg["lr"] for pg in param_groups],
+             total_steps=max(total_steps, 1),
+             pct_start=config.warmup_ratio,
+             anneal_strategy="cos")
+
+         # Mixed precision
+         self.scaler = (torch.amp.GradScaler("cuda")
+                        if config.fp16 and device.type == "cuda" else None)
+
+         # Tracking
+         self.best_metric = 0.0
+         self.best_epoch = 0
+         self.history = {
+             "train_loss": [], "val_loss": [],
+             "val_accuracy": [], "val_macro_f1": [],
+             # v2 additions
+             "train_loss_fusion": [], "train_loss_visual": [], "train_loss_audio": [],
+             "ogm_visual_conf": [], "ogm_audio_conf": [],
+             "ogm_coeff_visual": [], "ogm_coeff_audio": [],
+         }
+
+     def _get_param_groups(self, visual_lr_multiplier, audio_lr_multiplier):
+         """
+         Create 3 parameter groups with asymmetric learning rates.
+
+         For LoRA mode, lora_learning_rate is the base; otherwise learning_rate.
+         Visual branch gets a multiplier > 1 (boost the weak modality).
+         Audio branch gets a multiplier < 1 (slow down the dominant modality).
+         Fusion head + auxiliary heads get the base LR.
+         """
+         visual_params = []
+         audio_params = []
+         fusion_params = []  # fusion head + auxiliary heads
+
+         for name, param in self.model.named_parameters():
+             if not param.requires_grad:
+                 continue
+             if "visual_branch" in name:
+                 visual_params.append(param)
+             elif "audio_branch" in name:
+                 audio_params.append(param)
+             else:
+                 fusion_params.append(param)
+
+         # Determine base LR (the LoRA-vs-full distinction from the docstring)
+         if getattr(self.config, "finetune_method", "lora") == "lora":
+             base_lr = self.config.lora_learning_rate  # default: 5e-3
+         else:
+             base_lr = self.config.learning_rate
+
+         groups = []
+         if visual_params:
+             vlr = base_lr * visual_lr_multiplier
+             groups.append({"params": visual_params, "lr": vlr, "name": "visual_branch"})
+             print(f"[Trainer v2] visual_branch: {len(visual_params)} tensors, lr={vlr:.2e}")
+         if audio_params:
+             alr = base_lr * audio_lr_multiplier
+             groups.append({"params": audio_params, "lr": alr, "name": "audio_branch"})
+             print(f"[Trainer v2] audio_branch: {len(audio_params)} tensors, lr={alr:.2e}")
+         if fusion_params:
+             groups.append({"params": fusion_params, "lr": base_lr, "name": "fusion_heads"})
+             print(f"[Trainer v2] fusion_heads: {len(fusion_params)} tensors, lr={base_lr:.2e}")
+
+         if not groups:
+             raise ValueError("No trainable parameters!")
+         return groups
+
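+     # Worked example with the defaults above (base_lr = 5e-3, multipliers 3.0 / 0.5):
+     #     visual_branch lr = 3.0 * 5e-3 = 1.5e-2
+     #     audio_branch  lr = 0.5 * 5e-3 = 2.5e-3
+     #     fusion_heads  lr = 5e-3
+     # OneCycleLR receives these as a per-group max_lr list in __init__, so each
+     # group is annealed around its own peak.
+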
+     def train_epoch(self, epoch):
+         """Train one epoch with OGM-GE gradient modulation."""
+         self.model.train()
+         total_loss = 0.0
+         total_loss_fusion = 0.0
+         total_loss_visual = 0.0
+         total_loss_audio = 0.0
+         num_batches = 0
+
+         # OGM-GE stats accumulators
+         ogm_v_confs, ogm_a_confs = [], []
+         ogm_cv, ogm_ca = [], []
+
+         self.optimizer.zero_grad()
+
+         for batch_idx, batch in enumerate(self.train_loader):
+             pv = batch["pixel_values"].to(self.device)
+             av = batch["audio_values"].to(self.device)
+             labels = batch["labels"].to(self.device)
+
+             # Forward pass
+             if self.scaler:
+                 with torch.amp.autocast("cuda"):
+                     outputs = self.model(pixel_values=pv, audio_values=av, labels=labels)
+                     loss = outputs["loss"] / self.config.gradient_accumulation_steps
+                 self.scaler.scale(loss).backward()
+             else:
+                 outputs = self.model(pixel_values=pv, audio_values=av, labels=labels)
+                 loss = outputs["loss"] / self.config.gradient_accumulation_steps
+                 loss.backward()
+
+             # Accumulate losses (float() handles 0-dim tensors, so the running
+             # totals never keep autograd graphs alive)
+             total_loss += loss.item() * self.config.gradient_accumulation_steps
+             num_batches += 1
+             if "loss_fusion" in outputs:
+                 total_loss_fusion += float(outputs["loss_fusion"])
+                 total_loss_visual += float(outputs["loss_visual"])
+                 total_loss_audio += float(outputs["loss_audio"])
+
+             # Collect OGM-GE stats every batch (but only apply at accumulation boundary)
+             if (self.use_ogm and self.ogm is not None
+                     and "visual_logits" in outputs and "audio_logits" in outputs):
+                 _cv, _ca, _stats = self.ogm.compute_modulation_coefficients(
+                     outputs["visual_logits"], outputs["audio_logits"], labels)
+                 ogm_v_confs.append(_stats["visual_conf"])
+                 ogm_a_confs.append(_stats["audio_conf"])
+                 ogm_cv.append(_stats["coeff_visual"])
+                 ogm_ca.append(_stats["coeff_audio"])
+
+             # Optimizer step (at accumulation boundary)
+             if (batch_idx + 1) % self.config.gradient_accumulation_steps == 0:
+                 if self.scaler:
+                     self.scaler.unscale_(self.optimizer)
+
+                 # ==== OGM-GE: modulate gradients AFTER unscale, BEFORE step ====
+                 if self.use_ogm and self.ogm is not None and ogm_cv:
+                     # Use the most recent coefficients
+                     self.ogm.apply_gradient_modulation(
+                         self.model, ogm_cv[-1], ogm_ca[-1])
+
+                 nn.utils.clip_grad_norm_(
+                     self.model.parameters(), self.config.max_grad_norm)
+
+                 if self.scaler:
+                     self.scaler.step(self.optimizer)
+                     self.scaler.update()
+                 else:
+                     self.optimizer.step()
+                 self.scheduler.step()
+                 self.optimizer.zero_grad()
+
+             # Logging
+             if (batch_idx + 1) % self.config.logging_steps == 0 or batch_idx == 0:
+                 avg_loss = total_loss / num_batches
+                 msg = (f"  [Epoch {epoch+1}] Step {batch_idx+1}/{len(self.train_loader)} "
+                        f"| Loss: {avg_loss:.4f} "
+                        f"| LR_v: {self.optimizer.param_groups[0]['lr']:.2e}")
+                 if "loss_fusion" in outputs:
+                     msg += (f" | L_fus: {total_loss_fusion/num_batches:.4f}"
+                             f" L_vis: {total_loss_visual/num_batches:.4f}"
+                             f" L_aud: {total_loss_audio/num_batches:.4f}")
+                 if ogm_cv:
+                     msg += (f" | OGM c_v: {ogm_cv[-1]:.3f}"
+                             f" c_a: {ogm_ca[-1]:.3f}")
+                 print(msg)
+
+         # Epoch-level OGM stats
+         n = max(num_batches, 1)
+         epoch_stats = {
+             "train_loss": total_loss / n,
+             "loss_fusion": total_loss_fusion / n,
+             "loss_visual": total_loss_visual / n,
+             "loss_audio": total_loss_audio / n,
+         }
+         if ogm_v_confs:
+             epoch_stats["ogm_visual_conf"] = float(np.mean(ogm_v_confs))
+             epoch_stats["ogm_audio_conf"] = float(np.mean(ogm_a_confs))
+             epoch_stats["ogm_coeff_visual"] = float(np.mean(ogm_cv))
+             epoch_stats["ogm_coeff_audio"] = float(np.mean(ogm_ca))
+
+         return epoch_stats
+
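+     # Order of operations at each accumulation boundary, as implemented above:
+     #     backward (possibly loss-scaled) -> scaler.unscale_ -> OGM-GE gradient
+     #     modulation -> clip_grad_norm_ -> (scaler.)step -> scheduler.step -> zero_grad
+     # Modulating after unscale_ means the coefficients multiply true fp32 gradients
+     # rather than loss-scaled ones, and clipping sees the already-modulated norms.
+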
+     @torch.no_grad()
+     def evaluate(self, modality_mask=None):
+         """Evaluate on the validation set. Optionally mask a modality for robustness tests."""
+         self.model.eval()
+         all_preds, all_labels = [], []
+         total_loss = 0.0
+         num_batches = 0
+
+         for batch in self.val_loader:
+             pv = batch["pixel_values"].to(self.device)
+             av = batch["audio_values"].to(self.device)
+             labels = batch["labels"].to(self.device)
+
+             if modality_mask:
+                 if modality_mask.get("visual", 1.0) == 0.0:
+                     pv = torch.zeros_like(pv)
+                 if modality_mask.get("audio", 1.0) == 0.0:
+                     av = torch.zeros_like(av)
+
+             outputs = self.model(pixel_values=pv, audio_values=av, labels=labels)
+             total_loss += outputs["loss"].item()
+             num_batches += 1
+             all_preds.extend(outputs["logits"].argmax(dim=-1).cpu().numpy())
+             all_labels.extend(labels.cpu().numpy())
+
+         metrics = compute_metrics(np.array(all_preds), np.array(all_labels))
+         metrics["val_loss"] = total_loss / max(num_batches, 1)
+         return metrics
+
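+     # e.g. trainer.evaluate(modality_mask={"visual": 1.0, "audio": 0.0}) scores a
+     # visual-only condition by zero-filling the audio inputs; this is the same
+     # convention run_robustness_evaluation uses below.
+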
+     def train(self):
+         """Full training loop with OGM-GE and detailed logging."""
+         print(f"\n{'='*60}")
+         print(f"Training v2: mode={self.model.mode}, epochs={self.config.num_epochs}, "
+               f"batch={self.config.per_device_train_batch_size}, device={self.device}")
+         print(f"OGM-GE: {'ENABLED' if self.use_ogm else 'DISABLED'}")
+         if self.model.mode == "multimodal":
+             print(f"Auxiliary loss weights: λ_visual={self.model.lambda_visual}, "
+                   f"λ_audio={self.model.lambda_audio}")
+         print(f"{'='*60}\n")
+
+         for epoch in range(self.config.num_epochs):
+             t0 = time.time()
+             train_stats = self.train_epoch(epoch)
+             val_metrics = self.evaluate()
+
+             # Print epoch summary
+             elapsed = time.time() - t0
+             print(f"\n[Epoch {epoch+1}/{self.config.num_epochs}] ({elapsed:.1f}s)")
+             loss_msg = f"  Train Loss: {train_stats['train_loss']:.4f}"
+             if train_stats.get("loss_fusion", 0) > 0:
+                 loss_msg += (f" (fusion={train_stats['loss_fusion']:.4f} "
+                              f"visual={train_stats['loss_visual']:.4f} "
+                              f"audio={train_stats['loss_audio']:.4f})")
+             print(loss_msg)
+             print(f"  Val Loss: {val_metrics['val_loss']:.4f} "
+                   f"| Acc: {val_metrics['accuracy']:.4f} "
+                   f"| F1: {val_metrics['macro_f1']:.4f}")
+
+             if "ogm_visual_conf" in train_stats:
+                 print(f"  OGM-GE: visual_conf={train_stats['ogm_visual_conf']:.4f} "
+                       f"audio_conf={train_stats['ogm_audio_conf']:.4f} "
+                       f"| coeff_v={train_stats['ogm_coeff_visual']:.4f} "
+                       f"coeff_a={train_stats['ogm_coeff_audio']:.4f}")
+
+             # Update history
+             self.history["train_loss"].append(train_stats["train_loss"])
+             self.history["val_loss"].append(val_metrics["val_loss"])
+             self.history["val_accuracy"].append(val_metrics["accuracy"])
+             self.history["val_macro_f1"].append(val_metrics["macro_f1"])
+             self.history["train_loss_fusion"].append(train_stats["loss_fusion"])
+             self.history["train_loss_visual"].append(train_stats["loss_visual"])
+             self.history["train_loss_audio"].append(train_stats["loss_audio"])
+             if "ogm_visual_conf" in train_stats:
+                 self.history["ogm_visual_conf"].append(train_stats["ogm_visual_conf"])
+                 self.history["ogm_audio_conf"].append(train_stats["ogm_audio_conf"])
+                 self.history["ogm_coeff_visual"].append(train_stats["ogm_coeff_visual"])
+                 self.history["ogm_coeff_audio"].append(train_stats["ogm_coeff_audio"])
+
+             # Save best model
+             if val_metrics[self.config.metric_for_best_model] > self.best_metric:
+                 self.best_metric = val_metrics[self.config.metric_for_best_model]
+                 self.best_epoch = epoch + 1
+                 os.makedirs(self.config.output_dir, exist_ok=True)
+                 torch.save({
+                     "model_state_dict": self.model.state_dict(),
+                     "epoch": epoch + 1,
+                     "metrics": val_metrics,
+                 }, os.path.join(self.config.output_dir, "best_model.pt"))
+                 print(f"  ✓ Best model saved (F1={self.best_metric:.4f})")
+
+         print(f"\nTraining complete. Best epoch={self.best_epoch}, "
+               f"Best F1={self.best_metric:.4f}")
+         return self.history
+
+     def run_robustness_evaluation(self):
+         """Test with missing modalities to evaluate robustness."""
+         print("\n=== Missing Modality Robustness Evaluation ===")
+         results = {}
+         scenarios = [
+             ("both_modalities", None),
+             ("visual_only", {"visual": 1.0, "audio": 0.0}),
+             ("audio_only", {"visual": 0.0, "audio": 1.0}),
+         ]
+         for name, mask in scenarios:
+             m = self.evaluate(modality_mask=mask)
+             results[name] = {"accuracy": m["accuracy"], "macro_f1": m["macro_f1"]}
+             print(f"  {name:20s}: Acc={m['accuracy']:.4f} F1={m['macro_f1']:.4f}")
+
+             # Per-class breakdown for this scenario
+             for cls, cls_m in m["per_class"].items():
+                 print(f"    {cls:25s} P:{cls_m['precision']:.3f} "
+                       f"R:{cls_m['recall']:.3f} F1:{cls_m['f1']:.3f}")
+
+         # Reference point from the v1 baseline
+         print("\n  [Target] Visual-only should improve from ~0.23 acc / 0.08 F1 (v1)")
+         return results
+
+
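+ # Note on the robustness protocol: a "missing" modality is simulated by
+ # zero-filling that branch's inputs (see MultimodalTrainerV2.evaluate), an
+ # approximation of real sensor dropout; it presumably mirrors the zero-fill
+ # modality-dropout convention exposed below via --modality_dropout.
+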
+ def main():
+     parser = argparse.ArgumentParser(description="Multimodal PC Fault Detection Training v2")
+     parser.add_argument("--mode", default="multimodal",
+                         choices=["multimodal", "visual_only", "audio_only"])
+     parser.add_argument("--finetune", default="lora",
+                         choices=["lora", "full", "linear_probe"])
+     parser.add_argument("--epochs", type=int)
+     parser.add_argument("--batch_size", type=int)
+     parser.add_argument("--lr", type=float)
+     parser.add_argument("--fusion", default="concat")
+     parser.add_argument("--modality_dropout", type=float)
+     parser.add_argument("--output_dir", type=str)
+     parser.add_argument("--hub_model_id", type=str)
+     parser.add_argument("--no_push", action="store_true")
+     parser.add_argument("--eval_robustness", action="store_true")
+     parser.add_argument("--quick_test", action="store_true")
+
+     # v2-specific arguments
+     parser.add_argument("--no_ogm", action="store_true",
+                         help="Disable OGM-GE gradient modulation")
+     parser.add_argument("--ogm_alpha", type=float, default=None,
+                         help="OGM-GE modulation strength (default from config)")
+     parser.add_argument("--ogm_noise_sigma", type=float, default=None,
+                         help="OGM-GE noise sigma (default from config)")
+     parser.add_argument("--lambda_visual", type=float, default=None,
+                         help="Visual auxiliary loss weight (default from config)")
+     parser.add_argument("--lambda_audio", type=float, default=None,
+                         help="Audio auxiliary loss weight (default from config)")
+     parser.add_argument("--visual_lr_mult", type=float, default=None,
+                         help="LR multiplier for visual branch (default from config)")
+     parser.add_argument("--audio_lr_mult", type=float, default=None,
+                         help="LR multiplier for audio branch (default from config)")
+
+     args = parser.parse_args()
+
+     # Load config
+     config = ExperimentConfig()
+     config.experiment_name = "multimodal_pc_fault_v2"
+     config.train.mode = args.mode
+     config.train.finetune_method = args.finetune
+     config.model.fusion_type = args.fusion
+
+     if args.epochs:
+         config.train.num_epochs = args.epochs
+     if args.batch_size:
+         config.train.per_device_train_batch_size = args.batch_size
+     if args.lr:
+         config.train.learning_rate = args.lr
+         config.train.lora_learning_rate = args.lr
+     if args.modality_dropout is not None:
+         config.model.modality_dropout_p = args.modality_dropout
+     if args.output_dir:
+         config.train.output_dir = args.output_dir
+     if args.hub_model_id:
+         config.train.hub_model_id = args.hub_model_id
+     if args.no_push:
+         config.train.push_to_hub = False
+     if args.quick_test:
+         config.train.num_epochs = 2
+         config.train.per_device_train_batch_size = 4
+         config.train.per_device_eval_batch_size = 4
+         config.train.gradient_accumulation_steps = 1
+         config.train.logging_steps = 2
+     if args.finetune != "lora":
+         config.lora.enabled = False
+
+     # v2 hyperparameters from config (with CLI overrides)
+     ogm_alpha = args.ogm_alpha if args.ogm_alpha is not None else config.ogm_alpha
+     ogm_noise_sigma = (args.ogm_noise_sigma if args.ogm_noise_sigma is not None
+                        else config.ogm_noise_sigma)
+     lambda_visual = (args.lambda_visual if args.lambda_visual is not None
+                      else config.lambda_visual)
+     lambda_audio = (args.lambda_audio if args.lambda_audio is not None
+                     else config.lambda_audio)
+     visual_lr_mult = (args.visual_lr_mult if args.visual_lr_mult is not None
+                       else config.visual_lr_multiplier)
+     audio_lr_mult = (args.audio_lr_mult if args.audio_lr_mult is not None
+                      else config.audio_lr_multiplier)
+     use_ogm = not args.no_ogm
+
+     # Device and seeds
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     torch.manual_seed(config.train.seed)
+     np.random.seed(config.train.seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(config.train.seed)
+
+     print(f"\n{'='*60}")
+     print("Multimodal PC Fault Detection v2")
+     print(f"{'='*60}")
+     print(f"Mode: {args.mode} | Finetune: {args.finetune} | Device: {device}")
+     print(f"OGM-GE: {'ON' if use_ogm else 'OFF'} (alpha={ogm_alpha}, sigma={ogm_noise_sigma})")
+     print(f"Aux loss weights: λ_visual={lambda_visual}, λ_audio={lambda_audio}")
+     print(f"LR multipliers: visual={visual_lr_mult}x, audio={audio_lr_mult}x")
+     print(f"{'='*60}\n")
+
+     # Load processors and dataset
+     vit_proc, ast_ext = get_processors(config.model)
+     train_ds = PCFaultDataset(
+         config.data, config.model, "train", vit_proc, ast_ext, True)
+     val_ds = PCFaultDataset(
+         config.data, config.model, "val", vit_proc, ast_ext, False)
+
+     # Create model
+     model = create_model(
+         config.model, config.lora,
+         mode=args.mode,
+         finetune_method=args.finetune,
+         use_ogm=use_ogm,
+         lambda_visual=lambda_visual,
+         lambda_audio=lambda_audio)
+
+     # Create trainer
+     trainer = MultimodalTrainerV2(
+         model, train_ds, val_ds, config.train, device,
+         use_ogm=use_ogm,
+         ogm_alpha=ogm_alpha,
+         ogm_noise_sigma=ogm_noise_sigma,
+         visual_lr_multiplier=visual_lr_mult,
+         audio_lr_multiplier=audio_lr_mult)
+
+     # Train
+     history = trainer.train()
+
+     # Final evaluation
+     final = trainer.evaluate()
+     print("\nFinal Evaluation:")
+     print(f"  Acc={final['accuracy']:.4f} F1={final['macro_f1']:.4f}")
+     for cls, m in final["per_class"].items():
+         print(f"  {cls:25s} P:{m['precision']:.3f} R:{m['recall']:.3f} "
+               f"F1:{m['f1']:.3f} N:{m['support']}")
+
+     # Robustness evaluation
+     robustness_results = None
+     if args.eval_robustness and config.train.mode == "multimodal":
+         robustness_results = trainer.run_robustness_evaluation()
+
+     # Save results
+     os.makedirs(config.train.output_dir, exist_ok=True)
+     results = {
+         "experiment": config.experiment_name,
+         "version": "v2",
+         "mode": config.train.mode,
+         "finetune_method": config.train.finetune_method,
+         "anti_collapse_config": {
+             "ogm_ge": use_ogm,
+             "ogm_alpha": ogm_alpha,
+             "ogm_noise_sigma": ogm_noise_sigma,
+             "lambda_visual": lambda_visual,
+             "lambda_audio": lambda_audio,
+             "visual_lr_multiplier": visual_lr_mult,
+             "audio_lr_multiplier": audio_lr_mult,
+         },
+         "final_metrics": {
+             "accuracy": final["accuracy"],
+             "macro_f1": final["macro_f1"],
+             "weighted_f1": final["weighted_f1"],
+             "per_class": final["per_class"],
+             "confusion_matrix": final["confusion_matrix"],
+         },
+         "history": history,
+         "best_epoch": trainer.best_epoch,
+         "best_metric": trainer.best_metric,
+     }
+     if robustness_results:
+         results["robustness"] = robustness_results
+
+     with open(os.path.join(config.train.output_dir, "results_v2.json"), "w") as f:
+         json.dump(results, f, indent=2)
+     print(f"\nResults saved to {config.train.output_dir}/results_v2.json")
+
+     # Push to hub
+     if config.train.push_to_hub:
+         try:
+             from huggingface_hub import HfApi, login
+             login(token=os.environ.get("HF_TOKEN"))
+             HfApi().upload_folder(
+                 folder_path=config.train.output_dir,
+                 repo_id=config.train.hub_model_id,
+                 repo_type="model",
+                 commit_message=f"Training v2: {config.experiment_name} (OGM-GE)")
+             print(f"✓ Pushed to https://huggingface.co/{config.train.hub_model_id}")
+         except Exception as e:
+             print(f"✗ Push failed: {e}")
+
+
+ if __name__ == "__main__":
+     main()