"""
Multimodal PC Fault Detection - Model Architecture v2
======================================================
Changes from v1:
- Auxiliary unimodal classification heads (force each branch to independently classify)
- Asymmetric loss weighting: λ_visual=1.5 (boost weak), λ_audio=0.5 (dampen dominant)
- OGM-GE (On-the-fly Gradient Modulation + Generalization Enhancement) support
- Forward returns per-branch logits + embeddings for OGM-GE gradient modulation
Two-branch architecture:
- Visual: ViT-B/16 pretrained on ImageNet-21k
- Audio: AST pretrained on AudioSet
- Fusion: Late fusion (concat / weighted sum / attention)
Supports LoRA, full fine-tuning, and linear probe modes.
References:
- OGM-GE: Peng et al., "Balanced Multimodal Learning via On-the-fly Gradient
Modulation", CVPR 2022 (arXiv: 2203.15332)
"""
import math
from typing import Dict, Optional, Literal

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTModel, ASTModel, ViTImageProcessor, ASTFeatureExtractor
from peft import LoraConfig, get_peft_model

from config import ModelConfig, LoRAConfig, FAULT_CLASSES
# ===========================================================================
# Branch Modules (unchanged from v1)
# ===========================================================================
class VisualBranch(nn.Module):
    """Visual encoder: a pretrained ViT whose first output token is the embedding.

    Args:
        config: Model config; only ``config.vit_model_name`` is read here.
        lora_config: Optional LoRA settings. Applied only when
            ``finetune_method == "lora"`` and ``lora_config`` is truthy with
            ``enabled`` set.
        finetune_method: ``"lora"`` wraps the backbone in PEFT LoRA adapters
            (when ``lora_config`` allows it); ``"linear_probe"`` freezes every
            backbone parameter; any other value leaves the backbone trainable.
    """

    def __init__(self, config, lora_config=None, finetune_method="lora"):
        super().__init__()
        self.vit = ViTModel.from_pretrained(config.vit_model_name)

        wants_lora = finetune_method == "lora"
        if wants_lora and lora_config and lora_config.enabled:
            adapter_cfg = LoraConfig(
                r=lora_config.r,
                lora_alpha=lora_config.lora_alpha,
                target_modules=lora_config.vit_target_modules,
                lora_dropout=lora_config.lora_dropout,
                bias=lora_config.bias,
            )
            self.vit = get_peft_model(self.vit, adapter_cfg)
            self.vit.print_trainable_parameters()
        elif finetune_method == "linear_probe":
            # Freeze the backbone: no ViT parameter will receive updates.
            self.vit.requires_grad_(False)

    def forward(self, pixel_values):
        """Encode images; return token 0 of the last hidden state (the [CLS]
        position for HF ViT) for every sample."""
        hidden = self.vit(pixel_values=pixel_values).last_hidden_state
        return hidden[:, 0]
class AudioBranch(nn.Module):
    """Audio encoder: a pretrained AST whose first output token is the embedding.

    Args:
        config: Model config; only ``config.ast_model_name`` is read here.
        lora_config: Optional LoRA settings. Applied only when
            ``finetune_method == "lora"`` and ``lora_config`` is truthy with
            ``enabled`` set.
        finetune_method: ``"lora"`` wraps the backbone in PEFT LoRA adapters
            (when ``lora_config`` allows it); ``"linear_probe"`` freezes every
            backbone parameter; any other value leaves the backbone trainable.
    """

    def __init__(self, config, lora_config=None, finetune_method="lora"):
        super().__init__()
        self.ast = ASTModel.from_pretrained(config.ast_model_name)

        wants_lora = finetune_method == "lora"
        if wants_lora and lora_config and lora_config.enabled:
            adapter_cfg = LoraConfig(
                r=lora_config.r,
                lora_alpha=lora_config.lora_alpha,
                target_modules=lora_config.ast_target_modules,
                lora_dropout=lora_config.lora_dropout,
                bias=lora_config.bias,
            )
            self.ast = get_peft_model(self.ast, adapter_cfg)
            self.ast.print_trainable_parameters()
        elif finetune_method == "linear_probe":
            # Freeze the backbone: no AST parameter will receive updates.
            self.ast.requires_grad_(False)

    def forward(self, input_values):
        """Encode audio features; return token 0 of the last hidden state for
        every sample."""
        hidden = self.ast(input_values=input_values).last_hidden_state
        return hidden[:, 0]
