| """ |
| Multimodal PC Fault Detection - Model Architecture |
| ==================================================== |
| Two-branch architecture with anti-modality-collapse features: |
| - Visual: ViT-B/16 (ImageNet-21k) + LoRA |
| - Audio: AST (AudioSet) + LoRA |
| - Fusion: Late fusion (concat / weighted sum / attention) |
| - Auxiliary unimodal classification heads |
| - OGM-GE gradient modulation support (Peng et al., CVPR 2022) |
| |
| Loss = L_fusion + λ_visual * L_visual + λ_audio * L_audio |
| """ |
|
|
import math
from typing import Dict, Optional, Literal

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTModel, ASTModel, ViTImageProcessor, ASTFeatureExtractor
from peft import LoraConfig, get_peft_model

from config import ModelConfig, LoRAConfig, FAULT_CLASSES
|
|
|
|
class VisualBranch(nn.Module):
    """Pretrained ViT image encoder that returns the [CLS] token embedding.

    ``finetune_method`` decides which backbone weights stay trainable:

    * ``"lora"`` together with a truthy ``lora_config`` whose ``enabled`` flag
      is set: the ViT is wrapped with PEFT LoRA adapters on
      ``lora_config.vit_target_modules``.
    * ``"linear_probe"``: every backbone parameter is frozen.
    * anything else (including ``"lora"`` without an enabled config): the
      backbone is left fully trainable.
    """

    def __init__(self, config, lora_config=None, finetune_method="lora"):
        super().__init__()
        backbone = ViTModel.from_pretrained(config.vit_model_name)
        wants_lora = finetune_method == "lora" and lora_config and lora_config.enabled
        if wants_lora:
            adapter_cfg = LoraConfig(
                r=lora_config.r,
                lora_alpha=lora_config.lora_alpha,
                target_modules=lora_config.vit_target_modules,
                lora_dropout=lora_config.lora_dropout,
                bias=lora_config.bias,
            )
            backbone = get_peft_model(backbone, adapter_cfg)
            backbone.print_trainable_parameters()
        elif finetune_method == "linear_probe":
            # Equivalent to setting requires_grad=False on every parameter.
            backbone.requires_grad_(False)
        self.vit = backbone

    def forward(self, pixel_values):
        """Encode images and return the final-layer token at sequence index 0.

        The ``[:, 0, :]`` indexing implies ``last_hidden_state`` is rank 3,
        presumably ``(batch, seq_len, hidden)``.
        """
        hidden_states = self.vit(pixel_values=pixel_values).last_hidden_state
        cls_embedding = hidden_states[:, 0, :]
        return cls_embedding
|
|
|
|
class AudioBranch(nn.Module):
    """Pretrained Audio Spectrogram Transformer encoder returning the [CLS] embedding.

    ``finetune_method`` decides which backbone weights stay trainable:

    * ``"lora"`` together with a truthy ``lora_config`` whose ``enabled`` flag
      is set: the AST is wrapped with PEFT LoRA adapters on
      ``lora_config.ast_target_modules``.
    * ``"linear_probe"``: every backbone parameter is frozen.
    * anything else (including ``"lora"`` without an enabled config): the
      backbone is left fully trainable.
    """

    def __init__(self, config, lora_config=None, finetune_method="lora"):
        super().__init__()
        backbone = ASTModel.from_pretrained(config.ast_model_name)
        wants_lora = finetune_method == "lora" and lora_config and lora_config.enabled
        if wants_lora:
            adapter_cfg = LoraConfig(
                r=lora_config.r,
                lora_alpha=lora_config.lora_alpha,
                target_modules=lora_config.ast_target_modules,
                lora_dropout=lora_config.lora_dropout,
                bias=lora_config.bias,
            )
            backbone = get_peft_model(backbone, adapter_cfg)
            backbone.print_trainable_parameters()
        elif finetune_method == "linear_probe":
            # Equivalent to setting requires_grad=False on every parameter.
            backbone.requires_grad_(False)
        self.ast = backbone

    def forward(self, input_values):
        """Encode audio features and return the final-layer token at index 0.

        The ``[:, 0, :]`` indexing implies ``last_hidden_state`` is rank 3,
        presumably ``(batch, seq_len, hidden)``.
        """
        hidden_states = self.ast(input_values=input_values).last_hidden_state
        cls_embedding = hidden_states[:, 0, :]
        return cls_embedding
