pangweijlu commited on
Commit
5fc37d4
·
verified ·
1 Parent(s): aa7ac9b

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +127 -0
inference.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference script for multimodal fraudulent paper detection.
3
+ """
4
+
5
+ import os
6
+ import sys
7
+ import torch
8
+ import numpy as np
9
+ from transformers import AutoTokenizer
10
+ import argparse
11
+ import json
12
+
13
+ from model import MultimodalFraudDetector
14
+
15
+
16
+ def predict_fraud(model, tokenizer, text, tabular, metadata, device):
17
+ """Predict fraud probability for a single paper."""
18
+ model.eval()
19
+
20
+ # Tokenize text
21
+ encoding = tokenizer(
22
+ text,
23
+ max_length=512,
24
+ padding='max_length',
25
+ truncation=True,
26
+ return_tensors='pt'
27
+ )
28
+
29
+ input_ids = encoding['input_ids'].to(device)
30
+ attention_mask = encoding['attention_mask'].to(device)
31
+ tabular = torch.tensor(tabular, dtype=torch.float32).unsqueeze(0).to(device)
32
+ metadata = torch.tensor(metadata, dtype=torch.float32).unsqueeze(0).to(device)
33
+
34
+ with torch.no_grad():
35
+ outputs = model(
36
+ text_input_ids=input_ids,
37
+ text_attention_mask=attention_mask,
38
+ tabular_features=tabular,
39
+ metadata_features=metadata
40
+ )
41
+
42
+ logits = outputs['logits']
43
+ probs = torch.softmax(logits, dim=1)
44
+ fraud_prob = probs[0, 1].item()
45
+
46
+ modality_scores = outputs['modality_scores'][0].cpu().numpy()
47
+ anomaly_score = outputs['anomaly_score'][0].item()
48
+
49
+ return {
50
+ 'fraud_probability': fraud_prob,
51
+ 'is_fraudulent': fraud_prob > 0.5,
52
+ 'modality_contributions': {
53
+ 'text': float(modality_scores[0]),
54
+ 'image': float(modality_scores[1]),
55
+ 'tabular': float(modality_scores[2]),
56
+ 'metadata': float(modality_scores[3])
57
+ },
58
+ 'anomaly_score': anomaly_score
59
+ }
60
+
61
+
62
+ def explain_prediction(result):
63
+ """Generate human-readable explanation."""
64
+ explanations = []
65
+
66
+ if result['fraud_probability'] > 0.5:
67
+ explanations.append(f"FRAUDULENT (probability: {result['fraud_probability']:.2%})")
68
+ else:
69
+ explanations.append(f"AUTHENTIC (fraud probability: {result['fraud_probability']:.2%})")
70
+
71
+ # Modality contributions
72
+ contrib = result['modality_contributions']
73
+ max_modality = max(contrib, key=contrib.get)
74
+ explanations.append(f"Primary fraud indicator: {max_modality} modality (score: {contrib[max_modality]:.3f})")
75
+
76
+ if result['anomaly_score'] > 0.7:
77
+ explanations.append(f"High anomaly score ({result['anomaly_score']:.3f}): Paper shows strong outlier patterns")
78
+
79
+ return "\n".join(explanations)
80
+
81
+
82
+ def main():
83
+ parser = argparse.ArgumentParser()
84
+ parser.add_argument('--model_path', required=True)
85
+ parser.add_argument('--text', default='')
86
+ parser.add_argument('--title', default='')
87
+ parser.add_argument('--output', default='prediction.json')
88
+ args = parser.parse_args()
89
+
90
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
91
+
92
+ # Load model
93
+ checkpoint = torch.load(args.model_path, map_location=device)
94
+ model_args = checkpoint.get('args', {})
95
+
96
+ model = MultimodalFraudDetector(
97
+ text_model=model_args.get('text_model', 'allenai/scibert_scivocab_uncased'),
98
+ tabular_features=10,
99
+ metadata_features=12
100
+ ).to(device)
101
+
102
+ model.load_state_dict(checkpoint['model_state_dict'])
103
+ model.eval()
104
+
105
+ tokenizer = AutoTokenizer.from_pretrained(model_args.get('text_model'))
106
+
107
+ # Prepare input
108
+ text = f"{args.title} [SEP] {args.text}"
109
+
110
+ # Dummy features for demo (in production, extract from actual paper)
111
+ tabular = np.random.randn(10).astype(np.float32)
112
+ metadata = np.random.randn(12).astype(np.float32)
113
+
114
+ # Predict
115
+ result = predict_fraud(model, tokenizer, text, tabular, metadata, device)
116
+ result['explanation'] = explain_prediction(result)
117
+
118
+ print(result['explanation'])
119
+
120
+ with open(args.output, 'w') as f:
121
+ json.dump(result, f, indent=2)
122
+
123
+ print(f"\nSaved to {args.output}")
124
+
125
+
126
+ if __name__ == '__main__':
127
+ main()