# Commit 4e62071 (verified) by Ellaft:
# "Overwrite models.py with v2 architecture (auxiliary heads + OGM-GE + anti-collapse)"
"""
Multimodal PC Fault Detection - Model Architecture
====================================================
Two-branch architecture with anti-modality-collapse features:
- Visual: ViT-B/16 (ImageNet-21k) + LoRA
- Audio: AST (AudioSet) + LoRA
- Fusion: Late fusion (concat / weighted sum / attention)
- Auxiliary unimodal classification heads
- OGM-GE gradient modulation support (Peng et al., CVPR 2022)
Loss = L_fusion + λ_visual * L_visual + λ_audio * L_audio
"""
import math
from typing import Dict, Literal, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model
from transformers import ASTFeatureExtractor, ASTModel, ViTImageProcessor, ViTModel

from config import FAULT_CLASSES, LoRAConfig, ModelConfig
class VisualBranch(nn.Module):
    """Visual encoder: a pretrained ``ViTModel`` that returns its first-token embedding.

    Args:
        config: model config; ``config.vit_model_name`` selects the checkpoint.
        lora_config: LoRA settings, consulted only when ``finetune_method == "lora"``
            and ``lora_config.enabled`` is truthy.
        finetune_method: ``"lora"`` wraps the backbone with PEFT LoRA adapters on
            ``lora_config.vit_target_modules``; ``"linear_probe"`` freezes every
            backbone parameter. Any other value, or ``"lora"`` without an enabled
            LoRA config, leaves the backbone fully trainable.
    """

    def __init__(self, config, lora_config=None, finetune_method="lora"):
        super().__init__()
        backbone = ViTModel.from_pretrained(config.vit_model_name)
        wants_lora = finetune_method == "lora" and lora_config and lora_config.enabled
        if wants_lora:
            adapter_cfg = LoraConfig(
                r=lora_config.r,
                lora_alpha=lora_config.lora_alpha,
                target_modules=lora_config.vit_target_modules,
                lora_dropout=lora_config.lora_dropout,
                bias=lora_config.bias,
            )
            backbone = get_peft_model(backbone, adapter_cfg)
            backbone.print_trainable_parameters()
        elif finetune_method == "linear_probe":
            # Freeze the whole backbone; only downstream heads will train.
            backbone.requires_grad_(False)
        self.vit = backbone

    def forward(self, pixel_values):
        """Return position 0 of ``last_hidden_state`` (the [CLS] slot for HF ViT — confirm)."""
        hidden_states = self.vit(pixel_values=pixel_values).last_hidden_state
        return hidden_states[:, 0, :]
class AudioBranch(nn.Module):
    """Audio encoder: a pretrained ``ASTModel`` that returns its first-token embedding.

    Args:
        config: model config; ``config.ast_model_name`` selects the checkpoint.
        lora_config: LoRA settings, consulted only when ``finetune_method == "lora"``
            and ``lora_config.enabled`` is truthy.
        finetune_method: ``"lora"`` wraps the backbone with PEFT LoRA adapters on
            ``lora_config.ast_target_modules``; ``"linear_probe"`` freezes every
            backbone parameter. Any other value, or ``"lora"`` without an enabled
            LoRA config, leaves the backbone fully trainable.
    """

    def __init__(self, config, lora_config=None, finetune_method="lora"):
        super().__init__()
        backbone = ASTModel.from_pretrained(config.ast_model_name)
        wants_lora = finetune_method == "lora" and lora_config and lora_config.enabled
        if wants_lora:
            adapter_cfg = LoraConfig(
                r=lora_config.r,
                lora_alpha=lora_config.lora_alpha,
                target_modules=lora_config.ast_target_modules,
                lora_dropout=lora_config.lora_dropout,
                bias=lora_config.bias,
            )
            backbone = get_peft_model(backbone, adapter_cfg)
            backbone.print_trainable_parameters()
        elif finetune_method == "linear_probe":
            # Freeze the whole backbone; only downstream heads will train.
            backbone.requires_grad_(False)
        self.ast = backbone

    def forward(self, input_values):
        """Return position 0 of ``last_hidden_state`` (the [CLS] slot for HF AST — confirm)."""
        hidden_states = self.ast(input_values=input_values).last_hidden_state
        return hidden_states[:, 0, :]
