pangweijlu committed on
Commit
e384945
·
verified ·
1 Parent(s): 0125717

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +214 -0
model.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Multimodal Fraudulent Paper Detection - Core Model Architecture
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from transformers import AutoModel, AutoTokenizer, ViTModel, AutoConfig
8
+ from typing import Dict, Optional
9
+
10
+
11
class TextEncoder(nn.Module):
    """Text branch: SciBERT backbone whose lower layers can be frozen."""

    def __init__(self, model_name="allenai/scibert_scivocab_uncased", freeze_layers=6):
        super().__init__()
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Freeze the bottom `freeze_layers` transformer layers to reduce the
        # number of trainable parameters during fine-tuning.
        if freeze_layers > 0:
            frozen_blocks = self.encoder.encoder.layer[:freeze_layers]
            for block in frozen_blocks:
                for param in block.parameters():
                    param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        """Return the [CLS]-token embedding for each sequence in the batch."""
        hidden = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        return hidden.last_hidden_state[:, 0, :]

    def get_embedding_dim(self):
        """Width of the embedding produced by forward()."""
        return self.config.hidden_size
30
+
31
+
32
class ImageEncoder(nn.Module):
    """Image branch: ViT features optionally fused with a small forensic CNN."""

    def __init__(self, model_name="google/vit-base-patch16-224", forensic_dim=64):
        super().__init__()
        self.vit = ViTModel.from_pretrained(model_name)
        # Lightweight CNN over the raw image, separate from the semantic ViT
        # path; its output is concatenated with the ViT CLS embedding.
        self.forensic_cnn = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((8, 8)),
            nn.Flatten(),
            nn.Linear(32 * 8 * 8, forensic_dim),
        )
        vit_dim = self.vit.config.hidden_size
        self.fusion_proj = nn.Linear(vit_dim + forensic_dim, vit_dim)

    def forward(self, pixel_values, forensic_features=None):
        """Return the ViT CLS embedding, fused with forensic features when given."""
        cls_embedding = self.vit(pixel_values).last_hidden_state[:, 0, :]
        if forensic_features is None:
            return cls_embedding
        forensic_embedding = self.forensic_cnn(forensic_features)
        joint = torch.cat([cls_embedding, forensic_embedding], dim=-1)
        return self.fusion_proj(joint)

    def get_embedding_dim(self):
        """Width of the embedding produced by forward()."""
        return self.vit.config.hidden_size
57
+
58
+
59
class TabularEncoder(nn.Module):
    """FT-Transformer style encoder for tabular data."""

    def __init__(self, num_features, hidden_dim=256, num_layers=4, num_heads=8):
        super().__init__()
        self.input_proj = nn.Linear(num_features, hidden_dim)
        layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=0.1,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.hidden_dim = hidden_dim

    def forward(self, tabular_features):
        """Encode a (batch, num_features) tensor into (batch, hidden_dim)."""
        # The projected feature vector is treated as a single token sequence.
        token = self.input_proj(tabular_features).unsqueeze(1)
        encoded = self.transformer(token)
        return encoded.squeeze(1)

    def get_embedding_dim(self):
        """Width of the embedding produced by forward()."""
        return self.hidden_dim
