Ankit19102004 committed on
Commit
e70b7e5
·
0 Parent(s):

Clean TruthX API deployment without model weights

Browse files
Files changed (10) hide show
  1. .env +6 -0
  2. .gitattributes +6 -0
  3. .gitignore +22 -0
  4. README.md +52 -0
  5. api_keys.json +1 -0
  6. app.py +855 -0
  7. deployment.py +850 -0
  8. dockerfile +26 -0
  9. requirements.txt +12 -0
  10. requirements_space.txt +12 -0
.env ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ NEWSDATA_API_KEY=pub_427e5e1aadb64646a5e40826c0e7b5cc
2
+ NEWSAPI_API_KEY=e608b975addb47ffb8fdba39e756d631
3
+ GNEWS_API_KEY=310e612f245693ad3f86ad9a462ac7a0
4
+ MEDIASTACK_API_KEY=30f9f464ff009164a8827164df046170
5
+ FLASK_APP=main.py
6
+ FLASK_ENV=development
.gitattributes ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ *.csv filter=lfs diff=lfs merge=lfs -text
2
+ *.pt filter=lfs diff=lfs merge=lfs -text
3
+ *.h5 filter=lfs diff=lfs merge=lfs -text
4
+ *.keras filter=lfs diff=lfs merge=lfs -text
5
+ *.pkl filter=lfs diff=lfs merge=lfs -text
6
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /.vscode
2
+
3
+ # data
4
+ data/
5
+
6
+ # notebooks
7
+ notebook/
8
+
9
+ # python
10
+ __pycache__/
11
+ *.py[cod]
12
+ *$py.class
13
+ .env
14
+ api_keys.json
15
+ instruction.txt
16
+ news_api.py
17
+ data/
18
+ notebook/
19
+ mlflow.db
20
+ .pytest_cache/
21
+ .coverage
22
+ htmlcov/
README.md ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: TruthX Fake News Detector
3
+ emoji: 🔍
4
+ colorFrom: blue
5
+ colorTo: red
6
+ sdk: docker
7
+ app_file: app.py
8
+ pinned: false
9
+ ---
10
+
11
+ # TruthX - Fake News Detection
12
+
13
+ TruthX uses a state-of-the-art DistilBERT model to detect fake news articles with high accuracy.
14
+
15
+ ## Features
16
+
17
+ - **Real-time Detection**: Get instant predictions on news authenticity
18
+ - **Confidence Score**: See the model's confidence level
19
+ - **Multiple Models**: Supports BERT, DistilBERT, and RoBERTa models
20
+
21
+ ## How to Use
22
+
23
+ 1. Enter any news article or headline in the text box
24
+ 2. Click "Submit" to get the prediction
25
+ 3. View the classification (Real/Fake) with confidence scores
26
+
27
+ ## Technical Details
28
+
29
+ - **Model**: DistilBERT fine-tuned for fake news detection
30
+ - **Input**: Text up to 256 tokens (longer input is truncated)
31
+ - **Output**: Classification label with probability scores
32
+
33
+ ## API Access
34
+
35
+ You can also access the model programmatically via the Hugging Face Inference API:
36
+
37
+ ```python
38
+ import requests
39
+
40
+ API_URL = "https://api-inference.huggingface.co/models/Ankit74990/TruthX-DISTILBERT"
41
+ headers = {"Authorization": "Bearer YOUR_TOKEN"}
42
+
43
+ def query(text):
44
+     response = requests.post(API_URL, headers=headers, json={"inputs": text})
45
+     return response.json()
46
+
47
+ result = query("Your news text here")
48
+ ```
49
+
50
+ ## Model Card
51
+
52
+ This space uses the [TruthX-DISTILBERT](https://huggingface.co/Ankit74990/TruthX-DISTILBERT) model.
api_keys.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"a0a6124c-25a4-48a4-bf45-44a42b9ebdf1": "anonymous"}
app.py ADDED
@@ -0,0 +1,855 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle, json, uuid, re, traceback, nltk # noqa: E401
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import requests
6
+ from urllib.parse import quote
7
+ import xml.etree.ElementTree as ET
8
+
9
+ from flask import Flask, request, jsonify
10
+ from functools import wraps
11
+
12
+ from dotenv import load_dotenv
13
+ from nltk.corpus import stopwords
14
+ from nltk.stem.porter import PorterStemmer
15
+
16
+ import sys
17
+ import os
18
+
19
+ from transformers import (
20
+ AutoModel,
21
+ AutoTokenizer,
22
+ AutoModelForSequenceClassification,
23
+ BertTokenizerFast,
24
+ BertModel,
25
+ )
26
+ from huggingface_hub import hf_hub_download
27
+
28
+ torch.set_num_threads(1)
29
+ torch.set_grad_enabled(False)
30
+
31
+ import warnings
32
+
33
+ warnings.filterwarnings("ignore")
34
+
35
+
36
+ # ==============================
37
+ # APP INIT
38
+ # ==============================
39
+ load_dotenv()
40
+ app = Flask(__name__)
41
+ device = torch.device("cpu")
42
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
43
+ API_KEYS_FILE = os.path.join(BASE_DIR, "api_keys.json")
44
+
45
+ # ==============================
46
+ # NLTK
47
+ # ==============================
48
+ try:
49
+ nltk.download("stopwords", quiet=True)
50
+ all_stopwords = stopwords.words("english")
51
+ except Exception:
52
+ all_stopwords = []
53
+
54
+ ps = PorterStemmer()
55
+
56
+ # ==============================
57
+ # EXTERNAL API KEYS
58
+ # ==============================
59
+ NEWSDATA_KEY = os.getenv("NEWSDATA_API_KEY")
60
+ NEWSAPI_KEY = os.getenv("NEWSAPI_API_KEY")
61
+ GNEWS_KEY = os.getenv("GNEWS_API_KEY")
62
+ MEDIASTACK_KEY = os.getenv("MEDIASTACK_API_KEY")
63
+
64
+
65
+ # ==============================
66
+ # API KEY MANAGEMENT
67
+ # ==============================
68
+
69
+
70
def load_truthx_api_keys() -> dict:
    """Load the persisted TruthX API keys from ``API_KEYS_FILE``.

    Returns:
        The mapping stored in the file, or an empty dict when the file is
        missing, unreadable, malformed, or does not contain a JSON object.
        A missing file is created (holding ``{}``) as a side effect.
    """
    # EAFP: open directly instead of exists()+open(), which can race with a
    # concurrent create/delete between the two calls.
    try:
        with open(API_KEYS_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        # Create empty file if not exists
        save_truthx_api_keys({})
        return {}
    except Exception as e:
        print(f"[ERROR] Loading API keys: {e}")
        return {}
    return data if isinstance(data, dict) else {}
83
+
84
+
85
def save_truthx_api_keys(keys: dict) -> None:
    """Persist *keys* to ``API_KEYS_FILE`` as JSON.

    The data is written to a temporary file in the same directory and then
    moved over the target with ``os.replace``, so a reader never observes a
    half-written file. Failures are reported on stdout and otherwise ignored.

    Args:
        keys: Mapping of API key -> owner label to store.
    """
    import tempfile  # local import: only this function needs it

    tmp_path = None
    try:
        target_dir = os.path.dirname(API_KEYS_FILE) or "."
        fd, tmp_path = tempfile.mkstemp(
            dir=target_dir, prefix=".api_keys.", suffix=".tmp"
        )
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(keys, f)
        os.replace(tmp_path, API_KEYS_FILE)
        tmp_path = None  # ownership moved to API_KEYS_FILE; nothing to clean up
    except Exception as e:
        print(f"[ERROR] Saving API keys: {e}")
    finally:
        if tmp_path is not None:
            # Write or replace failed: do not leave the temp file behind.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
91
+
92
+
93
def verify_api_key(key: str) -> bool:
    """Return True when *key* is present in the stored TruthX API keys."""
    # The key store is re-read on every check so freshly generated keys work.
    return key in load_truthx_api_keys()
97
+
98
+
99
def require_api_key(f):
    """Decorator that answers 401 unless the request carries a valid API key.

    The key is taken from the ``X-API-KEY`` header, falling back to the
    ``api_key`` query parameter when the header is absent or empty.
    """

    @wraps(f)
    def decorated_function(*args, **kwargs):
        supplied = request.headers.get("X-API-KEY") or request.args.get("api_key")
        if supplied and verify_api_key(supplied):
            return f(*args, **kwargs)
        return jsonify({"error": "Invalid or missing API key. Use /generate_key"}), 401

    return decorated_function
113
+
114
+
115
+ # ==============================
116
+ # TEXT PREPROCESSING
117
+ # ==============================
118
+
119
+
120
def preprocess_text(text: str) -> str:
    """Normalise *text*: letters only, lower-case, no stopwords, stemmed.

    Every non-letter character becomes a space, the result is lower-cased and
    split on whitespace, tokens found in ``all_stopwords`` are dropped and the
    rest are stemmed with ``ps``. Empty or falsy input yields ``""``.
    """
    if not text:
        return ""
    letters_only = re.sub("[^a-zA-Z]", " ", text)
    stemmed = []
    for word in letters_only.lower().split():
        if word in all_stopwords:
            continue
        stemmed.append(ps.stem(word))
    return " ".join(stemmed)
126
+
127
+
128
+ # ==============================
129
+ # PAD SEQUENCES
130
+ # ==============================
131
+
132
+
133
def pad_sequences(sequences: list, maxlen: int, padding: str = "pre") -> np.ndarray:
    """Pad or truncate integer sequences to a fixed length.

    Sequences longer than *maxlen* keep their LAST *maxlen* items; shorter
    ones are filled with zeros.

    Args:
        sequences: Iterable-of-ints items, one per row.
        maxlen: Target row length (must be >= 0).
        padding: ``"pre"`` puts the zeros in front; any other value puts
            them after the data.

    Returns:
        An ``int32`` array of shape ``(len(sequences), maxlen)``.

    Raises:
        ValueError: If *maxlen* is negative.
    """
    if maxlen < 0:
        raise ValueError("maxlen must be non-negative")

    padded = np.zeros((len(sequences), maxlen), dtype=np.int32)
    if maxlen == 0:
        # Guard: seq[-0:] would return the WHOLE sequence, not an empty one.
        return padded

    for row, seq in enumerate(sequences):
        tail = list(seq)[-maxlen:]
        if not tail:
            continue  # row stays all zeros
        if padding == "pre":
            padded[row, maxlen - len(tail):] = tail
        else:
            padded[row, :len(tail)] = tail
    return padded
145
+
146
+
147
+ # ==============================
148
+ # EXTERNAL NEWS VERIFICATION
149
+ # ==============================
150
+
151
+
152
+ def check_external_news(query: str) -> float:
153
+ """Improved external verification with weighted scoring + Google RSS"""
154
+
155
+ if not query:
156
+ return 0.0
157
+
158
+ # 🔹 Full query
159
+ encoded = quote(query)
160
+
161
+ # 🔹 Smart keyword extraction (for Mediastack + Google)
162
+ stop_words = {"the", "is", "in", "on", "at", "a", "an", "of", "for", "to", "and"}
163
+ keywords = [w for w in query.lower().split() if w not in stop_words]
164
+ simple_query = " ".join(keywords[:3])
165
+ encoded_simple = quote(simple_query)
166
+
167
+ # =========================
168
+ # SCORES
169
+ # =========================
170
+ newsdata = 0
171
+ newsapi = 0
172
+ gnews = 0
173
+ mediastack = 0
174
+ google = 0
175
+
176
+ # =========================
177
+ # 1. NEWSDATA
178
+ # =========================
179
+ if NEWSDATA_KEY:
180
+ try:
181
+ r = requests.get(
182
+ f"https://newsdata.io/api/1/news?apikey={NEWSDATA_KEY}&q={encoded}",
183
+ timeout=5,
184
+ )
185
+ if r.status_code == 200 and r.json().get("totalResults", 0) > 0:
186
+ newsdata = 1
187
+ except Exception:
188
+ pass
189
+
190
+ # =========================
191
+ # 2. NEWSAPI
192
+ # =========================
193
+ if NEWSAPI_KEY:
194
+ try:
195
+ r = requests.get(
196
+ f"https://newsapi.org/v2/everything?q={encoded}&apiKey={NEWSAPI_KEY}&pageSize=1",
197
+ timeout=5,
198
+ )
199
+ if r.status_code == 200 and r.json().get("totalResults", 0) > 0:
200
+ newsapi = 1
201
+ except Exception:
202
+ pass
203
+
204
+ # =========================
205
+ # 3. GNEWS
206
+ # =========================
207
+ if GNEWS_KEY:
208
+ try:
209
+ r = requests.get(
210
+ f"https://gnews.io/api/v4/search?q={encoded}&token={GNEWS_KEY}&max=1",
211
+ timeout=5,
212
+ )
213
+ if r.status_code == 200 and r.json().get("totalArticles", 0) > 0:
214
+ gnews = 1
215
+ except Exception:
216
+ pass
217
+
218
+ # =========================
219
+ # 4. MEDIASTACK (FIXED)
220
+ # =========================
221
+ if MEDIASTACK_KEY:
222
+ try:
223
+ r = requests.get(
224
+ f"https://api.mediastack.com/v1/news?access_key={MEDIASTACK_KEY}&keywords={encoded_simple}&limit=1",
225
+ timeout=5,
226
+ )
227
+ total = r.json().get("pagination", {}).get("total", 0)
228
+
229
+ # 🔥 Ignore noisy results
230
+ if r.status_code == 200 and 0 < total < 5000:
231
+ mediastack = 1
232
+ except Exception:
233
+ pass
234
+
235
+ # =========================
236
+ # 5. GOOGLE NEWS RSS ⭐
237
+ # =========================
238
+ try:
239
+ r = requests.get(
240
+ f"https://news.google.com/rss/search?q={encoded_simple}",
241
+ timeout=5,
242
+ )
243
+ root = ET.fromstring(r.content)
244
+ items = root.findall(".//item")
245
+
246
+ if len(items) > 0:
247
+ google = 1
248
+ except Exception:
249
+ pass
250
+
251
+ # =========================
252
+ # FINAL WEIGHTED SCORE
253
+ # =========================
254
+ score = (
255
+ newsdata * 0.35
256
+ + newsapi * 0.15
257
+ + gnews * 0.25
258
+ + mediastack * 0.05
259
+ + google * 0.2
260
+ )
261
+
262
+ return round(score, 4)
263
+
264
+
265
+ # ======================================================
266
+ # MODEL 1 — NLP (TF-IDF + SVM)
267
+ # ======================================================
268
+
269
+ nlp_model = None
270
+ nlp_vector = None
271
+
272
def load_nlp():
    """Lazily download and unpickle the NLP classifier and its vectorizer.

    On success the module-level ``nlp_model`` and ``nlp_vector`` are set
    together. On any failure both are left untouched, so a later call retries
    instead of being stuck with only one of the two loaded.
    """
    global nlp_model, nlp_vector
    if nlp_model is not None and nlp_vector is not None:
        return
    try:
        repo_id = "Ankit74990/TruthX-NLP"
        m_path = hf_hub_download(repo_id=repo_id, filename="model2.pkl")
        v_path = hf_hub_download(repo_id=repo_id, filename="tfidfvect2.pkl")
        # NOTE(security): unpickling executes arbitrary code from the file;
        # this is only safe while the Hub repository above is trusted.
        with open(m_path, "rb") as fh:
            model = pickle.load(fh)
        with open(v_path, "rb") as fh:
            vector = pickle.load(fh)
    except Exception as e:
        print(f"[WARN] NLP model not loaded: {e}")
        return
    nlp_model, nlp_vector = model, vector
    print("[OK] NLP model loaded")
284
+
285
def predict_nlp(text: str) -> list:
    """Classify *text* with the NLP model.

    Returns:
        ``[(label, confidence)]`` where label is ``"Real News"`` when the
        model predicts class 1 and ``"Fake News"`` otherwise, or ``[]`` when
        the model or vectorizer is unavailable. Confidence is a sigmoid of the
        absolute decision value, or 0.8 when no decision value is available.
    """
    load_nlp()
    # Explicit None checks: truthiness of a model object is not meaningful.
    if nlp_model is None or nlp_vector is None:
        return []
    vec = nlp_vector.transform([preprocess_text(text)])
    pred = nlp_model.predict(vec)[0]
    try:
        decision = nlp_model.decision_function(vec)[0]
        conf = 1 / (1 + np.exp(-abs(decision)))
    except Exception:
        # Model exposes no usable decision_function: use a fixed confidence.
        conf = 0.8
    return [("Real News" if pred == 1 else "Fake News", float(conf))]
297
+
298
+
299
+ # ======================================================
300
+ # MODEL 2 — HYBRID
301
+ # ======================================================
302
+
303
+
304
class HybridModel_A(nn.Module):
    """Sequential hybrid: embedding -> Conv1d -> max-pool -> BiLSTM -> MLP head."""

    def __init__(self, vocab_size: int, embed_dim: int = 256):
        super().__init__()
        # Attribute names define the state_dict keys the checkpoint is loaded
        # into, so they are kept exactly as they were.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.conv = nn.Conv1d(embed_dim, 256, kernel_size=5)
        self.pool = nn.MaxPool1d(2)
        self.lstm = nn.LSTM(256, 128, batch_first=True, bidirectional=True)
        self.fc1 = nn.Linear(256, 128)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        """Return two-class logits for a batch of token-id rows."""
        # Conv1d wants channels before the sequence axis.
        channels_first = self.embedding(x).transpose(1, 2)
        features = self.pool(torch.relu(self.conv(channels_first)))

        # Back to (batch, steps, channels) for the batch_first LSTM.
        sequence, _ = self.lstm(features.transpose(1, 2))
        last_step = sequence[:, -1, :]

        hidden = self.dropout(torch.relu(self.fc1(last_step)))
        return self.fc2(hidden)
336
+
337
+
338
class HybridModel_B(nn.Module):
    """Parallel hybrid: a CNN branch and a BiLSTM branch over one embedding, summed."""

    def __init__(self, vocab_size: int, embed_dim: int = 256):
        super().__init__()
        # Attribute names define the state_dict keys the checkpoint is loaded
        # into, so they are kept exactly as they were.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.conv = nn.Conv1d(embed_dim, 256, kernel_size=5)
        self.lstm = nn.LSTM(256, 128, batch_first=True, bidirectional=True)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        """Return two-class logits for a batch of token-id rows."""
        embedded = self.embedding(x)

        # CNN branch: convolve over the sequence, then global max over time.
        conv_out = torch.relu(self.conv(embedded.transpose(1, 2)))
        cnn_features = conv_out.max(dim=2).values

        # LSTM branch: output at the final time step.
        lstm_out, _ = self.lstm(embedded)
        lstm_features = lstm_out[:, -1, :]

        merged = cnn_features + lstm_features
        return self.fc2(torch.relu(self.fc1(merged)))
366
+
367
+
368
+ # ======================================================
369
+ # MODEL 2 — HYBRID (FIXED)
370
+ # ======================================================
371
+
372
+
373
class HybridEnsemble:
    """Loads the hybrid CNN/LSTM checkpoints from the Hugging Face Hub and runs them.

    Models that fail to load are skipped; ``models``, ``tokenizers`` and
    ``max_lens`` are parallel lists holding one entry per loaded model.
    """

    # (hub repo id, model class, checkpoint filename)
    DIRS = [
        ("Ankit74990/TruthX-HYBRID", HybridModel_A, "hybrid_model1.pt"),
        ("Ankit74990/TruthX-HYBRID2", HybridModel_B, "hybrid_model2.pt"),
    ]

    class _WordIndexTokenizer:
        """Maps whitespace-split words to ids through a ``word_index`` dict (unknown -> 0)."""

        def __init__(self, word_index):
            self.word_index = word_index

        def texts_to_sequences(self, texts):
            return [[self.word_index.get(w, 0) for w in text.split()] for text in texts]

    class _PlaceholderTokenizer:
        """Fallback used when no word index is available: every word becomes id 1."""

        def texts_to_sequences(self, texts):
            return [[1] * len(t.split()) for t in texts]

    def __init__(self):
        self.models = []
        self.tokenizers = []
        self.max_lens = []

        print("[HYBRID] Loading models...")
        self._load_all()
        print(f"[OK] Hybrid models loaded ({len(self.models)})")

    @staticmethod
    def _read_pickle(path):
        """Unpickle *path*, closing the file handle afterwards."""
        # NOTE(security): unpickling executes arbitrary code from the file;
        # this is only safe while the Hub repositories in DIRS are trusted.
        with open(path, "rb") as fh:
            return pickle.load(fh)

    @classmethod
    def _build_tokenizer(cls, tok_path, repo_id):
        """Build a tokenizer from the pickled ``{"word_index": ...}`` dict, else a placeholder."""
        try:
            tok_data = cls._read_pickle(tok_path)
        except Exception:
            tok_data = None
        if isinstance(tok_data, dict) and "word_index" in tok_data:
            return cls._WordIndexTokenizer(tok_data["word_index"])
        # The placeholder loses all word identity, so make the fallback visible.
        print(f"[WARN] No usable word index for {repo_id}; using placeholder tokenizer")
        return cls._PlaceholderTokenizer()

    def _load_all(self):
        """Download and load every entry of ``DIRS``, skipping the ones that fail."""
        for repo_id, model_class, m_name in self.DIRS:
            try:
                tok_path = hf_hub_download(repo_id=repo_id, filename="tokenizer.pkl")
                cfg_path = hf_hub_download(repo_id=repo_id, filename="config.pkl")
                model_path = hf_hub_download(repo_id=repo_id, filename=m_name)

                tok = self._build_tokenizer(tok_path, repo_id)

                cfg = self._read_pickle(cfg_path)
                vocab_size = cfg.get("max_words") or cfg.get("vocab_size")
                max_len = cfg.get("max_len")

                if not vocab_size or not max_len:
                    print(f"[ERR] Incomplete hybrid config in {repo_id}; skipping")
                    continue

                model = model_class(vocab_size).to(device)
                model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
                model.eval()

                self.models.append(model)
                self.tokenizers.append(tok)
                self.max_lens.append(max_len)
                print(f"[OK] Hybrid model loaded from {repo_id}")

            except Exception as e:
                print(f"[ERR] Failed to load hybrid from {repo_id}: {e}")
                continue

    def predict(self, text: str) -> list:
        """Return one ``(label, confidence)`` per loaded model for *text*.

        Class index 1 is reported as ``"Real News"``, anything else as
        ``"Fake News"``. A model that raises is skipped and reported on stdout.
        """
        if not self.models:
            return []
        results = []
        for model, tok, max_len in zip(self.models, self.tokenizers, self.max_lens):
            try:
                seq = tok.texts_to_sequences([text])
                padded = pad_sequences(seq, maxlen=max_len, padding="pre")
                x = torch.tensor(padded, dtype=torch.long).to(device)
                with torch.no_grad():
                    probs = torch.softmax(model(x), dim=1)
                conf, pred = torch.max(probs, dim=1)
                label = "Real News" if pred.item() == 1 else "Fake News"
                results.append((label, float(conf.item())))
            except Exception as e:
                print(f"[ERR] Hybrid prediction failed: {e}")
                continue
        return results
449
+
450
+
451
+ hybrid_ensemble = None
452
+
453
+
454
def get_hybrid():
    """Return the shared HybridEnsemble, building it on the first call."""
    global hybrid_ensemble
    if hybrid_ensemble is not None:
        return hybrid_ensemble
    print("[HYBRID] Lazy loading...")
    hybrid_ensemble = HybridEnsemble()
    return hybrid_ensemble
460
+
461
+
462
def predict_hybrid(text: str) -> list:
    """Classify *text* with the hybrid ensemble; one (label, confidence) per model."""
    ensemble = get_hybrid()
    return ensemble.predict(text)
464
+
465
+
466
+ # ======================================================
467
+ # MODEL 3 — NAIVE (Naive Bayes / Passive-Aggressive)
468
+ # ======================================================
469
+
470
+ naive_models = []
471
+
472
def load_naive():
    """Lazily download and unpickle the naive models into ``naive_models``.

    Does nothing once at least one model is loaded. A file that cannot be
    downloaded or unpickled is skipped and reported on stdout.
    """
    global naive_models
    if naive_models:
        return
    repo_id = "Ankit74990/TruthX-NAIVE"
    files = ["nb_tfidf.pkl", "nb_count.pkl", "passive_aggressive.pkl", "best_passive_aggressive.pkl"]
    for f in files:
        try:
            p = hf_hub_download(repo_id=repo_id, filename=f)
            # NOTE(security): unpickling executes arbitrary code from the file;
            # this is only safe while the Hub repository above is trusted.
            with open(p, "rb") as fh:
                naive_models.append(pickle.load(fh))
        except Exception as e:
            print(f"[WARN] Naive model {f} not loaded: {e}")
    print(f"[OK] Naive models loaded ({len(naive_models)})")
484
+
485
def predict_naive(text: str) -> list:
    """Classify *text* with every loaded naive model.

    Each model is tried with ``predict_proba`` first and ``decision_function``
    second. Class 0 is reported as ``"Fake News"``, anything else as
    ``"Real News"``.

    Returns:
        One ``(label, confidence)`` per model that produced a prediction.
        A model for which both methods fail is skipped (and reported on
        stdout) instead of contributing a made-up vote to the ensemble.
    """
    load_naive()
    results = []
    for model in naive_models:
        try:
            probs = model.predict_proba([text])[0]
            pred, conf = int(np.argmax(probs)), float(probs.max())
        except Exception:
            # No usable predict_proba: fall back to the signed decision value.
            try:
                d = model.decision_function([text])[0]
                pred = 1 if d > 0 else 0
                conf = 1 / (1 + np.exp(-abs(d)))
            except Exception as e:
                print(f"[ERR] Naive prediction failed: {e}")
                continue
        results.append(("Fake News" if pred == 0 else "Real News", float(conf)))
    return results
502
+
503
+
504
+ # ======================================================
505
+ # MODEL 4 — BERT
506
+ # ======================================================
507
+
508
+ bert_tokenizer = None
509
+ _bert_base = None
510
+
511
def load_bert_base():
    """Lazily load the ``bert-base-uncased`` tokenizer and encoder.

    ``bert_tokenizer`` and ``_bert_base`` are published together only after
    BOTH loaded. If either load fails, both stay ``None``, so callers that
    check the tokenizer stop early and a later call can retry.
    """
    global bert_tokenizer, _bert_base
    if _bert_base is not None:
        return
    repo_id = "bert-base-uncased"
    try:
        tokenizer = BertTokenizerFast.from_pretrained(repo_id)
        base = BertModel.from_pretrained(repo_id).to(device)
    except Exception as e:
        print(f"[ERR] BERT base fail: {e}")
        return
    bert_tokenizer, _bert_base = tokenizer, base
    print("[OK] BERT base loaded")
521
+
522
class BERT_Arch(nn.Module):
    """Two linear layers (768 -> 512 -> 2) on top of an encoder's ``pooler_output``."""

    def __init__(self, bert):
        super().__init__()
        # Attribute names define the state_dict keys and must stay as they are.
        self.bert = bert
        self.fc1 = nn.Linear(768, 512)
        self.fc2 = nn.Linear(512, 2)

    def forward(self, sent_id, mask):
        """Return two-class logits for token ids *sent_id* with attention mask *mask*."""
        encoded = self.bert(sent_id, attention_mask=mask)
        pooled = encoded["pooler_output"]
        # The two linear layers are applied back to back, with no activation.
        hidden = self.fc1(pooled)
        return self.fc2(hidden)
531
+
532
def _load_bert_ckpt(repo_id: str, filename: str) -> "BERT_Arch | None":
    """Build a ``BERT_Arch`` and load the checkpoint *filename* from *repo_id*.

    Returns:
        The model in eval mode, or ``None`` when the BERT base is unavailable
        or the checkpoint cannot be downloaded or loaded. A model whose
        checkpoint failed is NOT returned, because its classification head
        would still hold its random initial weights.
    """
    load_bert_base()
    if _bert_base is None:
        return None
    # NOTE(review): every checkpoint wraps the SAME _bert_base instance, so
    # any "bert.*" weights in a checkpoint overwrite those of the previously
    # loaded ones. Harmless only if the encoder weights are identical across
    # checkpoints — confirm against how the checkpoints were trained.
    model = BERT_Arch(_bert_base)
    try:
        path = hf_hub_download(repo_id=repo_id, filename=filename)
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # objects (code execution). Safe only while the repository is trusted;
        # the hybrid loader in this file uses weights_only=True.
        model.load_state_dict(torch.load(path, map_location=device, weights_only=False))
    except Exception as e:
        print(f"[ERR] BERT checkpoint {filename} fail: {e}")
        return None
    model.eval()
    return model
544
+
545
+ bert_models = None
546
+
547
def get_bert_models():
    """Return the cached list of fine-tuned BERT models, loading them on first use."""
    global bert_models
    if bert_models is not None:
        return bert_models

    print("[BERT] Lazy loading...")
    repo_id = "Ankit74990/TruthX-BERT"
    checkpoint_files = ("bert_model.pt", "best_model.pt", "c2_new_model_weights.pt")
    candidates = [_load_bert_ckpt(repo_id, name) for name in checkpoint_files]
    # Filter out failed loads
    bert_models = [m for m in candidates if m is not None]
    print(f"[OK] BERT loaded ({len(bert_models)})")
    return bert_models
561
+
562
def predict_bert(text: str) -> list:
    """Classify *text* with every loaded BERT checkpoint.

    The text is tokenized once, padded/truncated to 128 tokens, and fed to
    each model. Returns one ``(label, confidence)`` per model, or ``[]`` when
    the tokenizer is unavailable.
    """
    load_bert_base()
    if bert_tokenizer is None:
        return []

    encoded = bert_tokenizer(
        [text],
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]

    results = []
    for model in get_bert_models():
        with torch.no_grad():
            logits = model(input_ids, attention_mask)
        probs = torch.softmax(logits, dim=1)
        pred = torch.argmax(probs, dim=1).item()
        conf = probs.max().item()
        # NOTE(review): class 1 means "Fake News" here, the opposite of the
        # other models in this file — presumably the BERT training labels
        # were encoded that way; confirm.
        label = "Fake News" if pred == 1 else "Real News"
        results.append((label, float(conf)))
    return results
583
+
584
+
585
+ # ======================================================
586
+ # MODEL 5 — DISTILBERT (HuggingFace fine-tuned)
587
+ # ======================================================
588
+
589
+ distil_model = None
590
+ distil_tokenizer = None
591
+
592
def get_distil():
    """Return ``(model, tokenizer)`` for DistilBERT, loading them on first use.

    Either element may be ``None`` when loading failed; the failure is
    reported on stdout.
    """
    global distil_model, distil_tokenizer
    if distil_model is not None:
        return distil_model, distil_tokenizer

    print("[DISTIL] Lazy loading...")
    repo_id = "Ankit74990/TruthX-DISTILBERT"
    try:
        distil_tokenizer = AutoTokenizer.from_pretrained(repo_id)
        distil_model = AutoModelForSequenceClassification.from_pretrained(repo_id).to(device)
        distil_model.eval()
        print("[OK] DistilBERT loaded")
    except Exception as e:
        print(f"[ERR] DistilBERT fail: {e}")
    return distil_model, distil_tokenizer
605
+
606
def predict_distil(text: str) -> list:
    """Classify *text* with the fine-tuned DistilBERT model.

    Input is truncated to 256 tokens. Class 1 is reported as ``"Real News"``,
    anything else as ``"Fake News"``.

    Returns:
        ``[(label, confidence)]``, or ``[]`` when the model is unavailable or
        inference fails (the failure is reported on stdout).
    """
    try:
        model, tokenizer = get_distil()
        if model is None or tokenizer is None:
            return []
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=256)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            out = model(**inputs)
        probs = torch.softmax(out.logits, dim=1)
        conf, pred = torch.max(probs, dim=1)
        return [("Real News" if pred.item() == 1 else "Fake News", float(conf.item()))]
    except Exception as e:
        # Keep the best-effort contract, but do not hide why inference failed.
        print(f"[ERR] DistilBERT predict fail: {e}")
        return []
