| """ |
| Data loading and preprocessing for multimodal fraud detection. |
| """ |
|
|
import re
import zlib

import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
|
|
|
|
class FraudPaperDataset(Dataset):
    """
    Dataset for multimodal fraudulent paper detection.
    Supports text, image, tabular, and metadata modalities.

    Each item yields tokenized title+abstract text, precomputed tabular and
    metadata feature vectors, a binary retraction label, and a paper id.
    """

    def __init__(self, hf_dataset_name="Lihuchen/pubmed_retraction", split="train",
                 text_tokenizer=None, max_length=512, text_column="Abstract",
                 title_column="Title", label_column="IsRetracted"):
        """
        Args:
            hf_dataset_name: HuggingFace dataset repo id to load.
            split: dataset split name (e.g. "train").
            text_tokenizer: optional pre-built tokenizer; defaults to SciBERT.
            max_length: max token length for the text encoding (pad/truncate).
            text_column / title_column / label_column: column names to read.
        """
        self.dataset = load_dataset(hf_dataset_name, split=split)
        self.tokenizer = text_tokenizer or AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
        self.max_length = max_length
        self.text_column = text_column
        self.title_column = title_column
        self.label_column = label_column
        # Precompute per-example arrays once so __getitem__ stays cheap.
        self.labels = self._extract_labels()
        self.tabular_features = self._extract_tabular_features()
        self.metadata_features = self._extract_metadata_features()

    @staticmethod
    def _count_segments(value, sep=';'):
        """Count non-empty `sep`-separated segments; '' yields 0, not 1."""
        return sum(1 for part in value.split(sep) if part.strip())

    def _extract_labels(self):
        """Map the label column to a binary int array (1 = retracted)."""
        labels = []
        for item in self.dataset:
            label = item.get(self.label_column, "No")
            if isinstance(label, str):
                labels.append(1 if label.lower() in ['yes', 'true', '1', 'retracted'] else 0)
            else:
                # Non-string labels (int/bool) are coerced directly.
                labels.append(int(label))
        return np.array(labels)

    def _extract_tabular_features(self):
        """Build an (N, 7) float32 array of simple text/author/journal stats."""
        features = []
        for item in self.dataset:
            feat = []
            text = str(item.get(self.text_column, ''))
            feat.append(len(text))          # abstract length in characters
            feat.append(len(text.split()))  # abstract length in words
            feat.append(len(str(item.get(self.title_column, '')).split()))  # title words
            authors = str(item.get('Authors', ''))
            # Count non-empty author entries so a missing field yields 0, not 1.
            feat.append(self._count_segments(authors))
            feat.append(len(authors))
            journal = str(item.get('Journal', ''))
            feat.append(len(journal))
            pub_date = str(item.get('PublicationDate', ''))
            year_match = re.search(r'(\d{4})', pub_date)
            # Fall back to 2020 when no 4-digit year appears in the date string.
            feat.append(int(year_match.group(1)) if year_match else 2020)
            features.append(feat)
        return np.array(features, dtype=np.float32)

    def _extract_metadata_features(self):
        """Build an (N, 11) float32 array of citation/affiliation/journal metadata."""
        features = []
        for item in self.dataset:
            feat = []
            cited_by = item.get('cited_by_count', 0) or 0
            references = item.get('reference_count', 0) or 0
            feat.append(float(cited_by))
            feat.append(float(references))
            feat.append(float(cited_by) / max(float(references), 1))  # citation ratio
            n_grants = item.get('n_grants', 0) or 0
            feat.append(float(n_grants))
            affiliations = str(item.get('affiliations', item.get('Institutions', '')))
            feat.append(self._count_segments(affiliations))
            feat.append(len(affiliations))
            country = str(item.get('country', ''))
            # crc32 instead of hash(): str hashes are salted per process
            # (PYTHONHASHSEED), which made this feature non-reproducible.
            feat.append(zlib.crc32(country.encode('utf-8')) % 1000 / 1000.0)
            pub_types = str(item.get('publication_types', ''))
            feat.append(1.0 if 'Review' in pub_types else 0.0)
            feat.append(1.0 if 'Research' in pub_types else 0.0)
            journal = str(item.get('Journal', ''))
            feat.append(1.0 if 'Q1' in journal else 0.0)
            feat.append(1.0 if 'Top' in journal else 0.0)
            features.append(feat)
        return np.array(features, dtype=np.float32)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return a dict with tokenized text, feature tensors, label, and id."""
        if torch.is_tensor(idx):
            idx = idx.item()  # some samplers hand over tensor indices
        item = self.dataset[idx]
        title = str(item.get(self.title_column, ''))
        abstract = str(item.get(self.text_column, ''))
        text = f"{title} [SEP] {abstract}"
        text_encoding = self.tokenizer(text, max_length=self.max_length,
                                       padding='max_length', truncation=True, return_tensors='pt')
        return {
            'input_ids': text_encoding['input_ids'].squeeze(0),
            'attention_mask': text_encoding['attention_mask'].squeeze(0),
            'tabular_features': torch.tensor(self.tabular_features[idx], dtype=torch.float32),
            'metadata_features': torch.tensor(self.metadata_features[idx], dtype=torch.float32),
            'labels': torch.tensor(self.labels[idx], dtype=torch.long),
            'paper_id': str(item.get('PID', idx))
        }
|
|
|
|
def collate_fn(batch):
    """
    Collate a list of dataset items into a single batch dict.

    Tensor-valued fields are stacked along a new leading batch dimension;
    everything else (e.g. paper ids) is collected into a plain list.

    Args:
        batch: list of dicts as produced by FraudPaperDataset.__getitem__.

    Returns:
        dict with the same keys as the items; empty dict for an empty batch.
    """
    if not batch:  # guard: batch[0] below would raise IndexError
        return {}
    # All tensor-valued fields are stacked identically, so a single
    # membership test replaces the previously duplicated branches.
    tensor_keys = {'input_ids', 'attention_mask', 'labels',
                   'tabular_features', 'metadata_features'}
    return {
        key: torch.stack([item[key] for item in batch]) if key in tensor_keys
        else [item[key] for item in batch]
        for key in batch[0]
    }
|
|