|
|
|
|
class LateFusion(nn.Module):
    """Fuse a visual and an audio embedding into class logits.

    Supported ``config.fusion_type`` values:

    * ``"concat"``: project each embedding to ``fusion_dim``, concatenate
      them, and classify with a two-layer MLP.
    * ``"weighted_sum"``: one linear head per modality; their logits are
      mixed with softmax-normalised learnable weights (initialised equal).
    * ``"attention"``: project both embeddings to ``fusion_dim``, treat them
      as a two-token sequence, run multi-head self-attention over it, average
      the two output tokens and classify.

    The attention head count is read from ``config.fusion_num_heads`` when
    that attribute exists and defaults to 8 otherwise. ``fusion_dim`` must be
    divisible by it (a ``torch.nn.MultiheadAttention`` requirement).

    Raises:
        ValueError: if ``config.fusion_type`` is not one of the values above.
    """

    _FUSION_TYPES = ("concat", "weighted_sum", "attention")

    def __init__(self, config):
        super().__init__()
        # Fail fast: an unknown type would otherwise build no layers at all and
        # make forward() silently return None.
        if config.fusion_type not in self._FUSION_TYPES:
            raise ValueError(
                f"Unknown fusion_type {config.fusion_type!r}; "
                f"expected one of {self._FUSION_TYPES}")
        self.fusion_type = config.fusion_type
        if config.fusion_type == "concat":
            self.visual_proj = nn.Linear(config.vit_embed_dim, config.fusion_dim)
            self.audio_proj = nn.Linear(config.ast_embed_dim, config.fusion_dim)
            self.classifier = nn.Sequential(
                nn.LayerNorm(config.fusion_dim * 2), nn.Dropout(config.fusion_dropout),
                nn.Linear(config.fusion_dim * 2, config.fusion_dim), nn.GELU(),
                nn.Dropout(config.fusion_dropout), nn.Linear(config.fusion_dim, config.num_classes))
        elif config.fusion_type == "weighted_sum":
            self.visual_head = nn.Linear(config.vit_embed_dim, config.num_classes)
            self.audio_head = nn.Linear(config.ast_embed_dim, config.num_classes)
            self.fusion_weights = nn.Parameter(torch.tensor([0.5, 0.5]))
        else:  # "attention"
            num_heads = getattr(config, "fusion_num_heads", 8)
            self.visual_proj = nn.Linear(config.vit_embed_dim, config.fusion_dim)
            self.audio_proj = nn.Linear(config.ast_embed_dim, config.fusion_dim)
            self.cross_attn = nn.MultiheadAttention(embed_dim=config.fusion_dim, num_heads=num_heads,
                                                    dropout=config.fusion_dropout, batch_first=True)
            self.classifier = nn.Sequential(nn.LayerNorm(config.fusion_dim),
                nn.Dropout(config.fusion_dropout), nn.Linear(config.fusion_dim, config.num_classes))

    def forward(self, visual_emb, audio_emb, modality_mask=None):
        """Return fused logits for a batch of embeddings.

        Args:
            visual_emb: visual embeddings, presumably ``(batch, vit_embed_dim)``.
            audio_emb: audio embeddings, presumably ``(batch, ast_embed_dim)``.
            modality_mask: optional dict of scalar multipliers keyed by
                ``"visual"`` / ``"audio"`` (a missing key means 1.0). It is
                applied to the embeddings before fusion; 0.0 drops a modality.
                ``None`` or an empty dict disables masking.

        Returns:
            Logits whose last dimension is ``num_classes``.
        """
        if modality_mask:
            visual_emb = visual_emb * modality_mask.get("visual", 1.0)
            audio_emb = audio_emb * modality_mask.get("audio", 1.0)
        if self.fusion_type == "concat":
            fused = torch.cat([self.visual_proj(visual_emb), self.audio_proj(audio_emb)], dim=-1)
            return self.classifier(fused)
        if self.fusion_type == "weighted_sum":
            w = torch.softmax(self.fusion_weights, dim=0)
            return w[0] * self.visual_head(visual_emb) + w[1] * self.audio_head(audio_emb)
        if self.fusion_type == "attention":
            # Two-token sequence: index 0 = visual, index 1 = audio.
            tokens = torch.stack([self.visual_proj(visual_emb), self.audio_proj(audio_emb)], dim=1)
            attended, _ = self.cross_attn(tokens, tokens, tokens)
            return self.classifier(attended.mean(dim=1))
        raise ValueError(f"Unknown fusion_type {self.fusion_type!r}")