620
+
621
+
622
+ # ======================================================
623
+ # MODEL 6 — ROBERTA (HuggingFace fine-tuned)
624
+ # ======================================================
625
+
626
+ roberta_model = None
627
+ roberta_tokenizer = None
628
+
629
def get_roberta():
    """Return ``(model, tokenizer)`` for RoBERTa, loading them on first use.

    Either element may be ``None`` when loading failed; the failure is
    reported on stdout.
    """
    global roberta_model, roberta_tokenizer
    if roberta_model is not None:
        return roberta_model, roberta_tokenizer

    print("[ROBERTA] Lazy loading...")
    repo_id = "Ankit74990/TruthX-ROBERTA"
    try:
        roberta_tokenizer = AutoTokenizer.from_pretrained(repo_id)
        roberta_model = AutoModelForSequenceClassification.from_pretrained(repo_id).to(device)
        roberta_model.eval()
        print("[OK] RoBERTa loaded")
    except Exception as e:
        print(f"[ERR] RoBERTa fail: {e}")
    return roberta_model, roberta_tokenizer
642
+
643
def predict_roberta(text: str) -> list:
    """Classify *text* with the fine-tuned RoBERTa model.

    Input is truncated to 256 tokens. Class 1 is reported as ``"Real News"``,
    anything else as ``"Fake News"``.

    Returns:
        ``[(label, confidence)]``, or ``[]`` when the model is unavailable or
        inference fails (the failure is reported on stdout).
    """
    try:
        model, tokenizer = get_roberta()
        if model is None or tokenizer is None:
            return []
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=256)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            out = model(**inputs)
        probs = torch.softmax(out.logits, dim=1)
        conf, pred = torch.max(probs, dim=1)
        return [("Real News" if pred.item() == 1 else "Fake News", float(conf.item()))]
    except Exception as e:
        # Keep the best-effort contract, but do not hide why inference failed.
        print(f"[ERR] RoBERTa predict fail: {e}")
        return []
