pangweijlu committed on
Commit
aa7ac9b
·
verified ·
1 Parent(s): 127e91a

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +187 -0
train.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training script for multimodal fraudulent paper detection.
3
+ """
4
+
5
+ import os
6
+ import sys
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.optim as optim
10
+ from torch.utils.data import DataLoader, random_split
11
+ from transformers import get_linear_schedule_with_warmup
12
+ import numpy as np
13
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
14
+ from tqdm import tqdm
15
+ import json
16
+
17
+ from model import MultimodalFraudDetector
18
+ from data_loader import FraudPaperDataset, collate_fn
19
+
20
+
21
def compute_metrics(predictions, labels, probs):
    """Compute binary-classification metrics from raw logits.

    Args:
        predictions: (N, 2) array of raw logits.
        labels: (N,) array of ground-truth binary labels (0/1).
        probs: (N, 2) array of softmax probabilities; column 1 (the
            fraud class) is used for the AUC.

    Returns:
        dict with 'accuracy', 'precision', 'recall', 'f1', and 'auc'.
    """
    preds = np.argmax(predictions, axis=1)
    accuracy = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary', zero_division=0)
    try:
        auc = roc_auc_score(labels, probs[:, 1])
    except ValueError:
        # roc_auc_score raises ValueError when `labels` contains only one
        # class; fall back to the chance-level score. Narrowed from a bare
        # `except:` that also swallowed KeyboardInterrupt/SystemExit.
        auc = 0.5
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc}
30
+
31
+
32
def train_epoch(model, dataloader, optimizer, scheduler, device, epoch):
    """Run one training epoch and return aggregated metrics.

    The loss is a weighted sum of:
      * cross-entropy on the fraud/non-fraud logits,
      * a consistency penalty pulling per-modality scores toward 0.5,
      * an anomaly regression term (fraud -> 1.0, non-fraud -> 0.0).

    Args:
        model: detector returning a dict with 'logits', 'modality_scores'
            and 'anomaly_score' (see MultimodalFraudDetector).
        dataloader: yields batches with 'input_ids', 'attention_mask',
            'tabular_features', 'metadata_features' and 'labels'.
        optimizer: optimizer over the model's trainable parameters.
        scheduler: LR scheduler, stepped once per batch.
        device: torch device tensors are moved to.
        epoch: current epoch number (progress-bar label only).

    Returns:
        dict of metrics from compute_metrics plus the mean 'loss'.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()  # stateless; hoisted out of the batch loop
    total_loss = 0
    all_preds, all_labels, all_probs = [], [], []
    pbar = tqdm(dataloader, desc=f"Epoch {epoch}")
    for batch in pbar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        tabular = batch['tabular_features'].to(device)
        metadata = batch['metadata_features'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(text_input_ids=input_ids, text_attention_mask=attention_mask,
                        tabular_features=tabular, metadata_features=metadata)
        logits = outputs['logits']
        modality_scores = outputs['modality_scores']
        anomaly_score = outputs['anomaly_score']
        ce_loss = criterion(logits, labels)
        consistency_loss = torch.mean((modality_scores - 0.5) ** 2) * 0.1
        # Anomaly target: 1.0 for fraud samples, 0.0 for clean ones.
        # Each term is guarded separately: torch.mean over an empty
        # selection returns NaN, which the previous version hit whenever a
        # batch was all-fraud (fraud_mask.any() true, ~fraud_mask empty),
        # and it skipped the clean-sample term entirely on all-clean batches.
        fraud_mask = labels == 1
        anomaly_loss = torch.tensor(0.0, device=device)
        if fraud_mask.any():
            anomaly_loss = anomaly_loss + torch.mean((anomaly_score[fraud_mask] - 1.0) ** 2)
        if (~fraud_mask).any():
            anomaly_loss = anomaly_loss + torch.mean((anomaly_score[~fraud_mask] - 0.0) ** 2)
        loss = ce_loss + consistency_loss + 0.1 * anomaly_loss
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()
        probs = torch.softmax(logits, dim=1).detach().cpu().numpy()
        all_preds.append(logits.detach().cpu().numpy())
        all_labels.append(labels.cpu().numpy())
        all_probs.append(probs)
        pbar.set_postfix({'loss': loss.item()})
    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)
    all_probs = np.concatenate(all_probs)
    metrics = compute_metrics(all_preds, all_labels, all_probs)
    metrics['loss'] = total_loss / len(dataloader)
    return metrics