|
|
|
|
class OGMGEModulator:
    """Gradient modulation in the spirit of OGM-GE (Peng et al., CVPR 2022).

    For each batch, the modality whose auxiliary head is more confident on the
    true labels has its encoder gradients scaled down. When a coefficient is
    below 1, Gaussian noise can optionally be added to those gradients.

    NOTE(review): the coefficient here is ``1 - alpha * tanh(ratio - 1)``,
    which never drops below ``1 - alpha``. The noise std is
    ``noise_sigma * mean(|grad|)``. Verify both against the paper or its
    reference code if exact reproduction matters.

    Args:
        alpha: maximum fractional suppression of the dominant modality.
        noise_sigma: noise scale relative to the mean absolute gradient;
            0 disables the noise.
    """

    def __init__(self, alpha=0.3, noise_sigma=0.1):
        self.alpha = alpha
        self.noise_sigma = noise_sigma

    @torch.no_grad()
    def compute_modulation_coefficients(self, visual_logits, audio_logits, labels):
        """Return ``(coeff_visual, coeff_audio, stats)`` for one batch.

        Each modality's confidence is the batch mean of the softmax
        probability assigned to the true label. With
        ``ratio = visual_conf / audio_conf`` (both offset by 1e-8), the
        dominant modality gets ``1 - alpha * tanh(r - 1)``. Here ``r`` is
        ``ratio`` or ``1 / ratio``, and the other modality gets 1.0.

        An empty batch has undefined confidence. It is treated as zero
        confidence for both modalities, which gives neutral coefficients
        (1.0, 1.0) instead of NaN.

        Args:
            visual_logits: visual head logits, presumably ``(batch, num_classes)``.
            audio_logits: audio head logits, same shape.
            labels: integer class indices, one per batch row.

        Returns:
            The two float coefficients and a dict of float diagnostics
            (``visual_conf``, ``audio_conf``, ``ratio``, ``coeff_visual``,
            ``coeff_audio``).
        """
        if labels.numel() == 0:
            v_conf = a_conf = 0.0
        else:
            batch_idx = torch.arange(labels.size(0), device=labels.device)
            v_conf = F.softmax(visual_logits, dim=-1)[batch_idx, labels].mean().item()
            a_conf = F.softmax(audio_logits, dim=-1)[batch_idx, labels].mean().item()
        eps = 1e-8
        ratio = (v_conf + eps) / (a_conf + eps)
        coeff_visual = coeff_audio = 1.0
        if ratio > 1.0:
            coeff_visual = 1.0 - self.alpha * math.tanh(ratio - 1.0)
        else:
            coeff_audio = 1.0 - self.alpha * math.tanh(1.0 / ratio - 1.0)
        return coeff_visual, coeff_audio, {"visual_conf": v_conf, "audio_conf": a_conf,
            "ratio": ratio, "coeff_visual": coeff_visual, "coeff_audio": coeff_audio}

    def apply_gradient_modulation(self, model, coeff_visual, coeff_audio):
        """Scale (and optionally perturb) encoder gradients in place.

        Parameters whose qualified name contains ``"visual_branch"`` are
        scaled by ``coeff_visual``, and those containing ``"audio_branch"`` by
        ``coeff_audio``. All other parameters (fusion and auxiliary heads) and
        parameters without a ``.grad`` are left untouched. Since it operates
        on ``.grad``, this is meant to run after ``backward()`` and before the
        optimizer step.
        """
        for name, param in model.named_parameters():
            grad = param.grad
            if grad is None:
                continue
            if "visual_branch" in name:
                self._modulate(grad, coeff_visual)
            elif "audio_branch" in name:
                self._modulate(grad, coeff_audio)

    @torch.no_grad()
    def _modulate(self, grad, coeff):
        """Scale ``grad`` by ``coeff`` in place; add noise when suppressed."""
        grad.mul_(coeff)
        if coeff < 1.0 and self.noise_sigma > 0:
            # Noise std is relative to the already-scaled gradient magnitude.
            grad.add_(torch.randn_like(grad) * self.noise_sigma * grad.abs().mean())