657
+
658
+
659
+ # ======================================================
660
+ # ENSEMBLE FUSION
661
+ # ======================================================
662
+
663
+
664
def final_ensemble(all_results: list) -> tuple:
    """Fuse per-model votes by summing confidence per label.

    Returns ``(label, share)`` where *share* is the winning label's fraction
    of the total confidence, rounded to 4 decimals. Ties go to "Real News";
    with no votes at all the result is ``("Real News", 0.5)``.
    """
    fake_total = 0
    real_total = 0
    for label, confidence in all_results:
        if "Fake" in label:
            fake_total += confidence
        if "Real" in label:
            real_total += confidence

    combined = fake_total + real_total
    if combined == 0:
        return "Real News", 0.5

    if fake_total > real_total:
        return "Fake News", round(fake_total / combined, 4)
    return "Real News", round(real_total / combined, 4)
673
+
674
+
675
def format_output(raw: dict) -> dict:
    """Turn ``{model: [(label, confidence), ...]}`` into JSON-friendly dicts.

    Each tuple becomes ``{"prediction": label, "confidence": rounded}`` with
    the confidence rounded to 4 decimals.
    """
    formatted = {}
    for model_name, predictions in raw.items():
        entries = []
        for label, confidence in predictions:
            entries.append({"prediction": label, "confidence": round(confidence, 4)})
        formatted[model_name] = entries
    return formatted
