"""
Multimodal PC Fault Detection - Model Architecture
==================================================

Two-branch architecture with anti-modality-collapse features:
  - Visual: ViT-B/16 (ImageNet-21k) + LoRA
  - Audio: AST (AudioSet) + LoRA
  - Fusion: Late fusion (concat / weighted sum / attention)
  - Auxiliary unimodal classification heads
  - OGM-GE gradient modulation support (Peng et al., CVPR 2022)

Loss = L_fusion + λ_visual * L_visual + λ_audio * L_audio
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Literal
from transformers import ViTModel, ASTModel, ViTImageProcessor, ASTFeatureExtractor
from peft import LoraConfig, get_peft_model
from config import ModelConfig, LoRAConfig, FAULT_CLASSES


# Accepted values for ModelConfig.fusion_type and MultimodalPCFaultDetector(mode=...).
_FUSION_TYPES = ("concat", "weighted_sum", "attention")
_MODES = ("multimodal", "visual_only", "audio_only")


def _prepare_backbone(backbone, lora_config, target_modules_attr, finetune_method):
    """Apply the requested fine-tuning strategy to a pretrained backbone.

    Args:
        backbone: A pretrained Hugging Face encoder (ViTModel / ASTModel).
        lora_config: LoRA settings object, or None.
        target_modules_attr: Name of the attribute on ``lora_config`` that lists
            the modules to wrap (e.g. ``"vit_target_modules"``). It is only read
            when LoRA is actually applied, so ``lora_config`` may be None otherwise.
        finetune_method: ``"lora"``, ``"linear_probe"`` or anything else.

    Returns:
        The backbone, either wrapped with PEFT LoRA adapters ("lora" with an
        enabled config), fully frozen ("linear_probe"), or unchanged (any other
        method, or "lora" when LoRA is absent/disabled).
    """
    if finetune_method == "lora" and lora_config and lora_config.enabled:
        peft_config = LoraConfig(
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
            target_modules=getattr(lora_config, target_modules_attr),
            lora_dropout=lora_config.lora_dropout,
            bias=lora_config.bias,
        )
        backbone = get_peft_model(backbone, peft_config)
        backbone.print_trainable_parameters()
    elif finetune_method == "linear_probe":
        for param in backbone.parameters():
            param.requires_grad = False
    return backbone


class VisualBranch(nn.Module):
    """ViT image encoder that returns the first-token embedding per image."""

    def __init__(self, config: ModelConfig, lora_config: Optional[LoRAConfig] = None,
                 finetune_method: str = "lora"):
        super().__init__()
        self.vit = _prepare_backbone(
            ViTModel.from_pretrained(config.vit_model_name),
            lora_config, "vit_target_modules", finetune_method,
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode images; returns position 0 of ``last_hidden_state`` (batch, hidden)."""
        return self.vit(pixel_values=pixel_values).last_hidden_state[:, 0, :]


class AudioBranch(nn.Module):
    """AST audio encoder that returns the first-token embedding per clip."""

    def __init__(self, config: ModelConfig, lora_config: Optional[LoRAConfig] = None,
                 finetune_method: str = "lora"):
        super().__init__()
        self.ast = _prepare_backbone(
            ASTModel.from_pretrained(config.ast_model_name),
            lora_config, "ast_target_modules", finetune_method,
        )

    def forward(self, input_values: torch.Tensor) -> torch.Tensor:
        """Encode audio features; returns position 0 of ``last_hidden_state`` (batch, hidden)."""
        return self.ast(input_values=input_values).last_hidden_state[:, 0, :]