|
|
|
|
class MultimodalPCFaultDetector(nn.Module):
    """PC fault classifier built from a visual and/or an audio branch.

    ``mode`` selects the architecture:

    * ``"multimodal"``: both branches, a ``LateFusion`` head, and one
      auxiliary head per modality (LayerNorm -> Dropout(0.2) -> Linear).
      The training loss is
      ``L_fusion + lambda_visual * L_visual + lambda_audio * L_audio``.
      While training with ``model_config.modality_dropout_p > 0``, each
      modality's embedding is zeroed before fusion with that probability
      (once per batch). At least one modality is always kept, and the
      auxiliary heads always see the unmasked embeddings.
    * ``"visual_only"`` / ``"audio_only"``: a single branch followed by an
      MLP classifier trained with plain cross-entropy.

    ``use_ogm`` is only stored on the instance; this class does not apply any
    gradient modulation itself.

    Raises:
        ValueError: if ``mode`` is not one of the three values above.
    """

    _VALID_MODES = ("multimodal", "visual_only", "audio_only")

    def __init__(self, model_config, lora_config=None, finetune_method="lora",
                 mode="multimodal", use_ogm=True, lambda_visual=1.5, lambda_audio=0.5):
        super().__init__()
        # Fail fast: an unknown mode used to build a branch-less audio-style
        # classifier that only crashed later inside forward().
        if mode not in self._VALID_MODES:
            raise ValueError(f"Unknown mode {mode!r}; expected one of {self._VALID_MODES}")
        self.mode = mode
        self.modality_dropout_p = model_config.modality_dropout_p
        self.use_ogm = use_ogm
        self.lambda_visual = lambda_visual
        self.lambda_audio = lambda_audio

        uses_visual = mode in ("multimodal", "visual_only")
        uses_audio = mode in ("multimodal", "audio_only")
        self.visual_branch = VisualBranch(model_config, lora_config, finetune_method) if uses_visual else None
        self.audio_branch = AudioBranch(model_config, lora_config, finetune_method) if uses_audio else None

        if mode == "multimodal":
            self.fusion = LateFusion(model_config)
            # Auxiliary unimodal heads give each encoder its own supervision.
            self.visual_head = nn.Sequential(nn.LayerNorm(model_config.vit_embed_dim), nn.Dropout(0.2),
                                             nn.Linear(model_config.vit_embed_dim, model_config.num_classes))
            self.audio_head = nn.Sequential(nn.LayerNorm(model_config.ast_embed_dim), nn.Dropout(0.2),
                                            nn.Linear(model_config.ast_embed_dim, model_config.num_classes))
        else:
            embed_dim = model_config.vit_embed_dim if mode == "visual_only" else model_config.ast_embed_dim
            self.classifier = nn.Sequential(nn.LayerNorm(embed_dim), nn.Dropout(model_config.fusion_dropout),
                nn.Linear(embed_dim, model_config.fusion_dim), nn.GELU(), nn.Dropout(model_config.fusion_dropout),
                nn.Linear(model_config.fusion_dim, model_config.num_classes))
        self.loss_fn = nn.CrossEntropyLoss()

        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"[Model] Mode={mode}, Total={total:,}, Trainable={trainable:,} ({100*trainable/total:.2f}%)")

    def _sample_modality_mask(self):
        """Draw a per-batch modality-dropout mask that never drops both modalities.

        Each modality is dropped (multiplier 0.0) with probability
        ``self.modality_dropout_p``. If both are dropped, one is restored
        uniformly at random. Draws come from torch's global RNG in the order
        visual, audio, then (if needed) the tie-break.
        """
        p = self.modality_dropout_p
        mask = {"visual": 0.0 if torch.rand(1).item() < p else 1.0,
                "audio": 0.0 if torch.rand(1).item() < p else 1.0}
        if mask["visual"] == 0.0 and mask["audio"] == 0.0:
            mask["visual" if torch.rand(1).item() < 0.5 else "audio"] = 1.0
        return mask

    def forward(self, pixel_values=None, audio_values=None, labels=None):
        """Run the model and, when ``labels`` is given, compute the loss.

        Returns:
            A dict that always has ``"logits"``. In multimodal mode it also has
            ``"visual_logits"``, ``"audio_logits"``, ``"visual_emb"`` and
            ``"audio_emb"``. With labels it adds ``"loss"`` (a tensor) and, in
            multimodal mode, the Python floats ``"loss_fusion"``,
            ``"loss_visual"`` and ``"loss_audio"`` for logging.
        """
        if self.mode == "multimodal":
            v_emb = self.visual_branch(pixel_values)
            a_emb = self.audio_branch(audio_values)
            mask = None
            if self.training and self.modality_dropout_p > 0:
                mask = self._sample_modality_mask()
            logits = self.fusion(v_emb, a_emb, mask)
            visual_logits = self.visual_head(v_emb)
            audio_logits = self.audio_head(a_emb)
            outputs = {"logits": logits, "visual_logits": visual_logits, "audio_logits": audio_logits,
                       "visual_emb": v_emb, "audio_emb": a_emb}
            if labels is not None:
                loss_f = self.loss_fn(logits, labels)
                loss_v = self.loss_fn(visual_logits, labels)
                loss_a = self.loss_fn(audio_logits, labels)
                outputs.update({"loss": loss_f + self.lambda_visual * loss_v + self.lambda_audio * loss_a,
                                "loss_fusion": loss_f.item(), "loss_visual": loss_v.item(),
                                "loss_audio": loss_a.item()})
            return outputs

        if self.mode == "visual_only":
            logits = self.classifier(self.visual_branch(pixel_values))
        else:
            logits = self.classifier(self.audio_branch(audio_values))
        outputs = {"logits": logits}
        if labels is not None:
            outputs["loss"] = self.loss_fn(logits, labels)
        return outputs
|
|
|
|
def create_model(model_config, lora_config, mode="multimodal", finetune_method="lora",
                 use_ogm=True, lambda_visual=1.5, lambda_audio=0.5):
    """Construct a ``MultimodalPCFaultDetector``.

    This factory lists ``mode`` before ``finetune_method``, unlike the class
    constructor, so every argument is forwarded by keyword to rule out
    positional mix-ups.
    """
    return MultimodalPCFaultDetector(
        model_config,
        lora_config=lora_config,
        finetune_method=finetune_method,
        mode=mode,
        use_ogm=use_ogm,
        lambda_visual=lambda_visual,
        lambda_audio=lambda_audio,
    )
|
|
def get_processors(model_config):
    """Load the pretrained preprocessors that match the two backbones.

    Returns:
        A tuple ``(image_processor, feature_extractor)``. The first is a
        ``ViTImageProcessor`` for ``model_config.vit_model_name``; the second
        is an ``ASTFeatureExtractor`` for ``model_config.ast_model_name``.
    """
    image_processor = ViTImageProcessor.from_pretrained(model_config.vit_model_name)
    feature_extractor = ASTFeatureExtractor.from_pretrained(model_config.ast_model_name)
    return image_processor, feature_extractor
|
|