680
+
681
+
682
+ # ==============================
683
+ # ROUTES
684
+ # ==============================
685
+
686
+
687
@app.route("/", methods=["GET"])
def index():
    """Landing route: a welcome message and a short list of the endpoints."""
    endpoints = {
        "POST /generate_key": "Get a new API key",
        "POST /verify": "Full ensemble prediction (all models)",
        "POST /predict/<model>": "Individual model prediction (nlp, hybrid, naive, bert, distilbert, roberta)",
        "GET /test_hybrid": "Check how many hybrid models are loaded",
    }
    body = {"message": "Welcome to TruthX API", "endpoints": endpoints}
    return jsonify(body)
700
+
701
+
702
@app.route("/test_hybrid", methods=["GET"])
def test_hybrid():
    """Diagnostic route: report how many hybrid models loaded and their max_len."""
    try:
        ensemble = get_hybrid()
        configs = []
        for max_len in ensemble.max_lens:
            configs.append({"max_len": max_len, "vocab_size": "loaded"})
        body = {
            "hybrid_models_loaded": len(ensemble.models),
            "configs": configs,
        }
        return jsonify(body)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
718
+
719
+
720
@app.route("/generate_key", methods=["GET", "POST"])
def generate_key():
    """Create a new UUID4 API key, store it with the label "user", and return it."""
    keys = load_truthx_api_keys()
    new_key = str(uuid.uuid4())
    keys[new_key] = "user"
    save_truthx_api_keys(keys)

    body = {
        "status": "success",
        "api_key": new_key,
        "message": "Store this key — required for all /predict and /verify",
    }
    return jsonify(body)
734
+
735
+
736
def _get_request_text():
    """Extract the ``text`` field from the JSON request body.

    Returns:
        ``(text, None)`` with the stripped text on success, or
        ``(None, message)`` when the body is not a JSON object, lacks
        ``text``, holds a non-string ``text``, or the text is blank.
    """
    data = request.get_json(silent=True)
    # A JSON array or scalar is valid JSON but has no fields to read.
    if not isinstance(data, dict) or "text" not in data:
        return None, "Provide 'text' in request body"
    raw_text = data["text"]
    if not isinstance(raw_text, str):
        # Without this check .strip() raises and the client gets a 500.
        return None, "'text' must be a string"
    text = raw_text.strip()
    if not text:
        return None, "Empty text"
    return text, None
744
+
745
+
746
@app.route("/predict/nlp", methods=["POST"])
@require_api_key
def predict_nlp_endpoint():
    """POST /predict/nlp — run only the NLP model on the request text."""
    text, problem = _get_request_text()
    if problem:
        return jsonify({"error": problem}), 400
    result = predict_nlp(text)
    return jsonify({"prediction": result})
753
+
754
+
755
@app.route("/predict/hybrid", methods=["POST"])
@require_api_key
def predict_hybrid_endpoint():
    """POST /predict/hybrid — run only the hybrid models on the request text."""
    text, problem = _get_request_text()
    if problem:
        return jsonify({"error": problem}), 400
    result = predict_hybrid(text)
    return jsonify({"prediction": result})
762
+
763
+
764
@app.route("/predict/naive", methods=["POST"])
@require_api_key
def predict_naive_endpoint():
    """POST /predict/naive — run only the naive models on the request text."""
    text, problem = _get_request_text()
    if problem:
        return jsonify({"error": problem}), 400
    result = predict_naive(text)
    return jsonify({"prediction": result})
771
+
772
+
773
@app.route("/predict/bert", methods=["POST"])
@require_api_key
def predict_bert_endpoint():
    """POST /predict/bert — run only the BERT models on the request text."""
    text, problem = _get_request_text()
    if problem:
        return jsonify({"error": problem}), 400
    result = predict_bert(text)
    return jsonify({"prediction": result})
780
+
781
+
782
@app.route("/predict/distilbert", methods=["POST"])
@require_api_key
def predict_distilbert_endpoint():
    """POST /predict/distilbert — run only the DistilBERT model on the request text."""
    text, problem = _get_request_text()
    if problem:
        return jsonify({"error": problem}), 400
    result = predict_distil(text)
    return jsonify({"prediction": result})
789
+
790
+
791
@app.route("/predict/roberta", methods=["POST"])
@require_api_key
def predict_roberta_endpoint():
    """POST /predict/roberta — run only the RoBERTa model on the request text."""
    text, problem = _get_request_text()
    if problem:
        return jsonify({"error": problem}), 400
    result = predict_roberta(text)
    return jsonify({"prediction": result})
