""" Multimodal PC Fault Detection - Training Script ================================================= Loads dataset from Ellaft/pc-fault-real-dataset (HuggingFace Hub). Uses OGM-GE anti-modality-collapse + auxiliary heads + asymmetric LRs. Usage: python train.py # Full training python train.py --quick_test --no_push # Quick 2-epoch test python train.py --eval_robustness # + missing modality test python train.py --mode visual_only --no_push # Unimodal ablation python train.py --mode audio_only --no_push # Unimodal ablation All 6 ablations: python run_ablations.py python run_ablations.py --quick_test """ import os, sys, json, argparse, time import numpy as np import torch import torch.nn as nn from torch.utils.data import DataLoader from torch.optim import AdamW from torch.optim.lr_scheduler import OneCycleLR from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_recall_fscore_support from config import ExperimentConfig, FAULT_CLASSES, NUM_CLASSES from models import create_model, get_processors, OGMGEModulator from dataset_v2 import BuiltDataset as PCFaultDataset, multimodal_collate_fn def compute_metrics(preds, labels, class_names=FAULT_CLASSES): accuracy = accuracy_score(labels, preds) precision, recall, f1, support = precision_recall_fscore_support( labels, preds, average=None, labels=range(len(class_names)), zero_division=0) macro_f1 = f1_score(labels, preds, average="macro", zero_division=0) weighted_f1 = f1_score(labels, preds, average="weighted", zero_division=0) conf_matrix = confusion_matrix(labels, preds, labels=range(len(class_names))) metrics = {"accuracy": accuracy, "macro_f1": macro_f1, "weighted_f1": weighted_f1, "confusion_matrix": conf_matrix.tolist(), "per_class": {}} for i, name in enumerate(class_names): metrics["per_class"][name] = {"precision": precision[i], "recall": recall[i], "f1": f1[i], "support": int(support[i])} return metrics class MultimodalTrainer: def __init__(self, model, train_dataset, val_dataset, config, device, use_ogm=True, ogm_alpha=0.3, ogm_noise_sigma=0.1, visual_lr_multiplier=3.0, audio_lr_multiplier=0.5): self.model = model.to(device) self.device, self.config = device, config self.use_ogm = use_ogm and (model.mode == "multimodal") self.ogm = OGMGEModulator(alpha=ogm_alpha, noise_sigma=ogm_noise_sigma) if self.use_ogm else None if self.use_ogm: print(f"[Trainer] OGM-GE enabled: alpha={ogm_alpha}, noise_sigma={ogm_noise_sigma}") self.train_loader = DataLoader(train_dataset, batch_size=config.per_device_train_batch_size, shuffle=True, collate_fn=multimodal_collate_fn, num_workers=2, pin_memory=True, drop_last=True) self.val_loader = DataLoader(val_dataset, batch_size=config.per_device_eval_batch_size, shuffle=False, collate_fn=multimodal_collate_fn, num_workers=2, pin_memory=True) param_groups = self._get_param_groups(visual_lr_multiplier, audio_lr_multiplier) self.optimizer = AdamW(param_groups, weight_decay=config.weight_decay) total_steps = max(len(self.train_loader) * config.num_epochs // config.gradient_accumulation_steps, 1) self.scheduler = OneCycleLR(self.optimizer, max_lr=[pg["lr"] for pg in param_groups], total_steps=total_steps, pct_start=config.warmup_ratio, anneal_strategy="cos") self.scaler = torch.amp.GradScaler("cuda") if config.fp16 and device.type == "cuda" else None self.best_metric, self.best_epoch = 0.0, 0 self.history = {"train_loss": [], "val_loss": [], "val_accuracy": [], "val_macro_f1": [], "train_loss_fusion": [], "train_loss_visual": [], "train_loss_audio": [], "ogm_visual_conf": [], "ogm_audio_conf": [], "ogm_coeff_visual": [], "ogm_coeff_audio": []} def _get_param_groups(self, visual_lr_mult, audio_lr_mult): visual_p, audio_p, fusion_p = [], [], [] for name, p in self.model.named_parameters(): if not p.requires_grad: continue if "visual_branch" in name: visual_p.append(p) elif "audio_branch" in name: audio_p.append(p) else: fusion_p.append(p) base_lr = self.config.lora_learning_rate groups = [] if visual_p: groups.append({"params": visual_p, "lr": base_lr * visual_lr_mult, "name": "visual"}) if audio_p: groups.append({"params": audio_p, "lr": base_lr * audio_lr_mult, "name": "audio"}) if fusion_p: groups.append({"params": fusion_p, "lr": base_lr, "name": "fusion"}) for g in groups: print(f" [Param Group] {g.get('name','?')}: {len(g['params'])} tensors, lr={g['lr']:.2e}") if not groups: raise ValueError("No trainable parameters!") return groups def train_epoch(self, epoch): self.model.train() tot_loss, tot_fus, tot_vis, tot_aud, n = 0., 0., 0., 0., 0 ogm_vc, ogm_ac, ogm_cv, ogm_ca = [], [], [], [] self.optimizer.zero_grad() for bi, batch in enumerate(self.train_loader): pv, av, labels = batch["pixel_values"].to(self.device), batch["audio_values"].to(self.device), batch["labels"].to(self.device) if self.scaler: with torch.amp.autocast("cuda"): out = self.model(pixel_values=pv, audio_values=av, labels=labels) loss = out["loss"] / self.config.gradient_accumulation_steps self.scaler.scale(loss).backward() else: out = self.model(pixel_values=pv, audio_values=av, labels=labels) loss = out["loss"] / self.config.gradient_accumulation_steps loss.backward() tot_loss += loss.item() * self.config.gradient_accumulation_steps; n += 1 if "loss_fusion" in out: tot_fus += out["loss_fusion"]; tot_vis += out["loss_visual"]; tot_aud += out["loss_audio"] if self.use_ogm and "visual_logits" in out: cv, ca, st = self.ogm.compute_modulation_coefficients(out["visual_logits"], out["audio_logits"], labels) ogm_vc.append(st["visual_conf"]); ogm_ac.append(st["audio_conf"]); ogm_cv.append(cv); ogm_ca.append(ca) if (bi + 1) % self.config.gradient_accumulation_steps == 0: if self.scaler: self.scaler.unscale_(self.optimizer) if self.use_ogm and ogm_cv: self.ogm.apply_gradient_modulation(self.model, ogm_cv[-1], ogm_ca[-1]) torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm) if self.scaler: self.scaler.step(self.optimizer); self.scaler.update() else: self.optimizer.step() self.scheduler.step(); self.optimizer.zero_grad() if (bi + 1) % self.config.logging_steps == 0 or bi == 0: msg = f" [Epoch {epoch+1}] Step {bi+1}/{len(self.train_loader)} | Loss: {tot_loss/n:.4f}" if "loss_fusion" in out: msg += f" | fus:{tot_fus/n:.4f} vis:{tot_vis/n:.4f} aud:{tot_aud/n:.4f}" if ogm_cv: msg += f" | OGM cv:{ogm_cv[-1]:.3f} ca:{ogm_ca[-1]:.3f}" print(msg) stats = {"train_loss": tot_loss/max(n,1), "loss_fusion": tot_fus/max(n,1), "loss_visual": tot_vis/max(n,1), "loss_audio": tot_aud/max(n,1)} if ogm_vc: stats.update({"ogm_visual_conf": np.mean(ogm_vc), "ogm_audio_conf": np.mean(ogm_ac), "ogm_coeff_visual": np.mean(ogm_cv), "ogm_coeff_audio": np.mean(ogm_ca)}) return stats @torch.no_grad() def evaluate(self, modality_mask=None): self.model.eval() preds, lbls, tot_loss, n = [], [], 0., 0 for batch in self.val_loader: pv, av, labels = batch["pixel_values"].to(self.device), batch["audio_values"].to(self.device), batch["labels"].to(self.device) if modality_mask: if modality_mask.get("visual", 1.0) == 0.0: pv = torch.zeros_like(pv) if modality_mask.get("audio", 1.0) == 0.0: av = torch.zeros_like(av) out = self.model(pixel_values=pv, audio_values=av, labels=labels) tot_loss += out["loss"].item(); n += 1 preds.extend(out["logits"].argmax(dim=-1).cpu().numpy()); lbls.extend(labels.cpu().numpy()) m = compute_metrics(np.array(preds), np.array(lbls)) m["val_loss"] = tot_loss / max(n, 1) return m def train(self): print(f"\n{'='*60}\nTraining: mode={self.model.mode}, epochs={self.config.num_epochs}, device={self.device}") print(f"OGM-GE: {'ON' if self.use_ogm else 'OFF'}") if self.model.mode == "multimodal": print(f"λ_visual={self.model.lambda_visual}, λ_audio={self.model.lambda_audio}") print(f"{'='*60}\n") for epoch in range(self.config.num_epochs): t0 = time.time() ts = self.train_epoch(epoch) vm = self.evaluate() print(f"\n[Epoch {epoch+1}/{self.config.num_epochs}] ({time.time()-t0:.1f}s)") msg = f" Train Loss: {ts['train_loss']:.4f}" if ts.get("loss_fusion", 0) > 0: msg += f" (fus={ts['loss_fusion']:.4f} vis={ts['loss_visual']:.4f} aud={ts['loss_audio']:.4f})" print(msg) print(f" Val Loss: {vm['val_loss']:.4f} | Acc: {vm['accuracy']:.4f} | F1: {vm['macro_f1']:.4f}") if "ogm_visual_conf" in ts: print(f" OGM-GE: v_conf={ts['ogm_visual_conf']:.4f} a_conf={ts['ogm_audio_conf']:.4f} | cv={ts['ogm_coeff_visual']:.4f} ca={ts['ogm_coeff_audio']:.4f}") for k in ["train_loss","val_loss","val_accuracy","val_macro_f1","train_loss_fusion","train_loss_visual","train_loss_audio"]: self.history[k].append(ts.get(k, vm.get(k, 0))) if "ogm_visual_conf" in ts: for k in ["ogm_visual_conf","ogm_audio_conf","ogm_coeff_visual","ogm_coeff_audio"]: self.history[k].append(ts[k]) if vm[self.config.metric_for_best_model] > self.best_metric: self.best_metric, self.best_epoch = vm[self.config.metric_for_best_model], epoch + 1 os.makedirs(self.config.output_dir, exist_ok=True) torch.save({"model_state_dict": self.model.state_dict(), "epoch": epoch+1, "metrics": vm}, os.path.join(self.config.output_dir, "best_model.pt")) print(f" ✓ Best model saved (F1={self.best_metric:.4f})") print(f"\nDone. Best epoch={self.best_epoch}, F1={self.best_metric:.4f}") return self.history def run_robustness_evaluation(self): print("\n=== Missing Modality Robustness ===") results = {} for name, mask in [("both", None), ("visual_only", {"visual": 1.0, "audio": 0.0}), ("audio_only", {"visual": 0.0, "audio": 1.0})]: m = self.evaluate(modality_mask=mask) results[name] = {"accuracy": m["accuracy"], "macro_f1": m["macro_f1"]} print(f" {name:15s}: Acc={m['accuracy']:.4f} F1={m['macro_f1']:.4f}") for cls, cm in m["per_class"].items(): print(f" {cls:25s} P:{cm['precision']:.3f} R:{cm['recall']:.3f} F1:{cm['f1']:.3f}") return results def main(): parser = argparse.ArgumentParser() parser.add_argument("--mode", default="multimodal", choices=["multimodal", "visual_only", "audio_only"]) parser.add_argument("--finetune", default="lora", choices=["lora", "full", "linear_probe"]) parser.add_argument("--epochs", type=int); parser.add_argument("--batch_size", type=int) parser.add_argument("--lr", type=float); parser.add_argument("--fusion", default="concat") parser.add_argument("--modality_dropout", type=float); parser.add_argument("--output_dir", type=str) parser.add_argument("--hub_model_id", type=str); parser.add_argument("--no_push", action="store_true") parser.add_argument("--eval_robustness", action="store_true"); parser.add_argument("--quick_test", action="store_true") parser.add_argument("--no_ogm", action="store_true") parser.add_argument("--ogm_alpha", type=float); parser.add_argument("--ogm_noise_sigma", type=float) parser.add_argument("--lambda_visual", type=float); parser.add_argument("--lambda_audio", type=float) parser.add_argument("--visual_lr_mult", type=float); parser.add_argument("--audio_lr_mult", type=float) args = parser.parse_args() config = ExperimentConfig() config.train.mode, config.train.finetune_method, config.model.fusion_type = args.mode, args.finetune, args.fusion if args.epochs: config.train.num_epochs = args.epochs if args.batch_size: config.train.per_device_train_batch_size = args.batch_size if args.lr: config.train.learning_rate = config.train.lora_learning_rate = args.lr if args.modality_dropout is not None: config.model.modality_dropout_p = args.modality_dropout if args.output_dir: config.train.output_dir = args.output_dir if args.hub_model_id: config.train.hub_model_id = args.hub_model_id if args.no_push: config.train.push_to_hub = False if args.quick_test: config.train.num_epochs, config.train.per_device_train_batch_size = 2, 4 config.train.per_device_eval_batch_size, config.train.gradient_accumulation_steps = 4, 1 config.train.logging_steps = 2 if args.finetune != "lora": config.lora.enabled = False ogm_a = args.ogm_alpha if args.ogm_alpha is not None else config.ogm_alpha ogm_s = args.ogm_noise_sigma if args.ogm_noise_sigma is not None else config.ogm_noise_sigma lv = args.lambda_visual if args.lambda_visual is not None else config.lambda_visual la = args.lambda_audio if args.lambda_audio is not None else config.lambda_audio vlr = args.visual_lr_mult if args.visual_lr_mult is not None else config.visual_lr_multiplier alr = args.audio_lr_mult if args.audio_lr_mult is not None else config.audio_lr_multiplier device = torch.device("cuda" if torch.cuda.is_available() else "cpu") torch.manual_seed(config.train.seed); np.random.seed(config.train.seed) print(f"\n{'='*60}\nMultimodal PC Fault Detection\n{'='*60}") print(f"Mode: {args.mode} | Finetune: {args.finetune} | Device: {device}") print(f"Dataset: {config.hub_dataset}") print(f"OGM-GE: {'OFF' if args.no_ogm else 'ON'} (α={ogm_a}, σ={ogm_s})") print(f"λ_visual={lv}, λ_audio={la}, LR mult: visual={vlr}x audio={alr}x") vit_proc, ast_ext = get_processors(config.model) train_ds = PCFaultDataset(config.data, config.model, "train", vit_proc, ast_ext, True, source="hub", hub_dataset=config.hub_dataset) val_ds = PCFaultDataset(config.data, config.model, "val", vit_proc, ast_ext, False, source="hub", hub_dataset=config.hub_dataset) model = create_model(config.model, config.lora, mode=args.mode, finetune_method=args.finetune, use_ogm=not args.no_ogm, lambda_visual=lv, lambda_audio=la) trainer = MultimodalTrainer(model, train_ds, val_ds, config.train, device, use_ogm=not args.no_ogm, ogm_alpha=ogm_a, ogm_noise_sigma=ogm_s, visual_lr_multiplier=vlr, audio_lr_multiplier=alr) history = trainer.train() final = trainer.evaluate() print(f"\nFinal: Acc={final['accuracy']:.4f} F1={final['macro_f1']:.4f}") for cls, m in final["per_class"].items(): print(f" {cls:25s} P:{m['precision']:.3f} R:{m['recall']:.3f} F1:{m['f1']:.3f} N:{m['support']}") robustness = trainer.run_robustness_evaluation() if args.eval_robustness and args.mode == "multimodal" else None os.makedirs(config.train.output_dir, exist_ok=True) results = {"experiment": config.experiment_name, "mode": args.mode, "finetune": args.finetune, "dataset": config.hub_dataset, "ogm_ge": not args.no_ogm, "final_metrics": {"accuracy": final["accuracy"], "macro_f1": final["macro_f1"], "weighted_f1": final["weighted_f1"], "per_class": final["per_class"], "confusion_matrix": final["confusion_matrix"]}, "history": history, "best_epoch": trainer.best_epoch, "best_metric": trainer.best_metric} if robustness: results["robustness"] = robustness with open(os.path.join(config.train.output_dir, "results.json"), "w") as f: json.dump(results, f, indent=2) if config.train.push_to_hub: try: from huggingface_hub import HfApi, login login(token=os.environ.get("HF_TOKEN")) HfApi().upload_folder(folder_path=config.train.output_dir, repo_id=config.train.hub_model_id, repo_type="model", commit_message=f"Training: {config.experiment_name}") print(f"✓ Pushed to https://huggingface.co/{config.train.hub_model_id}") except Exception as e: print(f"✗ Push failed: {e}") if __name__ == "__main__": main()