Ellaft commited on
Commit
4e62071
verified
1 Parent(s): d1b8db1

Overwrite models.py with v2 architecture (auxiliary heads + OGM-GE + anti-collapse)

Browse files
Files changed (1) hide show
  1. src/models.py +80 -24
src/models.py CHANGED
@@ -1,12 +1,14 @@
1
  """
2
  Multimodal PC Fault Detection - Model Architecture
3
  ====================================================
4
- Two-branch architecture:
5
- - Visual: ViT-B/16 pretrained on ImageNet-21k
6
- - Audio: AST pretrained on AudioSet
7
  - Fusion: Late fusion (concat / weighted sum / attention)
8
-
9
- Supports LoRA, full fine-tuning, and linear probe modes.
 
 
10
  """
11
 
12
  import torch
@@ -24,12 +26,13 @@ class VisualBranch(nn.Module):
24
  self.vit = ViTModel.from_pretrained(config.vit_model_name)
25
  if finetune_method == "lora" and lora_config and lora_config.enabled:
26
  peft_config = LoraConfig(r=lora_config.r, lora_alpha=lora_config.lora_alpha,
27
- target_modules=lora_config.vit_target_modules, lora_dropout=lora_config.lora_dropout, bias=lora_config.bias)
 
28
  self.vit = get_peft_model(self.vit, peft_config)
29
  self.vit.print_trainable_parameters()
30
  elif finetune_method == "linear_probe":
31
  for param in self.vit.parameters(): param.requires_grad = False
32
-
33
  def forward(self, pixel_values):
34
  return self.vit(pixel_values=pixel_values).last_hidden_state[:, 0, :]
35
 
@@ -40,12 +43,13 @@ class AudioBranch(nn.Module):
40
  self.ast = ASTModel.from_pretrained(config.ast_model_name)
41
  if finetune_method == "lora" and lora_config and lora_config.enabled:
42
  peft_config = LoraConfig(r=lora_config.r, lora_alpha=lora_config.lora_alpha,
43
- target_modules=lora_config.ast_target_modules, lora_dropout=lora_config.lora_dropout, bias=lora_config.bias)
 
44
  self.ast = get_peft_model(self.ast, peft_config)
45
  self.ast.print_trainable_parameters()
46
  elif finetune_method == "linear_probe":
47
  for param in self.ast.parameters(): param.requires_grad = False
48
-
49
  def forward(self, input_values):
50
  return self.ast(input_values=input_values).last_hidden_state[:, 0, :]
51
 
@@ -57,9 +61,10 @@ class LateFusion(nn.Module):
57
  if config.fusion_type == "concat":
58
  self.visual_proj = nn.Linear(config.vit_embed_dim, config.fusion_dim)
59
  self.audio_proj = nn.Linear(config.ast_embed_dim, config.fusion_dim)
60
- self.classifier = nn.Sequential(nn.LayerNorm(config.fusion_dim * 2), nn.Dropout(config.fusion_dropout),
61
- nn.Linear(config.fusion_dim * 2, config.fusion_dim), nn.GELU(), nn.Dropout(config.fusion_dropout),
62
- nn.Linear(config.fusion_dim, config.num_classes))
 
63
  elif config.fusion_type == "weighted_sum":
64
  self.visual_head = nn.Linear(config.vit_embed_dim, config.num_classes)
65
  self.audio_head = nn.Linear(config.ast_embed_dim, config.num_classes)
@@ -67,16 +72,17 @@ class LateFusion(nn.Module):
67
  elif config.fusion_type == "attention":
68
  self.visual_proj = nn.Linear(config.vit_embed_dim, config.fusion_dim)
69
  self.audio_proj = nn.Linear(config.ast_embed_dim, config.fusion_dim)
70
- self.cross_attn = nn.MultiheadAttention(embed_dim=config.fusion_dim, num_heads=8, dropout=config.fusion_dropout, batch_first=True)
71
- self.classifier = nn.Sequential(nn.LayerNorm(config.fusion_dim), nn.Dropout(config.fusion_dropout), nn.Linear(config.fusion_dim, config.num_classes))
72
-
 
 
73
  def forward(self, visual_emb, audio_emb, modality_mask=None):
74
  if modality_mask:
75
  visual_emb = visual_emb * modality_mask.get("visual", 1.0)
76
  audio_emb = audio_emb * modality_mask.get("audio", 1.0)
77
  if self.fusion_type == "concat":
78
- fused = torch.cat([self.visual_proj(visual_emb), self.audio_proj(audio_emb)], dim=-1)
79
- return self.classifier(fused)
80
  elif self.fusion_type == "weighted_sum":