class LateFusion(nn.Module):
    """Late-fusion head that turns a visual and an audio embedding into class logits.

    Supported ``config.fusion_type`` values:
      - ``"concat"``: project each embedding to ``fusion_dim``, concatenate, and
        classify with a LayerNorm/Dropout/Linear/GELU/Dropout/Linear MLP.
      - ``"weighted_sum"``: one linear head per modality; their logits are mixed
        with softmax-normalised learnable weights (initialised equal).
      - ``"attention"``: project each embedding to ``fusion_dim``, run multi-head
        attention over the resulting two-token sequence (tokens serve as query,
        key and value), mean-pool the tokens, and classify.

    Args:
        config: model config exposing ``fusion_type``, ``vit_embed_dim``,
            ``ast_embed_dim``, ``fusion_dim``, ``fusion_dropout`` and ``num_classes``.
            For ``"attention"``, an optional ``fusion_num_heads`` (default 8) is
            read; ``fusion_dim`` must be divisible by it.

    Raises:
        ValueError: if ``config.fusion_type`` is not a supported value.
    """

    FUSION_TYPES = ("concat", "weighted_sum", "attention")

    def __init__(self, config):
        super().__init__()
        self.fusion_type = config.fusion_type
        if self.fusion_type not in self.FUSION_TYPES:
            # Previously an unknown type built no layers and forward() returned None.
            raise ValueError(
                f"Unknown fusion_type {self.fusion_type!r}; expected one of {self.FUSION_TYPES}"
            )
        # Attribute names and construction order are kept stable for checkpoint compatibility.
        if self.fusion_type == "concat":
            self.visual_proj = nn.Linear(config.vit_embed_dim, config.fusion_dim)
            self.audio_proj = nn.Linear(config.ast_embed_dim, config.fusion_dim)
            self.classifier = nn.Sequential(
                nn.LayerNorm(config.fusion_dim * 2),
                nn.Dropout(config.fusion_dropout),
                nn.Linear(config.fusion_dim * 2, config.fusion_dim),
                nn.GELU(),
                nn.Dropout(config.fusion_dropout),
                nn.Linear(config.fusion_dim, config.num_classes),
            )
        elif self.fusion_type == "weighted_sum":
            self.visual_head = nn.Linear(config.vit_embed_dim, config.num_classes)
            self.audio_head = nn.Linear(config.ast_embed_dim, config.num_classes)
            self.fusion_weights = nn.Parameter(torch.tensor([0.5, 0.5]))
        else:  # "attention"
            self.visual_proj = nn.Linear(config.vit_embed_dim, config.fusion_dim)
            self.audio_proj = nn.Linear(config.ast_embed_dim, config.fusion_dim)
            # Named cross_attn for checkpoint compatibility; it attends over both tokens jointly.
            self.cross_attn = nn.MultiheadAttention(
                embed_dim=config.fusion_dim,
                num_heads=getattr(config, "fusion_num_heads", 8),
                dropout=config.fusion_dropout,
                batch_first=True,
            )
            self.classifier = nn.Sequential(
                nn.LayerNorm(config.fusion_dim),
                nn.Dropout(config.fusion_dropout),
                nn.Linear(config.fusion_dim, config.num_classes),
            )

    def forward(self, visual_emb, audio_emb, modality_mask=None):
        """Fuse the two embeddings into logits with last dim ``num_classes``.

        Args:
            visual_emb: visual embedding whose last dim is ``vit_embed_dim``.
            audio_emb: audio embedding whose last dim is ``ast_embed_dim``.
            modality_mask: optional dict of multipliers under ``"visual"`` / ``"audio"``
                (a missing key means 1.0), applied to the embeddings before fusion.
                A zeroed modality still contributes its projection/head bias.

        Raises:
            ValueError: if ``self.fusion_type`` was changed to an unsupported value.
        """
        if modality_mask:
            visual_emb = visual_emb * modality_mask.get("visual", 1.0)
            audio_emb = audio_emb * modality_mask.get("audio", 1.0)
        if self.fusion_type == "concat":
            fused = torch.cat([self.visual_proj(visual_emb), self.audio_proj(audio_emb)], dim=-1)
            return self.classifier(fused)
        if self.fusion_type == "weighted_sum":
            w = torch.softmax(self.fusion_weights, dim=0)
            return w[0] * self.visual_head(visual_emb) + w[1] * self.audio_head(audio_emb)
        if self.fusion_type == "attention":
            # Two-token sequence: index 0 = visual, index 1 = audio.
            tokens = torch.cat(
                [self.visual_proj(visual_emb).unsqueeze(1), self.audio_proj(audio_emb).unsqueeze(1)],
                dim=1,
            )
            attended, _ = self.cross_attn(tokens, tokens, tokens)
            return self.classifier(attended.mean(dim=1))
        raise ValueError(f"Unknown fusion_type {self.fusion_type!r}")
