File size: 9,204 Bytes
e384945
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
"""
Multimodal Fraudulent Paper Detection - Core Model Architecture
"""

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, ViTModel, AutoConfig
from typing import Dict, Optional


class TextEncoder(nn.Module):
    """Encode paper text with SciBERT, optionally freezing lower layers.

    The final-layer [CLS] token is used as the document embedding. The
    tokenizer is kept on the instance so callers can prepare inputs.
    """

    def __init__(self, model_name="allenai/scibert_scivocab_uncased", freeze_layers=6):
        super().__init__()
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Freeze the bottom `freeze_layers` transformer blocks to cut
        # fine-tuning cost; upper blocks stay trainable.
        frozen_blocks = self.encoder.encoder.layer[:freeze_layers] if freeze_layers > 0 else []
        for block in frozen_blocks:
            for p in block.parameters():
                p.requires_grad = False

    def forward(self, input_ids, attention_mask):
        """Return the [CLS] embedding, shape (batch, hidden_size)."""
        hidden = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        return hidden[:, 0]

    def get_embedding_dim(self):
        """Hidden size of the underlying transformer."""
        return self.config.hidden_size


class ImageEncoder(nn.Module):
    """Encode scientific figures with a ViT backbone plus a small forensic CNN.

    When forensic input is supplied, its CNN features are concatenated
    with the ViT [CLS] embedding and projected back to the ViT width;
    otherwise the raw [CLS] embedding is returned unchanged.
    """

    def __init__(self, model_name="google/vit-base-patch16-224", forensic_dim=64):
        super().__init__()
        self.vit = ViTModel.from_pretrained(model_name)
        # Lightweight CNN branch over the raw pixels of the forensic input
        # (3-channel), reduced to a fixed-size feature vector.
        self.forensic_cnn = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((8, 8)),
            nn.Flatten(),
            nn.Linear(32 * 8 * 8, forensic_dim),
        )
        vit_dim = self.vit.config.hidden_size
        self.fusion_proj = nn.Linear(vit_dim + forensic_dim, vit_dim)

    def forward(self, pixel_values, forensic_features=None):
        """Return an image embedding of shape (batch, vit_hidden_size)."""
        cls_embedding = self.vit(pixel_values).last_hidden_state[:, 0, :]
        if forensic_features is None:
            return cls_embedding
        forensic_embedding = self.forensic_cnn(forensic_features)
        combined = torch.cat([cls_embedding, forensic_embedding], dim=-1)
        return self.fusion_proj(combined)

    def get_embedding_dim(self):
        """Width of the returned embedding (ViT hidden size)."""
        return self.vit.config.hidden_size


