File size: 24,018 Bytes
0c3f8a7
 
 
 
 
 
 
 
61c951d
0c3f8a7
 
61c951d
0c3f8a7
61c951d
 
 
 
 
 
0c3f8a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61c951d
 
0c3f8a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61c951d
0c3f8a7
 
 
 
 
 
 
61c951d
0c3f8a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61c951d
0c3f8a7
 
 
 
 
 
 
 
 
 
61c951d
0c3f8a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61c951d
0c3f8a7
61c951d
0c3f8a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61c951d
 
0c3f8a7
 
 
 
 
 
 
 
 
 
 
61c951d
0c3f8a7
 
 
 
 
61c951d
0c3f8a7
 
 
61c951d
 
0c3f8a7
61c951d
 
0c3f8a7
 
 
 
 
61c951d
0c3f8a7
 
 
 
 
61c951d
 
0c3f8a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61c951d
0c3f8a7
 
 
 
 
 
 
61c951d
0c3f8a7
 
 
 
61c951d
0c3f8a7
61c951d
0c3f8a7
61c951d
 
0c3f8a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61c951d
 
0c3f8a7
 
61c951d
0c3f8a7
 
 
 
 
61c951d
 
0c3f8a7
 
 
 
61c951d
0c3f8a7
 
 
 
 
 
61c951d
 
0c3f8a7
 
 
 
 
 
 
 
 
 
 
61c951d
 
 
 
 
 
 
 
 
 
0c3f8a7
61c951d
 
 
 
 
 
 
0c3f8a7
 
 
 
 
 
 
 
 
 
61c951d
 
 
 
 
 
 
0c3f8a7
61c951d
 
0c3f8a7
61c951d
0c3f8a7
 
61c951d
 
 
 
 
0c3f8a7
 
 
 
 
61c951d
0c3f8a7
 
 
 
 
61c951d
 
0c3f8a7
 
 
 
 
61c951d
0c3f8a7
61c951d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c3f8a7
 
61c951d
 
0c3f8a7
 
 
 
61c951d
 
 
0c3f8a7
 
 
 
 
 
 
 
 
61c951d
0c3f8a7
 
 
 
 
 
 
 
61c951d
 
 
0c3f8a7
61c951d
 
 
0c3f8a7
 
61c951d
 
0c3f8a7
 
61c951d
0c3f8a7
61c951d
0c3f8a7
 
 
 
 
 
 
 
 
61c951d
 
0c3f8a7
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
"""
Multimodal PC Fault Detection - Training Script v2
====================================================
Changes from v1:
  - OGM-GE gradient modulation after loss.backward(), before optimizer.step()
  - Asymmetric learning rates: higher for visual branch, lower for audio
  - Auxiliary loss logging (loss_fusion, loss_visual, loss_audio per epoch)
  - OGM-GE stats logging (visual_conf, audio_conf, modulation coefficients)
  - Supports both old proxy data (dataset_real) and new built data (dataset_v2)

Usage:
  # With old proxy data (ToyADMOS + MVTec, default)
  python train_v2.py --mode multimodal --finetune lora --eval_robustness

  # With new built dataset (from build_dataset.py)
  python train_v2.py --dataset local --dataset_dir ../data/dataset_build --eval_robustness
  python train_v2.py --dataset hub --hub_dataset Ellaft/pc-fault-real-dataset

  # Other options
  python train_v2.py --mode visual_only --finetune lora --no_push
  python train_v2.py --quick_test --no_push

References:
  OGM-GE: Peng et al., "Balanced Multimodal Learning via On-the-fly Gradient 
  Modulation", CVPR 2022
"""

import os, sys, json, argparse, time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_recall_fscore_support

from config import ExperimentConfig, FAULT_CLASSES, NUM_CLASSES
from models_v2 import create_model, get_processors, OGMGEModulator