class OGMGEModulator:
    """Gradient modulation with Gaussian-noise enhancement, after OGM-GE (Peng et al., CVPR 2022).

    The modality whose auxiliary logits assign more (batch-mean) softmax
    probability to the true class is treated as dominant. Gradients of parameters
    whose names contain ``"visual_branch"`` / ``"audio_branch"`` are scaled by the
    corresponding coefficient, and a suppressed modality also receives additive
    Gaussian noise.

    The dominant modality's coefficient is ``1 - alpha * tanh(ratio - 1)`` with
    ``ratio >= 1`` its confidence ratio over the other modality, so it lies in
    ``(1 - alpha, 1]``; the other modality keeps 1.0.
    NOTE(review): the paper's exact coefficient/noise formulas may differ from
    this variant — confirm if faithful reproduction matters.

    Args:
        alpha: maximum suppression strength.
        noise_sigma: noise std as a multiple of the mean absolute (already scaled)
            gradient of each parameter; ``0`` disables noise.
    """

    def __init__(self, alpha=0.3, noise_sigma=0.1):
        self.alpha = alpha
        self.noise_sigma = noise_sigma

    def _suppression(self, ratio):
        """Coefficient for a modality whose confidence exceeds the other's by ``ratio`` (>= 1)."""
        # math.tanh on the Python float avoids a float32 tensor round-trip and a .item() call.
        return 1.0 - self.alpha * math.tanh(ratio - 1.0)

    @torch.no_grad()
    def compute_modulation_coefficients(self, visual_logits, audio_logits, labels):
        """Return ``(coeff_visual, coeff_audio, stats)`` for one batch.

        Args:
            visual_logits: auxiliary visual-head logits, classes on the last dim.
            audio_logits: auxiliary audio-head logits, classes on the last dim.
            labels: integer class index per sample (indexes dim 1 of the logits).

        Returns:
            The two float coefficients and a dict with ``visual_conf``,
            ``audio_conf``, ``ratio``, ``coeff_visual`` and ``coeff_audio``.
        """
        batch_idx = torch.arange(labels.size(0), device=labels.device)
        v_conf = F.softmax(visual_logits, dim=-1)[batch_idx, labels].mean().item()
        a_conf = F.softmax(audio_logits, dim=-1)[batch_idx, labels].mean().item()
        eps = 1e-8
        ratio = (v_conf + eps) / (a_conf + eps)
        if ratio > 1.0:
            coeff_visual, coeff_audio = self._suppression(ratio), 1.0
        else:
            # ratio == 1.0 lands here and yields coeff_audio == 1.0 (tanh(0) == 0).
            coeff_visual, coeff_audio = 1.0, self._suppression(1.0 / ratio)
        stats = {"visual_conf": v_conf, "audio_conf": a_conf,
                 "ratio": ratio, "coeff_visual": coeff_visual, "coeff_audio": coeff_audio}
        return coeff_visual, coeff_audio, stats

    @torch.no_grad()
    def apply_gradient_modulation(self, model, coeff_visual, coeff_audio):
        """Scale (and, when suppressed, add noise to) the branch gradients of ``model`` in place.

        Call after ``backward()`` and before the optimizer step. Parameters without
        a gradient, or whose name contains neither branch marker, are left untouched.
        """
        for name, param in model.named_parameters():
            grad = param.grad
            if grad is None:
                continue
            if "visual_branch" in name:
                coeff = coeff_visual
            elif "audio_branch" in name:
                coeff = coeff_audio
            else:
                continue
            grad.mul_(coeff)
            if coeff < 1.0 and self.noise_sigma > 0:
                # Noise std tracks the magnitude of the already-scaled gradient.
                grad.add_(torch.randn_like(grad) * self.noise_sigma * grad.abs().mean())