class TabularEncoder(nn.Module):
    """FT-Transformer style encoder for tabular data."""
    
    def __init__(self, num_features, hidden_dim=256, num_layers=4, num_heads=8):
        super().__init__()
        self.input_proj = nn.Linear(num_features, hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads,
            dim_feedforward=hidden_dim*4, dropout=0.1, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.hidden_dim = hidden_dim
    
    def forward(self, tabular_features):
        x = self.input_proj(tabular_features).unsqueeze(1)
        x = self.transformer(x)
        return x.squeeze(1)
    
    def get_embedding_dim(self):
        return self.hidden_dim


class MetadataEncoder(nn.Module):
    """MLP encoder for metadata (author, journal, citation patterns).

    A three-layer MLP with ReLU activations and dropout, mapping a flat
    (batch, metadata_dim) float tensor to (batch, hidden_dim).
    """

    def __init__(self, metadata_dim, hidden_dim=128):
        super().__init__()
        layers = [
            nn.Linear(metadata_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        self.mlp = nn.Sequential(*layers)
        self.hidden_dim = hidden_dim

    def forward(self, metadata):
        """Encode (batch, metadata_dim) -> (batch, hidden_dim)."""
        return self.mlp(metadata)

    def get_embedding_dim(self):
        """Width of the returned embedding."""
        return self.hidden_dim


class CrossModalFusion(nn.Module):
    """Cross-modal attention fusion layer."""
    
    def __init__(self, embed_dims, fused_dim=512, num_heads=8, num_layers=2):
        super().__init__()
        self.modalities = list(embed_dims.keys())
        self.projections = nn.ModuleDict({
            mod: nn.Linear(dim, fused_dim) for mod, dim in embed_dims.items()
        })
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=fused_dim, nhead=num_heads,
            dim_feedforward=fused_dim*4, dropout=0.1, batch_first=True
        )
        self.fusion_transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.modality_embeddings = nn.ParameterDict({
            mod: nn.Parameter(torch.randn(1, 1, fused_dim) * 0.02)
            for mod in self.modalities
        })
    
    def forward(self, embeddings, mask=None):
        batch_size = next(iter(embeddings.values())).size(0)
        projected = []
        for mod in self.modalities:
            if mod in embeddings:
                proj = self.projections[mod](embeddings[mod]).unsqueeze(1)
                proj = proj + self.modality_embeddings[mod]
                projected.append(proj)
        stacked = torch.cat(projected, dim=1)
        if mask is not None:
            padding_mask = torch.ones(batch_size, len(self.modalities), dtype=torch.bool, device=stacked.device)
            for i, mod in enumerate(self.modalities):
                if mod in mask:
                    padding_mask[:, i] = ~mask[mod]
        else:
            padding_mask = None
        fused = self.fusion_transformer(stacked, src_key_padding_mask=padding_mask)
        if padding_mask is not None:
            mask_expanded = (~padding_mask).unsqueeze(-1).float()
            fused = (fused * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1).clamp(min=1)
        else:
            fused = fused.mean(dim=1)
        return fused


class FraudDetectionHead(nn.Module):
    """Classification head with explainability and anomaly scoring.

    Produces three outputs from a fused embedding: class logits, a
    per-modality suspicion score in [0, 1], and a scalar anomaly score
    in [0, 1].
    """

    def __init__(self, input_dim, num_classes=2, explanation_dim=256, num_modalities=4):
        """
        Args:
            input_dim: width of the fused embedding.
            num_classes: number of output classes (default binary).
            explanation_dim: hidden width of the explanation branch.
            num_modalities: number of per-modality explanation scores to
                emit. Default 4 (text, image, tabular, metadata) — this was
                previously hard-coded; the parameter generalizes it while
                staying backward compatible.
        """
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, input_dim//2), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(input_dim//2, input_dim//4), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(input_dim//4, num_classes)
        )
        self.explanation_proj = nn.Sequential(
            nn.Linear(input_dim, explanation_dim), nn.ReLU(),
            nn.Linear(explanation_dim, num_modalities)
        )
        self.anomaly_proj = nn.Linear(input_dim, 1)

    def forward(self, fused_embedding):
        """Return (logits, modality_scores, anomaly_score) for (batch, input_dim) input.

        logits: (batch, num_classes) raw scores.
        modality_scores: (batch, num_modalities) sigmoid outputs.
        anomaly_score: (batch, 1) sigmoid output.
        """
        logits = self.classifier(fused_embedding)
        modality_scores = torch.sigmoid(self.explanation_proj(fused_embedding))
        anomaly_score = torch.sigmoid(self.anomaly_proj(fused_embedding))
        return logits, modality_scores, anomaly_score


class MultimodalFraudDetector(nn.Module):
    """
    Complete multimodal fraudulent paper detection system.
    Combines text, image, tabular, and metadata modalities; any subset of
    modalities may be supplied per forward call.
    """

    def __init__(self, text_model="allenai/scibert_scivocab_uncased",
                 image_model="google/vit-base-patch16-224", tabular_features=20,
                 metadata_features=15, fused_dim=512, freeze_text_layers=6):
        """
        Args:
            text_model: HF checkpoint name for the text encoder.
            image_model: HF checkpoint name for the ViT image encoder.
            tabular_features: number of tabular input features.
            metadata_features: number of metadata input features.
            fused_dim: shared width of the fusion space.
            freeze_text_layers: bottom text-encoder layers to freeze.
        """
        super().__init__()
        self.text_encoder = TextEncoder(text_model, freeze_text_layers)
        self.image_encoder = ImageEncoder(image_model)
        self.tabular_encoder = TabularEncoder(tabular_features)
        self.metadata_encoder = MetadataEncoder(metadata_features)
        embed_dims = {
            'text': self.text_encoder.get_embedding_dim(),
            'image': self.image_encoder.get_embedding_dim(),
            'tabular': self.tabular_encoder.get_embedding_dim(),
            'metadata': self.metadata_encoder.get_embedding_dim()
        }
        self.fusion = CrossModalFusion(embed_dims, fused_dim)
        self.head = FraudDetectionHead(fused_dim)
        self.fused_dim = fused_dim

    def forward(self, text_input_ids=None, text_attention_mask=None,
                image_pixels=None, image_forensic=None, tabular_features=None,
                metadata_features=None, modality_mask=None):
        """Encode every supplied modality, fuse, and score.

        Args:
            modality_mask: optional dict of modality name -> (batch,) bool
                tensor marking which samples are valid for that modality.
                Entries supplied by the caller are honored; missing entries
                default to all-valid for any modality that was provided.

        Returns:
            Dict with 'logits', 'fused_embedding', 'modality_scores',
            'anomaly_score', and the per-modality 'embeddings'.
        """
        embeddings = {}
        # BUG FIX: copy the caller's dict instead of aliasing it — the
        # previous `modality_mask or {}` mutated the argument in place.
        mask = dict(modality_mask) if modality_mask else {}

        def _all_valid(ref):
            # Default mask: every sample in the batch counts as valid.
            return torch.ones(ref.size(0), dtype=torch.bool, device=ref.device)

        if text_input_ids is not None:
            embeddings['text'] = self.text_encoder(text_input_ids, text_attention_mask)
            # setdefault keeps any caller-provided per-sample mask (the old
            # code unconditionally overwrote it, making the parameter dead).
            mask.setdefault('text', _all_valid(text_input_ids))
        if image_pixels is not None:
            embeddings['image'] = self.image_encoder(image_pixels, image_forensic)
            mask.setdefault('image', _all_valid(image_pixels))
        if tabular_features is not None:
            embeddings['tabular'] = self.tabular_encoder(tabular_features)
            mask.setdefault('tabular', _all_valid(tabular_features))
        if metadata_features is not None:
            embeddings['metadata'] = self.metadata_encoder(metadata_features)
            mask.setdefault('metadata', _all_valid(metadata_features))
        fused = self.fusion(embeddings, mask)
        logits, modality_scores, anomaly_score = self.head(fused)
        return {
            'logits': logits, 'fused_embedding': fused,
            'modality_scores': modality_scores, 'anomaly_score': anomaly_score,
            'embeddings': embeddings
        }