"""Training script for multimodal fraudulent paper detection.

Fine-tunes a ``MultimodalFraudDetector`` on the dataset wrapped by
``FraudPaperDataset``, evaluates it on a held-out split after every epoch,
and writes the best checkpoint plus validation embeddings / anomaly scores
to ``./outputs``.
"""

import json
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup

from model import MultimodalFraudDetector
from data_loader import FraudPaperDataset, collate_fn

# Number of passes over the training split. The LR schedule length, the
# training loop and the progress banner all derive from this one value.
NUM_EPOCHS = 3


def compute_metrics(predictions, labels, probs):
    """Compute binary classification metrics.

    Args:
        predictions: 2-D array of per-class scores; the predicted class is
            the argmax over axis 1.
        labels: 1-D array of ground-truth class labels.
        probs: 2-D array of class probabilities; column 1 is used as the
            positive-class score for ROC-AUC.

    Returns:
        Dict with keys ``accuracy``, ``precision``, ``recall``, ``f1`` and
        ``auc``. ``auc`` falls back to 0.5 when it cannot be computed.
    """
    preds = np.argmax(predictions, axis=1)
    accuracy = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='binary', zero_division=0
    )
    try:
        auc = roc_auc_score(labels, probs[:, 1])
    except (ValueError, IndexError):
        # ValueError: AUC is undefined (e.g. only one class in `labels`).
        # IndexError: `probs` has no column 1.
        auc = 0.5
    else:
        # NOTE(review): also guard against a NaN result, in case the
        # installed scikit-learn reports an undefined AUC that way instead
        # of raising.
        if np.isnan(auc):
            auc = 0.5
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auc': auc,
    }


def _forward_batch(model, batch, device):
    """Move one collated batch to `device` and run the model on it.

    Returns:
        Tuple ``(outputs, labels)`` where `outputs` is whatever the model
        returns and `labels` is ``batch['labels']`` on `device`.
    """
    labels = batch['labels'].to(device)
    outputs = model(
        text_input_ids=batch['input_ids'].to(device),
        text_attention_mask=batch['attention_mask'].to(device),
        tabular_features=batch['tabular_features'].to(device),
        metadata_features=batch['metadata_features'].to(device),
    )
    return outputs, labels


def _anomaly_loss(anomaly_score, labels):
    """Squared-error loss pushing anomaly scores to 1 for fraud, 0 otherwise.

    Each class term is added only when that class is present in the batch.
    Averaging over an empty selection yields NaN, which would otherwise
    poison the total loss for single-class batches.
    """
    fraud_mask = labels == 1
    clean_mask = ~fraud_mask
    loss = anomaly_score.new_zeros(())
    if fraud_mask.any():
        loss = loss + torch.mean((anomaly_score[fraud_mask] - 1.0) ** 2)
    if clean_mask.any():
        loss = loss + torch.mean((anomaly_score[clean_mask] - 0.0) ** 2)
    return loss


def train_epoch(model, dataloader, optimizer, scheduler, device, epoch):
    """Train `model` for one pass over `dataloader`.

    The loss is cross-entropy on the logits, plus a 0.1-weighted penalty on
    the squared deviation of the modality scores from 0.5, plus a
    0.1-weighted anomaly-score regression term. The scheduler is stepped
    once per batch.

    Args:
        model: Network returning a dict with ``logits``, ``modality_scores``
            and ``anomaly_score``.
        dataloader: Iterable of collated training batches.
        optimizer: Optimizer updating the model parameters.
        scheduler: LR scheduler stepped after every optimizer step.
        device: Device the batch tensors are moved to.
        epoch: Epoch number, used only for the progress-bar label.

    Returns:
        Metrics dict from `compute_metrics` plus ``loss``, the mean
        per-batch loss.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0
    all_preds, all_labels, all_probs = [], [], []

    pbar = tqdm(dataloader, desc=f"Epoch {epoch}")
    for batch in pbar:
        outputs, labels = _forward_batch(model, batch, device)
        logits = outputs['logits']
        modality_scores = outputs['modality_scores']
        anomaly_score = outputs['anomaly_score']

        ce_loss = criterion(logits, labels)
        consistency_loss = torch.mean((modality_scores - 0.5) ** 2) * 0.1
        anomaly_loss = _anomaly_loss(anomaly_score, labels)
        loss = ce_loss + consistency_loss + 0.1 * anomaly_loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        detached_logits = logits.detach()
        all_preds.append(detached_logits.cpu().numpy())
        all_labels.append(labels.cpu().numpy())
        all_probs.append(torch.softmax(detached_logits, dim=1).cpu().numpy())
        pbar.set_postfix({'loss': batch_loss})

    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)
    all_probs = np.concatenate(all_probs)

    metrics = compute_metrics(all_preds, all_labels, all_probs)
    metrics['loss'] = total_loss / len(dataloader)
    return metrics


def evaluate(model, dataloader, device):
    """Evaluate `model` on `dataloader` without gradient tracking.

    Only the cross-entropy term is reported as the loss here; the auxiliary
    training terms are not included.

    Returns:
        Tuple ``(metrics, embeddings, anomaly)``: the metrics dict from
        `compute_metrics` plus ``loss`` (mean per-batch cross-entropy), the
        concatenated ``fused_embedding`` outputs, and the concatenated
        ``anomaly_score`` outputs, both as NumPy arrays.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0
    all_preds, all_labels, all_probs = [], [], []
    all_embeddings, all_anomaly = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            outputs, labels = _forward_batch(model, batch, device)
            logits = outputs['logits']

            total_loss += criterion(logits, labels).item()

            all_preds.append(logits.cpu().numpy())
            all_labels.append(labels.cpu().numpy())
            all_probs.append(torch.softmax(logits, dim=1).cpu().numpy())
            all_embeddings.append(outputs['fused_embedding'].cpu().numpy())
            all_anomaly.append(outputs['anomaly_score'].cpu().numpy())

    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)
    all_probs = np.concatenate(all_probs)
    all_embeddings = np.concatenate(all_embeddings)
    all_anomaly = np.concatenate(all_anomaly)

    metrics = compute_metrics(all_preds, all_labels, all_probs)
    metrics['loss'] = total_loss / len(dataloader)
    return metrics, all_embeddings, all_anomaly