798
+
799
+
800
@app.route("/verify", methods=["POST"])
@require_api_key
def verify():
    """POST /verify — run every model plus the external news check.

    Expects a JSON object with ``text`` and an optional ``title`` (defaults to
    the first 100 characters of the text). Responds with the fused label, the
    model confidence, the external score and the per-model predictions.
    Malformed input yields 400; unexpected failures yield 500.
    """
    try:
        data = request.get_json(silent=True)
        # A JSON array or scalar is valid JSON but has no fields to read.
        if not isinstance(data, dict) or "text" not in data:
            return jsonify({"error": "Provide 'text' in request body"}), 400

        raw_text = data["text"]
        if not isinstance(raw_text, str):
            # Without this check .strip() raises and the client gets a 500.
            return jsonify({"error": "'text' must be a string"}), 400
        text = raw_text.strip()

        raw_title = data.get("title")
        if raw_title is None:
            # Missing (or null) title: fall back to the start of the text.
            raw_title = text[:100]
        elif not isinstance(raw_title, str):
            return jsonify({"error": "'title' must be a string"}), 400
        title = raw_title.strip()

        if not text:
            return jsonify({"error": "Empty text"}), 400

        full_doc = f"{title} {text}".strip()

        def safe(fn):
            # One failing model must not take the whole request down.
            try:
                return fn(full_doc)
            except Exception as e:
                print(f"[MODEL ERROR] {fn.__name__}: {e}")
                return []

        raw = {
            "nlp": safe(predict_nlp),
            "hybrid": safe(predict_hybrid),
            "naive": safe(predict_naive),
            "bert": safe(predict_bert),
            "distilbert": safe(predict_distil),
            "roberta": safe(predict_roberta),
        }

        all_preds = [p for preds in raw.values() for p in preds]
        final_label, model_conf = final_ensemble(all_preds)

        ext_score = check_external_news(title)
        # Weighted ensemble: 40% models, 60% external as per user request
        final_accuracy = round((model_conf * 0.4 + ext_score * 0.6) * 100, 2)

        return jsonify(
            {
                "title": title,
                "prediction": final_label,
                "confidence": model_conf,
                "accuracy": f"{final_accuracy}%",
                "external_score": round(ext_score, 4),
                "models": format_output(raw),
            }
        )

    except Exception as e:
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
852
+
853
+
854
+ if __name__ == "__main__":
855
+ app.run(host="0.0.0.0", port=7860, debug=False)
deployment.py ADDED
@@ -0,0 +1,850 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle, json, uuid, re, traceback, nltk # noqa: E401
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import requests
6
+ from urllib.parse import quote
7
+ import xml.etree.ElementTree as ET
8
+
9
+ from flask import Flask, request, jsonify
10
+ from functools import wraps
11
+
12
+ from dotenv import load_dotenv
13
+ from nltk.corpus import stopwords
14
+ from nltk.stem.porter import PorterStemmer
15
+
16
+ import sys
17
+ import os
18
+
19
+ from transformers import (
20
+ AutoModel,
21
+ AutoTokenizer,
22
+ AutoModelForSequenceClassification,
23
+ BertTokenizerFast,
24
+ BertModel,
25
+ )
26
+
27
+ torch.set_num_threads(1)
28
+ torch.set_grad_enabled(False)
29
+
30
+ import warnings
31
+
32
+ warnings.filterwarnings("ignore")
33
+
34
+
35
+ # ==============================
36
+ # APP INIT
37
+ # ==============================
38
+ load_dotenv()
39
+ app = Flask(__name__)
40
+ device = torch.device("cpu")
41
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
42
+ API_KEYS_FILE = "api_keys.json"
43
+
44
+ # ==============================
45
+ # NLTK
46
+ # ==============================
47
try:
    nltk.download("stopwords", quiet=True)
    # frozenset: preprocess_text tests membership once per token, so a hashed
    # lookup avoids a linear scan of the stopword list for every word.
    all_stopwords = frozenset(stopwords.words("english"))
except Exception:
    # Best effort: without the corpus, preprocessing simply keeps every word.
    all_stopwords = frozenset()
52
+
53
+ ps = PorterStemmer()
54
+
55
+ # ==============================
56
+ # EXTERNAL API KEYS
57
+ # ==============================
58
+ NEWSDATA_KEY = os.getenv("NEWSDATA_API_KEY")
59
+ NEWSAPI_KEY = os.getenv("NEWSAPI_API_KEY")
60
+ GNEWS_KEY = os.getenv("GNEWS_API_KEY")
61
+ MEDIASTACK_KEY = os.getenv("MEDIASTACK_API_KEY")
62
+
63
+
64
+ # ==============================
65
+ # API KEY MANAGEMENT
66
+ # ==============================
67
+
68
+
69
def load_truthx_api_keys(path: "str | None" = None) -> dict:
    """Load the persisted API-key -> user-id mapping.

    Args:
        path: JSON file to read. Defaults to ``API_KEYS_FILE``.

    Returns:
        The stored mapping, or an empty dict when the file is missing,
        unreadable, not valid JSON, or does not hold a JSON object.
    """
    target = API_KEYS_FILE if path is None else path
    try:
        with open(target, "r", encoding="utf-8") as f:
            keys = json.load(f)
    except (OSError, ValueError):
        # Missing/unreadable file or malformed JSON: start with no keys.
        return {}
    # Callers use `key in mapping` and item assignment, so only a dict is usable.
    return keys if isinstance(keys, dict) else {}
77
+
78
+
79
def save_truthx_api_keys(keys: dict, path: "str | None" = None) -> None:
    """Persist the API-key mapping as JSON.

    Failures are reported on stdout and never raised, so a read-only
    filesystem does not break key generation.

    Args:
        keys: Mapping of API key to user id.
        path: Destination file. Defaults to ``API_KEYS_FILE``.
    """
    target = API_KEYS_FILE if path is None else path
    try:
        # Serialise first: opening with "w" truncates the file, so a value that
        # cannot be encoded must fail before the existing keys are wiped.
        payload = json.dumps(keys)
        with open(target, "w", encoding="utf-8") as f:
            f.write(payload)
    except (OSError, TypeError, ValueError) as e:
        print(f"[ERROR] Saving API keys: {e}")
85
+
86
+
87
def verify_api_key(key: str) -> bool:
    """Report whether *key* is one of the issued TruthX API keys."""
    issued_keys = truthx_api_keys
    return key in issued_keys
89
+
90
+
91
def require_api_key(f):
    """Decorator that rejects requests lacking a valid TruthX API key.

    The key is taken from the ``X-API-KEY`` header, falling back to the
    ``api_key`` query parameter. A missing or unknown key yields a 401 JSON
    response; otherwise the wrapped view runs with its original arguments.
    """

    @wraps(f)
    def decorated_function(*args, **kwargs):
        # Header wins; the query parameter is consulted only when it is absent.
        supplied = request.headers.get("X-API-KEY") or request.args.get("api_key")
        if supplied and verify_api_key(supplied):
            return f(*args, **kwargs)
        return jsonify({"error": "Invalid or missing API key. Use /generate_key"}), 401

    return decorated_function
105
+
106
+
107
+ truthx_api_keys = load_truthx_api_keys()
108
+
109
+
110
+ # ==============================
111
+ # TEXT PREPROCESSING
112
+ # ==============================
113
+
114
+
115
def preprocess_text(text: str) -> str:
    """Normalise *text* before vectorisation.

    Every non-letter becomes a space, the result is lower-cased and split on
    whitespace, stopwords are dropped, and the remaining words are stemmed
    with ``ps`` and re-joined with single spaces.
    """
    letters_only = re.sub("[^a-zA-Z]", " ", text)
    stems = []
    for word in letters_only.lower().split():
        if word in all_stopwords:
            continue
        stems.append(ps.stem(word))
    return " ".join(stems)
119
+
120
+
121
+ # ==============================
122
+ # PAD SEQUENCES
123
+ # ==============================
124
+
125
+
126
def pad_sequences(sequences: list, maxlen: int, padding: str = "pre") -> np.ndarray:
    """Pad or truncate integer sequences to a common length.

    Sequences longer than *maxlen* keep their last *maxlen* items. Shorter
    ones are filled with zeros, in front when ``padding == "pre"`` and at
    the end for any other value.

    Args:
        sequences: Iterable of integer sequences.
        maxlen: Target length of every row; must be non-negative.
        padding: ``"pre"`` to pad in front, anything else to pad at the end.

    Returns:
        An ``int32`` array of shape ``(len(sequences), maxlen)``.

    Raises:
        ValueError: If *maxlen* is negative.
    """
    if maxlen < 0:
        raise ValueError("maxlen must be non-negative")

    rows = []
    for seq in sequences:
        seq = list(seq)
        if len(seq) >= maxlen:
            # Not seq[-maxlen:]: with maxlen == 0 that slice is the whole list.
            seq = seq[len(seq) - maxlen:]
        else:
            fill = [0] * (maxlen - len(seq))
            seq = (fill + seq) if padding == "pre" else (seq + fill)
        rows.append(seq)
    # Explicit reshape keeps the result 2-D even when there are no rows.
    return np.array(rows, dtype=np.int32).reshape(len(rows), maxlen)