class MultimodalPCFaultDetector(nn.Module):
    """PC fault classifier over images and/or audio.

    In ``"multimodal"`` mode both encoders run. Their embeddings feed a
    ``LateFusion`` head plus two auxiliary unimodal heads, and the training loss is
    ``L_fusion + lambda_visual * L_visual + lambda_audio * L_audio``. During
    training, modality dropout may zero one modality's input to the fusion head
    (never both). In ``"visual_only"`` / ``"audio_only"`` mode a single encoder
    feeds an MLP classifier.

    Args:
        model_config: config exposing ``modality_dropout_p``, ``vit_embed_dim``,
            ``ast_embed_dim``, ``num_classes``, ``fusion_dim`` and ``fusion_dropout``
            (plus whatever the branches and fusion head read).
        lora_config: forwarded to the encoder branches.
        finetune_method: forwarded to the encoder branches.
        mode: one of ``MODES``.
        use_ogm: stored as ``self.use_ogm``; not read inside this class
            (presumably consumed by the training loop — confirm).
        lambda_visual: weight of the auxiliary visual loss (multimodal mode).
        lambda_audio: weight of the auxiliary audio loss (multimodal mode).

    Raises:
        ValueError: if ``mode`` is not supported.
    """

    MODES = ("multimodal", "visual_only", "audio_only")

    def __init__(self, model_config, lora_config=None, finetune_method="lora",
                 mode="multimodal", use_ogm=True, lambda_visual=1.5, lambda_audio=0.5):
        super().__init__()
        if mode not in self.MODES:
            # Previously an unknown mode built no encoders and failed only inside forward().
            raise ValueError(f"Unknown mode {mode!r}; expected one of {self.MODES}")
        self.mode = mode
        self.modality_dropout_p = model_config.modality_dropout_p
        self.use_ogm = use_ogm
        self.lambda_visual = lambda_visual
        self.lambda_audio = lambda_audio
        # Only build the encoders this mode needs.
        self.visual_branch = (VisualBranch(model_config, lora_config, finetune_method)
                              if mode in ("multimodal", "visual_only") else None)
        self.audio_branch = (AudioBranch(model_config, lora_config, finetune_method)
                             if mode in ("multimodal", "audio_only") else None)
        if mode == "multimodal":
            self.fusion = LateFusion(model_config)
            self.visual_head = nn.Sequential(
                nn.LayerNorm(model_config.vit_embed_dim),
                nn.Dropout(0.2),
                nn.Linear(model_config.vit_embed_dim, model_config.num_classes),
            )
            self.audio_head = nn.Sequential(
                nn.LayerNorm(model_config.ast_embed_dim),
                nn.Dropout(0.2),
                nn.Linear(model_config.ast_embed_dim, model_config.num_classes),
            )
        else:
            embed_dim = model_config.vit_embed_dim if mode == "visual_only" else model_config.ast_embed_dim
            self.classifier = nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Dropout(model_config.fusion_dropout),
                nn.Linear(embed_dim, model_config.fusion_dim),
                nn.GELU(),
                nn.Dropout(model_config.fusion_dropout),
                nn.Linear(model_config.fusion_dim, model_config.num_classes),
            )
        self.loss_fn = nn.CrossEntropyLoss()
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"[Model] Mode={mode}, Total={total:,}, Trainable={trainable:,} ({100*trainable/total:.2f}%)")

    @staticmethod
    def _sample_modality_mask(p):
        """Draw a per-batch modality-dropout mask that never drops both modalities.

        Returns ``None`` when ``p <= 0``; otherwise a dict mapping ``"visual"`` and
        ``"audio"`` to 0.0 (dropped, with probability ``p`` each) or 1.0 (kept).
        """
        if p <= 0:
            return None
        mask = {"visual": 0.0 if torch.rand(1).item() < p else 1.0,
                "audio": 0.0 if torch.rand(1).item() < p else 1.0}
        if mask["visual"] == 0.0 and mask["audio"] == 0.0:
            # Re-enable one modality at random so the fusion head always gets some input.
            mask["visual" if torch.rand(1).item() < 0.5 else "audio"] = 1.0
        return mask

    def _forward_multimodal(self, pixel_values, audio_values, labels):
        """Run the fused and auxiliary heads; add the weighted total loss when labels are given."""
        v_emb = self.visual_branch(pixel_values)
        a_emb = self.audio_branch(audio_values)
        # Modality dropout perturbs only the fusion input; auxiliary heads get unmasked embeddings.
        mask = self._sample_modality_mask(self.modality_dropout_p) if self.training else None
        logits = self.fusion(v_emb, a_emb, mask)
        visual_logits = self.visual_head(v_emb)
        audio_logits = self.audio_head(a_emb)
        outputs = {"logits": logits, "visual_logits": visual_logits, "audio_logits": audio_logits,
                   "visual_emb": v_emb, "audio_emb": a_emb}
        if labels is not None:
            loss_f = self.loss_fn(logits, labels)
            loss_v = self.loss_fn(visual_logits, labels)
            loss_a = self.loss_fn(audio_logits, labels)
            outputs.update({
                "loss": loss_f + self.lambda_visual * loss_v + self.lambda_audio * loss_a,
                # Per-term losses as plain floats (for logging only).
                "loss_fusion": loss_f.item(),
                "loss_visual": loss_v.item(),
                "loss_audio": loss_a.item(),
            })
        return outputs

    def forward(self, pixel_values=None, audio_values=None, labels=None):
        """Return a dict with ``"logits"`` (and ``"loss"`` when ``labels`` is given).

        Multimodal mode additionally returns auxiliary logits, both embeddings, and
        the per-term losses as floats.
        """
        if self.mode == "multimodal":
            return self._forward_multimodal(pixel_values, audio_values, labels)
        if self.mode == "visual_only":
            logits = self.classifier(self.visual_branch(pixel_values))
        else:  # "audio_only" (validated in __init__)
            logits = self.classifier(self.audio_branch(audio_values))
        outputs = {"logits": logits}
        if labels is not None:
            outputs["loss"] = self.loss_fn(logits, labels)
        return outputs
def create_model(model_config, lora_config, mode="multimodal", finetune_method="lora",
                 use_ogm=True, lambda_visual=1.5, lambda_audio=0.5):
    """Build a ``MultimodalPCFaultDetector`` from the given configs.

    The factory's positional order (``mode`` before ``finetune_method``) differs
    from the detector constructor's, so every option is forwarded by keyword.
    """
    return MultimodalPCFaultDetector(
        model_config,
        lora_config=lora_config,
        finetune_method=finetune_method,
        mode=mode,
        use_ogm=use_ogm,
        lambda_visual=lambda_visual,
        lambda_audio=lambda_audio,
    )
def get_processors(model_config):
    """Load the pretrained input preprocessors matching the two encoders.

    Returns:
        A tuple ``(image_processor, audio_feature_extractor)`` loaded from
        ``model_config.vit_model_name`` and ``model_config.ast_model_name``.
    """
    image_processor = ViTImageProcessor.from_pretrained(model_config.vit_model_name)
    audio_feature_extractor = ASTFeatureExtractor.from_pretrained(model_config.ast_model_name)
    return image_processor, audio_feature_extractor