74
+
75
+
76
def evaluate(model, dataloader, device):
    """Evaluate the model over a dataloader without gradient tracking.

    Returns:
        tuple of (metrics dict from compute_metrics plus mean 'loss',
        stacked fused embeddings, stacked anomaly scores).
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    running_loss = 0.0
    logit_chunks, label_chunks, prob_chunks = [], [], []
    embedding_chunks, anomaly_chunks = [], []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            targets = batch['labels'].to(device)
            outputs = model(
                text_input_ids=batch['input_ids'].to(device),
                text_attention_mask=batch['attention_mask'].to(device),
                tabular_features=batch['tabular_features'].to(device),
                metadata_features=batch['metadata_features'].to(device),
            )
            batch_logits = outputs['logits']
            running_loss += criterion(batch_logits, targets).item()
            logit_chunks.append(batch_logits.cpu().numpy())
            label_chunks.append(targets.cpu().numpy())
            prob_chunks.append(torch.softmax(batch_logits, dim=1).cpu().numpy())
            embedding_chunks.append(outputs['fused_embedding'].cpu().numpy())
            anomaly_chunks.append(outputs['anomaly_score'].cpu().numpy())
    logits_all = np.concatenate(logit_chunks)
    labels_all = np.concatenate(label_chunks)
    probs_all = np.concatenate(prob_chunks)
    metrics = compute_metrics(logits_all, labels_all, probs_all)
    metrics['loss'] = running_loss / len(dataloader)
    return metrics, np.concatenate(embedding_chunks), np.concatenate(anomaly_chunks)
107
+
108
+
109
def main():
    """Train the multimodal fraud detector end to end.

    Loads the PubMed retraction dataset, splits it 80/20, fine-tunes the
    model for a fixed number of epochs, checkpoints the best model by
    validation F1, and saves the final validation embeddings and anomaly
    scores for downstream analysis.
    """
    print("=" * 60)
    print("MULTIMODAL FRAUD DETECTION - TRAINING")
    print("=" * 60)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    output_dir = './outputs'
    os.makedirs(output_dir, exist_ok=True)

    # Hyperparameters, defined once. The epoch count in particular was
    # previously hard-coded in three places (total_steps, the loop range,
    # and the progress print), which could silently desync the LR schedule
    # from the actual number of epochs trained.
    num_epochs = 3
    batch_size = 16
    learning_rate = 2e-5

    # Load data
    print("\nLoading dataset...")
    dataset = FraudPaperDataset("Lihuchen/pubmed_retraction", split="train", max_length=256)

    # 80/20 train/validation split
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_ds, val_ds = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=collate_fn)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, collate_fn=collate_fn)

    print(f"Train: {len(train_ds)}, Val: {len(val_ds)}")

    # Infer per-modality feature dimensions from one collated batch
    sample = next(iter(train_loader))
    tabular_dim = sample['tabular_features'].shape[1]
    metadata_dim = sample['metadata_features'].shape[1]
    print(f"Tabular: {tabular_dim}, Metadata: {metadata_dim}")

    # Model
    print("\nBuilding model...")
    model = MultimodalFraudDetector(
        text_model="allenai/scibert_scivocab_uncased",
        tabular_features=tabular_dim,
        metadata_features=metadata_dim,
        fused_dim=256,
        freeze_text_layers=8
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total params: {total_params:,}, Trainable: {trainable:,}")

    # Optimizer with linear warmup/decay schedule (stepped once per batch
    # inside train_epoch, hence total_steps = batches * epochs)
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    total_steps = len(train_loader) * num_epochs
    warmup = int(total_steps * 0.1)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup, num_training_steps=total_steps)

    # Train, checkpointing on best validation F1
    best_f1 = 0
    val_emb, val_anom = None, None
    for epoch in range(1, num_epochs + 1):
        print(f"\n=== Epoch {epoch}/{num_epochs} ===")
        train_metrics = train_epoch(model, train_loader, optimizer, scheduler, device, epoch)
        print(f"Train - Loss: {train_metrics['loss']:.4f}, Acc: {train_metrics['accuracy']:.4f}, F1: {train_metrics['f1']:.4f}")
        val_metrics, val_emb, val_anom = evaluate(model, val_loader, device)
        print(f"Val - Loss: {val_metrics['loss']:.4f}, Acc: {val_metrics['accuracy']:.4f}, F1: {val_metrics['f1']:.4f}, AUC: {val_metrics['auc']:.4f}")

        if val_metrics['f1'] > best_f1:
            best_f1 = val_metrics['f1']
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'f1': best_f1,
            }, os.path.join(output_dir, 'best_model.pt'))
            print(f"Saved best model (F1: {best_f1:.4f})")

    # Save embeddings/anomaly scores from the final validation pass
    np.save(os.path.join(output_dir, 'val_embeddings.npy'), val_emb)
    np.save(os.path.join(output_dir, 'val_anomaly.npy'), val_anom)

    print(f"\nTraining complete! Best F1: {best_f1:.4f}")
    print(f"Outputs saved to {output_dir}")
184
+
185
+
186
# Entry point: run training only when executed as a script.
if __name__ == '__main__':
    main()