Ellaft commited on
Commit
38fdf87
·
verified ·
1 Parent(s): 3f69a1e

Add model architecture (ViT + AST + Late Fusion + LoRA)

Browse files
Files changed (1) hide show
  1. src/models.py +131 -0
src/models.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Multimodal PC Fault Detection - Model Architecture
3
+ ====================================================
4
+ Two-branch architecture:
5
+ - Visual: ViT-B/16 pretrained on ImageNet-21k
6
+ - Audio: AST pretrained on AudioSet
7
+ - Fusion: Late fusion (concat / weighted sum / attention)
8
+
9
+ Supports LoRA, full fine-tuning, and linear probe modes.
10
+ """
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from typing import Dict, Optional, Literal
16
+ from transformers import ViTModel, ASTModel, ViTImageProcessor, ASTFeatureExtractor
17
+ from peft import LoraConfig, get_peft_model
18
+ from config import ModelConfig, LoRAConfig, FAULT_CLASSES
19
+
20
+
21
+ class VisualBranch(nn.Module):
22
+ def __init__(self, config, lora_config=None, finetune_method="lora"):
23
+ super().__init__()
24
+ self.vit = ViTModel.from_pretrained(config.vit_model_name)
25
+ if finetune_method == "lora" and lora_config and lora_config.enabled:
26
+ peft_config = LoraConfig(r=lora_config.r, lora_alpha=lora_config.lora_alpha,
27
+ target_modules=lora_config.vit_target_modules, lora_dropout=lora_config.lora_dropout, bias=lora_config.bias)
28
+ self.vit = get_peft_model(self.vit, peft_config)
29
+ self.vit.print_trainable_parameters()
30
+ elif finetune_method == "linear_probe":
31
+ for param in self.vit.parameters(): param.requires_grad = False
32
+
33
+ def forward(self, pixel_values):
34
+ return self.vit(pixel_values=pixel_values).last_hidden_state[:, 0, :]
35
+
36
+
37
+ class AudioBranch(nn.Module):
38
+ def __init__(self, config, lora_config=None, finetune_method="lora"):
39
+ super().__init__()
40
+ self.ast = ASTModel.from_pretrained(config.ast_model_name)
41
+ if finetune_method == "lora" and lora_config and lora_config.enabled:
42
+ peft_config = LoraConfig(r=lora_config.r, lora_alpha=lora_config.lora_alpha,
43
+ target_modules=lora_config.ast_target_modules, lora_dropout=lora_config.lora_dropout, bias=lora_config.bias)
44
+ self.ast = get_peft_model(self.ast, peft_config)
45
+ self.ast.print_trainable_parameters()
46
+ elif finetune_method == "linear_probe":
47
+ for param in self.ast.parameters(): param.requires_grad = False
48
+
49
+ def forward(self, input_values):
50
+ return self.ast(input_values=input_values).last_hidden_state[:, 0, :]
51
+
52
+
53
+ class LateFusion(nn.Module):
54
+ def __init__(self, config):
55
+ super().__init__()
56
+ self.fusion_type = config.fusion_type
57
+ if config.fusion_type == "concat":
58
+ self.visual_proj = nn.Linear(config.vit_embed_dim, config.fusion_dim)
59
+ self.audio_proj = nn.Linear(config.ast_embed_dim, config.fusion_dim)
60
+ self.classifier = nn.Sequential(nn.LayerNorm(config.fusion_dim * 2), nn.Dropout(config.fusion_dropout),
61
+ nn.Linear(config.fusion_dim * 2, config.fusion_dim), nn.GELU(), nn.Dropout(config.fusion_dropout),
62
+ nn.Linear(config.fusion_dim, config.num_classes))
63
+ elif config.fusion_type == "weighted_sum":
64
+ self.visual_head = nn.Linear(config.vit_embed_dim, config.num_classes)
65
+ self.audio_head = nn.Linear(config.ast_embed_dim, config.num_classes)
66
+ self.fusion_weights = nn.Parameter(torch.tensor([0.5, 0.5]))
67
+ elif config.fusion_type == "attention":
68
+ self.visual_proj = nn.Linear(config.vit_embed_dim, config.fusion_dim)
69
+ self.audio_proj = nn.Linear(config.ast_embed_dim, config.fusion_dim)
70
+ self.cross_attn = nn.MultiheadAttention(embed_dim=config.fusion_dim, num_heads=8, dropout=config.fusion_dropout, batch_first=True)
71
+ self.classifier = nn.Sequential(nn.LayerNorm(config.fusion_dim), nn.Dropout(config.fusion_dropout), nn.Linear(config.fusion_dim, config.num_classes))
72
+
73
+ def forward(self, visual_emb, audio_emb, modality_mask=None):
74
+ if modality_mask:
75
+ visual_emb = visual_emb * modality_mask.get("visual", 1.0)
76
+ audio_emb = audio_emb * modality_mask.get("audio", 1.0)
77
+ if self.fusion_type == "concat":
78
+ fused = torch.cat([self.visual_proj(visual_emb), self.audio_proj(audio_emb)], dim=-1)
79
+ return self.classifier(fused)
80
+ elif self.fusion_type == "weighted_sum":
81
+ w = torch.softmax(self.fusion_weights, dim=0)
82
+ return w[0] * self.visual_head(visual_emb) + w[1] * self.audio_head(audio_emb)
83
+ elif self.fusion_type == "attention":
84
+ tokens = torch.cat([self.visual_proj(visual_emb).unsqueeze(1), self.audio_proj(audio_emb).unsqueeze(1)], dim=1)
85
+ return self.classifier(self.cross_attn(tokens, tokens, tokens)[0].mean(dim=1))
86
+
87
+
88
+ class MultimodalPCFaultDetector(nn.Module):
89
+ def __init__(self, model_config, lora_config=None, finetune_method="lora", mode="multimodal"):
90
+ super().__init__()
91
+ self.mode, self.modality_dropout_p = mode, model_config.modality_dropout_p
92
+ self.visual_branch = VisualBranch(model_config, lora_config, finetune_method) if mode in ("multimodal", "visual_only") else None
93
+ self.audio_branch = AudioBranch(model_config, lora_config, finetune_method) if mode in ("multimodal", "audio_only") else None
94
+ if mode == "multimodal":
95
+ self.fusion = LateFusion(model_config)
96
+ else:
97
+ embed_dim = model_config.vit_embed_dim if mode == "visual_only" else model_config.ast_embed_dim
98
+ self.classifier = nn.Sequential(nn.LayerNorm(embed_dim), nn.Dropout(model_config.fusion_dropout),
99
+ nn.Linear(embed_dim, model_config.fusion_dim), nn.GELU(), nn.Dropout(model_config.fusion_dropout),
100
+ nn.Linear(model_config.fusion_dim, model_config.num_classes))
101
+ self.loss_fn = nn.CrossEntropyLoss()
102
+ total = sum(p.numel() for p in self.parameters())
103
+ trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
104
+ print(f"[Model] Mode={mode}, Total={total:,}, Trainable={trainable:,} ({100*trainable/total:.2f}%)")
105
+
106
+ def forward(self, pixel_values=None, audio_values=None, labels=None):
107
+ if self.mode == "multimodal":
108
+ v_emb, a_emb = self.visual_branch(pixel_values), self.audio_branch(audio_values)
109
+ mask = None
110
+ if self.training and self.modality_dropout_p > 0:
111
+ mask = {"visual": 0.0 if torch.rand(1).item() < self.modality_dropout_p else 1.0,
112
+ "audio": 0.0 if torch.rand(1).item() < self.modality_dropout_p else 1.0}
113
+ if mask["visual"] == 0.0 and mask["audio"] == 0.0:
114
+ mask["visual" if torch.rand(1).item() < 0.5 else "audio"] = 1.0
115
+ logits = self.fusion(v_emb, a_emb, mask)
116
+ elif self.mode == "visual_only":
117
+ logits = self.classifier(self.visual_branch(pixel_values))
118
+ else:
119
+ logits = self.classifier(self.audio_branch(audio_values))
120
+ outputs = {"logits": logits}
121
+ if labels is not None:
122
+ outputs["loss"] = self.loss_fn(logits, labels)
123
+ return outputs
124
+
125
+
126
+ def create_model(model_config, lora_config, mode="multimodal", finetune_method="lora"):
127
+ return MultimodalPCFaultDetector(model_config, lora_config, finetune_method, mode)
128
+
129
+ def get_processors(model_config):
130
+ return (ViTImageProcessor.from_pretrained(model_config.vit_model_name),
131
+ ASTFeatureExtractor.from_pretrained(model_config.ast_model_name))