def compute_metrics(preds, labels, class_names=FAULT_CLASSES):
    """Summarise classification quality for a set of predictions.

    Args:
        preds: Predicted class indices, aligned with ``labels``.
        labels: Ground-truth class indices.
        class_names: Class names indexed by class id. Their count fixes the
            label set used for the per-class scores and the confusion matrix,
            so classes absent from both ``preds`` and ``labels`` still appear.

    Returns:
        dict with ``accuracy``, ``macro_f1``, ``weighted_f1``,
        ``confusion_matrix`` (nested lists) and ``per_class`` (class name ->
        precision / recall / f1 / support).
    """
    accuracy = accuracy_score(labels, preds)
    class_ids = range(len(class_names))
    prec, rec, f1s, counts = precision_recall_fscore_support(
        labels, preds, average=None, labels=class_ids, zero_division=0)

    per_class = {
        name: {
            "precision": prec[idx], "recall": rec[idx],
            "f1": f1s[idx], "support": int(counts[idx]),
        }
        for idx, name in enumerate(class_names)
    }

    return {
        "accuracy": accuracy,
        "macro_f1": f1_score(labels, preds, average="macro", zero_division=0),
        "weighted_f1": f1_score(labels, preds, average="weighted", zero_division=0),
        "confusion_matrix": confusion_matrix(labels, preds, labels=class_ids).tolist(),
        "per_class": per_class,
    }