79
+
80
+
81
class MetadataEncoder(nn.Module):
    """MLP encoder for metadata (author, journal, citation patterns)."""

    def __init__(self, metadata_dim, hidden_dim=128):
        super().__init__()
        # Three linear layers; dropout after each of the first two ReLUs.
        self.mlp = nn.Sequential(
            nn.Linear(metadata_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.hidden_dim = hidden_dim

    def forward(self, metadata):
        """Map a (batch, metadata_dim) tensor to (batch, hidden_dim)."""
        return self.mlp(metadata)

    def get_embedding_dim(self):
        """Width of the embedding produced by forward()."""
        return self.hidden_dim
98
+
99
+
100
class CrossModalFusion(nn.Module):
    """Cross-modal attention fusion layer.

    Projects each supplied modality embedding to a shared width, adds a
    learned per-modality type embedding, runs a transformer encoder over
    the resulting (batch, num_present, fused_dim) sequence, and pools it
    (mask-aware mean when a validity mask is supplied, plain mean otherwise).
    """

    def __init__(self, embed_dims, fused_dim=512, num_heads=8, num_layers=2):
        """
        Args:
            embed_dims: mapping of modality name -> embedding width.
            fused_dim: shared width all modalities are projected to.
            num_heads: attention heads in the fusion transformer.
            num_layers: number of transformer encoder layers.
        """
        super().__init__()
        self.modalities = list(embed_dims.keys())
        self.projections = nn.ModuleDict({
            mod: nn.Linear(dim, fused_dim) for mod, dim in embed_dims.items()
        })
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=fused_dim, nhead=num_heads,
            dim_feedforward=fused_dim * 4, dropout=0.1, batch_first=True
        )
        self.fusion_transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # One learned "type" embedding per modality, added after projection.
        self.modality_embeddings = nn.ParameterDict({
            mod: nn.Parameter(torch.randn(1, 1, fused_dim) * 0.02)
            for mod in self.modalities
        })

    def forward(self, embeddings, mask=None):
        """Fuse per-modality embeddings into one (batch, fused_dim) vector.

        Args:
            embeddings: dict of modality name -> (batch, dim) tensor; may
                contain any subset of the configured modalities.
            mask: optional dict of modality name -> (batch,) bool tensor,
                True where that modality is valid for the sample.
        """
        batch_size = next(iter(embeddings.values())).size(0)
        # Only modalities actually supplied take part in fusion; keep the
        # configured order so positions are deterministic.
        present = [mod for mod in self.modalities if mod in embeddings]
        projected = [
            self.projections[mod](embeddings[mod]).unsqueeze(1) + self.modality_embeddings[mod]
            for mod in present
        ]
        stacked = torch.cat(projected, dim=1)  # (batch, len(present), fused_dim)
        if mask is not None:
            # BUGFIX: the padding mask must have one column per *present*
            # modality, not per configured modality — otherwise its width
            # disagrees with the sequence length whenever a modality is
            # missing from `embeddings` and the transformer call fails.
            # A present modality with no mask entry is assumed valid
            # (previously it was silently treated as padding).
            padding_mask = torch.zeros(
                batch_size, len(present), dtype=torch.bool, device=stacked.device
            )
            for i, mod in enumerate(present):
                if mod in mask:
                    # src_key_padding_mask is True where positions are IGNORED.
                    padding_mask[:, i] = ~mask[mod]
        else:
            padding_mask = None
        # NOTE(review): a sample whose present modalities are ALL masked out
        # leaves attention with no valid keys (NaN risk upstream of the clamp
        # below) — callers are expected to supply at least one valid modality.
        fused = self.fusion_transformer(stacked, src_key_padding_mask=padding_mask)
        if padding_mask is not None:
            valid = (~padding_mask).unsqueeze(-1).float()
            fused = (fused * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1)
        else:
            fused = fused.mean(dim=1)
        return fused
142
+
143
+
144
class FraudDetectionHead(nn.Module):
    """Classification head with explainability and anomaly scoring."""

    def __init__(self, input_dim, num_classes=2, explanation_dim=256, num_modalities=4):
        """
        Args:
            input_dim: width of the fused embedding.
            num_classes: number of output classes (binary by default).
            explanation_dim: hidden width of the explanation projection.
            num_modalities: number of per-modality contribution scores to
                emit (generalized from a hard-coded 4; default preserves
                the original behavior).
        """
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, input_dim // 2), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(input_dim // 2, input_dim // 4), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(input_dim // 4, num_classes)
        )
        # Projects the fused embedding to one sigmoid score per modality,
        # intended as a coarse "which modality drove the decision" signal.
        self.explanation_proj = nn.Sequential(
            nn.Linear(input_dim, explanation_dim), nn.ReLU(),
            nn.Linear(explanation_dim, num_modalities)
        )
        self.anomaly_proj = nn.Linear(input_dim, 1)

    def forward(self, fused_embedding):
        """Return (class logits, per-modality scores in [0,1], anomaly score in [0,1])."""
        logits = self.classifier(fused_embedding)
        modality_scores = torch.sigmoid(self.explanation_proj(fused_embedding))
        anomaly_score = torch.sigmoid(self.anomaly_proj(fused_embedding))
        return logits, modality_scores, anomaly_score