# ===========================================================================
# Fusion Module (unchanged from v1)
# ===========================================================================
class LateFusion(nn.Module):
    """Fuse one visual and one audio embedding into class logits.

    ``config.fusion_type`` selects the strategy:

    * ``"concat"``: project each embedding to ``fusion_dim``, concatenate, and
      classify with a two-layer MLP.
    * ``"weighted_sum"``: one linear head per modality; their logits are mixed
      with softmax-normalised learnable weights (initialised equal).
    * ``"attention"``: project each embedding to a ``fusion_dim`` token, run
      multi-head self-attention over the two tokens, mean-pool, and classify.
      The head count is ``config.fusion_num_heads`` when that attribute
      exists, else 8; ``fusion_dim`` must be divisible by it.

    Raises:
        ValueError: if ``config.fusion_type`` is not one of the values above.
    """

    _FUSION_TYPES = ("concat", "weighted_sum", "attention")

    def __init__(self, config):
        super().__init__()
        self.fusion_type = config.fusion_type
        if self.fusion_type not in self._FUSION_TYPES:
            # An unknown type would otherwise build no layers and make
            # forward() silently return None.
            raise ValueError(
                f"Unknown fusion_type {self.fusion_type!r}; "
                f"expected one of {self._FUSION_TYPES}")

        if self.fusion_type == "concat":
            self.visual_proj = nn.Linear(config.vit_embed_dim, config.fusion_dim)
            self.audio_proj = nn.Linear(config.ast_embed_dim, config.fusion_dim)
            self.classifier = nn.Sequential(
                nn.LayerNorm(config.fusion_dim * 2),
                nn.Dropout(config.fusion_dropout),
                nn.Linear(config.fusion_dim * 2, config.fusion_dim),
                nn.GELU(),
                nn.Dropout(config.fusion_dropout),
                nn.Linear(config.fusion_dim, config.num_classes))
        elif self.fusion_type == "weighted_sum":
            self.visual_head = nn.Linear(config.vit_embed_dim, config.num_classes)
            self.audio_head = nn.Linear(config.ast_embed_dim, config.num_classes)
            self.fusion_weights = nn.Parameter(torch.tensor([0.5, 0.5]))
        else:  # "attention"
            num_heads = getattr(config, "fusion_num_heads", 8)
            self.visual_proj = nn.Linear(config.vit_embed_dim, config.fusion_dim)
            self.audio_proj = nn.Linear(config.ast_embed_dim, config.fusion_dim)
            self.cross_attn = nn.MultiheadAttention(
                embed_dim=config.fusion_dim, num_heads=num_heads,
                dropout=config.fusion_dropout, batch_first=True)
            self.classifier = nn.Sequential(
                nn.LayerNorm(config.fusion_dim),
                nn.Dropout(config.fusion_dropout),
                nn.Linear(config.fusion_dim, config.num_classes))

    def forward(self, visual_emb, audio_emb, modality_mask=None):
        """Return fused logits (last dimension ``num_classes``).

        Args:
            visual_emb: visual embedding, presumably ``(B, vit_embed_dim)``.
            audio_emb: audio embedding, presumably ``(B, ast_embed_dim)``.
            modality_mask: optional dict of ``"visual"`` / ``"audio"`` scale
                factors (missing keys default to 1.0), used for modality
                dropout. An empty dict or ``None`` disables masking.
        """
        if modality_mask:
            # NOTE: a zeroed embedding still contributes its projection/head
            # bias, so a dropped modality becomes a constant input, not nothing.
            visual_emb = visual_emb * modality_mask.get("visual", 1.0)
            audio_emb = audio_emb * modality_mask.get("audio", 1.0)

        if self.fusion_type == "concat":
            fused = torch.cat(
                [self.visual_proj(visual_emb), self.audio_proj(audio_emb)], dim=-1)
            return self.classifier(fused)
        if self.fusion_type == "weighted_sum":
            w = torch.softmax(self.fusion_weights, dim=0)
            return w[0] * self.visual_head(visual_emb) + w[1] * self.audio_head(audio_emb)
        if self.fusion_type == "attention":
            # One token per modality, stacked on dim 1 (batch_first layout).
            tokens = torch.stack(
                [self.visual_proj(visual_emb), self.audio_proj(audio_emb)], dim=1)
            attended, _ = self.cross_attn(tokens, tokens, tokens)
            return self.classifier(attended.mean(dim=1))
        raise ValueError(f"Unknown fusion_type {self.fusion_type!r}")
