Esvanth commited on
Commit
202132a
Β·
verified Β·
1 Parent(s): 78264fe

Add predict.py

Browse files
Files changed (1) hide show
  1. predict.py +321 -0
predict.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MindScan β€” Prediction Logic
3
+ NCI H9DAI Research Project 2026
4
+
5
+ All model loading and prediction functions.
6
+ Imported by app.py β€” do not run directly.
7
+
8
+ Datasets:
9
+ D1 β€” Zenodo (Nusrat 2024) β€” 6-class depression type
10
+ D2 β€” Kaggle (albertobellardini) β€” binary depression (labels: '0'/'1')
11
+ D3 β€” Kaggle (nikhileswarkomati) β€” binary suicide risk
12
+
13
+ Models per dataset:
14
+ Logistic Regression, SVM, XGBoost, XLM-RoBERTa
15
+ (Random Forest excluded β€” 646 MB, worst performer on D1/D3)
16
+ """
17
+
18
+ import os, re, string, joblib
19
+ import numpy as np
20
+
21
+ # ─────────────────────────────────────────────────────────────────
22
+ # PATHS
23
+ # ─────────────────────────────────────────────────────────────────
24
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
25
+ CLASSICAL_DIR = os.path.join(BASE_DIR, 'models', 'classical')
26
+ TRANSFORMER_DIR = os.path.join(BASE_DIR, 'models', 'transformers')
27
+
28
+ # If transformers aren't present locally, fetch them from the HF model repo.
29
+ # Used on HF Spaces where only app/classical are pushed and heavy weights live
30
+ # in a separate model repo to avoid Space LFS limits.
31
+ HF_XLMR_REPO = "Esvanth/mindscan-xlmr"
32
+
33
+ # ─────────────────────────────────────────────────────────────────
34
+ # D2 LABEL MAPPING
35
+ # The dataset uses '0' and '1' as labels.
36
+ # We map them to human-readable strings for the UI.
37
+ # ─────────────────────────────────────────────────────────────────
38
+ D2_LABEL_MAP = {
39
+ '0': 'Not Depressed',
40
+ '1': 'Depressed',
41
+ 0: 'Not Depressed',
42
+ 1: 'Depressed',
43
+ }
44
+
45
+ # ─────────────────────────────────────────────────────────────────
46
+ # MODEL STORAGE β€” populated by load_all_models()
47
+ # ─────────────────────────────────────────────────────────────────
48
+ _models = {}
49
+ _loaded = False
50
+
51
+
52
+ def models_loaded():
53
+ return _loaded
54
+
55
+
56
+ def load_all_models():
57
+ """
58
+ Loads all 12 models (4 per dataset Γ— 3 datasets) into memory.
59
+ Called once at server startup. Takes ~30s on CPU due to XLM-RoBERTa.
60
+ """
61
+ global _loaded
62
+
63
+ # ── Classical support files ───────────────────────────────────
64
+ for ds in ['d1', 'd2', 'd3']:
65
+ _models[f'le_{ds}'] = joblib.load(os.path.join(CLASSICAL_DIR, f'le_{ds}.pkl'))
66
+ _models[f'tfidf_{ds}'] = joblib.load(os.path.join(CLASSICAL_DIR, f'tfidf_{ds}.pkl'))
67
+ print(f" βœ“ Loaded encoders/tfidf for {ds}")
68
+
69
+ # ── Classical models ──────────────────────────────────────────
70
+ for model_name in ['logistic_regression', 'svm', 'xgboost']:
71
+ for ds in ['d1', 'd2', 'd3']:
72
+ key = f'{model_name}_{ds}'
73
+ path = os.path.join(CLASSICAL_DIR, f'{key}.pkl')
74
+ _models[key] = joblib.load(path)
75
+ print(f" βœ“ Loaded {key}")
76
+
77
+ # ── XLM-RoBERTa transformers ──────────────────────────────────
78
+ try:
79
+ import torch
80
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
81
+
82
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
83
+ _models['device'] = device
84
+ print(f" βœ“ Using device: {device}")
85
+
86
+ # On HF Spaces the weights aren't bundled with the app β€” fetch them
87
+ # from the model repo into TRANSFORMER_DIR on first startup.
88
+ d1_local = os.path.join(TRANSFORMER_DIR, 'xlmr_d1_final')
89
+ if not os.path.isdir(d1_local):
90
+ from huggingface_hub import snapshot_download
91
+ print(f" ↓ Downloading transformers from {HF_XLMR_REPO} ...")
92
+ snapshot_download(
93
+ repo_id=HF_XLMR_REPO,
94
+ repo_type="model",
95
+ local_dir=TRANSFORMER_DIR,
96
+ local_dir_use_symlinks=False,
97
+ )
98
+ print(" βœ“ Transformers downloaded")
99
+
100
+ # Shared tokenizer (all 3 models use the same base tokeniser)
101
+ tokenizer_path = os.path.join(TRANSFORMER_DIR, 'xlmr_d1_final')
102
+ _models['tokenizer'] = AutoTokenizer.from_pretrained(tokenizer_path)
103
+ print(" βœ“ Tokeniser loaded")
104
+
105
+ for ds, max_len in [('d1', 128), ('d2', 128), ('d3', 256)]:
106
+ folder = os.path.join(TRANSFORMER_DIR, f'xlmr_{ds}_final')
107
+ model = AutoModelForSequenceClassification.from_pretrained(folder)
108
+ model = model.to(device)
109
+ model.eval()
110
+ _models[f'xlmr_{ds}'] = model
111
+ _models[f'xlmr_{ds}_len'] = max_len
112
+ print(f" βœ“ Loaded XLM-RoBERTa {ds} (max_length={max_len})")
113
+
114
+ except Exception as e:
115
+ print(f" ⚠ XLM-RoBERTa failed to load: {e}")
116
+ print(" Classical models will still work.")
117
+
118
+ _loaded = True
119
+ print(" βœ… All models ready")
120
+
121
+
122
+ # ─────────────────────────────────────────────────────────────────
123
+ # TEXT CLEANING β€” same function used in both notebooks
124
+ # ─────────────────────────────────────────────────────────────────
125
+ def clean_text(text):
126
+ text = str(text).lower()
127
+ text = re.sub(r'http\S+|www\S+|https\S+', '', text)
128
+ text = re.sub(r'@\w+', '', text)
129
+ text = re.sub(r'#', '', text)
130
+ text = text.translate(str.maketrans('', '', string.punctuation))
131
+ text = re.sub(r'\s+', ' ', text).strip()
132
+ return text
133
+
134
+
135
+ # ─────────────────────────────────────────────────────────────────
136
+ # PREDICTION HELPERS
137
+ # ─────────────────────────────────────────────────────────────────
138
+ def predict_classical(text_clean, ds):
139
+ """
140
+ Runs text through the 3 classical models for one dataset.
141
+ Returns dict: { model_name: {label, confidence} }
142
+ """
143
+ tfidf = _models[f'tfidf_{ds}']
144
+ le = _models[f'le_{ds}']
145
+ vec = tfidf.transform([text_clean])
146
+
147
+ results = {}
148
+ display_names = {
149
+ 'logistic_regression': 'Logistic Regression',
150
+ 'svm': 'SVM',
151
+ 'xgboost': 'XGBoost',
152
+ }
153
+
154
+ for key, display in display_names.items():
155
+ model = _models[f'{key}_{ds}']
156
+ pred_idx = model.predict(vec)[0]
157
+ raw_label = le.classes_[pred_idx]
158
+
159
+ # Map D2 numeric labels to readable strings
160
+ if ds == 'd2':
161
+ label = D2_LABEL_MAP.get(raw_label, str(raw_label))
162
+ else:
163
+ label = str(raw_label)
164
+
165
+ # Confidence: predict_proba if available, else softmax of decision_function
166
+ if hasattr(model, 'predict_proba'):
167
+ conf = float(model.predict_proba(vec)[0][pred_idx])
168
+ elif hasattr(model, 'decision_function'):
169
+ scores = model.decision_function(vec)[0]
170
+ if np.ndim(scores) == 0:
171
+ scores = np.array([float(-scores), float(scores)])
172
+ e = np.exp(scores - scores.max())
173
+ conf = float(e[pred_idx] / e.sum())
174
+ else:
175
+ conf = 1.0
176
+
177
+ results[display] = {
178
+ 'label': label,
179
+ 'confidence': round(conf, 4),
180
+ }
181
+
182
+ return results
183
+
184
+
185
+ def predict_transformer(text_raw, ds):
186
+ """
187
+ Runs text through XLM-RoBERTa for one dataset.
188
+ Returns { label, confidence, all_probs }
189
+ all_probs = { class_name: probability } for all classes.
190
+ Used for the class breakdown bars in the UI.
191
+ """
192
+ if f'xlmr_{ds}' not in _models:
193
+ return None
194
+
195
+ import torch
196
+
197
+ model = _models[f'xlmr_{ds}']
198
+ tok = _models['tokenizer']
199
+ le = _models[f'le_{ds}']
200
+ max_len = _models[f'xlmr_{ds}_len']
201
+ device = _models.get('device', 'cpu')
202
+
203
+ inputs = tok(
204
+ text_raw,
205
+ return_tensors='pt',
206
+ max_length=max_len,
207
+ truncation=True,
208
+ padding='max_length'
209
+ ).to(device)
210
+
211
+ with torch.no_grad():
212
+ logits = model(**inputs).logits
213
+
214
+ probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
215
+ pred_idx = int(probs.argmax())
216
+ raw_label = le.classes_[pred_idx]
217
+
218
+ if ds == 'd2':
219
+ label = D2_LABEL_MAP.get(raw_label, str(raw_label))
220
+ else:
221
+ label = str(raw_label)
222
+
223
+ # Build all_probs dict with readable labels
224
+ all_probs = {}
225
+ for i, p in enumerate(probs):
226
+ raw = le.classes_[i]
227
+ readable = D2_LABEL_MAP.get(raw, str(raw)) if ds == 'd2' else str(raw)
228
+ all_probs[readable] = round(float(p), 4)
229
+
230
+ return {
231
+ 'label': label,
232
+ 'confidence': round(float(probs[pred_idx]), 4),
233
+ 'all_probs': all_probs,
234
+ }
235
+
236
+
237
+ # ─────────────────────────────────────────────────────────────────
238
+ # MAIN FUNCTION β€” called by Flask /predict endpoint
239
+ # ─────────────────────────────────────────────────────────────────
240
+ def predict_all(raw_text):
241
+ """
242
+ Runs text through all 12 models across 3 datasets.
243
+
244
+ Returns dict:
245
+ {
246
+ dataset1: {
247
+ task, models: {LR, SVM, XGBoost, XLM-RoBERTa},
248
+ winner_model, winner_prediction, winner_confidence,
249
+ class_probs ← only D1, 6-class breakdown from XLM-RoBERTa
250
+ },
251
+ dataset2: { same structure, D2 labels mapped to readable strings },
252
+ dataset3: { same structure },
253
+ risk_flag: bool, ← True if β‰₯3 of 4 D3 models say "suicide"
254
+ suicide_votes: "N/4 models flagged suicide risk",
255
+ winner_summary: { depression_type, depressed, suicide_risk }
256
+ }
257
+ """
258
+ clean = clean_text(raw_text)
259
+
260
+ # ── Dataset 1: Depression type ────────────────────────────────
261
+ d1 = predict_classical(clean, 'd1')
262
+ xlmr1 = predict_transformer(raw_text, 'd1')
263
+ if xlmr1:
264
+ d1['XLM-RoBERTa'] = {k: xlmr1[k] for k in ('label','confidence')}
265
+
266
+ d1_winner = max(d1.items(), key=lambda x: x[1]['confidence'])
267
+
268
+ # ── Dataset 2: Binary depression ─────────────────────────────
269
+ d2 = predict_classical(clean, 'd2')
270
+ xlmr2 = predict_transformer(raw_text, 'd2')
271
+ if xlmr2:
272
+ d2['XLM-RoBERTa'] = {k: xlmr2[k] for k in ('label','confidence')}
273
+
274
+ d2_winner = max(d2.items(), key=lambda x: x[1]['confidence'])
275
+
276
+ # ── Dataset 3: Suicide risk ───────────────────────────────────
277
+ d3 = predict_classical(clean, 'd3')
278
+ xlmr3 = predict_transformer(raw_text, 'd3')
279
+ if xlmr3:
280
+ d3['XLM-RoBERTa'] = {k: xlmr3[k] for k in ('label','confidence')}
281
+
282
+ d3_winner = max(d3.items(), key=lambda x: x[1]['confidence'])
283
+
284
+ # ── Suicide risk flag β€” majority vote across 4 D3 models ─────
285
+ suicide_count = sum(
286
+ 1 for r in d3.values()
287
+ if 'suicide' in r['label'].lower() and 'non' not in r['label'].lower()
288
+ )
289
+ risk_flag = suicide_count >= 3
290
+
291
+ return {
292
+ 'dataset1': {
293
+ 'task': 'Depression Type (6 Classes)',
294
+ 'models': d1,
295
+ 'winner_model': d1_winner[0],
296
+ 'winner_prediction': d1_winner[1]['label'],
297
+ 'winner_confidence': d1_winner[1]['confidence'],
298
+ 'class_probs': xlmr1.get('all_probs', {}) if xlmr1 else {},
299
+ },
300
+ 'dataset2': {
301
+ 'task': 'Depressed or Not?',
302
+ 'models': d2,
303
+ 'winner_model': d2_winner[0],
304
+ 'winner_prediction': d2_winner[1]['label'],
305
+ 'winner_confidence': d2_winner[1]['confidence'],
306
+ },
307
+ 'dataset3': {
308
+ 'task': 'Suicide Risk Detection',
309
+ 'models': d3,
310
+ 'winner_model': d3_winner[0],
311
+ 'winner_prediction': d3_winner[1]['label'],
312
+ 'winner_confidence': d3_winner[1]['confidence'],
313
+ },
314
+ 'risk_flag': risk_flag,
315
+ 'suicide_votes': f'{suicide_count}/4 models flagged suicide risk',
316
+ 'winner_summary': {
317
+ 'depression_type': f"{d1_winner[1]['label']} ({d1_winner[1]['confidence']*100:.1f}% β€” {d1_winner[0]})",
318
+ 'depressed': f"{d2_winner[1]['label']} ({d2_winner[1]['confidence']*100:.1f}% β€” {d2_winner[0]})",
319
+ 'suicide_risk': f"{d3_winner[1]['label']} ({d3_winner[1]['confidence']*100:.1f}% β€” {d3_winner[0]})",
320
+ }
321
+ }