def main():
    """Run the full pipeline: load data, build the model, train, save outputs."""
    print("=" * 60)
    print("MULTIMODAL FRAUD DETECTION - TRAINING")
    print("=" * 60)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    output_dir = './outputs'
    os.makedirs(output_dir, exist_ok=True)

    # Load data
    print("\nLoading dataset...")
    dataset = FraudPaperDataset("Lihuchen/pubmed_retraction", split="train", max_length=256)

    # Split 80/20 into train and validation subsets
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_ds, val_ds = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=2, collate_fn=collate_fn)
    val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=2, collate_fn=collate_fn)
    print(f"Train: {len(train_ds)}, Val: {len(val_ds)}")

    # Read the feature dimensions off one collated batch
    sample = next(iter(train_loader))
    tabular_dim = sample['tabular_features'].shape[1]
    metadata_dim = sample['metadata_features'].shape[1]
    print(f"Tabular: {tabular_dim}, Metadata: {metadata_dim}")

    # Model
    print("\nBuilding model...")
    model = MultimodalFraudDetector(
        text_model="allenai/scibert_scivocab_uncased",
        tabular_features=tabular_dim,
        metadata_features=metadata_dim,
        fused_dim=256,
        freeze_text_layers=8
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total params: {total_params:,}, Trainable: {trainable:,}")

    # Optimizer and linear warmup/decay schedule (10% of all steps are warmup)
    optimizer = optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
    total_steps = len(train_loader) * NUM_EPOCHS
    warmup = int(total_steps * 0.1)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup, num_training_steps=total_steps)

    # Train. Start below any attainable F1 so that the first epoch always
    # writes a checkpoint, even if validation F1 is exactly 0.
    best_f1 = -1.0
    for epoch in range(1, NUM_EPOCHS + 1):
        print(f"\n=== Epoch {epoch}/{NUM_EPOCHS} ===")

        train_metrics = train_epoch(model, train_loader, optimizer, scheduler, device, epoch)
        print(f"Train - Loss: {train_metrics['loss']:.4f}, Acc: {train_metrics['accuracy']:.4f}, F1: {train_metrics['f1']:.4f}")

        val_metrics, val_emb, val_anom = evaluate(model, val_loader, device)
        print(f"Val - Loss: {val_metrics['loss']:.4f}, Acc: {val_metrics['accuracy']:.4f}, F1: {val_metrics['f1']:.4f}, AUC: {val_metrics['auc']:.4f}")

        if val_metrics['f1'] > best_f1:
            best_f1 = val_metrics['f1']
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'f1': best_f1,
            }, os.path.join(output_dir, 'best_model.pt'))
            print(f"Saved best model (F1: {best_f1:.4f})")

    # Save embeddings
    # NOTE(review): these come from the final epoch's evaluation, which is
    # not necessarily the epoch stored in best_model.pt.
    np.save(os.path.join(output_dir, 'val_embeddings.npy'), val_emb)
    np.save(os.path.join(output_dir, 'val_anomaly.npy'), val_anom)

    print(f"\nTraining complete! Best F1: {best_f1:.4f}")
    print(f"Outputs saved to {output_dir}")


if __name__ == '__main__':
    main()