81
  w = torch.softmax(self.fusion_weights, dim=0)
82
  return w[0] * self.visual_head(visual_emb) + w[1] * self.audio_head(audio_emb)
@@ -85,14 +91,55 @@ class LateFusion(nn.Module):
85
  return self.classifier(self.cross_attn(tokens, tokens, tokens)[0].mean(dim=1))
86
 
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  class MultimodalPCFaultDetector(nn.Module):
89
- def __init__(self, model_config, lora_config=None, finetune_method="lora", mode="multimodal"):
 
90
  super().__init__()
91
  self.mode, self.modality_dropout_p = mode, model_config.modality_dropout_p
 
92
  self.visual_branch = VisualBranch(model_config, lora_config, finetune_method) if mode in ("multimodal", "visual_only") else None
93
  self.audio_branch = AudioBranch(model_config, lora_config, finetune_method) if mode in ("multimodal", "audio_only") else None
94
  if mode == "multimodal":
95
  self.fusion = LateFusion(model_config)
 
 
96
  else:
97
  embed_dim = model_config.vit_embed_dim if mode == "visual_only" else model_config.ast_embed_dim
98
  self.classifier = nn.Sequential(nn.LayerNorm(embed_dim), nn.Dropout(model_config.fusion_dropout),
@@ -102,7 +149,7 @@ class MultimodalPCFaultDetector(nn.Module):
102
  total = sum(p.numel() for p in self.parameters())
103
  trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
104
  print(f"[Model] Mode={mode}, Total={total:,}, Trainable={trainable:,} ({100*trainable/total:.2f}%)")
105
-
106
  def forward(self, pixel_values=None, audio_values=None, labels=None):
107
  if self.mode == "multimodal":
108
  v_emb, a_emb = self.visual_branch(pixel_values), self.audio_branch(audio_values)
@@ -113,18 +160,27 @@ class MultimodalPCFaultDetector(nn.Module):
113
  if mask["visual"] == 0.0 and mask["audio"] == 0.0:
114
  mask["visual" if torch.rand(1).item() < 0.5 else "audio"] = 1.0
115
  logits = self.fusion(v_emb, a_emb, mask)
 
 
 
 
 
 
116
  elif self.mode == "visual_only":
117
  logits = self.classifier(self.visual_branch(pixel_values))
 
 
118
  else:
119
  logits = self.classifier(self.audio_branch(audio_values))
120
- outputs = {"logits": logits}
121
- if labels is not None:
122
- outputs["loss"] = self.loss_fn(logits, labels)
123
  return outputs
124
 
125
 
126
- def create_model(model_config, lora_config, mode="multimodal", finetune_method="lora"):
127
- return MultimodalPCFaultDetector(model_config, lora_config, finetune_method, mode)
 
 
128
 
129
  def get_processors(model_config):
130
  return (ViTImageProcessor.from_pretrained(model_config.vit_model_name),
 
1
  """
2
  Multimodal PC Fault Detection - Model Architecture
3
  ====================================================
4
+ Two-branch architecture with anti-modality-collapse features:
5
+ - Visual: ViT-B/16 (ImageNet-21k) + LoRA
6
+ - Audio: AST (AudioSet) + LoRA
7
  - Fusion: Late fusion (concat / weighted sum / attention)
8
+ - Auxiliary unimodal classification heads
9
+ - OGM-GE gradient modulation support (Peng et al., CVPR 2022)
10
+
11
+ Loss = L_fusion + 位_visual * L_visual + 位_audio * L_audio
12
  """
13
 
14
  import torch
 
26
  self.vit = ViTModel.from_pretrained(config.vit_model_name)
27
  if finetune_method == "lora" and lora_config and lora_config.enabled:
28
  peft_config = LoraConfig(r=lora_config.r, lora_alpha=lora_config.lora_alpha,
29
+ target_modules=lora_config.vit_target_modules,
30
+ lora_dropout=lora_config.lora_dropout, bias=lora_config.bias)
31
  self.vit = get_peft_model(self.vit, peft_config)
32
  self.vit.print_trainable_parameters()
33
  elif finetune_method == "linear_probe":
34
  for param in self.vit.parameters(): param.requires_grad = False
35
+
36
  def forward(self, pixel_values):
37
  return self.vit(pixel_values=pixel_values).last_hidden_state[:, 0, :]
38
 
 
43
  self.ast = ASTModel.from_pretrained(config.ast_model_name)
44
  if finetune_method == "lora" and lora_config and lora_config.enabled:
45
  peft_config = LoraConfig(r=lora_config.r, lora_alpha=lora_config.lora_alpha,
46
+ target_modules=lora_config.ast_target_modules,
47
+ lora_dropout=lora_config.lora_dropout, bias=lora_config.bias)
48
  self.ast = get_peft_model(self.ast, peft_config)
49
  self.ast.print_trainable_parameters()
50
  elif finetune_method == "linear_probe":
51
  for param in self.ast.parameters(): param.requires_grad = False
52
+
53
  def forward(self, input_values):
54
  return self.ast(input_values=input_values).last_hidden_state[:, 0, :]
55
 
 
61
  if config.fusion_type == "concat":
62
  self.visual_proj = nn.Linear(config.vit_embed_dim, config.fusion_dim)
63
  self.audio_proj = nn.Linear(config.ast_embed_dim, config.fusion_dim)
64
+ self.classifier = nn.Sequential(
65
+ nn.LayerNorm(config.fusion_dim * 2), nn.Dropout(config.fusion_dropout),
66
+ nn.Linear(config.fusion_dim * 2, config.fusion_dim), nn.GELU(),
67
+ nn.Dropout(config.fusion_dropout), nn.Linear(config.fusion_dim, config.num_classes))
68
  elif config.fusion_type == "weighted_sum":
69
  self.visual_head = nn.Linear(config.vit_embed_dim, config.num_classes)
70
  self.audio_head = nn.Linear(config.ast_embed_dim, config.num_classes)
 
72
  elif config.fusion_type == "attention":
73
  self.visual_proj = nn.Linear(config.vit_embed_dim, config.fusion_dim)
74
  self.audio_proj = nn.Linear(config.ast_embed_dim, config.fusion_dim)
75
+ self.cross_attn = nn.MultiheadAttention(embed_dim=config.fusion_dim, num_heads=8,
76
+ dropout=config.fusion_dropout, batch_first=True)
77
+ self.classifier = nn.Sequential(nn.LayerNorm(config.fusion_dim),
78
+ nn.Dropout(config.fusion_dropout), nn.Linear(config.fusion_dim, config.num_classes))
79
+
80
  def forward(self, visual_emb, audio_emb, modality_mask=None):
81
  if modality_mask:
82
  visual_emb = visual_emb * modality_mask.get("visual", 1.0)
83
  audio_emb = audio_emb * modality_mask.get("audio", 1.0)
84
  if self.fusion_type == "concat":
85
+ return self.classifier(torch.cat([self.visual_proj(visual_emb), self.audio_proj(audio_emb)], dim=-1))
 
86
  elif self.fusion_type == "weighted_sum":
87
  w = torch.softmax(self.fusion_weights, dim=0)
88
  return w[0] * self.visual_head(visual_emb) + w[1] * self.audio_head(audio_emb)
 
91
  return self.classifier(self.cross_attn(tokens, tokens, tokens)[0].mean(dim=1))
92
 
93
 
94
+ class OGMGEModulator:
95
+ """OGM-GE from Peng et al., CVPR 2022. Suppresses dominant modality gradients."""
96
+ def __init__(self, alpha=0.3, noise_sigma=0.1):
97
+ self.alpha = alpha
98
+ self.noise_sigma = noise_sigma
99
+
100
+ @torch.no_grad()
101
+ def compute_modulation_coefficients(self, visual_logits, audio_logits, labels):
102
+ v_probs = F.softmax(visual_logits, dim=-1)
103
+ a_probs = F.softmax(audio_logits, dim=-1)
104
+ batch_idx = torch.arange(labels.size(0), device=labels.device)
105
+ v_conf = v_probs[batch_idx, labels].mean().item()
106
+ a_conf = a_probs[batch_idx, labels].mean().item()
107
+ eps = 1e-8
108
+ ratio = (v_conf + eps) / (a_conf + eps)
109
+ if ratio > 1.0:
110
+ coeff_visual = 1.0 - self.alpha * torch.tanh(torch.tensor(ratio - 1.0)).item()
111
+ coeff_audio = 1.0
112
+ else:
113
+ coeff_visual = 1.0
114
+ coeff_audio = 1.0 - self.alpha * torch.tanh(torch.tensor(1.0 / ratio - 1.0)).item()
115
+ return coeff_visual, coeff_audio, {"visual_conf": v_conf, "audio_conf": a_conf,
116
+ "ratio": ratio, "coeff_visual": coeff_visual, "coeff_audio": coeff_audio}
117
+
118
+ def apply_gradient_modulation(self, model, coeff_visual, coeff_audio):
119
+ for name, param in model.named_parameters():
120
+ if param.grad is None: continue
121
+ if "visual_branch" in name:
122
+ param.grad.data.mul_(coeff_visual)
123
+ if coeff_visual < 1.0 and self.noise_sigma > 0:
124
+ param.grad.data.add_(torch.randn_like(param.grad.data) * self.noise_sigma * param.grad.data.abs().mean())
125
+ elif "audio_branch" in name:
126
+ param.grad.data.mul_(coeff_audio)
127
+ if coeff_audio < 1.0 and self.noise_sigma > 0:
128
+ param.grad.data.add_(torch.randn_like(param.grad.data) * self.noise_sigma * param.grad.data.abs().mean())
129
+
130
+
131
  class MultimodalPCFaultDetector(nn.Module):
132
+ def __init__(self, model_config, lora_config=None, finetune_method="lora",
133
+ mode="multimodal", use_ogm=True, lambda_visual=1.5, lambda_audio=0.5):
134
  super().__init__()
135
  self.mode, self.modality_dropout_p = mode, model_config.modality_dropout_p
136
+ self.use_ogm, self.lambda_visual, self.lambda_audio = use_ogm, lambda_visual, lambda_audio
137
  self.visual_branch = VisualBranch(model_config, lora_config, finetune_method) if mode in ("multimodal", "visual_only") else None
138
  self.audio_branch = AudioBranch(model_config, lora_config, finetune_method) if mode in ("multimodal", "audio_only") else None
139
  if mode == "multimodal":
140
  self.fusion = LateFusion(model_config)
141
+ self.visual_head = nn.Sequential(nn.LayerNorm(model_config.vit_embed_dim), nn.Dropout(0.2), nn.Linear(model_config.vit_embed_dim, model_config.num_classes))
142
+ self.audio_head = nn.Sequential(nn.LayerNorm(model_config.ast_embed_dim), nn.Dropout(0.2), nn.Linear(model_config.ast_embed_dim, model_config.num_classes))
143
  else:
144
  embed_dim = model_config.vit_embed_dim if mode == "visual_only" else model_config.ast_embed_dim
145
  self.classifier = nn.Sequential(nn.LayerNorm(embed_dim), nn.Dropout(model_config.fusion_dropout),
 
149
  total = sum(p.numel() for p in self.parameters())
150
  trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
151
  print(f"[Model] Mode={mode}, Total={total:,}, Trainable={trainable:,} ({100*trainable/total:.2f}%)")
152
+
153
  def forward(self, pixel_values=None, audio_values=None, labels=None):
154
  if self.mode == "multimodal":
155
  v_emb, a_emb = self.visual_branch(pixel_values), self.audio_branch(audio_values)
 
160
  if mask["visual"] == 0.0 and mask["audio"] == 0.0:
161
  mask["visual" if torch.rand(1).item() < 0.5 else "audio"] = 1.0
162
  logits = self.fusion(v_emb, a_emb, mask)
163
+ visual_logits, audio_logits = self.visual_head(v_emb), self.audio_head(a_emb)
164
+ outputs = {"logits": logits, "visual_logits": visual_logits, "audio_logits": audio_logits, "visual_emb": v_emb, "audio_emb": a_emb}
165
+ if labels is not None:
166
+ loss_f, loss_v, loss_a = self.loss_fn(logits, labels), self.loss_fn(visual_logits, labels), self.loss_fn(audio_logits, labels)
167
+ outputs.update({"loss": loss_f + self.lambda_visual * loss_v + self.lambda_audio * loss_a,
168
+ "loss_fusion": loss_f.item(), "loss_visual": loss_v.item(), "loss_audio": loss_a.item()})
169
  elif self.mode == "visual_only":
170
  logits = self.classifier(self.visual_branch(pixel_values))
171
+ outputs = {"logits": logits}
172
+ if labels is not None: outputs["loss"] = self.loss_fn(logits, labels)
173
  else:
174
  logits = self.classifier(self.audio_branch(audio_values))
175
+ outputs = {"logits": logits}
176
+ if labels is not None: outputs["loss"] = self.loss_fn(logits, labels)
 
177
  return outputs
178
 
179
 
180
+ def create_model(model_config, lora_config, mode="multimodal", finetune_method="lora",
181
+ use_ogm=True, lambda_visual=1.5, lambda_audio=0.5):
182
+ return MultimodalPCFaultDetector(model_config, lora_config, finetune_method, mode,
183
+ use_ogm=use_ogm, lambda_visual=lambda_visual, lambda_audio=lambda_audio)
184
 
185
  def get_processors(model_config):
186
  return (ViTImageProcessor.from_pretrained(model_config.vit_model_name),