class MultimodalTrainerV2:
    """
    Training loop v2 with OGM-GE gradient modulation.

    Key differences from v1:
    1. Three separate parameter groups with asymmetric LRs:
       - visual_branch: higher LR (visual_lr_multiplier × base_lr)
       - audio_branch: lower LR (audio_lr_multiplier × base_lr)
       - fusion + auxiliary heads: base LR
    2. OGM-GE applied after backward(), before optimizer.step()
    3. Logs auxiliary losses and OGM-GE stats per epoch

    Best-model selection follows ``config.metric_for_best_model``. Metrics whose
    name contains "loss" are treated as lower-is-better unless the config
    provides an explicit ``greater_is_better`` attribute.
    """

    def __init__(self, model, train_dataset, val_dataset, config, device,
                 use_ogm=True, ogm_alpha=0.3, ogm_noise_sigma=0.1,
                 visual_lr_multiplier=3.0, audio_lr_multiplier=0.5,
                 collate_fn=None):
        """Set up loaders, optimizer, scheduler, AMP scaler and tracking state.

        Args:
            model: Module exposing ``.mode`` whose forward returns a dict with at
                least ``"loss"`` and ``"logits"``.
            train_dataset, val_dataset: Datasets batched through ``collate_fn``.
            config: Training config (batch sizes, epochs, LRs, accumulation,
                warmup ratio, fp16, logging, output dir, best-model metric).
            device: ``torch.device`` to train on.
            use_ogm: Enable OGM-GE; only honoured when ``model.mode`` is
                "multimodal".
            ogm_alpha, ogm_noise_sigma: Forwarded to ``OGMGEModulator``.
            visual_lr_multiplier, audio_lr_multiplier: Per-branch LR scale
                relative to ``config.lora_learning_rate``.
            collate_fn: Batch collation function for both loaders.
        """
        self.model = model.to(device)
        self.device = device
        self.config = config
        # OGM-GE needs both unimodal logits, so only the multimodal model uses it.
        self.use_ogm = use_ogm and (model.mode == "multimodal")

        # OGM-GE modulator
        if self.use_ogm:
            self.ogm = OGMGEModulator(alpha=ogm_alpha, noise_sigma=ogm_noise_sigma)
            print(f"[Trainer v2] OGM-GE enabled: alpha={ogm_alpha}, noise_sigma={ogm_noise_sigma}")
        else:
            self.ogm = None

        # Pinned host memory only speeds up host->GPU copies; on CPU it just warns.
        pin_memory = device.type == "cuda"
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=config.per_device_train_batch_size,
            shuffle=True,
            collate_fn=collate_fn,
            num_workers=2,
            pin_memory=pin_memory,
            drop_last=True)
        self.val_loader = DataLoader(
            val_dataset,
            batch_size=config.per_device_eval_batch_size,
            shuffle=False,
            collate_fn=collate_fn,
            num_workers=2,
            pin_memory=pin_memory)

        # Asymmetric parameter groups
        param_groups = self._get_param_groups(visual_lr_multiplier, audio_lr_multiplier)
        self.optimizer = AdamW(param_groups, weight_decay=config.weight_decay)

        # One scheduler step per optimizer step. Flooring the product gives at
        # least epochs * (batches // accum) steps, so OneCycleLR never over-steps.
        total_steps = (len(self.train_loader) * config.num_epochs
                       // config.gradient_accumulation_steps)
        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=[pg["lr"] for pg in param_groups],
            total_steps=max(total_steps, 1),
            pct_start=config.warmup_ratio,
            anneal_strategy="cos")

        # Mixed precision
        self.scaler = (torch.amp.GradScaler("cuda")
                       if config.fp16 and device.type == "cuda" else None)

        # Tracking. Start from the worst possible value so epoch 1 always
        # produces a checkpoint, whichever direction the metric improves in.
        self.greater_is_better = self._resolve_greater_is_better(config)
        self.best_metric = float("-inf") if self.greater_is_better else float("inf")
        self.best_epoch = 0
        self.history = {
            "train_loss": [], "val_loss": [],
            "val_accuracy": [], "val_macro_f1": [],
            "train_loss_fusion": [], "train_loss_visual": [], "train_loss_audio": [],
            "ogm_visual_conf": [], "ogm_audio_conf": [],
            "ogm_coeff_visual": [], "ogm_coeff_audio": [],
        }

    @staticmethod
    def _resolve_greater_is_better(config):
        """Return True if a larger ``config.metric_for_best_model`` is better.

        An explicit ``config.greater_is_better`` wins. Otherwise metrics with
        "loss" in their name are lower-is-better and all others higher-is-better.
        """
        explicit = getattr(config, "greater_is_better", None)
        if explicit is not None:
            return bool(explicit)
        return "loss" not in config.metric_for_best_model

    def _is_improvement(self, value):
        """Return True if ``value`` beats ``self.best_metric`` in the configured direction."""
        if self.greater_is_better:
            return value > self.best_metric
        return value < self.best_metric

    def _get_param_groups(self, visual_lr_multiplier, audio_lr_multiplier):
        """Split trainable parameters into visual / audio / fusion optimizer groups.

        Parameters are routed by substring match on their qualified name
        ("visual_branch", "audio_branch", else fusion). Only non-empty groups
        are returned, each tagged with a ``"name"`` key.

        Raises:
            ValueError: if the model has no trainable parameters.
        """
        visual_params, audio_params, fusion_params = [], [], []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if "visual_branch" in name:
                visual_params.append(param)
            elif "audio_branch" in name:
                audio_params.append(param)
            else:
                fusion_params.append(param)

        base_lr = self.config.lora_learning_rate
        groups = []
        if visual_params:
            vlr = base_lr * visual_lr_multiplier
            groups.append({"params": visual_params, "lr": vlr, "name": "visual_branch"})
            print(f"[Trainer v2] visual_branch: {len(visual_params)} tensors, lr={vlr:.2e}")
        if audio_params:
            alr = base_lr * audio_lr_multiplier
            groups.append({"params": audio_params, "lr": alr, "name": "audio_branch"})
            print(f"[Trainer v2] audio_branch: {len(audio_params)} tensors, lr={alr:.2e}")
        if fusion_params:
            groups.append({"params": fusion_params, "lr": base_lr, "name": "fusion_heads"})
            print(f"[Trainer v2] fusion_heads: {len(fusion_params)} tensors, lr={base_lr:.2e}")
        if not groups:
            raise ValueError("No trainable parameters!")
        return groups

    def _optimizer_step(self, coeff_visual=None, coeff_audio=None):
        """Apply one optimizer + scheduler step on the accumulated gradients.

        Order: unscale (AMP) -> OGM-GE modulation -> grad clipping -> step.
        Unscaling first lets modulation and clipping act on true gradient
        magnitudes.
        """
        if self.scaler:
            self.scaler.unscale_(self.optimizer)
        if self.use_ogm and self.ogm is not None and coeff_visual is not None:
            self.ogm.apply_gradient_modulation(self.model, coeff_visual, coeff_audio)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
        if self.scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()

    def train_epoch(self, epoch):
        """Run one pass over ``train_loader`` with gradient accumulation.

        Args:
            epoch: Zero-based epoch index (used only in log messages).

        Returns:
            dict with mean ``train_loss``, ``loss_fusion``, ``loss_visual`` and
            ``loss_audio`` as Python floats, plus ``ogm_*`` means when OGM-GE
            statistics were collected.
        """
        self.model.train()
        accum = self.config.gradient_accumulation_steps
        total_loss = total_loss_fusion = total_loss_visual = total_loss_audio = 0.0
        num_batches = 0
        ogm_v_confs, ogm_a_confs, ogm_cv, ogm_ca = [], [], [], []
        # Gradients left over from a trailing partial accumulation window of the
        # previous epoch are discarded here rather than stepped.
        self.optimizer.zero_grad()

        for batch_idx, batch in enumerate(self.train_loader):
            pv = batch["pixel_values"].to(self.device)
            av = batch["audio_values"].to(self.device)
            labels = batch["labels"].to(self.device)

            if self.scaler:
                with torch.amp.autocast("cuda"):
                    outputs = self.model(pixel_values=pv, audio_values=av, labels=labels)
                    loss = outputs["loss"] / accum
                self.scaler.scale(loss).backward()
            else:
                outputs = self.model(pixel_values=pv, audio_values=av, labels=labels)
                loss = outputs["loss"] / accum
                loss.backward()

            total_loss += loss.item() * accum
            num_batches += 1
            if "loss_fusion" in outputs:
                # float() accepts Python numbers and 0-dim tensors alike. For
                # tensors it drops the autograd graph, so the running sums do not
                # keep every batch's graph alive and stay JSON-serialisable.
                total_loss_fusion += float(outputs["loss_fusion"])
                total_loss_visual += float(outputs["loss_visual"])
                total_loss_audio += float(outputs["loss_audio"])

            if (self.use_ogm and self.ogm is not None
                    and "visual_logits" in outputs and "audio_logits" in outputs):
                _cv, _ca, stats = self.ogm.compute_modulation_coefficients(
                    outputs["visual_logits"], outputs["audio_logits"], labels)
                ogm_v_confs.append(stats["visual_conf"])
                ogm_a_confs.append(stats["audio_conf"])
                ogm_cv.append(stats["coeff_visual"])
                ogm_ca.append(stats["coeff_audio"])

            if (batch_idx + 1) % accum == 0:
                # NOTE(review): with accumulation > 1 only the last micro-batch's
                # coefficients modulate the accumulated gradients.
                if ogm_cv:
                    self._optimizer_step(ogm_cv[-1], ogm_ca[-1])
                else:
                    self._optimizer_step()

            if (batch_idx + 1) % self.config.logging_steps == 0 or batch_idx == 0:
                avg_loss = total_loss / num_batches
                # Label the LR with its group name: group 0 is the visual branch
                # only when the model has trainable visual parameters.
                lead = self.optimizer.param_groups[0]
                msg = (f"  [Epoch {epoch+1}] Step {batch_idx+1}/{len(self.train_loader)} "
                       f"| Loss: {avg_loss:.4f} | LR[{lead.get('name', 'group0')}]: {lead['lr']:.2e}")
                if "loss_fusion" in outputs:
                    msg += (f" | L_fus: {total_loss_fusion/num_batches:.4f}"
                            f" L_vis: {total_loss_visual/num_batches:.4f}"
                            f" L_aud: {total_loss_audio/num_batches:.4f}")
                if ogm_cv:
                    msg += f" | OGM c_v: {ogm_cv[-1]:.3f} c_a: {ogm_ca[-1]:.3f}"
                print(msg)

        n = max(num_batches, 1)
        epoch_stats = {"train_loss": total_loss / n, "loss_fusion": total_loss_fusion / n,
                       "loss_visual": total_loss_visual / n, "loss_audio": total_loss_audio / n}
        if ogm_v_confs:
            epoch_stats.update({
                "ogm_visual_conf": float(np.mean(ogm_v_confs)),
                "ogm_audio_conf": float(np.mean(ogm_a_confs)),
                "ogm_coeff_visual": float(np.mean(ogm_cv)),
                "ogm_coeff_audio": float(np.mean(ogm_ca)),
            })
        return epoch_stats

    @torch.no_grad()
    def evaluate(self, modality_mask=None):
        """Evaluate on ``val_loader`` and return ``compute_metrics`` output plus ``val_loss``.

        Args:
            modality_mask: Optional dict such as ``{"visual": 0.0}``. A value of
                exactly 0.0 replaces that modality's input with zeros, which is
                used for missing-modality robustness checks.
        """
        self.model.eval()
        all_preds, all_labels, total_loss, num_batches = [], [], 0.0, 0
        for batch in self.val_loader:
            pv = batch["pixel_values"].to(self.device)
            av = batch["audio_values"].to(self.device)
            labels = batch["labels"].to(self.device)
            if modality_mask:
                if modality_mask.get("visual", 1.0) == 0.0:
                    pv = torch.zeros_like(pv)
                if modality_mask.get("audio", 1.0) == 0.0:
                    av = torch.zeros_like(av)
            outputs = self.model(pixel_values=pv, audio_values=av, labels=labels)
            total_loss += outputs["loss"].item()
            num_batches += 1
            all_preds.extend(outputs["logits"].argmax(dim=-1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
        metrics = compute_metrics(np.array(all_preds), np.array(all_labels))
        metrics["val_loss"] = total_loss / max(num_batches, 1)
        return metrics

    def train(self):
        """Train for ``config.num_epochs`` epochs, checkpointing the best model.

        The checkpoint ``best_model.pt`` is written to ``config.output_dir``
        whenever ``config.metric_for_best_model`` improves.

        Returns:
            The per-epoch ``history`` dict.
        """
        metric_name = self.config.metric_for_best_model
        print(f"\n{'='*60}")
        print(f"Training v2: mode={self.model.mode}, epochs={self.config.num_epochs}, "
              f"batch={self.config.per_device_train_batch_size}, device={self.device}")
        print(f"OGM-GE: {'ENABLED' if self.use_ogm else 'DISABLED'}")
        if self.model.mode == "multimodal":
            print(f"Auxiliary loss weights: λ_visual={self.model.lambda_visual}, λ_audio={self.model.lambda_audio}")
        print(f"{'='*60}\n")

        for epoch in range(self.config.num_epochs):
            t0 = time.time()
            train_stats = self.train_epoch(epoch)
            val_metrics = self.evaluate()
            elapsed = time.time() - t0

            print(f"\n[Epoch {epoch+1}/{self.config.num_epochs}] ({elapsed:.1f}s)")
            loss_msg = f"  Train Loss: {train_stats['train_loss']:.4f}"
            if train_stats.get("loss_fusion", 0) > 0:
                loss_msg += (f" (fusion={train_stats['loss_fusion']:.4f} "
                             f"visual={train_stats['loss_visual']:.4f} audio={train_stats['loss_audio']:.4f})")
            print(loss_msg)
            print(f"  Val Loss: {val_metrics['val_loss']:.4f} | Acc: {val_metrics['accuracy']:.4f} | F1: {val_metrics['macro_f1']:.4f}")
            if "ogm_visual_conf" in train_stats:
                print(f"  OGM-GE: visual_conf={train_stats['ogm_visual_conf']:.4f} audio_conf={train_stats['ogm_audio_conf']:.4f} "
                      f"| coeff_v={train_stats['ogm_coeff_visual']:.4f} coeff_a={train_stats['ogm_coeff_audio']:.4f}")

            self.history["train_loss"].append(train_stats["train_loss"])
            self.history["val_loss"].append(val_metrics["val_loss"])
            self.history["val_accuracy"].append(val_metrics["accuracy"])
            self.history["val_macro_f1"].append(val_metrics["macro_f1"])
            self.history["train_loss_fusion"].append(train_stats["loss_fusion"])
            self.history["train_loss_visual"].append(train_stats["loss_visual"])
            self.history["train_loss_audio"].append(train_stats["loss_audio"])
            if "ogm_visual_conf" in train_stats:
                self.history["ogm_visual_conf"].append(train_stats["ogm_visual_conf"])
                self.history["ogm_audio_conf"].append(train_stats["ogm_audio_conf"])
                self.history["ogm_coeff_visual"].append(train_stats["ogm_coeff_visual"])
                self.history["ogm_coeff_audio"].append(train_stats["ogm_coeff_audio"])

            current = val_metrics[metric_name]
            if self._is_improvement(current):
                self.best_metric = current
                self.best_epoch = epoch + 1
                os.makedirs(self.config.output_dir, exist_ok=True)
                torch.save({"model_state_dict": self.model.state_dict(), "epoch": epoch + 1,
                            "metrics": val_metrics}, os.path.join(self.config.output_dir, "best_model.pt"))
                print(f"  ✓ Best model saved ({metric_name}={self.best_metric:.4f})")

        print(f"\nTraining complete. Best epoch={self.best_epoch}, "
              f"Best {metric_name}={self.best_metric:.4f}")
        return self.history

    def run_robustness_evaluation(self):
        """Evaluate with both modalities, visual only, and audio only.

        Returns:
            dict mapping scenario name to its ``accuracy`` and ``macro_f1``.
        """
        print("\n=== Missing Modality Robustness Evaluation ===")
        results = {}
        for name, mask in [("both_modalities", None), ("visual_only", {"visual": 1.0, "audio": 0.0}),
                           ("audio_only", {"visual": 0.0, "audio": 1.0})]:
            m = self.evaluate(modality_mask=mask)
            results[name] = {"accuracy": m["accuracy"], "macro_f1": m["macro_f1"]}
            print(f"  {name:20s}: Acc={m['accuracy']:.4f} F1={m['macro_f1']:.4f}")
            for cls, cls_m in m["per_class"].items():
                print(f"    {cls:25s} P:{cls_m['precision']:.3f} R:{cls_m['recall']:.3f} F1:{cls_m['f1']:.3f}")
        print("\n  [Target] Visual-only should improve from ~0.23 acc / 0.08 F1 (v1)")
        return results


def _build_arg_parser():
    """Build the command-line parser for the v2 training script."""
    parser = argparse.ArgumentParser(description="Multimodal PC Fault Detection Training v2")
    parser.add_argument("--mode", default="multimodal", choices=["multimodal", "visual_only", "audio_only"])
    parser.add_argument("--finetune", default="lora", choices=["lora", "full", "linear_probe"])
    parser.add_argument("--epochs", type=int)
    parser.add_argument("--batch_size", type=int)
    parser.add_argument("--lr", type=float)
    parser.add_argument("--fusion", default="concat")
    parser.add_argument("--modality_dropout", type=float)
    parser.add_argument("--output_dir", type=str)
    parser.add_argument("--hub_model_id", type=str)
    parser.add_argument("--no_push", action="store_true")
    parser.add_argument("--eval_robustness", action="store_true")
    parser.add_argument("--quick_test", action="store_true")

    # Dataset selection
    parser.add_argument("--dataset", default="proxy",
                        choices=["proxy", "local", "hub"],
                        help="Dataset source: 'proxy' (ToyADMOS+MVTec, default), "
                             "'local' (build_dataset.py output), 'hub' (HF Hub dataset)")
    parser.add_argument("--dataset_dir", default="./dataset_build",
                        help="Path to build_dataset.py output (for --dataset local)")
    parser.add_argument("--hub_dataset", default="Ellaft/pc-fault-real-dataset",
                        help="HuggingFace dataset ID (for --dataset hub)")

    # v2-specific arguments (None means "use the config default")
    parser.add_argument("--no_ogm", action="store_true")
    parser.add_argument("--ogm_alpha", type=float, default=None)
    parser.add_argument("--ogm_noise_sigma", type=float, default=None)
    parser.add_argument("--lambda_visual", type=float, default=None)
    parser.add_argument("--lambda_audio", type=float, default=None)
    parser.add_argument("--visual_lr_mult", type=float, default=None)
    parser.add_argument("--audio_lr_mult", type=float, default=None)
    return parser


def _load_datasets(args, config, vit_proc, ast_ext):
    """Build the train/val datasets and collate function selected by ``--dataset``.

    Returns:
        (train_ds, val_ds, collate_fn)
    """
    if args.dataset in ("local", "hub"):
        from dataset_v2 import BuiltDataset as PCFaultDataset, multimodal_collate_fn
        source_kwargs = dict(source=args.dataset, dataset_dir=args.dataset_dir,
                             hub_dataset=args.hub_dataset)
        train_ds = PCFaultDataset(
            config.data, config.model, "train", vit_proc, ast_ext, True, **source_kwargs)
        val_ds = PCFaultDataset(
            config.data, config.model, "val", vit_proc, ast_ext, False, **source_kwargs)
    else:
        # Default: old proxy data (ToyADMOS + MVTec)
        from dataset_real import RealPCFaultDataset as PCFaultDataset, multimodal_collate_fn
        train_ds = PCFaultDataset(config.data, config.model, "train", vit_proc, ast_ext, True)
        val_ds = PCFaultDataset(config.data, config.model, "val", vit_proc, ast_ext, False)
    return train_ds, val_ds, multimodal_collate_fn


def _push_results_to_hub(output_dir, repo_id, experiment_name):
    """Best-effort upload of ``output_dir`` to a Hugging Face model repo.

    Failures are reported and swallowed on purpose, so a finished training run
    is not lost to a network or auth problem at the very end.
    """
    if not repo_id:
        print("✗ Push skipped: no hub_model_id configured")
        return
    try:
        from huggingface_hub import HfApi
        # Use HF_TOKEN when set. Otherwise HfApi(token=None) falls back to the
        # cached login. Calling login(token=None) instead would prompt
        # interactively and can block an unattended job.
        api = HfApi(token=os.environ.get("HF_TOKEN"))
        api.upload_folder(folder_path=output_dir, repo_id=repo_id,
                          repo_type="model", commit_message=f"Training v2: {experiment_name} (OGM-GE)")
        print(f"✓ Pushed to https://huggingface.co/{repo_id}")
    except Exception as e:  # top-level best-effort boundary
        print(f"✗ Push failed: {e}")


def main():
    """CLI entry point: parse flags, train, evaluate, then save (and optionally push) results."""
    args = _build_arg_parser().parse_args()

    # Load config and apply CLI overrides
    config = ExperimentConfig()
    config.experiment_name = "multimodal_pc_fault_v2"
    config.train.mode = args.mode
    config.train.finetune_method = args.finetune
    config.model.fusion_type = args.fusion

    if args.epochs:
        config.train.num_epochs = args.epochs
    if args.batch_size:
        config.train.per_device_train_batch_size = args.batch_size
    if args.lr:
        config.train.learning_rate = config.train.lora_learning_rate = args.lr
    if args.modality_dropout is not None:
        config.model.modality_dropout_p = args.modality_dropout
    if args.output_dir:
        config.train.output_dir = args.output_dir
    if args.hub_model_id:
        config.train.hub_model_id = args.hub_model_id
    if args.no_push:
        config.train.push_to_hub = False
    if args.quick_test:
        config.train.num_epochs, config.train.per_device_train_batch_size = 2, 4
        config.train.per_device_eval_batch_size, config.train.gradient_accumulation_steps = 4, 1
        config.train.logging_steps = 2
    if args.finetune != "lora":
        config.lora.enabled = False

    # v2 anti-collapse knobs: CLI value if given, else the config default
    # (config is only consulted when the flag is absent).
    ogm_alpha = args.ogm_alpha if args.ogm_alpha is not None else config.ogm_alpha
    ogm_noise_sigma = args.ogm_noise_sigma if args.ogm_noise_sigma is not None else config.ogm_noise_sigma
    lambda_visual = args.lambda_visual if args.lambda_visual is not None else config.lambda_visual
    lambda_audio = args.lambda_audio if args.lambda_audio is not None else config.lambda_audio
    visual_lr_mult = args.visual_lr_mult if args.visual_lr_mult is not None else config.visual_lr_multiplier
    audio_lr_mult = args.audio_lr_mult if args.audio_lr_mult is not None else config.audio_lr_multiplier
    use_ogm = not args.no_ogm

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(config.train.seed)
    np.random.seed(config.train.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(config.train.seed)

    print(f"\n{'='*60}")
    print(f"Multimodal PC Fault Detection v2")
    print(f"{'='*60}")
    print(f"Mode: {args.mode} | Finetune: {args.finetune} | Device: {device}")
    print(f"Dataset: {args.dataset}" + (f" ({args.dataset_dir})" if args.dataset == "local" else
          f" ({args.hub_dataset})" if args.dataset == "hub" else " (ToyADMOS + MVTec proxy)"))
    print(f"OGM-GE: {'ON' if use_ogm else 'OFF'} (alpha={ogm_alpha}, sigma={ogm_noise_sigma})")
    print(f"Aux loss weights: λ_visual={lambda_visual}, λ_audio={lambda_audio}")
    print(f"LR multipliers: visual={visual_lr_mult}x, audio={audio_lr_mult}x")
    print(f"{'='*60}\n")

    # Processors and data
    vit_proc, ast_ext = get_processors(config.model)
    train_ds, val_ds, collate_fn = _load_datasets(args, config, vit_proc, ast_ext)

    # Model and trainer
    model = create_model(config.model, config.lora, mode=args.mode, finetune_method=args.finetune,
                         use_ogm=use_ogm, lambda_visual=lambda_visual, lambda_audio=lambda_audio)
    trainer = MultimodalTrainerV2(
        model, train_ds, val_ds, config.train, device,
        use_ogm=use_ogm, ogm_alpha=ogm_alpha, ogm_noise_sigma=ogm_noise_sigma,
        visual_lr_multiplier=visual_lr_mult, audio_lr_multiplier=audio_lr_mult,
        collate_fn=collate_fn)

    history = trainer.train()

    # NOTE(review): this evaluates the weights from the LAST epoch, not the
    # best_model.pt checkpoint saved during training. Confirm that is intended.
    final = trainer.evaluate()
    print(f"\nFinal Evaluation:")
    print(f"  Acc={final['accuracy']:.4f} F1={final['macro_f1']:.4f}")
    for cls, m in final["per_class"].items():
        print(f"  {cls:25s} P:{m['precision']:.3f} R:{m['recall']:.3f} F1:{m['f1']:.3f} N:{m['support']}")

    robustness_results = None
    if args.eval_robustness and config.train.mode == "multimodal":
        robustness_results = trainer.run_robustness_evaluation()

    # Save results
    os.makedirs(config.train.output_dir, exist_ok=True)
    results = {
        "experiment": config.experiment_name, "version": "v2",
        "mode": config.train.mode, "finetune_method": config.train.finetune_method,
        "dataset_source": args.dataset,
        "anti_collapse_config": {
            "ogm_ge": use_ogm, "ogm_alpha": ogm_alpha, "ogm_noise_sigma": ogm_noise_sigma,
            "lambda_visual": lambda_visual, "lambda_audio": lambda_audio,
            "visual_lr_multiplier": visual_lr_mult, "audio_lr_multiplier": audio_lr_mult,
        },
        "final_metrics": {
            "accuracy": final["accuracy"], "macro_f1": final["macro_f1"],
            "weighted_f1": final["weighted_f1"], "per_class": final["per_class"],
            "confusion_matrix": final["confusion_matrix"],
        },
        "history": history, "best_epoch": trainer.best_epoch, "best_metric": trainer.best_metric,
    }
    if robustness_results:
        results["robustness"] = robustness_results

    with open(os.path.join(config.train.output_dir, "results_v2.json"), "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to {config.train.output_dir}/results_v2.json")

    if config.train.push_to_hub:
        _push_results_to_hub(config.train.output_dir, config.train.hub_model_id,
                             config.experiment_name)


if __name__ == "__main__":
    # Run the training CLI only when executed as a script, not on import.
    main()