| """ |
| Training script for multimodal fraudulent paper detection. |
| """ |
|
|
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from transformers import get_linear_schedule_with_warmup
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
from tqdm import tqdm
|
|
| from model import MultimodalFraudDetector |
| from data_loader import FraudPaperDataset, collate_fn |
|
|
|
|
| def compute_metrics(predictions, labels, probs): |
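    """Compute accuracy, precision, recall, F1, and AUC from raw logits, integer labels, and softmax probabilities."""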
| preds = np.argmax(predictions, axis=1) |
| accuracy = accuracy_score(labels, preds) |
| precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary', zero_division=0) |
| try: |
| auc = roc_auc_score(labels, probs[:, 1]) |
    except ValueError:
        # roc_auc_score raises ValueError when only one class is present; fall back to chance level.
        auc = 0.5
| return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc} |
|
|
|
|
| def train_epoch(model, dataloader, optimizer, scheduler, device, epoch): |
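    """Run one training epoch and return aggregate metrics, including the mean loss."""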
| model.train() |
| total_loss = 0 |
| all_preds, all_labels, all_probs = [], [], [] |
| pbar = tqdm(dataloader, desc=f"Epoch {epoch}") |
| for batch in pbar: |
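        # Move all batch tensors to the training device.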
| input_ids = batch['input_ids'].to(device) |
| attention_mask = batch['attention_mask'].to(device) |
| tabular = batch['tabular_features'].to(device) |
| metadata = batch['metadata_features'].to(device) |
| labels = batch['labels'].to(device) |
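        # Forward pass returns the fused logits plus auxiliary per-modality and anomaly outputs.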
| outputs = model(text_input_ids=input_ids, text_attention_mask=attention_mask, |
| tabular_features=tabular, metadata_features=metadata) |
| logits = outputs['logits'] |
| modality_scores = outputs['modality_scores'] |
| anomaly_score = outputs['anomaly_score'] |
        # Classification loss on the fused logits.
        ce_loss = nn.CrossEntropyLoss()(logits, labels)
        # Pull per-modality scores toward 0.5 so no single modality dominates the fused decision.
        consistency_loss = torch.mean((modality_scores - 0.5) ** 2) * 0.1
        # Regress the anomaly head toward 1.0 for fraudulent papers and 0.0 for legitimate ones.
        # Each term is guarded so an all-fraud or all-clean batch cannot produce a NaN from an empty mean.
        fraud_mask = labels == 1
        anomaly_loss = torch.tensor(0.0, device=device)
        if fraud_mask.any():
            anomaly_loss = anomaly_loss + torch.mean((anomaly_score[fraud_mask] - 1.0) ** 2)
        if (~fraud_mask).any():
            anomaly_loss = anomaly_loss + torch.mean((anomaly_score[~fraud_mask] - 0.0) ** 2)
        loss = ce_loss + consistency_loss + 0.1 * anomaly_loss
| optimizer.zero_grad() |
| loss.backward() |
| torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) |
| optimizer.step() |
| scheduler.step() |
| total_loss += loss.item() |
| probs = torch.softmax(logits, dim=1).detach().cpu().numpy() |
| all_preds.append(logits.detach().cpu().numpy()) |
| all_labels.append(labels.cpu().numpy()) |
| all_probs.append(probs) |
| pbar.set_postfix({'loss': loss.item()}) |
| all_preds = np.concatenate(all_preds) |
| all_labels = np.concatenate(all_labels) |
| all_probs = np.concatenate(all_probs) |
| metrics = compute_metrics(all_preds, all_labels, all_probs) |
| metrics['loss'] = total_loss / len(dataloader) |
| return metrics |
|
|
|
|
| def evaluate(model, dataloader, device): |
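    """Evaluate the model; returns metrics plus the fused embeddings and anomaly scores for later analysis."""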
| model.eval() |
| total_loss = 0 |
| all_preds, all_labels, all_probs = [], [], [] |
| all_embeddings, all_anomaly = [], [] |
| with torch.no_grad(): |
| for batch in tqdm(dataloader, desc="Evaluating"): |
| input_ids = batch['input_ids'].to(device) |
| attention_mask = batch['attention_mask'].to(device) |
| tabular = batch['tabular_features'].to(device) |
| metadata = batch['metadata_features'].to(device) |
| labels = batch['labels'].to(device) |
| outputs = model(text_input_ids=input_ids, text_attention_mask=attention_mask, |
| tabular_features=tabular, metadata_features=metadata) |
| logits = outputs['logits'] |
| loss = nn.CrossEntropyLoss()(logits, labels) |
| total_loss += loss.item() |
| probs = torch.softmax(logits, dim=1).cpu().numpy() |
| all_preds.append(logits.cpu().numpy()) |
| all_labels.append(labels.cpu().numpy()) |
| all_probs.append(probs) |
| all_embeddings.append(outputs['fused_embedding'].cpu().numpy()) |
| all_anomaly.append(outputs['anomaly_score'].cpu().numpy()) |
| all_preds = np.concatenate(all_preds) |
| all_labels = np.concatenate(all_labels) |
| all_probs = np.concatenate(all_probs) |
| all_embeddings = np.concatenate(all_embeddings) |
| all_anomaly = np.concatenate(all_anomaly) |
| metrics = compute_metrics(all_preds, all_labels, all_probs) |
| metrics['loss'] = total_loss / len(dataloader) |
| return metrics, all_embeddings, all_anomaly |
|
|
|
|
| def main(): |
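    """Train the multimodal fraud detector and checkpoint the best model by validation F1."""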
| print("=" * 60) |
| print("MULTIMODAL FRAUD DETECTION - TRAINING") |
| print("=" * 60) |
| |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| print(f"Device: {device}") |
| |
| output_dir = './outputs' |
| os.makedirs(output_dir, exist_ok=True) |
| |
| |
| print("\nLoading dataset...") |
| dataset = FraudPaperDataset("Lihuchen/pubmed_retraction", split="train", max_length=256) |
| |
| |
    # Deterministic 80/20 split so runs are comparable and the saved validation artifacts are reproducible.
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_ds, val_ds = random_split(dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))
| |
| train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=2, collate_fn=collate_fn) |
| val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=2, collate_fn=collate_fn) |
| |
| print(f"Train: {len(train_ds)}, Val: {len(val_ds)}") |
| |
| |
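    # Infer the tabular and metadata feature dimensions from one sample batch.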
| sample = next(iter(train_loader)) |
| tabular_dim = sample['tabular_features'].shape[1] |
| metadata_dim = sample['metadata_features'].shape[1] |
| print(f"Tabular: {tabular_dim}, Metadata: {metadata_dim}") |
| |
| |
| print("\nBuilding model...") |
| model = MultimodalFraudDetector( |
| text_model="allenai/scibert_scivocab_uncased", |
| tabular_features=tabular_dim, |
| metadata_features=metadata_dim, |
| fused_dim=256, |
| freeze_text_layers=8 |
| ).to(device) |
| |
| total_params = sum(p.numel() for p in model.parameters()) |
| trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) |
| print(f"Total params: {total_params:,}, Trainable: {trainable:,}") |
| |
| |
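    # AdamW with a linear warmup/decay schedule spanning all training steps.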
| optimizer = optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01) |
    num_epochs = 3
    total_steps = len(train_loader) * num_epochs
    warmup = int(total_steps * 0.1)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup, num_training_steps=total_steps)
| |
| |
| best_f1 = 0 |
    for epoch in range(1, num_epochs + 1):
        print(f"\n=== Epoch {epoch}/{num_epochs} ===")
| train_metrics = train_epoch(model, train_loader, optimizer, scheduler, device, epoch) |
| print(f"Train - Loss: {train_metrics['loss']:.4f}, Acc: {train_metrics['accuracy']:.4f}, F1: {train_metrics['f1']:.4f}") |
| val_metrics, val_emb, val_anom = evaluate(model, val_loader, device) |
| print(f"Val - Loss: {val_metrics['loss']:.4f}, Acc: {val_metrics['accuracy']:.4f}, F1: {val_metrics['f1']:.4f}, AUC: {val_metrics['auc']:.4f}") |
| |
| if val_metrics['f1'] > best_f1: |
| best_f1 = val_metrics['f1'] |
| torch.save({ |
| 'epoch': epoch, |
| 'model_state_dict': model.state_dict(), |
| 'f1': best_f1, |
| }, os.path.join(output_dir, 'best_model.pt')) |
| print(f"Saved best model (F1: {best_f1:.4f})") |
| |
| |
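    # Overwritten each epoch: the files on disk correspond to the final epoch's validation pass.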
| np.save(os.path.join(output_dir, 'val_embeddings.npy'), val_emb) |
| np.save(os.path.join(output_dir, 'val_anomaly.npy'), val_anom) |
| |
| print(f"\nTraining complete! Best F1: {best_f1:.4f}") |
| print(f"Outputs saved to {output_dir}") |
|
|
|
|
| if __name__ == '__main__': |
| main() |
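
# Example (assumed usage, not exercised by this script): reload the best checkpoint
# for inference. MultimodalFraudDetector must be constructed with the same text model
# and feature dimensions that were used during training.
#
#   checkpoint = torch.load('./outputs/best_model.pt', map_location='cpu')
#   model.load_state_dict(checkpoint['model_state_dict'])
#   model.eval()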
|
|