# Overwrite train.py: defaults to Hub dataset (Ellaft/pc-fault-real-dataset),
# uses models.py + dataset_v2.py, OGM-GE enabled (commit 2df13e2, verified)
"""
Multimodal PC Fault Detection - Training Script
=================================================
Loads dataset from Ellaft/pc-fault-real-dataset (HuggingFace Hub).
Uses OGM-GE anti-modality-collapse + auxiliary heads + asymmetric LRs.
Usage:
python train.py # Full training
python train.py --quick_test --no_push # Quick 2-epoch test
python train.py --eval_robustness # + missing modality test
python train.py --mode visual_only --no_push # Unimodal ablation
python train.py --mode audio_only --no_push # Unimodal ablation
All 6 ablations:
python run_ablations.py
python run_ablations.py --quick_test
"""
import os, sys, json, argparse, time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_recall_fscore_support
from config import ExperimentConfig, FAULT_CLASSES, NUM_CLASSES
from models import create_model, get_processors, OGMGEModulator
from dataset_v2 import BuiltDataset as PCFaultDataset, multimodal_collate_fn
def compute_metrics(preds, labels, class_names=FAULT_CLASSES):
    """Compute classification metrics for a set of predictions.

    Returns a dict with overall accuracy, macro/weighted F1, the confusion
    matrix (as nested lists), and per-class precision/recall/F1/support
    keyed by class name.
    """
    label_ids = list(range(len(class_names)))
    metrics = {
        "accuracy": accuracy_score(labels, preds),
        "macro_f1": f1_score(labels, preds, average="macro", zero_division=0),
        "weighted_f1": f1_score(labels, preds, average="weighted", zero_division=0),
        "confusion_matrix": confusion_matrix(labels, preds, labels=label_ids).tolist(),
        "per_class": {},
    }
    # zero_division=0 keeps classes absent from preds from raising warnings.
    prec, rec, f1_vals, supp = precision_recall_fscore_support(
        labels, preds, average=None, labels=label_ids, zero_division=0)
    for idx, cname in enumerate(class_names):
        metrics["per_class"][cname] = {
            "precision": prec[idx],
            "recall": rec[idx],
            "f1": f1_vals[idx],
            "support": int(supp[idx]),
        }
    return metrics