# ===========================================================================
# OGM-GE: On-the-fly Gradient Modulation with Generalization Enhancement
# ===========================================================================
class OGMGEModulator:
    """
    Gradient modulation in the style of OGM-GE (Peng et al., CVPR 2022).

    After ``loss.backward()``, per-modality confidences from the auxiliary
    unimodal heads are compared. The encoder of the more confident (dominant)
    modality has its gradients scaled by a coefficient in ``(1 - alpha, 1]``;
    the other encoder keeps coefficient 1.0 (its gradients are untouched).
    Gaussian noise is then added to the suppressed encoder's gradients (the
    "generalization enhancement" step).

    NOTE(review): the coefficient ``1 - alpha * tanh(ratio - 1)`` and the
    noise scale ``noise_sigma * mean(|grad|)`` are this implementation's
    choices; confirm against the paper / official code if exact reproduction
    matters.

    Usage in training loop:
        loss.backward()
        coeff_v, coeff_a, stats = ogm.compute_modulation_coefficients(
            visual_logits, audio_logits, labels)
        ogm.apply_gradient_modulation(model, coeff_v, coeff_a)
        optimizer.step()
    """

    def __init__(self, alpha=0.3, noise_sigma=0.1):
        """
        Args:
            alpha: Modulation strength; the dominant modality's coefficient
                never drops below ``1 - alpha``. Values of 0.3-0.5 were noted
                for this code (verify against the paper).
            noise_sigma: Relative std of the Gaussian noise added to the
                suppressed modality's gradients, as a multiple of the mean
                absolute (already scaled) gradient. 0 disables the noise.
        """
        self.alpha = alpha
        self.noise_sigma = noise_sigma

    @torch.no_grad()
    def compute_modulation_coefficients(self, visual_logits, audio_logits, labels):
        """
        Compute modulation coefficients from per-modality confidence.

        A modality's confidence is the batch-mean softmax probability assigned
        to the correct class. The more confident modality is treated as
        dominant and receives a coefficient below 1.

        Args:
            visual_logits: (B, C) logits from the auxiliary visual head
            audio_logits: (B, C) logits from the auxiliary audio head
            labels: (B,) ground-truth class indices

        Returns:
            coeff_visual: gradient scaling factor for the visual encoder
            coeff_audio: gradient scaling factor for the audio encoder
            stats: dict with ``visual_conf``, ``audio_conf``, ``ratio``,
                ``coeff_visual`` and ``coeff_audio`` for logging. For an empty
                batch or non-finite logits the confidences/ratio are NaN and
                both coefficients are 1.0 (no modulation).
        """
        v_probs = F.softmax(visual_logits, dim=-1)
        a_probs = F.softmax(audio_logits, dim=-1)

        rows = torch.arange(labels.size(0), device=labels.device)
        v_conf = v_probs[rows, labels].mean().item()
        a_conf = a_probs[rows, labels].mean().item()

        # ratio > 1: visual dominant; ratio < 1: audio dominant.
        eps = 1e-8
        ratio = (v_conf + eps) / (a_conf + eps)

        coeff_visual = coeff_audio = 1.0
        if not math.isfinite(ratio):
            # Empty batch (mean of nothing) or NaN logits: a NaN coefficient
            # would poison every encoder gradient, so skip modulation.
            pass
        elif ratio > 1.0:
            coeff_visual = 1.0 - self.alpha * math.tanh(ratio - 1.0)
        else:
            # ratio == 1.0 lands here too and yields coefficient 1.0.
            coeff_audio = 1.0 - self.alpha * math.tanh(1.0 / ratio - 1.0)

        stats = {
            "visual_conf": v_conf,
            "audio_conf": a_conf,
            "ratio": ratio,
            "coeff_visual": coeff_visual,
            "coeff_audio": coeff_audio,
        }
        return coeff_visual, coeff_audio, stats

    @torch.no_grad()
    def apply_gradient_modulation(self, model, coeff_visual, coeff_audio):
        """
        Scale encoder gradients in place.

        Parameters are selected by name: any parameter whose name contains
        ``"visual_branch"`` is scaled by ``coeff_visual``; ``"audio_branch"``
        by ``coeff_audio``. All other parameters (fusion head, auxiliary
        heads) and parameters without a gradient are left untouched. For a
        suppressed modality (coeff < 1) Gaussian noise is also added.
        """
        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            if "visual_branch" in name:
                self._modulate_grad(param.grad, coeff_visual)
            elif "audio_branch" in name:
                self._modulate_grad(param.grad, coeff_audio)

    def _modulate_grad(self, grad, coeff):
        """Scale ``grad`` in place by ``coeff``; add GE noise when coeff < 1."""
        if coeff == 1.0:
            return  # multiplying by 1 and adding no noise is a no-op
        grad.mul_(coeff)
        if coeff < 1.0 and self.noise_sigma > 0:
            noise_scale = self.noise_sigma * grad.abs().mean()
            grad.add_(torch.randn_like(grad) * noise_scale)