class LateFusion(nn.Module):
    """Fuse visual and audio embeddings into class logits.

    ``config.fusion_type`` selects the strategy:
      - ``"concat"``: project each embedding to ``fusion_dim``, concatenate,
        then LayerNorm -> MLP classifier.
      - ``"weighted_sum"``: one linear head per modality; the two logit tensors
        are mixed with softmax-normalised learnable weights.
      - ``"attention"``: project each embedding to a ``fusion_dim`` token,
        run multi-head self-attention over the two tokens, mean-pool, classify.
        The head count is ``config.fusion_num_heads`` if present, else 8.

    Raises:
        ValueError: if ``config.fusion_type`` is not one of the above.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        if config.fusion_type not in _FUSION_TYPES:
            raise ValueError(
                f"Unknown fusion_type {config.fusion_type!r}; expected one of {_FUSION_TYPES}"
            )
        self.fusion_type = config.fusion_type

        # NOTE: layers are created in the same order as before so parameter
        # initialisation consumes the RNG identically.
        if config.fusion_type == "concat":
            self.visual_proj = nn.Linear(config.vit_embed_dim, config.fusion_dim)
            self.audio_proj = nn.Linear(config.ast_embed_dim, config.fusion_dim)
            self.classifier = nn.Sequential(
                nn.LayerNorm(config.fusion_dim * 2),
                nn.Dropout(config.fusion_dropout),
                nn.Linear(config.fusion_dim * 2, config.fusion_dim),
                nn.GELU(),
                nn.Dropout(config.fusion_dropout),
                nn.Linear(config.fusion_dim, config.num_classes),
            )
        elif config.fusion_type == "weighted_sum":
            self.visual_head = nn.Linear(config.vit_embed_dim, config.num_classes)
            self.audio_head = nn.Linear(config.ast_embed_dim, config.num_classes)
            # Raw (pre-softmax) mixing weights, initialised to an even split.
            self.fusion_weights = nn.Parameter(torch.tensor([0.5, 0.5]))
        else:  # "attention"
            self.visual_proj = nn.Linear(config.vit_embed_dim, config.fusion_dim)
            self.audio_proj = nn.Linear(config.ast_embed_dim, config.fusion_dim)
            self.cross_attn = nn.MultiheadAttention(
                embed_dim=config.fusion_dim,
                num_heads=getattr(config, "fusion_num_heads", 8),
                dropout=config.fusion_dropout,
                batch_first=True,
            )
            self.classifier = nn.Sequential(
                nn.LayerNorm(config.fusion_dim),
                nn.Dropout(config.fusion_dropout),
                nn.Linear(config.fusion_dim, config.num_classes),
            )

    def forward(self, visual_emb: torch.Tensor, audio_emb: torch.Tensor,
                modality_mask: Optional[Dict[str, float]] = None) -> torch.Tensor:
        """Return fused class logits.

        Args:
            visual_emb: Visual embedding, assumed (batch, vit_embed_dim).
            audio_emb: Audio embedding, assumed (batch, ast_embed_dim).
            modality_mask: Optional ``{"visual": s, "audio": s}`` multipliers
                applied to the embeddings before fusion; missing keys default
                to 1.0. A 0.0 zeroes that embedding (layer biases downstream
                still contribute).
        """
        if modality_mask:
            visual_emb = visual_emb * modality_mask.get("visual", 1.0)
            audio_emb = audio_emb * modality_mask.get("audio", 1.0)

        if self.fusion_type == "concat":
            fused = torch.cat([self.visual_proj(visual_emb), self.audio_proj(audio_emb)], dim=-1)
            return self.classifier(fused)

        if self.fusion_type == "weighted_sum":
            w = torch.softmax(self.fusion_weights, dim=0)
            return w[0] * self.visual_head(visual_emb) + w[1] * self.audio_head(audio_emb)

        if self.fusion_type == "attention":
            # Two-token sequence: (batch, 2, fusion_dim).
            tokens = torch.cat(
                [self.visual_proj(visual_emb).unsqueeze(1), self.audio_proj(audio_emb).unsqueeze(1)],
                dim=1,
            )
            attended, _ = self.cross_attn(tokens, tokens, tokens)
            return self.classifier(attended.mean(dim=1))

        raise ValueError(f"Unknown fusion_type {self.fusion_type!r}")


class OGMGEModulator:
    """OGM-GE gradient modulation (Peng et al., CVPR 2022), as adapted here.

    Per step, the modality whose auxiliary head is more confident on the true
    class has its branch gradients scaled down by
    ``1 - alpha * tanh(ratio - 1)`` (ratio = dominant / weaker mean confidence).
    Gaussian noise proportional to the gradient's mean magnitude is then added
    to the suppressed branch only.
    """

    def __init__(self, alpha=0.3, noise_sigma=0.1):
        self.alpha = alpha  # maximum suppression strength (tanh saturates at 1)
        self.noise_sigma = noise_sigma  # noise scale relative to mean |grad|; 0 disables

    @torch.no_grad()
    def compute_modulation_coefficients(self, visual_logits, audio_logits, labels):
        """Compute per-modality gradient scale factors from the unimodal logits.

        Args:
            visual_logits: Visual auxiliary-head logits, assumed (batch, num_classes).
            audio_logits: Audio auxiliary-head logits, same shape.
            labels: Integer class indices, shape (batch,).

        Returns:
            ``(coeff_visual, coeff_audio, stats)`` where the non-dominant
            modality's coefficient is 1.0 and ``stats`` holds the confidences,
            ratio and coefficients as Python floats for logging.
        """
        v_probs = F.softmax(visual_logits, dim=-1)
        a_probs = F.softmax(audio_logits, dim=-1)
        batch_idx = torch.arange(labels.size(0), device=labels.device)
        # Mean probability each head assigns to the ground-truth class.
        v_conf = v_probs[batch_idx, labels].mean().item()
        a_conf = a_probs[batch_idx, labels].mean().item()

        eps = 1e-8
        ratio = (v_conf + eps) / (a_conf + eps)
        if ratio > 1.0:
            coeff_visual = 1.0 - self.alpha * torch.tanh(torch.tensor(ratio - 1.0)).item()
            coeff_audio = 1.0
        else:
            coeff_visual = 1.0
            coeff_audio = 1.0 - self.alpha * torch.tanh(torch.tensor(1.0 / ratio - 1.0)).item()

        return coeff_visual, coeff_audio, {
            "visual_conf": v_conf,
            "audio_conf": a_conf,
            "ratio": ratio,
            "coeff_visual": coeff_visual,
            "coeff_audio": coeff_audio,
        }

    def apply_gradient_modulation(self, model, coeff_visual, coeff_audio):
        """Scale (and possibly perturb) backbone gradients in place.

        Must be called after ``backward()`` and before ``optimizer.step()``.
        Parameters are matched by name: those containing ``"visual_branch"``
        use ``coeff_visual``, those containing ``"audio_branch"`` use
        ``coeff_audio``; all others (fusion, auxiliary heads) are untouched.
        """
        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            if "visual_branch" in name:
                self._modulate_grad(param.grad.data, coeff_visual)
            elif "audio_branch" in name:
                self._modulate_grad(param.grad.data, coeff_audio)

    def _modulate_grad(self, grad, coeff):
        """Multiply ``grad`` by ``coeff``; add scaled Gaussian noise when suppressed."""
        grad.mul_(coeff)
        if coeff < 1.0 and self.noise_sigma > 0:
            grad.add_(torch.randn_like(grad) * self.noise_sigma * grad.abs().mean())


class MultimodalPCFaultDetector(nn.Module):
    """Visual/audio fault classifier with optional late fusion.

    Modes:
      - ``"multimodal"``: both branches + LateFusion + auxiliary unimodal heads.
        Training loss = L_fusion + lambda_visual * L_visual + lambda_audio * L_audio.
      - ``"visual_only"`` / ``"audio_only"``: one branch + MLP classifier.

    Raises:
        ValueError: if ``mode`` is not one of the above.
    """

    def __init__(self, model_config: ModelConfig, lora_config: Optional[LoRAConfig] = None,
                 finetune_method: str = "lora",
                 mode: Literal["multimodal", "visual_only", "audio_only"] = "multimodal",
                 use_ogm: bool = True, lambda_visual: float = 1.5, lambda_audio: float = 0.5):
        super().__init__()
        if mode not in _MODES:
            raise ValueError(f"Unknown mode {mode!r}; expected one of {_MODES}")

        self.mode = mode
        self.modality_dropout_p = model_config.modality_dropout_p
        # use_ogm is only stored here; presumably read by the training loop
        # that drives OGMGEModulator — confirm against the trainer.
        self.use_ogm = use_ogm
        self.lambda_visual = lambda_visual
        self.lambda_audio = lambda_audio

        self.visual_branch = (
            VisualBranch(model_config, lora_config, finetune_method)
            if mode in ("multimodal", "visual_only") else None
        )
        self.audio_branch = (
            AudioBranch(model_config, lora_config, finetune_method)
            if mode in ("multimodal", "audio_only") else None
        )

        if mode == "multimodal":
            self.fusion = LateFusion(model_config)
            self.visual_head = nn.Sequential(
                nn.LayerNorm(model_config.vit_embed_dim),
                nn.Dropout(0.2),
                nn.Linear(model_config.vit_embed_dim, model_config.num_classes),
            )
            self.audio_head = nn.Sequential(
                nn.LayerNorm(model_config.ast_embed_dim),
                nn.Dropout(0.2),
                nn.Linear(model_config.ast_embed_dim, model_config.num_classes),
            )
        else:
            embed_dim = model_config.vit_embed_dim if mode == "visual_only" else model_config.ast_embed_dim
            self.classifier = nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Dropout(model_config.fusion_dropout),
                nn.Linear(embed_dim, model_config.fusion_dim),
                nn.GELU(),
                nn.Dropout(model_config.fusion_dropout),
                nn.Linear(model_config.fusion_dim, model_config.num_classes),
            )

        self.loss_fn = nn.CrossEntropyLoss()

        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"[Model] Mode={mode}, Total={total:,}, Trainable={trainable:,} ({100*trainable/total:.2f}%)")

    def _sample_modality_mask(self):
        """Draw a batch-level modality-dropout mask for the fusion module.

        Each modality is independently set to 0.0 with probability
        ``modality_dropout_p``; if both would be dropped, one of them (chosen
        with probability 0.5) is restored so both are never dropped at once.
        """
        mask = {
            "visual": 0.0 if torch.rand(1).item() < self.modality_dropout_p else 1.0,
            "audio": 0.0 if torch.rand(1).item() < self.modality_dropout_p else 1.0,
        }
        if mask["visual"] == 0.0 and mask["audio"] == 0.0:
            mask["visual" if torch.rand(1).item() < 0.5 else "audio"] = 1.0
        return mask

    def _forward_multimodal(self, pixel_values, audio_values, labels):
        """Fused + auxiliary forward pass; adds the weighted loss when labels are given."""
        v_emb = self.visual_branch(pixel_values)
        a_emb = self.audio_branch(audio_values)

        mask = None
        if self.training and self.modality_dropout_p > 0:
            mask = self._sample_modality_mask()
        logits = self.fusion(v_emb, a_emb, mask)

        # The auxiliary heads always use the unmasked embeddings; the mask only
        # affects the fusion path.
        visual_logits = self.visual_head(v_emb)
        audio_logits = self.audio_head(a_emb)

        outputs = {
            "logits": logits,
            "visual_logits": visual_logits,
            "audio_logits": audio_logits,
            "visual_emb": v_emb,
            "audio_emb": a_emb,
        }
        if labels is not None:
            loss_f = self.loss_fn(logits, labels)
            loss_v = self.loss_fn(visual_logits, labels)
            loss_a = self.loss_fn(audio_logits, labels)
            outputs.update({
                "loss": loss_f + self.lambda_visual * loss_v + self.lambda_audio * loss_a,
                # Component losses as Python floats, for logging only.
                "loss_fusion": loss_f.item(),
                "loss_visual": loss_v.item(),
                "loss_audio": loss_a.item(),
            })
        return outputs

    def forward(self, pixel_values=None, audio_values=None, labels=None):
        """Run the model.

        Returns:
            A dict with ``"logits"`` and, when ``labels`` is given, ``"loss"``.
            In multimodal mode it also contains ``"visual_logits"``,
            ``"audio_logits"``, ``"visual_emb"``, ``"audio_emb"`` and, with
            labels, the float components ``"loss_fusion"``, ``"loss_visual"``
            and ``"loss_audio"``.
        """
        if self.mode == "multimodal":
            return self._forward_multimodal(pixel_values, audio_values, labels)

        if self.mode == "visual_only":
            emb = self.visual_branch(pixel_values)
        else:
            emb = self.audio_branch(audio_values)
        logits = self.classifier(emb)
        outputs = {"logits": logits}
        if labels is not None:
            outputs["loss"] = self.loss_fn(logits, labels)
        return outputs


def create_model(model_config, lora_config, mode="multimodal", finetune_method="lora",
                 use_ogm=True, lambda_visual=1.5, lambda_audio=0.5):
    """Build a MultimodalPCFaultDetector (positional order differs from the class)."""
    return MultimodalPCFaultDetector(
        model_config, lora_config, finetune_method, mode,
        use_ogm=use_ogm, lambda_visual=lambda_visual, lambda_audio=lambda_audio,
    )


def get_processors(model_config):
    """Load the (image processor, audio feature extractor) matching the backbones."""
    return (
        ViTImageProcessor.from_pretrained(model_config.vit_model_name),
        ASTFeatureExtractor.from_pretrained(model_config.ast_model_name),
    )