# multimodal-fraud-detection/data_loader.py
"""
Data loading and preprocessing for multimodal fraud detection.
"""
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
import numpy as np
import re
import zlib


class FraudPaperDataset(Dataset):
"""
Dataset for multimodal fraudulent paper detection.
Supports text, image, tabular, and metadata modalities.
"""
def __init__(self, hf_dataset_name="Lihuchen/pubmed_retraction", split="train",
text_tokenizer=None, max_length=512, text_column="Abstract",
title_column="Title", label_column="IsRetracted"):
        self.dataset = load_dataset(hf_dataset_name, split=split)
        # Default to SciBERT, pretrained on scientific text, when no tokenizer is given.
        self.tokenizer = text_tokenizer or AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
self.max_length = max_length
self.text_column = text_column
self.title_column = title_column
self.label_column = label_column
        # Precompute labels and the non-text feature matrices once, up front.
        self.labels = self._extract_labels()
        self.tabular_features = self._extract_tabular_features()
        self.metadata_features = self._extract_metadata_features()

    def _extract_labels(self):
        """Normalize string or integer retraction labels to binary {0, 1}."""
labels = []
for item in self.dataset:
label = item.get(self.label_column, "No")
if isinstance(label, str):
labels.append(1 if label.lower() in ['yes', 'true', '1', 'retracted'] else 0)
else:
labels.append(int(label))
        return np.array(labels)

    def _extract_tabular_features(self):
        """Surface statistics: text/title lengths plus author, journal, and year features."""
        features = []
        for item in self.dataset:
            feat = []
            text = str(item.get(self.text_column, ''))
            feat.append(len(text))          # abstract length in characters
            feat.append(len(text.split()))  # abstract length in words
            feat.append(len(str(item.get(self.title_column, '')).split()))  # title word count
            authors = str(item.get('Authors', ''))
            # Count only non-empty entries so a missing Authors field yields 0, not 1.
            feat.append(len([a for a in authors.split(';') if a.strip()]))
            feat.append(len(authors))
            journal = str(item.get('Journal', ''))
            feat.append(len(journal))
            pub_date = str(item.get('PublicationDate', ''))
            year_match = re.search(r'(\d{4})', pub_date)
            # Fall back to 2020 when no four-digit year can be parsed.
            feat.append(int(year_match.group(1)) if year_match else 2020)
            features.append(feat)
        return np.array(features, dtype=np.float32)

    def _extract_metadata_features(self):
        """Citation, grant, affiliation, and journal-level signals."""
        features = []
        for item in self.dataset:
            feat = []
            cited_by = item.get('cited_by_count', 0) or 0
            references = item.get('reference_count', 0) or 0
            feat.append(float(cited_by))
            feat.append(float(references))
            feat.append(float(cited_by) / max(float(references), 1))  # citations per reference
            n_grants = item.get('n_grants', 0) or 0
            feat.append(float(n_grants))
            affiliations = str(item.get('affiliations', item.get('Institutions', '')))
            # Count only non-empty entries so a missing field yields 0, not 1.
            feat.append(len([a for a in affiliations.split(';') if a.strip()]))
            feat.append(len(affiliations))
            country = str(item.get('country', ''))
            # Hash the country into [0, 1). zlib.crc32 is deterministic across runs,
            # unlike Python's built-in hash(), which is salted per process.
            feat.append(zlib.crc32(country.encode('utf-8')) % 1000 / 1000.0)
            pub_types = str(item.get('publication_types', ''))
            feat.append(1.0 if 'Review' in pub_types else 0.0)
            feat.append(1.0 if 'Research' in pub_types else 0.0)
            journal = str(item.get('Journal', ''))
            feat.append(1.0 if 'Q1' in journal else 0.0)
            feat.append(1.0 if 'Top' in journal else 0.0)
            features.append(feat)
        return np.array(features, dtype=np.float32)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        title = str(item.get(self.title_column, ''))
        abstract = str(item.get(self.text_column, ''))
        # Join title and abstract with the tokenizer's separator token rather than
        # a hard-coded string, falling back to '[SEP]' if the tokenizer defines none.
        sep = self.tokenizer.sep_token or '[SEP]'
        text = f"{title} {sep} {abstract}"
        text_encoding = self.tokenizer(text, max_length=self.max_length,
                                       padding='max_length', truncation=True, return_tensors='pt')
        return {
            'input_ids': text_encoding['input_ids'].squeeze(0),
            'attention_mask': text_encoding['attention_mask'].squeeze(0),
            'tabular_features': torch.tensor(self.tabular_features[idx], dtype=torch.float32),
            'metadata_features': torch.tensor(self.metadata_features[idx], dtype=torch.float32),
            'labels': torch.tensor(self.labels[idx], dtype=torch.long),
            'paper_id': str(item.get('PID', idx))
        }

def collate_fn(batch):
    """Stack tensor fields into batch tensors; keep string fields as plain lists."""
    result = {}
    for key in batch[0]:
        values = [item[key] for item in batch]
        result[key] = torch.stack(values) if torch.is_tensor(values[0]) else values
    return result
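

# A minimal usage sketch: builds the dataset and wires it into a DataLoader with
# the custom collate_fn. It assumes network access to the Hugging Face Hub and
# that the default split/column names above match the dataset; batch_size,
# num_workers, and the printed shapes are illustrative, not values taken from
# the original project.
if __name__ == "__main__":
    dataset = FraudPaperDataset(split="train")
    loader = DataLoader(dataset, batch_size=16, shuffle=True,
                        num_workers=2, collate_fn=collate_fn)
    batch = next(iter(loader))
    print(batch['input_ids'].shape)         # (16, 512) token ids
    print(batch['tabular_features'].shape)  # (16, 7) surface statistics
    print(batch['metadata_features'].shape) # (16, 11) citation/journal signals
    print(batch['labels'].shape)            # (16,) binary retraction labels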