# ===========================================================================
# Main Model v2 — with auxiliary heads and OGM-GE support
# ===========================================================================
class MultimodalPCFaultDetector(nn.Module):
    """
    Visual/audio PC fault classifier (v2).

    Modes:
      * ``"multimodal"``: both branches, a ``LateFusion`` head, and two
        auxiliary unimodal heads (``visual_head``, ``audio_head``).
        Loss = loss_fusion + lambda_visual * loss_visual + lambda_audio * loss_audio
        (defaults 1.5 / 0.5: upweight visual, downweight audio). Per-branch
        logits and embeddings are returned so a training loop can feed them
        to ``OGMGEModulator``.
      * ``"visual_only"`` / ``"audio_only"``: a single branch followed by an
        MLP classifier.

    ``use_ogm`` is only stored and reported here; presumably the training
    loop decides whether to apply OGM-GE based on it.

    Raises:
        ValueError: if ``mode`` is not one of the three values above.
    """

    _MODES = ("multimodal", "visual_only", "audio_only")

    def __init__(self, model_config, lora_config=None, finetune_method="lora",
                 mode: Literal["multimodal", "visual_only", "audio_only"] = "multimodal",
                 use_ogm=True, lambda_visual=1.5, lambda_audio=0.5):
        super().__init__()
        if mode not in self._MODES:
            # An unknown mode would otherwise build neither branch yet still
            # create an audio-sized classifier, failing only inside forward().
            raise ValueError(f"Unknown mode {mode!r}; expected one of {self._MODES}")
        self.mode = mode
        self.modality_dropout_p = model_config.modality_dropout_p
        self.use_ogm = use_ogm
        self.lambda_visual = lambda_visual
        self.lambda_audio = lambda_audio

        # --- Branches ---
        self.visual_branch = (
            VisualBranch(model_config, lora_config, finetune_method)
            if mode in ("multimodal", "visual_only") else None)
        self.audio_branch = (
            AudioBranch(model_config, lora_config, finetune_method)
            if mode in ("multimodal", "audio_only") else None)

        # --- Fusion / classifier ---
        if mode == "multimodal":
            self.fusion = LateFusion(model_config)
            # Auxiliary unimodal heads: each branch must classify on its own.
            aux_dropout = getattr(model_config, "aux_head_dropout", 0.2)
            self.visual_head = nn.Sequential(
                nn.LayerNorm(model_config.vit_embed_dim),
                nn.Dropout(aux_dropout),
                nn.Linear(model_config.vit_embed_dim, model_config.num_classes))
            self.audio_head = nn.Sequential(
                nn.LayerNorm(model_config.ast_embed_dim),
                nn.Dropout(aux_dropout),
                nn.Linear(model_config.ast_embed_dim, model_config.num_classes))
        else:
            embed_dim = (model_config.vit_embed_dim if mode == "visual_only"
                         else model_config.ast_embed_dim)
            self.classifier = nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Dropout(model_config.fusion_dropout),
                nn.Linear(embed_dim, model_config.fusion_dim),
                nn.GELU(),
                nn.Dropout(model_config.fusion_dropout),
                nn.Linear(model_config.fusion_dim, model_config.num_classes))
        self.loss_fn = nn.CrossEntropyLoss()

        # --- Print parameter counts ---
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"[Model v2] Mode={mode}, Total={total:,}, Trainable={trainable:,} "
              f"({100*trainable/total:.2f}%)")
        if mode == "multimodal":
            print(f"[Model v2] OGM-GE={'ON' if use_ogm else 'OFF'}, "
                  f"λ_visual={lambda_visual}, λ_audio={lambda_audio}")

    def _sample_modality_mask(self):
        """Draw a modality-dropout mask for this batch, or None when disabled.

        Each modality is dropped independently with ``modality_dropout_p``;
        if both are dropped, one of them (chosen uniformly) is restored.
        Only active in training mode.
        """
        if not (self.training and self.modality_dropout_p > 0):
            return None
        p = self.modality_dropout_p
        mask = {
            "visual": 0.0 if torch.rand(1).item() < p else 1.0,
            "audio": 0.0 if torch.rand(1).item() < p else 1.0,
        }
        if mask["visual"] == 0.0 and mask["audio"] == 0.0:
            mask["visual" if torch.rand(1).item() < 0.5 else "audio"] = 1.0
        return mask

    def _forward_multimodal(self, pixel_values, audio_values, labels):
        """Fusion + auxiliary-head forward pass for ``mode == "multimodal"``."""
        v_emb = self.visual_branch(pixel_values)
        a_emb = self.audio_branch(audio_values)

        # The mask only affects the fusion head; auxiliary heads always see
        # the full embeddings (their logits feed OGM-GE).
        logits = self.fusion(v_emb, a_emb, self._sample_modality_mask())
        visual_logits = self.visual_head(v_emb)
        audio_logits = self.audio_head(a_emb)

        outputs = {
            "logits": logits,
            "visual_logits": visual_logits,
            "audio_logits": audio_logits,
            "visual_emb": v_emb,
            "audio_emb": a_emb,
        }
        if labels is not None:
            loss_fusion = self.loss_fn(logits, labels)
            loss_visual = self.loss_fn(visual_logits, labels)
            loss_audio = self.loss_fn(audio_logits, labels)
            outputs["loss"] = (loss_fusion
                               + self.lambda_visual * loss_visual
                               + self.lambda_audio * loss_audio)
            # Python floats for logging (note: .item() synchronises the device).
            outputs["loss_fusion"] = loss_fusion.item()
            outputs["loss_visual"] = loss_visual.item()
            outputs["loss_audio"] = loss_audio.item()
        return outputs

    def forward(self, pixel_values=None, audio_values=None, labels=None):
        """Run the model; returns a dict with at least ``"logits"``.

        In multimodal mode the dict also holds per-branch logits/embeddings
        and, when ``labels`` is given, ``"loss"`` plus float component losses.
        In unimodal modes it holds ``"logits"`` and, with labels, ``"loss"``.
        """
        if self.mode == "multimodal":
            return self._forward_multimodal(pixel_values, audio_values, labels)
        if self.mode == "visual_only":
            logits = self.classifier(self.visual_branch(pixel_values))
        else:  # audio_only
            logits = self.classifier(self.audio_branch(audio_values))
        outputs = {"logits": logits}
        if labels is not None:
            outputs["loss"] = self.loss_fn(logits, labels)
        return outputs
# ===========================================================================
# Factory functions
# ===========================================================================
def create_model(model_config, lora_config, mode="multimodal",
                 finetune_method="lora", use_ogm=True,
                 lambda_visual=1.5, lambda_audio=0.5):
    """Build a ``MultimodalPCFaultDetector`` with the v2 anti-collapse options.

    Note that this factory orders ``mode`` before ``finetune_method`` while the
    constructor does the reverse, so every argument is forwarded by keyword.
    """
    return MultimodalPCFaultDetector(
        model_config,
        lora_config=lora_config,
        finetune_method=finetune_method,
        mode=mode,
        use_ogm=use_ogm,
        lambda_visual=lambda_visual,
        lambda_audio=lambda_audio,
    )
def get_processors(model_config):
    """Return ``(image_processor, feature_extractor)`` for the two backbones.

    The ViT image processor is loaded from ``model_config.vit_model_name`` and
    the AST feature extractor from ``model_config.ast_model_name``.
    """
    image_processor = ViTImageProcessor.from_pretrained(model_config.vit_model_name)
    feature_extractor = ASTFeatureExtractor.from_pretrained(model_config.ast_model_name)
    return image_processor, feature_extractor
|