138
+
139
+
140
+ # ==============================
141
+ # EXTERNAL NEWS VERIFICATION
142
+ # ==============================
143
+
144
+
145
def check_external_news(query: str) -> float:
    """Score how widely *query* is corroborated by external news sources.

    Five providers are queried in turn. Each contributes a fixed weight when
    it reports at least one matching article: NewsData 0.35, NewsAPI 0.15,
    GNews 0.25, Mediastack 0.05 and Google News RSS 0.20. Providers whose
    API key is not configured are skipped, and any network or parsing
    failure counts as "no match" so one bad provider never fails the call.

    Args:
        query: Headline or short text to search for.

    Returns:
        The sum of the weights of the providers that found a match, rounded
        to four decimals; ``0.0`` for an empty query.
    """
    if not query:
        return 0.0

    # Mediastack and Google News get a short keyword query: the first three
    # words of the headline that are not common stop words.
    stop_words = {"the", "is", "in", "on", "at", "a", "an", "of", "for", "to", "and"}
    keywords = [w for w in query.lower().split() if w not in stop_words]
    simple_query = " ".join(keywords[:3])

    def fetch(name, url, params):
        """Return the response of a successful (HTTP 200) GET, else None."""
        try:
            resp = requests.get(url, params=params, timeout=5)
        except Exception as exc:
            # Log the exception type only: its message can embed the request
            # URL, which carries the provider API key.
            print(f"[EXTERNAL] {name} request failed: {type(exc).__name__}")
            return None
        return resp if resp.status_code == 200 else None

    def match_count(name, api_key, url, params, extract):
        """Return the match count a JSON provider reports, or 0 on failure."""
        if not api_key:
            return 0  # provider not configured: skip the network round-trip
        resp = fetch(name, url, params)
        if resp is None:
            return 0
        try:
            count = extract(resp.json())
        except (ValueError, AttributeError, TypeError):
            return 0
        return count if isinstance(count, (int, float)) else 0

    # 1. NewsData
    newsdata = match_count(
        "newsdata",
        NEWSDATA_KEY,
        "https://newsdata.io/api/1/news",
        {"apikey": NEWSDATA_KEY, "q": query},
        lambda body: body.get("totalResults", 0),
    ) > 0

    # 2. NewsAPI
    newsapi = match_count(
        "newsapi",
        NEWSAPI_KEY,
        "https://newsapi.org/v2/everything",
        {"q": query, "apiKey": NEWSAPI_KEY, "pageSize": 1},
        lambda body: body.get("totalResults", 0),
    ) > 0

    # 3. GNews
    gnews = match_count(
        "gnews",
        GNEWS_KEY,
        "https://gnews.io/api/v4/search",
        {"q": query, "token": GNEWS_KEY, "max": 1},
        lambda body: body.get("totalArticles", 0),
    ) > 0

    # 4. Mediastack: very large totals are treated as noise and ignored.
    mediastack_total = match_count(
        "mediastack",
        MEDIASTACK_KEY,
        "https://api.mediastack.com/v1/news",
        {"access_key": MEDIASTACK_KEY, "keywords": simple_query, "limit": 1},
        lambda body: body.get("pagination", {}).get("total", 0),
    )
    mediastack = 0 < mediastack_total < 5000

    # 5. Google News RSS (no key required)
    google = False
    rss = fetch("google-rss", "https://news.google.com/rss/search", {"q": simple_query})
    if rss is not None:
        try:
            google = ET.fromstring(rss.content).find(".//item") is not None
        except ET.ParseError:
            google = False

    # Booleans multiply as 0/1; the term order matches the documented weights.
    score = (
        newsdata * 0.35
        + newsapi * 0.15
        + gnews * 0.25
        + mediastack * 0.05
        + google * 0.2
    )

    return round(score, 4)
252
+
253
+
254
+ # ======================================================
255
+ # MODEL 1 — NLP (TF-IDF + SVM)
256
+ # ======================================================
257
+
258
try:
    # Context managers close the artifact files even when unpickling fails.
    # NOTE(security): pickle executes code on load; only ship trusted artifacts.
    with open(os.path.join(BASE_DIR, "model", "NLP", "model2.pkl"), "rb") as _fh:
        nlp_model = pickle.load(_fh)
    with open(os.path.join(BASE_DIR, "model", "NLP", "tfidfvect2.pkl"), "rb") as _fh:
        nlp_vector = pickle.load(_fh)
    print(f"[OK] NLP model loaded ({1 if nlp_model else 0})")
except Exception as e:
    # Either artifact failing disables the model; predict_nlp then returns [].
    nlp_model = nlp_vector = None
    print(f"[WARN] NLP model not loaded: {e}")
269
+
270
+
271
def predict_nlp(text: str) -> list:
    """Classify *text* with the TF-IDF + SVM model.

    Returns:
        ``[(label, confidence)]`` with a single entry, or ``[]`` when the
        model or vectorizer was not loaded. Class ``1`` maps to
        ``"Real News"``; confidence is the sigmoid of the absolute
        decision-function value.
    """
    # Identity check, not truthiness: bool() on an estimator object may call
    # its __len__ and says nothing about whether loading succeeded.
    if nlp_model is None or nlp_vector is None:
        return []
    vec = nlp_vector.transform([preprocess_text(text)])
    pred = nlp_model.predict(vec)[0]
    margin = nlp_model.decision_function(vec)[0]
    conf = 1 / (1 + np.exp(-abs(margin)))
    return [("Real News" if pred == 1 else "Fake News", float(conf))]
