"""
Multimodal PC Fault Detection - Model Architecture
====================================================
Two-branch architecture with anti-modality-collapse features:
- Visual: ViT-B/16 (ImageNet-21k) + LoRA
- Audio: AST (AudioSet) + LoRA
- Fusion: Late fusion (concat / weighted sum / attention)
- Auxiliary unimodal classification heads
- OGM-GE gradient modulation support (Peng et al., CVPR 2022)
Loss = L_fusion + λ_visual * L_visual + λ_audio * L_audio
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Literal
from transformers import ViTModel, ASTModel, ViTImageProcessor, ASTFeatureExtractor
from peft import LoraConfig, get_peft_model
from config import ModelConfig, LoRAConfig, FAULT_CLASSES
class VisualBranch(nn.Module):
    """ViT visual encoder (optionally LoRA-adapted); returns the [CLS] embedding."""
    def __init__(self, config, lora_config=None, finetune_method="lora"):
        super().__init__()
        self.vit = ViTModel.from_pretrained(config.vit_model_name)
        if finetune_method == "lora" and lora_config and lora_config.enabled:
            peft_config = LoraConfig(r=lora_config.r, lora_alpha=lora_config.lora_alpha,
                                     target_modules=lora_config.vit_target_modules,
                                     lora_dropout=lora_config.lora_dropout, bias=lora_config.bias)
            self.vit = get_peft_model(self.vit, peft_config)
            self.vit.print_trainable_parameters()
        elif finetune_method == "linear_probe":
            # Freeze the backbone; only the heads stacked on top of it are trained.
            for param in self.vit.parameters():
                param.requires_grad = False
    def forward(self, pixel_values):
        # [CLS] token embedding, shape (batch, vit_embed_dim).
        return self.vit(pixel_values=pixel_values).last_hidden_state[:, 0, :]
class AudioBranch(nn.Module):
    """AST audio encoder (optionally LoRA-adapted); returns the [CLS] embedding."""
    def __init__(self, config, lora_config=None, finetune_method="lora"):
        super().__init__()
        self.ast = ASTModel.from_pretrained(config.ast_model_name)
        if finetune_method == "lora" and lora_config and lora_config.enabled:
            peft_config = LoraConfig(r=lora_config.r, lora_alpha=lora_config.lora_alpha,
                                     target_modules=lora_config.ast_target_modules,
                                     lora_dropout=lora_config.lora_dropout, bias=lora_config.bias)
            self.ast = get_peft_model(self.ast, peft_config)
            self.ast.print_trainable_parameters()
        elif finetune_method == "linear_probe":
            # Freeze the backbone; only the heads stacked on top of it are trained.
            for param in self.ast.parameters():
                param.requires_grad = False
    def forward(self, input_values):
        # [CLS] token embedding, shape (batch, ast_embed_dim).
        return self.ast(input_values=input_values).last_hidden_state[:, 0, :]
class LateFusion(nn.Module):
    """Late-fusion head over the two CLS embeddings.

    Supports three strategies: "concat", "weighted_sum", and "attention".
    """
    def __init__(self, config):
        super().__init__()
        self.fusion_type = config.fusion_type
if config.fusion_type == "concat":
self.visual_proj = nn.Linear(config.vit_embed_dim, config.fusion_dim)
self.audio_proj = nn.Linear(config.ast_embed_dim, config.fusion_dim)
self.classifier = nn.Sequential(
nn.LayerNorm(config.fusion_dim * 2), nn.Dropout(config.fusion_dropout),
nn.Linear(config.fusion_dim * 2, config.fusion_dim), nn.GELU(),
nn.Dropout(config.fusion_dropout), nn.Linear(config.fusion_dim, config.num_classes))
elif config.fusion_type == "weighted_sum":
self.visual_head = nn.Linear(config.vit_embed_dim, config.num_classes)
self.audio_head = nn.Linear(config.ast_embed_dim, config.num_classes)
self.fusion_weights = nn.Parameter(torch.tensor([0.5, 0.5]))
elif config.fusion_type == "attention":
self.visual_proj = nn.Linear(config.vit_embed_dim, config.fusion_dim)
self.audio_proj = nn.Linear(config.ast_embed_dim, config.fusion_dim)
self.cross_attn = nn.MultiheadAttention(embed_dim=config.fusion_dim, num_heads=8,
dropout=config.fusion_dropout, batch_first=True)
self.classifier = nn.Sequential(nn.LayerNorm(config.fusion_dim),
nn.Dropout(config.fusion_dropout), nn.Linear(config.fusion_dim, config.num_classes))
    def forward(self, visual_emb, audio_emb, modality_mask: Optional[Dict[str, float]] = None):
        # Optional modality dropout: zero an embedding when its mask entry is 0.0.
        if modality_mask:
            visual_emb = visual_emb * modality_mask.get("visual", 1.0)
            audio_emb = audio_emb * modality_mask.get("audio", 1.0)
        if self.fusion_type == "concat":
            fused = torch.cat([self.visual_proj(visual_emb), self.audio_proj(audio_emb)], dim=-1)
            return self.classifier(fused)
        elif self.fusion_type == "weighted_sum":
            # Softmax keeps the learned fusion weights positive and summing to one.
            w = torch.softmax(self.fusion_weights, dim=0)
            return w[0] * self.visual_head(visual_emb) + w[1] * self.audio_head(audio_emb)
        elif self.fusion_type == "attention":
            # Treat the two projected embeddings as a two-token sequence and self-attend.
            tokens = torch.cat([self.visual_proj(visual_emb).unsqueeze(1),
                                self.audio_proj(audio_emb).unsqueeze(1)], dim=1)
            attended, _ = self.cross_attn(tokens, tokens, tokens)
            return self.classifier(attended.mean(dim=1))
class OGMGEModulator:
"""OGM-GE from Peng et al., CVPR 2022. Suppresses dominant modality gradients."""
def __init__(self, alpha=0.3, noise_sigma=0.1):
self.alpha = alpha
self.noise_sigma = noise_sigma
    @torch.no_grad()
    def compute_modulation_coefficients(self, visual_logits, audio_logits, labels):
        """Return (coeff_visual, coeff_audio, stats) from per-modality confidence.

        The modality whose mean probability on the ground-truth class is higher
        gets its gradient scaled by a coefficient below 1.0.
        """
        v_probs = F.softmax(visual_logits, dim=-1)
        a_probs = F.softmax(audio_logits, dim=-1)
        batch_idx = torch.arange(labels.size(0), device=labels.device)
        v_conf = v_probs[batch_idx, labels].mean().item()
        a_conf = a_probs[batch_idx, labels].mean().item()
        eps = 1e-8
        ratio = (v_conf + eps) / (a_conf + eps)
        if ratio > 1.0:
            # Visual branch is dominant: damp its gradients.
            coeff_visual = 1.0 - self.alpha * torch.tanh(torch.tensor(ratio - 1.0)).item()
            coeff_audio = 1.0
        else:
            # Audio branch is dominant: damp its gradients.
            coeff_visual = 1.0
            coeff_audio = 1.0 - self.alpha * torch.tanh(torch.tensor(1.0 / ratio - 1.0)).item()
        stats = {"visual_conf": v_conf, "audio_conf": a_conf, "ratio": ratio,
                 "coeff_visual": coeff_visual, "coeff_audio": coeff_audio}
        return coeff_visual, coeff_audio, stats
    def apply_gradient_modulation(self, model, coeff_visual, coeff_audio):
        """Scale branch gradients in place; call after backward() and before optimizer.step()."""
        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            if "visual_branch" in name:
                param.grad.data.mul_(coeff_visual)
                if coeff_visual < 1.0 and self.noise_sigma > 0:
                    # Generalization-Enhancement (GE) term: Gaussian noise scaled to the
                    # mean gradient magnitude of the suppressed branch.
                    param.grad.data.add_(torch.randn_like(param.grad.data)
                                         * self.noise_sigma * param.grad.data.abs().mean())
            elif "audio_branch" in name:
                param.grad.data.mul_(coeff_audio)
                if coeff_audio < 1.0 and self.noise_sigma > 0:
                    param.grad.data.add_(torch.randn_like(param.grad.data)
                                         * self.noise_sigma * param.grad.data.abs().mean())
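# A minimal sketch of how the modulator is intended to slot into a training step.
# The trainer itself lives elsewhere; `model`, `batch`, and `optimizer` below are
# assumptions for illustration, not part of this module:
#
#     modulator = OGMGEModulator(alpha=0.3, noise_sigma=0.1)
#     outputs = model(**batch)  # multimodal forward pass with labels
#     outputs["loss"].backward()
#     cv, ca, stats = modulator.compute_modulation_coefficients(
#         outputs["visual_logits"], outputs["audio_logits"], batch["labels"])
#     modulator.apply_gradient_modulation(model, cv, ca)
#     optimizer.step()
#     optimizer.zero_grad()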
class MultimodalPCFaultDetector(nn.Module):
    """Two-branch fault classifier with auxiliary unimodal heads.

    Total loss = L_fusion + lambda_visual * L_visual + lambda_audio * L_audio,
    where the default lambda_visual=1.5 and lambda_audio=0.5 put more auxiliary
    weight on the visual branch.
    """
    def __init__(self, model_config, lora_config=None, finetune_method="lora",
                 mode: Literal["multimodal", "visual_only", "audio_only"] = "multimodal",
                 use_ogm=True, lambda_visual=1.5, lambda_audio=0.5):
        super().__init__()
        self.mode = mode
        self.modality_dropout_p = model_config.modality_dropout_p
        self.use_ogm = use_ogm
        self.lambda_visual = lambda_visual
        self.lambda_audio = lambda_audio
        self.visual_branch = VisualBranch(model_config, lora_config, finetune_method) if mode in ("multimodal", "visual_only") else None
        self.audio_branch = AudioBranch(model_config, lora_config, finetune_method) if mode in ("multimodal", "audio_only") else None
        if mode == "multimodal":
            self.fusion = LateFusion(model_config)
            # Auxiliary unimodal heads keep each encoder directly supervised even if
            # the fused head starts relying on a single modality.
            self.visual_head = nn.Sequential(nn.LayerNorm(model_config.vit_embed_dim), nn.Dropout(0.2),
                                             nn.Linear(model_config.vit_embed_dim, model_config.num_classes))
            self.audio_head = nn.Sequential(nn.LayerNorm(model_config.ast_embed_dim), nn.Dropout(0.2),
                                            nn.Linear(model_config.ast_embed_dim, model_config.num_classes))
        else:
            embed_dim = model_config.vit_embed_dim if mode == "visual_only" else model_config.ast_embed_dim
            self.classifier = nn.Sequential(nn.LayerNorm(embed_dim), nn.Dropout(model_config.fusion_dropout),
                                            nn.Linear(embed_dim, model_config.fusion_dim), nn.GELU(),
                                            nn.Dropout(model_config.fusion_dropout),
                                            nn.Linear(model_config.fusion_dim, model_config.num_classes))
        self.loss_fn = nn.CrossEntropyLoss()
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"[Model] Mode={mode}, Total={total:,}, Trainable={trainable:,} ({100 * trainable / total:.2f}%)")
    def forward(self, pixel_values=None, audio_values=None, labels=None):
        if self.mode == "multimodal":
            v_emb = self.visual_branch(pixel_values)
            a_emb = self.audio_branch(audio_values)
            mask = None
            if self.training and self.modality_dropout_p > 0:
                # Modality dropout: randomly zero an embedding so the fusion head
                # cannot come to rely on a single modality.
                mask = {"visual": 0.0 if torch.rand(1).item() < self.modality_dropout_p else 1.0,
                        "audio": 0.0 if torch.rand(1).item() < self.modality_dropout_p else 1.0}
                if mask["visual"] == 0.0 and mask["audio"] == 0.0:
                    # Never drop both modalities at once; restore one at random.
                    mask["visual" if torch.rand(1).item() < 0.5 else "audio"] = 1.0
            logits = self.fusion(v_emb, a_emb, mask)
            visual_logits, audio_logits = self.visual_head(v_emb), self.audio_head(a_emb)
            outputs = {"logits": logits, "visual_logits": visual_logits, "audio_logits": audio_logits,
                       "visual_emb": v_emb, "audio_emb": a_emb}
            if labels is not None:
                loss_f = self.loss_fn(logits, labels)
                loss_v = self.loss_fn(visual_logits, labels)
                loss_a = self.loss_fn(audio_logits, labels)
                outputs.update({"loss": loss_f + self.lambda_visual * loss_v + self.lambda_audio * loss_a,
                                "loss_fusion": loss_f.item(), "loss_visual": loss_v.item(),
                                "loss_audio": loss_a.item()})
        elif self.mode == "visual_only":
            logits = self.classifier(self.visual_branch(pixel_values))
            outputs = {"logits": logits}
            if labels is not None:
                outputs["loss"] = self.loss_fn(logits, labels)
        else:  # audio_only
            logits = self.classifier(self.audio_branch(audio_values))
            outputs = {"logits": logits}
            if labels is not None:
                outputs["loss"] = self.loss_fn(logits, labels)
        return outputs
def create_model(model_config, lora_config, mode="multimodal", finetune_method="lora",
use_ogm=True, lambda_visual=1.5, lambda_audio=0.5):
return MultimodalPCFaultDetector(model_config, lora_config, finetune_method, mode,
use_ogm=use_ogm, lambda_visual=lambda_visual, lambda_audio=lambda_audio)
def get_processors(model_config):
return (ViTImageProcessor.from_pretrained(model_config.vit_model_name),
ASTFeatureExtractor.from_pretrained(model_config.ast_model_name))
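# Minimal usage sketch (assumptions: ModelConfig and LoRAConfig from config.py can be
# constructed with their defaults, pretrained ViT/AST weights are available, and the
# dummy input shapes below match what the processors returned by get_processors()
# would produce):
if __name__ == "__main__":
    model_config, lora_config = ModelConfig(), LoRAConfig()
    model = create_model(model_config, lora_config, mode="multimodal")
    image_processor, audio_extractor = get_processors(model_config)
    pixel_values = torch.randn(2, 3, 224, 224)   # ViT-B/16 expects 224x224 RGB images
    audio_values = torch.randn(2, 1024, 128)     # AST expects (frames, mel bins) spectrograms
    labels = torch.randint(0, model_config.num_classes, (2,))
    out = model(pixel_values=pixel_values, audio_values=audio_values, labels=labels)
    print("fused logits:", tuple(out["logits"].shape), "loss:", out["loss"].item())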