class MultimodalTrainer:
    """Training loop for the multimodal PC-fault classifier.

    Handles OGM-GE gradient modulation (multimodal mode only), asymmetric
    per-branch learning rates (visual/audio/fusion parameter groups), AMP
    mixed precision, gradient accumulation, OneCycle LR scheduling, and
    best-model checkpointing on the configured validation metric.
    """

    def __init__(self, model, train_dataset, val_dataset, config, device,
                 use_ogm=True, ogm_alpha=0.3, ogm_noise_sigma=0.1,
                 visual_lr_multiplier=3.0, audio_lr_multiplier=0.5):
        self.model = model.to(device)
        self.device, self.config = device, config
        # OGM-GE balances the two branches' gradients, so it only applies
        # when both modalities are trained jointly.
        self.use_ogm = use_ogm and (model.mode == "multimodal")
        self.ogm = OGMGEModulator(alpha=ogm_alpha, noise_sigma=ogm_noise_sigma) if self.use_ogm else None
        if self.use_ogm:
            print(f"[Trainer] OGM-GE enabled: alpha={ogm_alpha}, noise_sigma={ogm_noise_sigma}")
        self.train_loader = DataLoader(train_dataset, batch_size=config.per_device_train_batch_size,
                                       shuffle=True, collate_fn=multimodal_collate_fn, num_workers=2,
                                       pin_memory=True, drop_last=True)
        self.val_loader = DataLoader(val_dataset, batch_size=config.per_device_eval_batch_size,
                                     shuffle=False, collate_fn=multimodal_collate_fn, num_workers=2,
                                     pin_memory=True)
        param_groups = self._get_param_groups(visual_lr_multiplier, audio_lr_multiplier)
        self.optimizer = AdamW(param_groups, weight_decay=config.weight_decay)
        # drop_last=True means an epoch has len(train_loader) batches; the
        # floor division may slightly over-estimate optimizer steps, which
        # OneCycleLR tolerates (it only errors when stepped PAST total_steps).
        total_steps = max(len(self.train_loader) * config.num_epochs // config.gradient_accumulation_steps, 1)
        self.scheduler = OneCycleLR(self.optimizer, max_lr=[pg["lr"] for pg in param_groups],
                                    total_steps=total_steps, pct_start=config.warmup_ratio,
                                    anneal_strategy="cos")
        self.scaler = torch.amp.GradScaler("cuda") if config.fp16 and device.type == "cuda" else None
        self.best_metric, self.best_epoch = 0.0, 0
        self.history = {"train_loss": [], "val_loss": [], "val_accuracy": [], "val_macro_f1": [],
                        "train_loss_fusion": [], "train_loss_visual": [], "train_loss_audio": [],
                        "ogm_visual_conf": [], "ogm_audio_conf": [], "ogm_coeff_visual": [], "ogm_coeff_audio": []}

    def _get_param_groups(self, visual_lr_mult, audio_lr_mult):
        """Split trainable params into visual/audio/fusion groups with scaled LRs."""
        visual_p, audio_p, fusion_p = [], [], []
        for name, p in self.model.named_parameters():
            if not p.requires_grad:
                continue
            if "visual_branch" in name:
                visual_p.append(p)
            elif "audio_branch" in name:
                audio_p.append(p)
            else:
                fusion_p.append(p)
        base_lr = self.config.lora_learning_rate
        groups = []
        if visual_p:
            groups.append({"params": visual_p, "lr": base_lr * visual_lr_mult, "name": "visual"})
        if audio_p:
            groups.append({"params": audio_p, "lr": base_lr * audio_lr_mult, "name": "audio"})
        if fusion_p:
            groups.append({"params": fusion_p, "lr": base_lr, "name": "fusion"})
        for g in groups:
            print(f"  [Param Group] {g.get('name','?')}: {len(g['params'])} tensors, lr={g['lr']:.2e}")
        if not groups:
            raise ValueError("No trainable parameters!")
        return groups

    def train_epoch(self, epoch):
        """Run one training epoch; return averaged loss and OGM statistics."""
        self.model.train()
        tot_loss, tot_fus, tot_vis, tot_aud, n = 0., 0., 0., 0., 0
        ogm_vc, ogm_ac, ogm_cv, ogm_ca = [], [], [], []
        self.optimizer.zero_grad()
        for bi, batch in enumerate(self.train_loader):
            pv = batch["pixel_values"].to(self.device)
            av = batch["audio_values"].to(self.device)
            labels = batch["labels"].to(self.device)
            if self.scaler:
                with torch.amp.autocast("cuda"):
                    out = self.model(pixel_values=pv, audio_values=av, labels=labels)
                    loss = out["loss"] / self.config.gradient_accumulation_steps
                self.scaler.scale(loss).backward()
            else:
                out = self.model(pixel_values=pv, audio_values=av, labels=labels)
                loss = out["loss"] / self.config.gradient_accumulation_steps
                loss.backward()
            # Undo the accumulation division so the logged loss is per-batch.
            tot_loss += loss.item() * self.config.gradient_accumulation_steps
            n += 1
            if "loss_fusion" in out:
                tot_fus += out["loss_fusion"]
                tot_vis += out["loss_visual"]
                tot_aud += out["loss_audio"]
            if self.use_ogm and "visual_logits" in out:
                cv, ca, st = self.ogm.compute_modulation_coefficients(out["visual_logits"], out["audio_logits"], labels)
                ogm_vc.append(st["visual_conf"]); ogm_ac.append(st["audio_conf"])
                ogm_cv.append(cv); ogm_ca.append(ca)
            if (bi + 1) % self.config.gradient_accumulation_steps == 0:
                # Gradients must be unscaled before OGM modulation / clipping
                # so both operate on true-magnitude gradients.
                if self.scaler:
                    self.scaler.unscale_(self.optimizer)
                if self.use_ogm and ogm_cv:
                    self.ogm.apply_gradient_modulation(self.model, ogm_cv[-1], ogm_ca[-1])
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                if self.scaler:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()
            if (bi + 1) % self.config.logging_steps == 0 or bi == 0:
                msg = f"  [Epoch {epoch+1}] Step {bi+1}/{len(self.train_loader)} | Loss: {tot_loss/n:.4f}"
                if "loss_fusion" in out:
                    msg += f" | fus:{tot_fus/n:.4f} vis:{tot_vis/n:.4f} aud:{tot_aud/n:.4f}"
                if ogm_cv:
                    msg += f" | OGM cv:{ogm_cv[-1]:.3f} ca:{ogm_ca[-1]:.3f}"
                print(msg)
        stats = {"train_loss": tot_loss/max(n,1), "loss_fusion": tot_fus/max(n,1),
                 "loss_visual": tot_vis/max(n,1), "loss_audio": tot_aud/max(n,1)}
        if ogm_vc:
            stats.update({"ogm_visual_conf": np.mean(ogm_vc), "ogm_audio_conf": np.mean(ogm_ac),
                          "ogm_coeff_visual": np.mean(ogm_cv), "ogm_coeff_audio": np.mean(ogm_ca)})
        return stats

    @torch.no_grad()
    def evaluate(self, modality_mask=None):
        """Evaluate on the validation set.

        modality_mask: optional dict like {"visual": 0.0, "audio": 1.0};
        a 0.0 entry zeroes out that modality's input tensor to simulate a
        missing modality.
        """
        self.model.eval()
        preds, lbls, tot_loss, n = [], [], 0., 0
        for batch in self.val_loader:
            pv = batch["pixel_values"].to(self.device)
            av = batch["audio_values"].to(self.device)
            labels = batch["labels"].to(self.device)
            if modality_mask:
                if modality_mask.get("visual", 1.0) == 0.0:
                    pv = torch.zeros_like(pv)
                if modality_mask.get("audio", 1.0) == 0.0:
                    av = torch.zeros_like(av)
            out = self.model(pixel_values=pv, audio_values=av, labels=labels)
            tot_loss += out["loss"].item(); n += 1
            preds.extend(out["logits"].argmax(dim=-1).cpu().numpy())
            lbls.extend(labels.cpu().numpy())
        m = compute_metrics(np.array(preds), np.array(lbls))
        m["val_loss"] = tot_loss / max(n, 1)
        return m

    def train(self):
        """Full training loop with per-epoch validation and best-model saving."""
        print(f"\n{'='*60}\nTraining: mode={self.model.mode}, epochs={self.config.num_epochs}, device={self.device}")
        print(f"OGM-GE: {'ON' if self.use_ogm else 'OFF'}")
        if self.model.mode == "multimodal":
            print(f"λ_visual={self.model.lambda_visual}, λ_audio={self.model.lambda_audio}")
        print(f"{'='*60}\n")
        for epoch in range(self.config.num_epochs):
            t0 = time.time()
            ts = self.train_epoch(epoch)
            vm = self.evaluate()
            print(f"\n[Epoch {epoch+1}/{self.config.num_epochs}] ({time.time()-t0:.1f}s)")
            msg = f"  Train Loss: {ts['train_loss']:.4f}"
            if ts.get("loss_fusion", 0) > 0:
                msg += f" (fus={ts['loss_fusion']:.4f} vis={ts['loss_visual']:.4f} aud={ts['loss_audio']:.4f})"
            print(msg)
            print(f"  Val Loss: {vm['val_loss']:.4f} | Acc: {vm['accuracy']:.4f} | F1: {vm['macro_f1']:.4f}")
            if "ogm_visual_conf" in ts:
                print(f"  OGM-GE: v_conf={ts['ogm_visual_conf']:.4f} a_conf={ts['ogm_audio_conf']:.4f} | cv={ts['ogm_coeff_visual']:.4f} ca={ts['ogm_coeff_audio']:.4f}")
            # BUGFIX: history keys do not match the stats/metrics keys
            # ("train_loss_fusion" vs "loss_fusion", "val_accuracy" vs
            # "accuracy", ...), so the old ts.get(k, vm.get(k, 0)) lookup
            # recorded 0 for five of the seven series. Map explicitly.
            self.history["train_loss"].append(ts["train_loss"])
            self.history["val_loss"].append(vm["val_loss"])
            self.history["val_accuracy"].append(vm["accuracy"])
            self.history["val_macro_f1"].append(vm["macro_f1"])
            self.history["train_loss_fusion"].append(ts.get("loss_fusion", 0))
            self.history["train_loss_visual"].append(ts.get("loss_visual", 0))
            self.history["train_loss_audio"].append(ts.get("loss_audio", 0))
            if "ogm_visual_conf" in ts:
                for k in ["ogm_visual_conf", "ogm_audio_conf", "ogm_coeff_visual", "ogm_coeff_audio"]:
                    self.history[k].append(ts[k])
            if vm[self.config.metric_for_best_model] > self.best_metric:
                self.best_metric, self.best_epoch = vm[self.config.metric_for_best_model], epoch + 1
                os.makedirs(self.config.output_dir, exist_ok=True)
                torch.save({"model_state_dict": self.model.state_dict(), "epoch": epoch+1, "metrics": vm},
                           os.path.join(self.config.output_dir, "best_model.pt"))
                print(f"  ✓ Best model saved (F1={self.best_metric:.4f})")
        print(f"\nDone. Best epoch={self.best_epoch}, F1={self.best_metric:.4f}")
        return self.history

    def run_robustness_evaluation(self):
        """Evaluate with each modality zeroed out in turn (missing-modality test)."""
        print("\n=== Missing Modality Robustness ===")
        results = {}
        for name, mask in [("both", None),
                           ("visual_only", {"visual": 1.0, "audio": 0.0}),
                           ("audio_only", {"visual": 0.0, "audio": 1.0})]:
            m = self.evaluate(modality_mask=mask)
            results[name] = {"accuracy": m["accuracy"], "macro_f1": m["macro_f1"]}
            print(f"  {name:15s}: Acc={m['accuracy']:.4f} F1={m['macro_f1']:.4f}")
            for cls, cm in m["per_class"].items():
                print(f"    {cls:25s} P:{cm['precision']:.3f} R:{cm['recall']:.3f} F1:{cm['f1']:.3f}")
        return results
def main():
    """CLI entry point: parse args, build datasets/model, train, evaluate, save/push."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", default="multimodal", choices=["multimodal", "visual_only", "audio_only"])
    parser.add_argument("--finetune", default="lora", choices=["lora", "full", "linear_probe"])
    parser.add_argument("--epochs", type=int)
    parser.add_argument("--batch_size", type=int)
    parser.add_argument("--lr", type=float)
    parser.add_argument("--fusion", default="concat")
    parser.add_argument("--modality_dropout", type=float)
    parser.add_argument("--output_dir", type=str)
    parser.add_argument("--hub_model_id", type=str)
    parser.add_argument("--no_push", action="store_true")
    parser.add_argument("--eval_robustness", action="store_true")
    parser.add_argument("--quick_test", action="store_true")
    parser.add_argument("--no_ogm", action="store_true")
    parser.add_argument("--ogm_alpha", type=float)
    parser.add_argument("--ogm_noise_sigma", type=float)
    parser.add_argument("--lambda_visual", type=float)
    parser.add_argument("--lambda_audio", type=float)
    parser.add_argument("--visual_lr_mult", type=float)
    parser.add_argument("--audio_lr_mult", type=float)
    args = parser.parse_args()

    config = ExperimentConfig()
    config.train.mode, config.train.finetune_method, config.model.fusion_type = args.mode, args.finetune, args.fusion
    # BUGFIX: use `is not None` (not truthiness) so explicit zero/falsy CLI
    # values are honored, consistent with the checks below.
    if args.epochs is not None: config.train.num_epochs = args.epochs
    if args.batch_size is not None: config.train.per_device_train_batch_size = args.batch_size
    if args.lr is not None: config.train.learning_rate = config.train.lora_learning_rate = args.lr
    if args.modality_dropout is not None: config.model.modality_dropout_p = args.modality_dropout
    if args.output_dir: config.train.output_dir = args.output_dir
    if args.hub_model_id: config.train.hub_model_id = args.hub_model_id
    if args.no_push: config.train.push_to_hub = False
    if args.quick_test:
        # Quick smoke-test profile: tiny batches, 2 epochs, no accumulation.
        config.train.num_epochs, config.train.per_device_train_batch_size = 2, 4
        config.train.per_device_eval_batch_size, config.train.gradient_accumulation_steps = 4, 1
        config.train.logging_steps = 2
    if args.finetune != "lora": config.lora.enabled = False

    # CLI overrides fall back to config defaults.
    ogm_a = args.ogm_alpha if args.ogm_alpha is not None else config.ogm_alpha
    ogm_s = args.ogm_noise_sigma if args.ogm_noise_sigma is not None else config.ogm_noise_sigma
    lv = args.lambda_visual if args.lambda_visual is not None else config.lambda_visual
    la = args.lambda_audio if args.lambda_audio is not None else config.lambda_audio
    vlr = args.visual_lr_mult if args.visual_lr_mult is not None else config.visual_lr_multiplier
    alr = args.audio_lr_mult if args.audio_lr_mult is not None else config.audio_lr_multiplier

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(config.train.seed); np.random.seed(config.train.seed)
    print(f"\n{'='*60}\nMultimodal PC Fault Detection\n{'='*60}")
    print(f"Mode: {args.mode} | Finetune: {args.finetune} | Device: {device}")
    print(f"Dataset: {config.hub_dataset}")
    print(f"OGM-GE: {'OFF' if args.no_ogm else 'ON'} (α={ogm_a}, σ={ogm_s})")
    print(f"λ_visual={lv}, λ_audio={la}, LR mult: visual={vlr}x audio={alr}x")

    vit_proc, ast_ext = get_processors(config.model)
    train_ds = PCFaultDataset(config.data, config.model, "train", vit_proc, ast_ext, True,
                              source="hub", hub_dataset=config.hub_dataset)
    val_ds = PCFaultDataset(config.data, config.model, "val", vit_proc, ast_ext, False,
                            source="hub", hub_dataset=config.hub_dataset)
    model = create_model(config.model, config.lora, mode=args.mode, finetune_method=args.finetune,
                         use_ogm=not args.no_ogm, lambda_visual=lv, lambda_audio=la)
    trainer = MultimodalTrainer(model, train_ds, val_ds, config.train, device,
                                use_ogm=not args.no_ogm, ogm_alpha=ogm_a, ogm_noise_sigma=ogm_s,
                                visual_lr_multiplier=vlr, audio_lr_multiplier=alr)

    history = trainer.train()
    final = trainer.evaluate()
    print(f"\nFinal: Acc={final['accuracy']:.4f} F1={final['macro_f1']:.4f}")
    for cls, m in final["per_class"].items():
        print(f"  {cls:25s} P:{m['precision']:.3f} R:{m['recall']:.3f} F1:{m['f1']:.3f} N:{m['support']}")
    # Robustness sweep only makes sense when both modalities were trained.
    robustness = trainer.run_robustness_evaluation() if args.eval_robustness and args.mode == "multimodal" else None

    os.makedirs(config.train.output_dir, exist_ok=True)
    results = {"experiment": config.experiment_name, "mode": args.mode, "finetune": args.finetune,
               "dataset": config.hub_dataset, "ogm_ge": not args.no_ogm,
               "final_metrics": {"accuracy": final["accuracy"], "macro_f1": final["macro_f1"],
                                 "weighted_f1": final["weighted_f1"], "per_class": final["per_class"],
                                 "confusion_matrix": final["confusion_matrix"]},
               "history": history, "best_epoch": trainer.best_epoch, "best_metric": trainer.best_metric}
    if robustness:
        results["robustness"] = robustness
    with open(os.path.join(config.train.output_dir, "results.json"), "w") as f:
        json.dump(results, f, indent=2)

    if config.train.push_to_hub:
        # Best-effort push: training results are already saved locally, so a
        # Hub failure is reported but does not crash the run.
        try:
            from huggingface_hub import HfApi, login
            login(token=os.environ.get("HF_TOKEN"))
            HfApi().upload_folder(folder_path=config.train.output_dir, repo_id=config.train.hub_model_id,
                                  repo_type="model", commit_message=f"Training: {config.experiment_name}")
            print(f"✓ Pushed to https://huggingface.co/{config.train.hub_model_id}")
        except Exception as e:
            print(f"✗ Push failed: {e}")
# Script entry point: parse CLI args, train, evaluate, and optionally push to Hub.
if __name__ == "__main__":
    main()