165
+
166
+
167
class MultimodalFraudDetector(nn.Module):
    """
    Complete multimodal fraudulent paper detection system.
    Combines text, image, tabular, and metadata modalities; any subset of
    modalities may be supplied at inference time.
    """

    def __init__(self, text_model="allenai/scibert_scivocab_uncased",
                 image_model="google/vit-base-patch16-224", tabular_features=20,
                 metadata_features=15, fused_dim=512, freeze_text_layers=6):
        """
        Args:
            text_model: HF checkpoint for the text encoder.
            image_model: HF checkpoint for the image encoder.
            tabular_features: width of the tabular input vector.
            metadata_features: width of the metadata input vector.
            fused_dim: width of the fused cross-modal embedding.
            freeze_text_layers: lower text-encoder layers to freeze.
        """
        super().__init__()
        self.text_encoder = TextEncoder(text_model, freeze_text_layers)
        self.image_encoder = ImageEncoder(image_model)
        self.tabular_encoder = TabularEncoder(tabular_features)
        self.metadata_encoder = MetadataEncoder(metadata_features)
        embed_dims = {
            'text': self.text_encoder.get_embedding_dim(),
            'image': self.image_encoder.get_embedding_dim(),
            'tabular': self.tabular_encoder.get_embedding_dim(),
            'metadata': self.metadata_encoder.get_embedding_dim()
        }
        self.fusion = CrossModalFusion(embed_dims, fused_dim)
        self.head = FraudDetectionHead(fused_dim)
        self.fused_dim = fused_dim

    def forward(self, text_input_ids=None, text_attention_mask=None,
                image_pixels=None, image_forensic=None, tabular_features=None,
                metadata_features=None, modality_mask=None):
        """Run every supplied modality through its encoder, fuse, and classify.

        Returns a dict with 'logits', 'fused_embedding', 'modality_scores',
        'anomaly_score', and the per-modality 'embeddings'.
        """
        embeddings = {}
        # BUGFIX: copy the caller's mask dict instead of mutating it in place;
        # the loop below adds entries for every supplied modality.
        mask = dict(modality_mask) if modality_mask else {}
        if text_input_ids is not None:
            embeddings['text'] = self.text_encoder(text_input_ids, text_attention_mask)
            mask['text'] = torch.ones(text_input_ids.size(0), dtype=torch.bool, device=text_input_ids.device)
        if image_pixels is not None:
            embeddings['image'] = self.image_encoder(image_pixels, image_forensic)
            mask['image'] = torch.ones(image_pixels.size(0), dtype=torch.bool, device=image_pixels.device)
        if tabular_features is not None:
            embeddings['tabular'] = self.tabular_encoder(tabular_features)
            mask['tabular'] = torch.ones(tabular_features.size(0), dtype=torch.bool, device=tabular_features.device)
        if metadata_features is not None:
            embeddings['metadata'] = self.metadata_encoder(metadata_features)
            mask['metadata'] = torch.ones(metadata_features.size(0), dtype=torch.bool, device=metadata_features.device)
        fused = self.fusion(embeddings, mask)
        logits, modality_scores, anomaly_score = self.head(fused)
        return {
            'logits': logits, 'fused_embedding': fused,
            'modality_scores': modality_scores, 'anomaly_score': anomaly_score,
            'embeddings': embeddings
        }