| """ |
| Multimodal PC Fault Detection - Training Script |
| ================================================= |
| Loads dataset from Ellaft/pc-fault-real-dataset (HuggingFace Hub). |
Uses OGM-GE gradient modulation to prevent modality collapse, plus auxiliary
per-modality heads and asymmetric branch learning rates.
| |
| Usage: |
| python train.py # Full training |
| python train.py --quick_test --no_push # Quick 2-epoch test |
| python train.py --eval_robustness # + missing modality test |
| python train.py --mode visual_only --no_push # Unimodal ablation |
| python train.py --mode audio_only --no_push # Unimodal ablation |
| |
| All 6 ablations: |
| python run_ablations.py |
| python run_ablations.py --quick_test |
| """ |
|
|
import os, json, argparse, time
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader |
| from torch.optim import AdamW |
| from torch.optim.lr_scheduler import OneCycleLR |
| from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_recall_fscore_support |
|
|
| from config import ExperimentConfig, FAULT_CLASSES, NUM_CLASSES |
| from models import create_model, get_processors, OGMGEModulator |
| from dataset_v2 import BuiltDataset as PCFaultDataset, multimodal_collate_fn |
|
|
|
|
| def compute_metrics(preds, labels, class_names=FAULT_CLASSES): |
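    """Compute accuracy, macro/weighted F1, per-class precision/recall/F1/support,
    and the confusion matrix, returned as a JSON-serializable dict."""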
| accuracy = accuracy_score(labels, preds) |
| precision, recall, f1, support = precision_recall_fscore_support( |
| labels, preds, average=None, labels=range(len(class_names)), zero_division=0) |
| macro_f1 = f1_score(labels, preds, average="macro", zero_division=0) |
| weighted_f1 = f1_score(labels, preds, average="weighted", zero_division=0) |
| conf_matrix = confusion_matrix(labels, preds, labels=range(len(class_names))) |
| metrics = {"accuracy": accuracy, "macro_f1": macro_f1, "weighted_f1": weighted_f1, |
| "confusion_matrix": conf_matrix.tolist(), "per_class": {}} |
| for i, name in enumerate(class_names): |
| metrics["per_class"][name] = {"precision": precision[i], "recall": recall[i], |
| "f1": f1[i], "support": int(support[i])} |
| return metrics |
|
|
|
|
| class MultimodalTrainer: |
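    """Custom training loop: AMP + gradient accumulation + OneCycleLR, with optional
    OGM-GE gradient modulation and asymmetric per-branch learning rates."""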
| def __init__(self, model, train_dataset, val_dataset, config, device, |
| use_ogm=True, ogm_alpha=0.3, ogm_noise_sigma=0.1, |
| visual_lr_multiplier=3.0, audio_lr_multiplier=0.5): |
| self.model = model.to(device) |
| self.device, self.config = device, config |
| self.use_ogm = use_ogm and (model.mode == "multimodal") |
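        # OGM-GE (Peng et al., CVPR 2022): at each optimizer step the better-performing
        # modality's gradients are damped via a coefficient derived from the per-modality
        # confidence ratio, and zero-mean Gaussian noise ("generalization enhancement",
        # scaled by noise_sigma) is added. See OGMGEModulator in models.py for the exact rule.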
| self.ogm = OGMGEModulator(alpha=ogm_alpha, noise_sigma=ogm_noise_sigma) if self.use_ogm else None |
| if self.use_ogm: |
| print(f"[Trainer] OGM-GE enabled: alpha={ogm_alpha}, noise_sigma={ogm_noise_sigma}") |
|
|
| self.train_loader = DataLoader(train_dataset, batch_size=config.per_device_train_batch_size, |
| shuffle=True, collate_fn=multimodal_collate_fn, num_workers=2, pin_memory=True, drop_last=True) |
| self.val_loader = DataLoader(val_dataset, batch_size=config.per_device_eval_batch_size, |
| shuffle=False, collate_fn=multimodal_collate_fn, num_workers=2, pin_memory=True) |
|
|
| param_groups = self._get_param_groups(visual_lr_multiplier, audio_lr_multiplier) |
| self.optimizer = AdamW(param_groups, weight_decay=config.weight_decay) |
| total_steps = max(len(self.train_loader) * config.num_epochs // config.gradient_accumulation_steps, 1) |
| self.scheduler = OneCycleLR(self.optimizer, max_lr=[pg["lr"] for pg in param_groups], |
| total_steps=total_steps, pct_start=config.warmup_ratio, anneal_strategy="cos") |
| self.scaler = torch.amp.GradScaler("cuda") if config.fp16 and device.type == "cuda" else None |
| self.best_metric, self.best_epoch = 0.0, 0 |
| self.history = {"train_loss": [], "val_loss": [], "val_accuracy": [], "val_macro_f1": [], |
| "train_loss_fusion": [], "train_loss_visual": [], "train_loss_audio": [], |
| "ogm_visual_conf": [], "ogm_audio_conf": [], "ogm_coeff_visual": [], "ogm_coeff_audio": []} |
|
|
| def _get_param_groups(self, visual_lr_mult, audio_lr_mult): |
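        """Split trainable params into visual / audio / fusion groups so each branch gets
        its own learning rate. With the defaults (visual x3.0, audio x0.5) the visual
        branch is sped up and the audio branch slowed, presumably to counter audio
        dominance; this complements the OGM-GE gradient modulation."""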
| visual_p, audio_p, fusion_p = [], [], [] |
| for name, p in self.model.named_parameters(): |
| if not p.requires_grad: continue |
| if "visual_branch" in name: visual_p.append(p) |
| elif "audio_branch" in name: audio_p.append(p) |
| else: fusion_p.append(p) |
| base_lr = self.config.lora_learning_rate |
| groups = [] |
| if visual_p: groups.append({"params": visual_p, "lr": base_lr * visual_lr_mult, "name": "visual"}) |
| if audio_p: groups.append({"params": audio_p, "lr": base_lr * audio_lr_mult, "name": "audio"}) |
| if fusion_p: groups.append({"params": fusion_p, "lr": base_lr, "name": "fusion"}) |
| for g in groups: print(f" [Param Group] {g.get('name','?')}: {len(g['params'])} tensors, lr={g['lr']:.2e}") |
| if not groups: raise ValueError("No trainable parameters!") |
| return groups |
|
|
| def train_epoch(self, epoch): |
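        """Run one training epoch; returns mean losses and, if enabled, OGM-GE statistics."""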
| self.model.train() |
| tot_loss, tot_fus, tot_vis, tot_aud, n = 0., 0., 0., 0., 0 |
| ogm_vc, ogm_ac, ogm_cv, ogm_ca = [], [], [], [] |
| self.optimizer.zero_grad() |
| for bi, batch in enumerate(self.train_loader): |
| pv, av, labels = batch["pixel_values"].to(self.device), batch["audio_values"].to(self.device), batch["labels"].to(self.device) |
| if self.scaler: |
| with torch.amp.autocast("cuda"): |
| out = self.model(pixel_values=pv, audio_values=av, labels=labels) |
| loss = out["loss"] / self.config.gradient_accumulation_steps |
| self.scaler.scale(loss).backward() |
| else: |
| out = self.model(pixel_values=pv, audio_values=av, labels=labels) |
| loss = out["loss"] / self.config.gradient_accumulation_steps |
| loss.backward() |
| tot_loss += loss.item() * self.config.gradient_accumulation_steps; n += 1 |
| if "loss_fusion" in out: tot_fus += out["loss_fusion"]; tot_vis += out["loss_visual"]; tot_aud += out["loss_audio"] |
| if self.use_ogm and "visual_logits" in out: |
| cv, ca, st = self.ogm.compute_modulation_coefficients(out["visual_logits"], out["audio_logits"], labels) |
| ogm_vc.append(st["visual_conf"]); ogm_ac.append(st["audio_conf"]); ogm_cv.append(cv); ogm_ca.append(ca) |
| if (bi + 1) % self.config.gradient_accumulation_steps == 0: |
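                # Unscale before OGM modulation and clipping so both act on true fp32
                # gradients; modulation reuses the most recent micro-batch's coefficients.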
| if self.scaler: self.scaler.unscale_(self.optimizer) |
| if self.use_ogm and ogm_cv: self.ogm.apply_gradient_modulation(self.model, ogm_cv[-1], ogm_ca[-1]) |
| torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm) |
| if self.scaler: self.scaler.step(self.optimizer); self.scaler.update() |
| else: self.optimizer.step() |
| self.scheduler.step(); self.optimizer.zero_grad() |
| if (bi + 1) % self.config.logging_steps == 0 or bi == 0: |
| msg = f" [Epoch {epoch+1}] Step {bi+1}/{len(self.train_loader)} | Loss: {tot_loss/n:.4f}" |
| if "loss_fusion" in out: msg += f" | fus:{tot_fus/n:.4f} vis:{tot_vis/n:.4f} aud:{tot_aud/n:.4f}" |
| if ogm_cv: msg += f" | OGM cv:{ogm_cv[-1]:.3f} ca:{ogm_ca[-1]:.3f}" |
| print(msg) |
| stats = {"train_loss": tot_loss/max(n,1), "loss_fusion": tot_fus/max(n,1), "loss_visual": tot_vis/max(n,1), "loss_audio": tot_aud/max(n,1)} |
| if ogm_vc: stats.update({"ogm_visual_conf": np.mean(ogm_vc), "ogm_audio_conf": np.mean(ogm_ac), "ogm_coeff_visual": np.mean(ogm_cv), "ogm_coeff_audio": np.mean(ogm_ca)}) |
| return stats |
|
|
| @torch.no_grad() |
| def evaluate(self, modality_mask=None): |
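        """Evaluate on the validation set. A modality_mask like {"visual": 0.0, "audio": 1.0}
        zeroes the corresponding input tensor to simulate a missing modality."""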
| self.model.eval() |
| preds, lbls, tot_loss, n = [], [], 0., 0 |
| for batch in self.val_loader: |
| pv, av, labels = batch["pixel_values"].to(self.device), batch["audio_values"].to(self.device), batch["labels"].to(self.device) |
| if modality_mask: |
| if modality_mask.get("visual", 1.0) == 0.0: pv = torch.zeros_like(pv) |
| if modality_mask.get("audio", 1.0) == 0.0: av = torch.zeros_like(av) |
| out = self.model(pixel_values=pv, audio_values=av, labels=labels) |
| tot_loss += out["loss"].item(); n += 1 |
| preds.extend(out["logits"].argmax(dim=-1).cpu().numpy()); lbls.extend(labels.cpu().numpy()) |
| m = compute_metrics(np.array(preds), np.array(lbls)) |
| m["val_loss"] = tot_loss / max(n, 1) |
| return m |
|
|
| def train(self): |
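        """Train for config.num_epochs, logging history and checkpointing whenever
        config.metric_for_best_model improves."""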
| print(f"\n{'='*60}\nTraining: mode={self.model.mode}, epochs={self.config.num_epochs}, device={self.device}") |
| print(f"OGM-GE: {'ON' if self.use_ogm else 'OFF'}") |
| if self.model.mode == "multimodal": print(f"λ_visual={self.model.lambda_visual}, λ_audio={self.model.lambda_audio}") |
| print(f"{'='*60}\n") |
| for epoch in range(self.config.num_epochs): |
| t0 = time.time() |
| ts = self.train_epoch(epoch) |
| vm = self.evaluate() |
| print(f"\n[Epoch {epoch+1}/{self.config.num_epochs}] ({time.time()-t0:.1f}s)") |
| msg = f" Train Loss: {ts['train_loss']:.4f}" |
| if ts.get("loss_fusion", 0) > 0: msg += f" (fus={ts['loss_fusion']:.4f} vis={ts['loss_visual']:.4f} aud={ts['loss_audio']:.4f})" |
| print(msg) |
| print(f" Val Loss: {vm['val_loss']:.4f} | Acc: {vm['accuracy']:.4f} | F1: {vm['macro_f1']:.4f}") |
| if "ogm_visual_conf" in ts: print(f" OGM-GE: v_conf={ts['ogm_visual_conf']:.4f} a_conf={ts['ogm_audio_conf']:.4f} | cv={ts['ogm_coeff_visual']:.4f} ca={ts['ogm_coeff_audio']:.4f}") |
| for k in ["train_loss","val_loss","val_accuracy","val_macro_f1","train_loss_fusion","train_loss_visual","train_loss_audio"]: self.history[k].append(ts.get(k, vm.get(k, 0))) |
| if "ogm_visual_conf" in ts: |
| for k in ["ogm_visual_conf","ogm_audio_conf","ogm_coeff_visual","ogm_coeff_audio"]: self.history[k].append(ts[k]) |
| if vm[self.config.metric_for_best_model] > self.best_metric: |
| self.best_metric, self.best_epoch = vm[self.config.metric_for_best_model], epoch + 1 |
| os.makedirs(self.config.output_dir, exist_ok=True) |
| torch.save({"model_state_dict": self.model.state_dict(), "epoch": epoch+1, "metrics": vm}, |
| os.path.join(self.config.output_dir, "best_model.pt")) |
| print(f" ✓ Best model saved (F1={self.best_metric:.4f})") |
| print(f"\nDone. Best epoch={self.best_epoch}, F1={self.best_metric:.4f}") |
| return self.history |
|
|
| def run_robustness_evaluation(self): |
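        """Re-evaluate with both modalities, then with each one zeroed in turn, to
        quantify missing-modality robustness."""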
| print("\n=== Missing Modality Robustness ===") |
| results = {} |
| for name, mask in [("both", None), ("visual_only", {"visual": 1.0, "audio": 0.0}), ("audio_only", {"visual": 0.0, "audio": 1.0})]: |
| m = self.evaluate(modality_mask=mask) |
| results[name] = {"accuracy": m["accuracy"], "macro_f1": m["macro_f1"]} |
| print(f" {name:15s}: Acc={m['accuracy']:.4f} F1={m['macro_f1']:.4f}") |
| for cls, cm in m["per_class"].items(): print(f" {cls:25s} P:{cm['precision']:.3f} R:{cm['recall']:.3f} F1:{cm['f1']:.3f}") |
| return results |
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--mode", default="multimodal", choices=["multimodal", "visual_only", "audio_only"]) |
| parser.add_argument("--finetune", default="lora", choices=["lora", "full", "linear_probe"]) |
| parser.add_argument("--epochs", type=int); parser.add_argument("--batch_size", type=int) |
| parser.add_argument("--lr", type=float); parser.add_argument("--fusion", default="concat") |
| parser.add_argument("--modality_dropout", type=float); parser.add_argument("--output_dir", type=str) |
| parser.add_argument("--hub_model_id", type=str); parser.add_argument("--no_push", action="store_true") |
| parser.add_argument("--eval_robustness", action="store_true"); parser.add_argument("--quick_test", action="store_true") |
| parser.add_argument("--no_ogm", action="store_true") |
| parser.add_argument("--ogm_alpha", type=float); parser.add_argument("--ogm_noise_sigma", type=float) |
| parser.add_argument("--lambda_visual", type=float); parser.add_argument("--lambda_audio", type=float) |
| parser.add_argument("--visual_lr_mult", type=float); parser.add_argument("--audio_lr_mult", type=float) |
| args = parser.parse_args() |
|
|
| config = ExperimentConfig() |
| config.train.mode, config.train.finetune_method, config.model.fusion_type = args.mode, args.finetune, args.fusion |
| if args.epochs: config.train.num_epochs = args.epochs |
| if args.batch_size: config.train.per_device_train_batch_size = args.batch_size |
| if args.lr: config.train.learning_rate = config.train.lora_learning_rate = args.lr |
| if args.modality_dropout is not None: config.model.modality_dropout_p = args.modality_dropout |
| if args.output_dir: config.train.output_dir = args.output_dir |
| if args.hub_model_id: config.train.hub_model_id = args.hub_model_id |
| if args.no_push: config.train.push_to_hub = False |
| if args.quick_test: |
| config.train.num_epochs, config.train.per_device_train_batch_size = 2, 4 |
| config.train.per_device_eval_batch_size, config.train.gradient_accumulation_steps = 4, 1 |
| config.train.logging_steps = 2 |
| if args.finetune != "lora": config.lora.enabled = False |
|
|
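    # CLI flags override config defaults; None means "fall back to the value in config.py".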
| ogm_a = args.ogm_alpha if args.ogm_alpha is not None else config.ogm_alpha |
| ogm_s = args.ogm_noise_sigma if args.ogm_noise_sigma is not None else config.ogm_noise_sigma |
| lv = args.lambda_visual if args.lambda_visual is not None else config.lambda_visual |
| la = args.lambda_audio if args.lambda_audio is not None else config.lambda_audio |
| vlr = args.visual_lr_mult if args.visual_lr_mult is not None else config.visual_lr_multiplier |
| alr = args.audio_lr_mult if args.audio_lr_mult is not None else config.audio_lr_multiplier |
|
|
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| torch.manual_seed(config.train.seed); np.random.seed(config.train.seed) |
|
|
| print(f"\n{'='*60}\nMultimodal PC Fault Detection\n{'='*60}") |
| print(f"Mode: {args.mode} | Finetune: {args.finetune} | Device: {device}") |
| print(f"Dataset: {config.hub_dataset}") |
| print(f"OGM-GE: {'OFF' if args.no_ogm else 'ON'} (α={ogm_a}, σ={ogm_s})") |
| print(f"λ_visual={lv}, λ_audio={la}, LR mult: visual={vlr}x audio={alr}x") |
|
|
| vit_proc, ast_ext = get_processors(config.model) |
| train_ds = PCFaultDataset(config.data, config.model, "train", vit_proc, ast_ext, True, |
| source="hub", hub_dataset=config.hub_dataset) |
| val_ds = PCFaultDataset(config.data, config.model, "val", vit_proc, ast_ext, False, |
| source="hub", hub_dataset=config.hub_dataset) |
|
|
| model = create_model(config.model, config.lora, mode=args.mode, finetune_method=args.finetune, |
| use_ogm=not args.no_ogm, lambda_visual=lv, lambda_audio=la) |
| trainer = MultimodalTrainer(model, train_ds, val_ds, config.train, device, |
| use_ogm=not args.no_ogm, ogm_alpha=ogm_a, ogm_noise_sigma=ogm_s, |
| visual_lr_multiplier=vlr, audio_lr_multiplier=alr) |
|
|
| history = trainer.train() |
| final = trainer.evaluate() |
| print(f"\nFinal: Acc={final['accuracy']:.4f} F1={final['macro_f1']:.4f}") |
| for cls, m in final["per_class"].items(): print(f" {cls:25s} P:{m['precision']:.3f} R:{m['recall']:.3f} F1:{m['f1']:.3f} N:{m['support']}") |
|
|
| robustness = trainer.run_robustness_evaluation() if args.eval_robustness and args.mode == "multimodal" else None |
|
|
| os.makedirs(config.train.output_dir, exist_ok=True) |
| results = {"experiment": config.experiment_name, "mode": args.mode, "finetune": args.finetune, |
| "dataset": config.hub_dataset, "ogm_ge": not args.no_ogm, |
| "final_metrics": {"accuracy": final["accuracy"], "macro_f1": final["macro_f1"], |
| "weighted_f1": final["weighted_f1"], "per_class": final["per_class"], "confusion_matrix": final["confusion_matrix"]}, |
| "history": history, "best_epoch": trainer.best_epoch, "best_metric": trainer.best_metric} |
| if robustness: results["robustness"] = robustness |
| with open(os.path.join(config.train.output_dir, "results.json"), "w") as f: json.dump(results, f, indent=2) |
|
|
| if config.train.push_to_hub: |
| try: |
| from huggingface_hub import HfApi, login |
            token = os.environ.get("HF_TOKEN")
            if token: login(token=token)  # skip the interactive prompt on headless runs
| HfApi().upload_folder(folder_path=config.train.output_dir, repo_id=config.train.hub_model_id, |
| repo_type="model", commit_message=f"Training: {config.experiment_name}") |
| print(f"✓ Pushed to https://huggingface.co/{config.train.hub_model_id}") |
| except Exception as e: print(f"✗ Push failed: {e}") |
|
|
| if __name__ == "__main__": |
| main() |
|
|