File size: 4,989 Bytes
e57d1ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
Data loading and preprocessing for multimodal fraud detection.
"""

import re
import zlib

import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer


class FraudPaperDataset(Dataset):
    """
    Dataset for multimodal fraudulent paper detection.
    Supports text, image, tabular, and metadata modalities.

    Each sample is a dict with tokenized title+abstract text
    (``input_ids``/``attention_mask``), hand-crafted tabular and metadata
    feature vectors, a binary retraction label, and the paper identifier.
    """

    def __init__(self, hf_dataset_name="Lihuchen/pubmed_retraction", split="train",
                 text_tokenizer=None, max_length=512, text_column="Abstract",
                 title_column="Title", label_column="IsRetracted"):
        """
        Args:
            hf_dataset_name: HuggingFace dataset identifier to load.
            split: Dataset split name (e.g. "train").
            text_tokenizer: Optional pre-built tokenizer; defaults to SciBERT.
            max_length: Max token length for text encoding (pad/truncate).
            text_column: Column holding the abstract text.
            title_column: Column holding the paper title.
            label_column: Column holding the retraction label.
        """
        self.dataset = load_dataset(hf_dataset_name, split=split)
        self.tokenizer = text_tokenizer or AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
        self.max_length = max_length
        self.text_column = text_column
        self.title_column = title_column
        self.label_column = label_column
        # Precompute labels and static (non-text) features once up front.
        self.labels = self._extract_labels()
        self.tabular_features = self._extract_tabular_features()
        self.metadata_features = self._extract_metadata_features()

    def _extract_labels(self):
        """Return a binary label array (1 = retracted) for every item.

        String labels are matched case-insensitively against a set of
        "positive" spellings; non-string labels are cast to int directly.
        """
        labels = []
        for item in self.dataset:
            label = item.get(self.label_column, "No")
            if isinstance(label, str):
                labels.append(1 if label.lower() in ['yes', 'true', '1', 'retracted'] else 0)
            else:
                labels.append(int(label))
        return np.array(labels)

    def _extract_tabular_features(self):
        """Build a (N, 7) float32 array of simple surface statistics.

        Features per item: abstract char count, abstract word count,
        title word count, author count (';'-separated), authors string
        length, journal name length, publication year (default 2020
        when no 4-digit year is found).
        """
        features = []
        for item in self.dataset:
            feat = []
            text = str(item.get(self.text_column, ''))
            feat.append(len(text))
            feat.append(len(text.split()))
            feat.append(len(str(item.get(self.title_column, '')).split()))
            authors = str(item.get('Authors', ''))
            # NOTE: ''.split(';') == [''] so an empty Authors field counts as 1.
            feat.append(len(authors.split(';')))
            feat.append(len(authors))
            journal = str(item.get('Journal', ''))
            feat.append(len(journal))
            pub_date = str(item.get('PublicationDate', ''))
            year_match = re.search(r'(\d{4})', pub_date)
            feat.append(int(year_match.group(1)) if year_match else 2020)
            features.append(feat)
        return np.array(features, dtype=np.float32)

    def _extract_metadata_features(self):
        """Build a (N, 11) float32 array of citation/affiliation metadata.

        Features per item: citation count, reference count,
        citations-per-reference ratio, grant count, affiliation count
        (';'-separated), affiliations string length, a stable country
        hash in [0, 1), two publication-type flags (Review / Research),
        and two journal-string flags (contains 'Q1' / 'Top').
        """
        features = []
        for item in self.dataset:
            feat = []
            # `or 0` also normalizes explicit None values to 0.
            cited_by = item.get('cited_by_count', 0) or 0
            references = item.get('reference_count', 0) or 0
            feat.append(float(cited_by))
            feat.append(float(references))
            # max(..., 1) guards against division by zero references.
            feat.append(float(cited_by) / max(float(references), 1))
            n_grants = item.get('n_grants', 0) or 0
            feat.append(float(n_grants))
            affiliations = str(item.get('affiliations', item.get('Institutions', '')))
            feat.append(len(affiliations.split(';')))
            feat.append(len(affiliations))
            country = str(item.get('country', ''))
            # BUG FIX: the builtin hash() of a str is randomized per process
            # (PYTHONHASHSEED), so the original feature was not reproducible
            # across runs/processes. zlib.crc32 is deterministic everywhere.
            feat.append(zlib.crc32(country.encode('utf-8')) % 1000 / 1000.0)
            pub_types = str(item.get('publication_types', ''))
            feat.append(1.0 if 'Review' in pub_types else 0.0)
            feat.append(1.0 if 'Research' in pub_types else 0.0)
            journal = str(item.get('Journal', ''))
            feat.append(1.0 if 'Q1' in journal else 0.0)
            feat.append(1.0 if 'Top' in journal else 0.0)
            features.append(feat)
        return np.array(features, dtype=np.float32)

    def __len__(self):
        """Number of papers in the dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return one sample dict of tensors (plus the paper id string).

        Title and abstract are joined with a literal " [SEP] " before
        tokenization; the encoding is padded/truncated to max_length.
        """
        item = self.dataset[idx]
        title = str(item.get(self.title_column, ''))
        abstract = str(item.get(self.text_column, ''))
        text = f"{title} [SEP] {abstract}"
        text_encoding = self.tokenizer(text, max_length=self.max_length,
            padding='max_length', truncation=True, return_tensors='pt')
        return {
            # squeeze(0) drops the batch dim the tokenizer adds with
            # return_tensors='pt'; the DataLoader re-batches later.
            'input_ids': text_encoding['input_ids'].squeeze(0),
            'attention_mask': text_encoding['attention_mask'].squeeze(0),
            'tabular_features': torch.tensor(self.tabular_features[idx], dtype=torch.float32),
            'metadata_features': torch.tensor(self.metadata_features[idx], dtype=torch.float32),
            'labels': torch.tensor(self.labels[idx], dtype=torch.long),
            'paper_id': str(item.get('PID', idx))
        }


def collate_fn(batch):
    """Collate a list of sample dicts into a single batch dict.

    Tensor-valued fields are stacked along a new leading batch
    dimension; all other fields (e.g. paper ids) are gathered into a
    plain Python list.
    """
    stackable = {'input_ids', 'attention_mask', 'labels',
                 'tabular_features', 'metadata_features'}
    batched = {}
    for key in batch[0]:
        values = [sample[key] for sample in batch]
        batched[key] = torch.stack(values) if key in stackable else values
    return batched