| """ |
| Multimodal PC Fault Detection - Model Architecture v2 |
| ====================================================== |
| Changes from v1: |
| - Auxiliary unimodal classification heads (force each branch to independently classify) |
| - Asymmetric loss weighting: λ_visual=1.5 (boost weak), λ_audio=0.5 (dampen dominant) |
| - OGM-GE (On-the-fly Gradient Modulation + Generalization Enhancement) support |
| - Forward returns per-branch logits + embeddings for OGM-GE gradient modulation |
| |
| Two-branch architecture: |
| - Visual: ViT-B/16 pretrained on ImageNet-21k |
| - Audio: AST pretrained on AudioSet |
| - Fusion: Late fusion (concat / weighted sum / attention) |
| |
| Supports LoRA, full fine-tuning, and linear probe modes. |
| |
| References: |
| - OGM-GE: Peng et al., "Balanced Multimodal Learning via On-the-fly Gradient |
| Modulation", CVPR 2022 (arXiv: 2203.15332) |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Dict, Optional, Literal |
| from transformers import ViTModel, ASTModel, ViTImageProcessor, ASTFeatureExtractor |
| from peft import LoraConfig, get_peft_model |
| from config import ModelConfig, LoRAConfig, FAULT_CLASSES |
|
|
|
|
| |
| |
| |
|
|
| class VisualBranch(nn.Module): |
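    """ViT image encoder. With finetune_method="lora" (and LoRA enabled) the
    backbone is wrapped with LoRA adapters; "linear_probe" freezes it; any
    other value leaves all parameters trainable (full fine-tuning).
    forward() returns the [CLS] token embedding."""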
| def __init__(self, config, lora_config=None, finetune_method="lora"): |
| super().__init__() |
| self.vit = ViTModel.from_pretrained(config.vit_model_name) |
| if finetune_method == "lora" and lora_config and lora_config.enabled: |
| peft_config = LoraConfig( |
| r=lora_config.r, lora_alpha=lora_config.lora_alpha, |
| target_modules=lora_config.vit_target_modules, |
| lora_dropout=lora_config.lora_dropout, bias=lora_config.bias) |
| self.vit = get_peft_model(self.vit, peft_config) |
| self.vit.print_trainable_parameters() |
| elif finetune_method == "linear_probe": |
| for param in self.vit.parameters(): |
| param.requires_grad = False |
|
|
| def forward(self, pixel_values): |
| return self.vit(pixel_values=pixel_values).last_hidden_state[:, 0, :] |
|
|
|
|
| class AudioBranch(nn.Module): |
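    """AST audio encoder. Same fine-tuning options as VisualBranch (LoRA,
    linear probe, or full fine-tuning). forward() returns the leading
    ([CLS]) token embedding of the last hidden state."""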
| def __init__(self, config, lora_config=None, finetune_method="lora"): |
| super().__init__() |
| self.ast = ASTModel.from_pretrained(config.ast_model_name) |
| if finetune_method == "lora" and lora_config and lora_config.enabled: |
| peft_config = LoraConfig( |
| r=lora_config.r, lora_alpha=lora_config.lora_alpha, |
| target_modules=lora_config.ast_target_modules, |
| lora_dropout=lora_config.lora_dropout, bias=lora_config.bias) |
| self.ast = get_peft_model(self.ast, peft_config) |
| self.ast.print_trainable_parameters() |
| elif finetune_method == "linear_probe": |
| for param in self.ast.parameters(): |
| param.requires_grad = False |
|
|
| def forward(self, input_values): |
| return self.ast(input_values=input_values).last_hidden_state[:, 0, :] |
|
|
|
|
| |
| |
| |
|
|
| class LateFusion(nn.Module): |
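    """Late fusion over the two branch embeddings.

    fusion_type selects the strategy:
      - "concat": project both embeddings, concatenate, classify with an MLP.
      - "weighted_sum": per-branch linear heads combined by learned softmax weights.
      - "attention": cross-attention over the two projected tokens, then classify.
    An optional modality_mask (dict of 0.0/1.0 scalars) zeroes a branch
    embedding; used for modality dropout during training."""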
| def __init__(self, config): |
| super().__init__() |
| self.fusion_type = config.fusion_type |
| if config.fusion_type == "concat": |
| self.visual_proj = nn.Linear(config.vit_embed_dim, config.fusion_dim) |
| self.audio_proj = nn.Linear(config.ast_embed_dim, config.fusion_dim) |
| self.classifier = nn.Sequential( |
| nn.LayerNorm(config.fusion_dim * 2), |
| nn.Dropout(config.fusion_dropout), |
| nn.Linear(config.fusion_dim * 2, config.fusion_dim), |
| nn.GELU(), |
| nn.Dropout(config.fusion_dropout), |
| nn.Linear(config.fusion_dim, config.num_classes)) |
| elif config.fusion_type == "weighted_sum": |
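            # Per-branch linear heads combined by learned, softmax-normalized weights.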
| self.visual_head = nn.Linear(config.vit_embed_dim, config.num_classes) |
| self.audio_head = nn.Linear(config.ast_embed_dim, config.num_classes) |
| self.fusion_weights = nn.Parameter(torch.tensor([0.5, 0.5])) |
| elif config.fusion_type == "attention": |
| self.visual_proj = nn.Linear(config.vit_embed_dim, config.fusion_dim) |
| self.audio_proj = nn.Linear(config.ast_embed_dim, config.fusion_dim) |
| self.cross_attn = nn.MultiheadAttention( |
| embed_dim=config.fusion_dim, num_heads=8, |
| dropout=config.fusion_dropout, batch_first=True) |
| self.classifier = nn.Sequential( |
| nn.LayerNorm(config.fusion_dim), |
| nn.Dropout(config.fusion_dropout), |
| nn.Linear(config.fusion_dim, config.num_classes)) |
|
|
| def forward(self, visual_emb, audio_emb, modality_mask=None): |
| if modality_mask: |
| visual_emb = visual_emb * modality_mask.get("visual", 1.0) |
| audio_emb = audio_emb * modality_mask.get("audio", 1.0) |
| if self.fusion_type == "concat": |
| fused = torch.cat([self.visual_proj(visual_emb), self.audio_proj(audio_emb)], dim=-1) |
| return self.classifier(fused) |
| elif self.fusion_type == "weighted_sum": |
| w = torch.softmax(self.fusion_weights, dim=0) |
| return w[0] * self.visual_head(visual_emb) + w[1] * self.audio_head(audio_emb) |
| elif self.fusion_type == "attention": |
| tokens = torch.cat([ |
| self.visual_proj(visual_emb).unsqueeze(1), |
| self.audio_proj(audio_emb).unsqueeze(1)], dim=1) |
| return self.classifier(self.cross_attn(tokens, tokens, tokens)[0].mean(dim=1)) |
|
|
|
|
| |
| |
| |
|
|
| class OGMGEModulator: |
| """ |
| Implements OGM-GE from Peng et al., CVPR 2022. |
| |
| After loss.backward(), this computes per-modality confidence ratios and |
| modulates encoder gradients to suppress the dominant modality and boost |
| the weaker one. Gaussian noise is added to suppressed gradients for |
| generalization enhancement. |
| |
| Usage in training loop: |
| loss.backward() |
| coeff_v, coeff_a, stats = ogm.compute_modulation_coefficients( |
| visual_logits, audio_logits, labels) |
| ogm.apply_gradient_modulation(model, coeff_v, coeff_a) |
| optimizer.step() |
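
    Example: with alpha=0.3, if the visual head's mean correct-class
    probability is 0.9 and the audio head's is 0.3, then ratio = 3.0,
    coeff_visual = 1 - 0.3 * tanh(2.0) ≈ 0.71, and coeff_audio stays 1.0.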
| """ |
|
|
| def __init__(self, alpha=0.3, noise_sigma=0.1): |
| """ |
| Args: |
| alpha: Modulation strength. Higher = more aggressive suppression |
| of dominant modality. Paper uses 0.3-0.5. |
| noise_sigma: Std of Gaussian noise added to suppressed modality's |
| gradients (Generalization Enhancement). Paper uses 0.1. |
| """ |
| self.alpha = alpha |
| self.noise_sigma = noise_sigma |
|
|
| @torch.no_grad() |
| def compute_modulation_coefficients(self, visual_logits, audio_logits, labels): |
| """ |
| Compute OGM-GE modulation coefficients based on per-modality confidence. |
| |
| For each modality, we compute the average softmax probability of the |
| correct class (confidence). The modality with higher confidence is |
| considered dominant and gets its gradients scaled down. |
| |
| Args: |
| visual_logits: (B, C) logits from the auxiliary visual head |
| audio_logits: (B, C) logits from the auxiliary audio head |
| labels: (B,) ground truth class indices |
| |
| Returns: |
| coeff_visual: gradient scaling factor for visual encoder |
| coeff_audio: gradient scaling factor for audio encoder |
| stats: dict with debugging info |
| """ |
| |
        # Per-class probabilities from each auxiliary (unimodal) head.
        v_probs = F.softmax(visual_logits, dim=-1)
        a_probs = F.softmax(audio_logits, dim=-1)

        # Confidence = mean probability assigned to the ground-truth class.
        batch_indices = torch.arange(labels.size(0), device=labels.device)
        v_conf = v_probs[batch_indices, labels].mean().item()
        a_conf = a_probs[batch_indices, labels].mean().item()

        # ratio > 1 means the visual branch is currently the more confident one.
        eps = 1e-8
        ratio = (v_conf + eps) / (a_conf + eps)

        # Damp the dominant modality's encoder gradients; leave the weaker one at 1.0.
        if ratio > 1.0:
            coeff_visual = 1.0 - self.alpha * torch.tanh(torch.tensor(ratio - 1.0)).item()
            coeff_audio = 1.0
        else:
            coeff_visual = 1.0
            coeff_audio = 1.0 - self.alpha * torch.tanh(torch.tensor(1.0 / ratio - 1.0)).item()
|
|
| stats = { |
| "visual_conf": v_conf, |
| "audio_conf": a_conf, |
| "ratio": ratio, |
| "coeff_visual": coeff_visual, |
| "coeff_audio": coeff_audio, |
| } |
| return coeff_visual, coeff_audio, stats |
|
|
| def apply_gradient_modulation(self, model, coeff_visual, coeff_audio): |
| """ |
| Scale gradients of encoder parameters. Only affects the visual_branch |
| and audio_branch encoder weights — NOT the fusion head or auxiliary heads. |
| |
| For the suppressed modality (coeff < 1), also adds Gaussian noise |
| to gradients (Generalization Enhancement from the paper). |
| """ |
| for name, param in model.named_parameters(): |
| if param.grad is None: |
| continue |
|
|
| if "visual_branch" in name: |
| param.grad.data.mul_(coeff_visual) |
| |
| if coeff_visual < 1.0 and self.noise_sigma > 0: |
| noise = torch.randn_like(param.grad.data) * self.noise_sigma * param.grad.data.abs().mean() |
| param.grad.data.add_(noise) |
|
|
| elif "audio_branch" in name: |
| param.grad.data.mul_(coeff_audio) |
| if coeff_audio < 1.0 and self.noise_sigma > 0: |
| noise = torch.randn_like(param.grad.data) * self.noise_sigma * param.grad.data.abs().mean() |
| param.grad.data.add_(noise) |
|
|
|
|
| |
| |
| |
|
|
| class MultimodalPCFaultDetector(nn.Module): |
| """ |
| v2 changes: |
| - Auxiliary classification heads on each branch (visual_head, audio_head) |
| - Forward returns per-branch logits for OGM-GE gradient modulation |
| - Loss = loss_fusion + λ_v * loss_visual + λ_a * loss_audio |
| - Asymmetric λ weights: λ_visual=1.5 (boost weak), λ_audio=0.5 (dampen dominant) |
| """ |
|
|
| def __init__(self, model_config, lora_config=None, finetune_method="lora", |
| mode="multimodal", use_ogm=True, lambda_visual=1.5, lambda_audio=0.5): |
| super().__init__() |
| self.mode = mode |
| self.modality_dropout_p = model_config.modality_dropout_p |
| self.use_ogm = use_ogm |
| self.lambda_visual = lambda_visual |
| self.lambda_audio = lambda_audio |
|
|
| |
| self.visual_branch = ( |
| VisualBranch(model_config, lora_config, finetune_method) |
| if mode in ("multimodal", "visual_only") else None) |
| self.audio_branch = ( |
| AudioBranch(model_config, lora_config, finetune_method) |
| if mode in ("multimodal", "audio_only") else None) |
|
|
| |
| if mode == "multimodal": |
| self.fusion = LateFusion(model_config) |
|
|
| |
| |
| self.visual_head = nn.Sequential( |
| nn.LayerNorm(model_config.vit_embed_dim), |
| nn.Dropout(0.2), |
| nn.Linear(model_config.vit_embed_dim, model_config.num_classes)) |
| self.audio_head = nn.Sequential( |
| nn.LayerNorm(model_config.ast_embed_dim), |
| nn.Dropout(0.2), |
| nn.Linear(model_config.ast_embed_dim, model_config.num_classes)) |
| else: |
| embed_dim = (model_config.vit_embed_dim if mode == "visual_only" |
| else model_config.ast_embed_dim) |
| self.classifier = nn.Sequential( |
| nn.LayerNorm(embed_dim), |
| nn.Dropout(model_config.fusion_dropout), |
| nn.Linear(embed_dim, model_config.fusion_dim), |
| nn.GELU(), |
| nn.Dropout(model_config.fusion_dropout), |
| nn.Linear(model_config.fusion_dim, model_config.num_classes)) |
|
|
| self.loss_fn = nn.CrossEntropyLoss() |
|
|
| |
| total = sum(p.numel() for p in self.parameters()) |
| trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) |
| print(f"[Model v2] Mode={mode}, Total={total:,}, Trainable={trainable:,} " |
| f"({100*trainable/total:.2f}%)") |
| if mode == "multimodal": |
| print(f"[Model v2] OGM-GE={'ON' if use_ogm else 'OFF'}, " |
| f"λ_visual={lambda_visual}, λ_audio={lambda_audio}") |
|
|
| def forward(self, pixel_values=None, audio_values=None, labels=None): |
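        """Run the configured branches and return a dict of outputs.

        Multimodal mode returns fused "logits", per-branch "visual_logits" /
        "audio_logits", the branch embeddings, and (when labels are given)
        the weighted total "loss" plus per-term scalars. Unimodal modes
        return "logits" and, with labels, "loss"."""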
| if self.mode == "multimodal": |
| v_emb = self.visual_branch(pixel_values) |
| a_emb = self.audio_branch(audio_values) |
|
|
| |
| mask = None |
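            # Modality dropout (training only): randomly zero one branch's
            # embedding before fusion so the fusion head cannot lean on a
            # single modality; never drop both at once.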
| if self.training and self.modality_dropout_p > 0: |
| mask = { |
| "visual": 0.0 if torch.rand(1).item() < self.modality_dropout_p else 1.0, |
| "audio": 0.0 if torch.rand(1).item() < self.modality_dropout_p else 1.0, |
| } |
| |
| if mask["visual"] == 0.0 and mask["audio"] == 0.0: |
| mask["visual" if torch.rand(1).item() < 0.5 else "audio"] = 1.0 |
|
|
| |
            logits = self.fusion(v_emb, a_emb, mask)

            # Auxiliary unimodal logits: feed the per-branch losses and OGM-GE.
            visual_logits = self.visual_head(v_emb)
            audio_logits = self.audio_head(a_emb)
|
|
| outputs = { |
| "logits": logits, |
| "visual_logits": visual_logits, |
| "audio_logits": audio_logits, |
| "visual_emb": v_emb, |
| "audio_emb": a_emb, |
| } |
|
|
| if labels is not None: |
| loss_fusion = self.loss_fn(logits, labels) |
| loss_visual = self.loss_fn(visual_logits, labels) |
| loss_audio = self.loss_fn(audio_logits, labels) |
|
|
| |
| loss = (loss_fusion |
| + self.lambda_visual * loss_visual |
| + self.lambda_audio * loss_audio) |
|
|
| outputs["loss"] = loss |
| outputs["loss_fusion"] = loss_fusion.item() |
| outputs["loss_visual"] = loss_visual.item() |
| outputs["loss_audio"] = loss_audio.item() |
|
|
| elif self.mode == "visual_only": |
| logits = self.classifier(self.visual_branch(pixel_values)) |
| outputs = {"logits": logits} |
| if labels is not None: |
| outputs["loss"] = self.loss_fn(logits, labels) |
|
|
| else: |
| logits = self.classifier(self.audio_branch(audio_values)) |
| outputs = {"logits": logits} |
| if labels is not None: |
| outputs["loss"] = self.loss_fn(logits, labels) |
|
|
| return outputs |
|
|
|
|
| |
| |
| |
|
|
| def create_model(model_config, lora_config, mode="multimodal", |
| finetune_method="lora", use_ogm=True, |
| lambda_visual=1.5, lambda_audio=0.5): |
| """Create model with v2 anti-collapse features.""" |
| return MultimodalPCFaultDetector( |
| model_config, lora_config, finetune_method, mode, |
| use_ogm=use_ogm, |
| lambda_visual=lambda_visual, |
| lambda_audio=lambda_audio) |
|
|
|
|
| def get_processors(model_config): |
| """Load ViT image processor and AST feature extractor.""" |
| return ( |
| ViTImageProcessor.from_pretrained(model_config.vit_model_name), |
| ASTFeatureExtractor.from_pretrained(model_config.ast_model_name)) |
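

# ---------------------------------------------------------------------------
# Minimal single-training-step sketch (illustrative only). It assumes
# ModelConfig() and LoRAConfig() can be constructed with their defaults from
# config.py, that AST features follow the default (batch, 1024 frames,
# 128 mel bins) layout, and it downloads the pretrained backbones; the random
# tensors below stand in for a real dataloader batch.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cfg, lora_cfg = ModelConfig(), LoRAConfig()  # assumed default-constructible
    model = create_model(cfg, lora_cfg, mode="multimodal", finetune_method="lora")
    ogm = OGMGEModulator(alpha=0.3, noise_sigma=0.1)
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad], lr=1e-4)

    # Dummy batch with assumed input shapes (ViT: 3x224x224, AST: 1024x128).
    pixel_values = torch.randn(2, 3, 224, 224)
    audio_values = torch.randn(2, 1024, 128)
    labels = torch.randint(0, cfg.num_classes, (2,))

    model.train()
    outputs = model(pixel_values=pixel_values, audio_values=audio_values, labels=labels)
    outputs["loss"].backward()

    # OGM-GE: rescale encoder gradients between backward() and step().
    coeff_v, coeff_a, stats = ogm.compute_modulation_coefficients(
        outputs["visual_logits"], outputs["audio_logits"], labels)
    ogm.apply_gradient_modulation(model, coeff_v, coeff_a)
    optimizer.step()
    optimizer.zero_grad()
    print(stats)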
|
|