279
+
280
+
281
+ # ======================================================
282
+ # MODEL 2 — HYBRID
283
+ # ======================================================
284
+
285
+
286
class HybridModel_A(nn.Module):
    """Sequential hybrid classifier: embedding, Conv1d, max-pool, BiLSTM, MLP.

    The attribute names are the keys ``load_state_dict`` matches against, so
    they must stay as they are for saved checkpoints to load.
    """

    def __init__(self, vocab_size: int, embed_dim: int = 256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.conv = nn.Conv1d(embed_dim, 256, kernel_size=5)
        self.pool = nn.MaxPool1d(2)
        self.lstm = nn.LSTM(256, 128, batch_first=True, bidirectional=True)
        self.fc1 = nn.Linear(256, 128)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        """Return two-class logits for a batch of token-id sequences."""
        # Conv1d wants channels before length, the batch-first LSTM the reverse.
        channels_first = self.embedding(x).transpose(1, 2)
        pooled = self.pool(torch.relu(self.conv(channels_first)))
        lstm_out, _ = self.lstm(pooled.transpose(1, 2))
        last_step = lstm_out[:, -1, :]
        hidden = self.dropout(torch.relu(self.fc1(last_step)))
        return self.fc2(hidden)
318
+
319
+
320
class HybridModel_B(nn.Module):
    """Parallel hybrid classifier: a CNN branch and a BiLSTM branch, summed.

    The attribute names are the keys ``load_state_dict`` matches against, so
    they must stay as they are for saved checkpoints to load.
    """

    def __init__(self, vocab_size: int, embed_dim: int = 256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.conv = nn.Conv1d(embed_dim, 256, kernel_size=5)
        self.lstm = nn.LSTM(256, 128, batch_first=True, bidirectional=True)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        """Return two-class logits for a batch of token-id sequences."""
        embedded = self.embedding(x)

        # CNN branch: global max over the sequence dimension.
        conv_out = torch.relu(self.conv(embedded.transpose(1, 2)))
        cnn_features = conv_out.max(dim=2).values

        # LSTM branch: output at the final time step.
        lstm_out, _ = self.lstm(embedded)
        lstm_features = lstm_out[:, -1, :]

        hidden = torch.relu(self.fc1(cnn_features + lstm_features))
        return self.fc2(hidden)
348
+
349
+
350
+ # ======================================================
351
+ # SAFE TOKENIZER
352
+ # ======================================================
353
+
354
+
355
def safe_load_tokenizer(path):
    """Unpickle a tokenizer from *path*, with a stand-in on failure.

    Args:
        path: Location of the pickled tokenizer.

    Returns:
        The unpickled object, or, when the file cannot be opened or
        unpickled, a fallback exposing ``texts_to_sequences`` that maps
        every whitespace-separated word to token id ``1``.
    """
    try:
        # NOTE(security): pickle executes code on load; *path* must be trusted.
        with open(path, "rb") as fh:
            return pickle.load(fh)
    except Exception as e:
        print(f"[TOKENIZER ERROR] {e}")
        print("[FIX] Using fallback tokenizer (reduced accuracy)")

        class SimpleTokenizer:
            """Length-preserving fallback: every word becomes token id 1."""

            def texts_to_sequences(self, texts):
                return [[1] * len(t.split()) for t in texts]

        return SimpleTokenizer()
367
+
368
+
369
+ # ======================================================
370
+ # MODEL 2 — HYBRID (FIXED)
371
+ # ======================================================
372
+
373
+
374
class HybridEnsemble:
    """Loads the hybrid CNN/LSTM checkpoints and predicts with each of them.

    Every directory in ``DIRS`` is scanned for a tokenizer pickle, a config
    pickle and a ``hybrid_model`` weights file. Directories that are missing,
    incomplete or fail to load are skipped with a log line, so ``models`` may
    end up empty.
    """

    DIRS = [
        (os.path.join(BASE_DIR, "model", "HYBRID"), HybridModel_A),
        (os.path.join(BASE_DIR, "model", "HYBRID_"), HybridModel_B),
    ]

    class _WordIndexTokenizer:
        """Looks words up in a ``word_index`` dict; unknown words map to 0."""

        def __init__(self, word_index):
            self.word_index = word_index

        def texts_to_sequences(self, texts):
            return [
                [self.word_index.get(w, 0) for w in text.split()]
                for text in texts
            ]

    class _FallbackTokenizer:
        """Length-preserving fallback: every word becomes token id 1."""

        def texts_to_sequences(self, texts):
            return [[1] * len(t.split()) for t in texts]

    def __init__(self):
        self.models = []
        self.tokenizers = []
        self.max_lens = []

        print("[HYBRID] Loading models...")
        self._load_all()
        print(f"[OK] Hybrid models loaded ({len(self.models)})")

    @staticmethod
    def _read_pickle(path):
        """Unpickle *path*, closing the file afterwards."""
        # NOTE(security): pickle executes code on load; artifacts must be trusted.
        with open(path, "rb") as fh:
            return pickle.load(fh)

    @staticmethod
    def _find_artifacts(directory):
        """Return ``(tokenizer, config, weights)`` paths found in *directory*.

        A file is classified by the first of "tokenizer", "config" and
        "hybrid_model" that its lower-cased name contains; an entry is
        ``None`` when no file matched.
        """
        tok_path, cfg_path, model_path = None, None, None
        for name in os.listdir(directory):
            lowered = name.lower()
            if "tokenizer" in lowered:
                tok_path = os.path.join(directory, name)
            elif "config" in lowered:
                cfg_path = os.path.join(directory, name)
            elif "hybrid_model" in lowered:
                model_path = os.path.join(directory, name)
        return tok_path, cfg_path, model_path

    def _build_tokenizer(self, tok_path):
        """Return a tokenizer for the pickle at *tok_path*, or the fallback."""
        try:
            tok_data = self._read_pickle(tok_path)
        except Exception as exc:
            print(f"[HYBRID] Tokenizer unreadable ({exc}); using fallback")
            return self._FallbackTokenizer()
        if isinstance(tok_data, dict) and "word_index" in tok_data:
            return self._WordIndexTokenizer(tok_data["word_index"])
        print("[HYBRID] Tokenizer has no word_index; using fallback")
        return self._FallbackTokenizer()

    def _load_all(self):
        for path, model_class in self.DIRS:
            try:
                tok_path, cfg_path, model_path = self._find_artifacts(path)
                if not tok_path or not cfg_path or not model_path:
                    print(f"[HYBRID] Skipping {path}: incomplete artifacts")
                    continue

                tok = self._build_tokenizer(tok_path)

                cfg = self._read_pickle(cfg_path)
                vocab_size = cfg.get("max_words") or cfg.get("vocab_size")
                max_len = cfg.get("max_len")
                if not vocab_size or not max_len:
                    print(f"[HYBRID] Skipping {path}: config lacks vocab size or max_len")
                    continue

                model = model_class(vocab_size).to(device)
                model.load_state_dict(
                    torch.load(model_path, map_location=device, weights_only=True)
                )
                model.eval()

                self.models.append(model)
                self.tokenizers.append(tok)
                self.max_lens.append(max_len)

                print("[OK] Hybrid model loaded")

            except Exception as exc:
                # Best effort: a broken directory must not block the others.
                print(f"[HYBRID] Failed to load {path}: {exc}")
                continue

    def predict(self, text: str) -> list:
        """Return ``(label, confidence)`` for every loaded model.

        Class ``1`` maps to ``"Real News"``. A model that fails on this
        input is logged and left out of the result.
        """
        if not self.models:
            return []

        results = []

        for model, tok, max_len in zip(self.models, self.tokenizers, self.max_lens):
            try:
                seq = tok.texts_to_sequences([text])
                padded = pad_sequences(seq, maxlen=max_len, padding="pre")

                x = torch.tensor(padded, dtype=torch.long).to(device)

                with torch.no_grad():
                    probs = torch.softmax(model(x), dim=1)

                conf, pred = torch.max(probs, dim=1)
                label = "Real News" if pred.item() == 1 else "Fake News"

                results.append((label, float(conf.item())))

            except Exception as exc:
                print(f"[HYBRID] Prediction failed: {exc}")
                continue

        return results
481
+
482
+
483
hybrid_ensemble = None  # shared HybridEnsemble instance, created on first use


def get_hybrid():
    """Return the shared ``HybridEnsemble``, building it on the first call."""
    global hybrid_ensemble
    if hybrid_ensemble is not None:
        return hybrid_ensemble
    print("[HYBRID] Lazy loading...")
    hybrid_ensemble = HybridEnsemble()
    return hybrid_ensemble
492
+
493
+
494
def predict_hybrid(text: str) -> list:
    """Return the hybrid ensemble's ``(label, confidence)`` pairs for *text*."""
    ensemble = get_hybrid()
    return ensemble.predict(text)
496
+
497
+
498
+ # ======================================================
499
+ # MODEL 3 — NAIVE (Naive Bayes / Passive-Aggressive)
500
+ # ======================================================
501
+
502
_naive_paths = [
    os.path.join(BASE_DIR, "model", "NAIVE_", "nb_tfidf.pkl"),
    os.path.join(BASE_DIR, "model", "NAIVE_", "nb_count.pkl"),
    os.path.join(BASE_DIR, "model", "NAIVE_", "passive_aggressive.pkl"),
    os.path.join(BASE_DIR, "model", "NAIVE_", "best_passive_aggressive.pkl"),
]
# Each artifact is optional: whatever loads takes part in predict_naive.
naive_models = []
for _p in _naive_paths:
    try:
        # NOTE(security): pickle executes code on load; artifacts must be trusted.
        with open(_p, "rb") as _fh:
            naive_models.append(pickle.load(_fh))
    except Exception as _e:
        # Say which artifact was skipped instead of dropping it silently.
        print(f"[WARN] Naive model not loaded ({os.path.basename(_p)}): {_e}")
print(f"[OK] Naive models loaded ({len(naive_models)})")
515
+
516
+
517
def _naive_vote(model, text: str) -> tuple:
    """Return ``(class_index, confidence)`` from one model for *text*."""
    try:
        probs = model.predict_proba([text])[0]
        return int(np.argmax(probs)), float(probs.max())
    except Exception:
        # No usable predict_proba: fall back to the sign of the decision
        # value and squash its magnitude through a sigmoid.
        d = model.decision_function([text])[0]
        return (1 if d > 0 else 0), float(1 / (1 + np.exp(-abs(d))))


def predict_naive(text: str) -> list:
    """Classify *text* with every loaded model in ``naive_models``.

    Returns:
        One ``(label, confidence)`` tuple per model that produced an answer;
        class ``0`` maps to ``"Fake News"``, anything else to
        ``"Real News"``. A model that fails on both prediction paths is
        logged and skipped, so it no longer discards the other votes.
    """
    results = []
    for model in naive_models:
        try:
            pred, conf = _naive_vote(model, text)
        except Exception as e:
            print(f"[MODEL ERROR] naive {type(model).__name__}: {e}")
            continue
        results.append(("Fake News" if pred == 0 else "Real News", float(conf)))
    return results
529
+
530
+
531
+ # ======================================================
532
+ # MODEL 4 — BERT
533
+ # ======================================================
534
+
535
BERT_CACHE_PATH = os.path.join(os.path.expanduser("~/.cache/huggingface"), "hub", "models--bert-base-uncased", "snapshots", "86b5e0934494bd15c9632b12f734a8a67f723594")
# Use the pinned local snapshot when it exists. On a fresh machine or
# container the cache is empty, and local_files_only=True would abort the
# import, so fall back to resolving the model by its hub id.
if os.path.isdir(BERT_CACHE_PATH):
    _bert_source, _bert_local_only = BERT_CACHE_PATH, True
else:
    _bert_source, _bert_local_only = "bert-base-uncased", False
bert_tokenizer = BertTokenizerFast.from_pretrained(_bert_source, local_files_only=_bert_local_only)
_bert_base = BertModel.from_pretrained(_bert_source, local_files_only=_bert_local_only).to(device)
print("[OK] BERT base loaded")
539
+
540
+
541
class BERT_Arch(nn.Module):
    """Wraps a BERT encoder with a two-layer linear classification head."""

    def __init__(self, bert):
        super().__init__()
        self.bert = bert
        self.fc1 = nn.Linear(768, 512)
        self.fc2 = nn.Linear(512, 2)

    def forward(self, sent_id, mask):
        """Return two-class logits computed from the encoder's pooled output."""
        encoded = self.bert(sent_id, attention_mask=mask)
        pooled = encoded["pooler_output"]
        hidden = self.fc1(pooled)
        return self.fc2(hidden)
551
+
552
+
553
def _load_bert_ckpt(path: str) -> BERT_Arch:
    """Build a ``BERT_Arch`` in eval mode and load the checkpoint at *path*.

    When *path* does not exist the model is still returned, but with an
    untrained classification head; a warning is printed so that the missing
    file is visible in the logs.

    NOTE(review): every model wraps the same ``_bert_base`` instance, so
    loading a later checkpoint also replaces any encoder weights it contains
    for the earlier models. Confirm all checkpoints share identical encoder
    weights.
    """
    model = BERT_Arch(_bert_base)
    if os.path.exists(path):
        # NOTE(security): weights_only=False unpickles arbitrary objects;
        # only load checkpoints from a trusted source.
        model.load_state_dict(torch.load(path, map_location=device, weights_only=False))
    else:
        print(f"[WARN] BERT checkpoint not found, head left untrained: {path}")
    model.eval()
    return model
559
+
560
+
561
bert_models = None  # list of BERT_Arch checkpoints, populated on first use


def get_bert_models():
    """Return the BERT checkpoint models, loading all three on first call."""
    global bert_models
    if bert_models is not None:
        return bert_models

    print("[BERT] Lazy loading...")
    checkpoint_dir = os.path.join(BASE_DIR, "model", "BERT")
    checkpoint_files = ("bert_model.pt", "best_model.pt", "c2_new_model_weights.pt")
    bert_models = [
        _load_bert_ckpt(os.path.join(checkpoint_dir, filename))
        for filename in checkpoint_files
    ]
    print(f"[OK] BERT loaded ({len(bert_models)})")
    return bert_models
577
+
578
+
579
+ # print(f"[OK] BERT checkpoints loaded ({len(bert_models)})")
580
+
581
+
582
def predict_bert(text: str) -> list:
    """Classify *text* with every BERT checkpoint.

    The text is tokenised once, padded or truncated to 128 tokens, and fed
    to each model returned by ``get_bert_models``.

    Returns:
        One ``(label, confidence)`` tuple per checkpoint.
    """
    encoded = bert_tokenizer(
        [text],
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]

    results = []
    for model in get_bert_models():
        with torch.no_grad():
            logits = model(input_ids, attention_mask)
            probs = torch.softmax(logits, dim=1)
        predicted_class = torch.argmax(probs, dim=1).item()
        confidence = float(probs.max().item())
        # Training convention: 1 = Fake News, 0 = Real News
        label = "Fake News" if predicted_class == 1 else "Real News"
        results.append((label, confidence))
    return results
602
+
603
+
604
+ # ======================================================
605
+ # MODEL 5 — DISTILBERT (HuggingFace fine-tuned)
606
+ # ======================================================
607
+
608
distil_model = None  # fine-tuned DistilBERT classifier, loaded on first use
distil_tokenizer = None  # tokenizer loaded from the same directory


def get_distil():
    """Return ``(model, tokenizer)`` for DistilBERT, loading them once."""
    global distil_model, distil_tokenizer
    if distil_model is not None:
        return distil_model, distil_tokenizer

    print("[DISTIL] Lazy loading...")
    model_dir = os.path.join(BASE_DIR, "model", "DISTILBERT", "distilbert_model")

    distil_tokenizer = AutoTokenizer.from_pretrained(model_dir)
    classifier = AutoModelForSequenceClassification.from_pretrained(model_dir)
    distil_model = classifier.to(device)

    distil_model.eval()
    print(f"[OK] DistilBERT loaded ({1 if distil_model else 0})")

    return distil_model, distil_tokenizer
627
+
628
+
629
def predict_distil(text: str) -> list:
    """Classify *text* with the fine-tuned DistilBERT model.

    Returns:
        ``[(label, confidence)]`` with a single entry, where class ``1``
        maps to ``"Real News"``, or ``[]`` when loading or inference fails.
        The failure is logged rather than swallowed silently.
    """
    try:
        model, tokenizer = get_distil()

        inputs = tokenizer(
            text, return_tensors="pt", truncation=True, padding=True, max_length=256
        )

        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            out = model(**inputs)

        probs = torch.softmax(out.logits, dim=1)
        conf, pred = torch.max(probs, dim=1)

        return [("Real News" if pred.item() == 1 else "Fake News", float(conf.item()))]

    except Exception as e:
        # Best effort: the ensemble carries on without this model's vote.
        print(f"[MODEL ERROR] predict_distil: {e}")
        return []
649
+
650
+
651
+ # ======================================================
652
+ # ENSEMBLE FUSION
653
+ # ======================================================
654
+
655
+
656
def final_ensemble(all_results: list) -> tuple:
    """Fuse per-model votes by summed confidence.

    Confidences are accumulated per label (matched on the substrings
    ``"Fake"`` and ``"Real"``). The label with the larger total wins, ties go
    to ``"Real News"``, and the returned confidence is the winner's share of
    the combined total, rounded to four decimals. With no votes the result
    is ``("Real News", 0.5)``.
    """
    fake_score = 0
    real_score = 0
    for label, confidence in all_results:
        if "Fake" in label:
            fake_score += confidence
        if "Real" in label:
            real_score += confidence

    combined = fake_score + real_score
    if combined == 0:
        return "Real News", 0.5
    if fake_score > real_score:
        return "Fake News", round(fake_score / combined, 4)
    return "Real News", round(real_score / combined, 4)
665
+
666
+
667
def format_output(raw: dict) -> dict:
    """Convert ``{model: [(label, confidence), ...]}`` into JSON-ready dicts.

    Each tuple becomes ``{"prediction": label, "confidence": value}`` with
    the confidence rounded to four decimals.
    """
    formatted = {}
    for model_name, votes in raw.items():
        formatted[model_name] = [
            {"prediction": label, "confidence": round(confidence, 4)}
            for label, confidence in votes
        ]
    return formatted
672
+
673
+
674
+ # ======================================================
675
+ # ROUTES
676
+ # ======================================================
677
+
678
+
679
@app.route("/", methods=["GET"])
def index():
    """Return a welcome message and a summary of the available endpoints."""
    endpoints = {
        "POST /generate_key": "Get a new API key",
        "POST /verify": "Full ensemble prediction (all models)",
        "POST /predict/<model>": "Individual model prediction (nlp, hybrid, naive, bert, distilbert)",
        "GET /test_hybrid": "Check how many hybrid models are loaded",
    }
    return jsonify({"message": "Welcome to TruthX API", "endpoints": endpoints})
692
+
693
+
694
@app.route("/test_hybrid", methods=["GET"])
def test_hybrid():
    """Diagnostic endpoint: report the hybrid models that are loaded.

    For each loaded model the response lists its ``max_len`` and the
    tokenizer's ``num_words`` when that attribute exists, otherwise
    ``"unknown"``. Any failure is returned as a 500 with the error text.
    """
    try:
        ensemble = get_hybrid()
        configs = []
        for tok, max_len in zip(ensemble.tokenizers, ensemble.max_lens):
            vocab_size = tok.num_words if hasattr(tok, "num_words") else "unknown"
            configs.append({"max_len": max_len, "vocab_size": vocab_size})
        summary = {
            "hybrid_models_loaded": len(ensemble.models),
            "configs": configs,
        }
        return jsonify(summary)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
712
+
713
+
714
@app.route("/generate_key", methods=["POST"])
def generate_key():
    """Generate and persist a new UUID API key.

    Body (optional JSON): ``{"user_id": "..."}``. A request without a JSON
    body, or with a JSON body that is not an object, is accepted and the key
    is registered for ``"anonymous"``.
    """
    # get_json(silent=True) returns None instead of raising when the request
    # has no JSON body, which is how the other routes in this file read it.
    body = request.get_json(silent=True)
    if not isinstance(body, dict):
        body = {}
    user_id = body.get("user_id", "anonymous")
    new_key = str(uuid.uuid4())
    truthx_api_keys[new_key] = user_id
    save_truthx_api_keys(truthx_api_keys)
    return jsonify(
        {
            "status": "success",
            "api_key": new_key,
            "message": "Store this key — required for all /predict and /verify",
        }
    )
729
+
730
+
731
+ def _get_request_text():
732
+ data = request.get_json(silent=True)
733
+ if not data or "text" not in data:
734
+ return None, "Provide 'text' in request body"
735
+ text = data["text"].strip()
736
+ if not text:
737
+ return None, "Empty text"
738
+ return text, None
739
+
740
+
741
@app.route("/predict/nlp", methods=["POST"])
@require_api_key
def predict_nlp_endpoint():
    """Return the TF-IDF + SVM prediction for the posted ``text``.

    Responds with HTTP 400 when the request body fails validation.
    """
    article, problem = _get_request_text()
    if problem:
        return jsonify({"error": problem}), 400
    outcome = predict_nlp(article)
    return jsonify({"prediction": outcome})
748
+
749
+
750
@app.route("/predict/hybrid", methods=["POST"])
@require_api_key
def predict_hybrid_endpoint():
    """Return the hybrid CNN/LSTM predictions for the posted ``text``.

    Responds with HTTP 400 when the request body fails validation.
    """
    article, problem = _get_request_text()
    if problem:
        return jsonify({"error": problem}), 400
    outcome = predict_hybrid(article)
    return jsonify({"prediction": outcome})
757
+
758
+
759
@app.route("/predict/naive", methods=["POST"])
@require_api_key
def predict_naive_endpoint():
    """Return the naive-model predictions for the posted ``text``.

    Responds with HTTP 400 when the request body fails validation.
    """
    article, problem = _get_request_text()
    if problem:
        return jsonify({"error": problem}), 400
    outcome = predict_naive(article)
    return jsonify({"prediction": outcome})
766
+
767
+
768
@app.route("/predict/bert", methods=["POST"])
@require_api_key
def predict_bert_endpoint():
    """Return the BERT checkpoint predictions for the posted ``text``.

    Responds with HTTP 400 when the request body fails validation.
    """
    article, problem = _get_request_text()
    if problem:
        return jsonify({"error": problem}), 400
    outcome = predict_bert(article)
    return jsonify({"prediction": outcome})
775
+
776
+
777
@app.route("/predict/distilbert", methods=["POST"])
@require_api_key
def predict_distilbert_endpoint():
    """Return the DistilBERT prediction for the posted ``text``.

    Responds with HTTP 400 when the request body fails validation.
    """
    article, problem = _get_request_text()
    if problem:
        return jsonify({"error": problem}), 400
    outcome = predict_distil(article)
    return jsonify({"prediction": outcome})
784
+
785
+
786
@app.route("/verify", methods=["POST"])
@require_api_key
def verify():
    """
    Run full ensemble on submitted news article.
    Header : X-API-KEY: <key>
    Body   : { "title": "...", "text": "..." }

    ``title`` is optional. Without it the external lookup uses the first 100
    characters of ``text`` and the full text stands in as the title.
    Responds 400 when ``text`` is missing, empty or not a string, or when
    ``title`` is not a string; 500 on unexpected failures.
    """
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or "text" not in data:
            return jsonify({"error": "Provide 'text' in request body"}), 400

        raw_text = data["text"]
        if not isinstance(raw_text, str):
            # Reject early: calling .strip() on a number/list would surface as a 500.
            return jsonify({"error": "'text' must be a string"}), 400

        text = raw_text.strip()
        if not text:
            return jsonify({"error": "Empty text"}), 400

        provided_title = data.get("title")
        if provided_title is None:
            # Absent or JSON null: derive both values from the text, so the
            # literal string "None" never ends up in the document.
            external = text[:100]
            title = text
        elif isinstance(provided_title, str):
            external = title = provided_title
        else:
            return jsonify({"error": "'title' must be a string"}), 400

        full_doc = f"{title} {text}".strip()

        # Wrap each model in try/except so one failure doesn't kill the whole request
        def safe(fn):
            try:
                return fn(full_doc)
            except Exception as e:
                print(f"[MODEL ERROR] {fn.__name__}: {e}")
                return []

        raw = {
            "nlp": safe(predict_nlp),
            "hybrid": safe(predict_hybrid),
            "naive": safe(predict_naive),
            "bert": safe(predict_bert),
            "distilbert": safe(predict_distil),
        }

        all_preds = [p for preds in raw.values() for p in preds]
        final_label, model_conf = final_ensemble(all_preds)

        ext_score = check_external_news(external)
        # Blend: 70% model confidence, 30% external corroboration.
        final_accuracy = round((model_conf * 0.7 + ext_score * 0.3) * 100, 2)

        return jsonify(
            {
                "title": title,
                "prediction": final_label,
                "confidence": model_conf,
                "accuracy": f"{final_accuracy}%",
                "external_score": round(ext_score, 4),
                "models": format_output(raw),
            }
        )

    except Exception as e:
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
844
+
845
+
846
+ # ==============================
847
+ # RUN
848
+ # ==============================
849
+ if __name__ == "__main__":
850
+ app.run(host="0.0.0.0", port=5000, debug=False)
dockerfile ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.10-slim

WORKDIR /app

# Install system dependencies (compiler toolchain for packages built from source).
# --no-install-recommends keeps optional extras out of the image.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first for better caching
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application files
COPY app.py .
# SECURITY: copying .env bakes the provider API keys into an image layer.
# Prefer supplying them as runtime environment variables / Space secrets
# and drop this line once that is in place.
COPY .env .

# Initialize api_keys.json if it doesn't exist
RUN if [ ! -f api_keys.json ]; then echo "{}" > api_keys.json; fi
# World-writable so the app can persist generated keys under any runtime user.
RUN chmod 666 api_keys.json

# Standard Hugging Face Space port
EXPOSE 7860

# Run the Flask app
CMD ["python", "app.py"]
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ flask
2
+ torch
3
+ transformers
4
+ numpy
5
+ requests
6
+ python-dotenv
7
+ nltk
8
+ scikit-learn
9
+ sentencepiece
10
+ protobuf
11
+ huggingface-hub
12
+ accelerate
requirements_space.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ flask
2
+ torch
3
+ transformers
4
+ numpy
5
+ requests
6
+ python-dotenv
7
+ nltk
8
+ scikit-learn
9
+ sentencepiece
10
+ protobuf
11
+ huggingface-hub
12
+ accelerate