trinadh-33 committed on
Commit
c2b56dc
·
verified ·
1 Parent(s): 3c49150

Upload 8 files

Browse files
Files changed (4) hide show
  1. main.py +247 -80
  2. model.pkl +3 -0
  3. requirements.txt +1 -4
  4. vectorizer.pkl +3 -0
main.py CHANGED
@@ -1,18 +1,15 @@
1
  from fastapi import FastAPI, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
 
4
  import re
5
  import os
6
  import uuid
7
- import json
8
- import numpy as np
9
- import torch
10
  from datetime import datetime
11
  from lime.lime_text import LimeTextExplainer
12
- from typing import List
13
- from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
14
 
15
- app = FastAPI(title="XAI Semantic Phishing Detector API")
16
 
17
  app.add_middleware(
18
  CORSMiddleware,
@@ -22,44 +19,40 @@ app.add_middleware(
22
  allow_headers=["*"],
23
  )
24
 
 
25
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
26
- MODEL_DIR = os.path.join(BASE_DIR, "bert_phishing_model")
27
 
28
- # ---- Load BERT Model & Pipeline ----
29
- classifier = None
30
  try:
31
- print("Loading Semantic DistilBERT Engine...")
32
- tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
33
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
34
-
35
- device = 0 if torch.cuda.is_available() else -1
36
- classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, device=device, top_k=None)
37
- print("Semantic Engine loaded successfully")
38
  except Exception as e:
39
- print(f"Warning: BERT Model not found. Error: {e}")
 
 
40
 
41
- # ---- Load SHAP-derived keywords ----
42
- KEYWORDS_PATH = os.path.join(BASE_DIR, "keywords.json")
43
-
44
- def load_keywords():
45
- if os.path.exists(KEYWORDS_PATH):
46
- with open(KEYWORDS_PATH) as f:
47
- return json.load(f)
48
-
49
- return {
50
- "verify": {"message": "Requests to verify identity are a classic phishing tactic.", "shap_score": 0.0},
51
- "password": {"message": "Never send or request passwords through email.", "shap_score": 0.0},
52
- "httplink": {"message": "Insecure HTTP link detected — legitimate emails use HTTPS.", "shap_score": 0.0},
53
- }
54
 
55
- PHISHING_KEYWORDS = load_keywords()
56
  scan_history: List[dict] = []
57
 
 
58
  class EmailRequest(BaseModel):
59
  email_text: str
60
 
 
 
 
 
 
61
  class ScanResult(BaseModel):
62
- cleaned_text: str
63
  id: str
64
  timestamp: str
65
  email_preview: str
@@ -69,6 +62,7 @@ class ScanResult(BaseModel):
69
  lime_words: List[dict]
70
  recommendations: List[str]
71
 
 
72
  def clean_text(text: str) -> str:
73
  text = text.lower()
74
  def expand_url(match):
@@ -82,91 +76,205 @@ def clean_text(text: str) -> str:
82
  text = re.sub(r"[^a-z\s]", "", text)
83
  return text
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  def generate_recommendations(exp_list, email_text: str, pred: int) -> List[str]:
86
  rec = set()
87
  for word, score in exp_list:
88
  w = word.lower()
89
  if score > 0 and w in PHISHING_KEYWORDS:
90
- rec.add(f"⚠ '{word}' detected: {PHISHING_KEYWORDS[w]['message']}")
91
- for key, info in PHISHING_KEYWORDS.items():
92
  if key in email_text.lower():
93
- rec.add(f"⚠ '{key}' detected: {info['message']}")
94
  if pred == 1:
95
  rec.add("🚨 Strong phishing signal. Do NOT click links or share personal info.")
96
- if re.search(r"http://", email_text, re.IGNORECASE):
97
- rec.add("⚠ Insecure HTTP link detected — legitimate emails use HTTPS.")
98
  return list(rec)
99
 
100
  def predict_proba_fn(texts):
101
- cleaned_texts = [clean_text(t) for t in texts]
102
- results = classifier(cleaned_texts, batch_size=16, truncation=True, max_length=512)
103
-
104
- probs = []
105
- for res in results:
106
- safe_score = next(x['score'] for x in res if x['label'] == 'LABEL_0')
107
- phish_score = next(x['score'] for x in res if x['label'] == 'LABEL_1')
108
- probs.append([safe_score, phish_score])
109
-
110
- return np.array(probs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  @app.get("/")
113
  def root():
114
- return {"status": "XAI Semantic Phishing Detector API is running"}
115
 
116
  @app.get("/health")
117
  def health():
118
  return {
119
  "status": "ok",
120
- "model_loaded": classifier is not None,
121
- "engine_type": "DistilBERT"
122
- }
123
-
124
- @app.get("/keywords")
125
- def get_keywords():
126
- return {
127
- "keywords": list(PHISHING_KEYWORDS.keys()),
128
- "count": len(PHISHING_KEYWORDS)
129
  }
130
 
131
  @app.post("/analyze", response_model=ScanResult)
132
  def analyze_email(request: EmailRequest):
133
- if classifier is None:
134
- raise HTTPException(status_code=503, detail="Model not loaded.")
135
 
136
  email_text = request.email_text.strip()
 
 
 
137
  cleaned = clean_text(email_text)
 
 
 
 
 
 
138
 
139
- # 1. Get Prediction from BERT
140
- bert_result = classifier([cleaned], truncation=True, max_length=512)[0]
141
- safe_score = next(x['score'] for x in bert_result if x['label'] == 'LABEL_0')
142
- phish_score = next(x['score'] for x in bert_result if x['label'] == 'LABEL_1')
143
-
144
- pred = 1 if phish_score > safe_score else 0
145
- confidence = float(phish_score if pred == 1 else safe_score)
146
- risk_score = float(phish_score)
147
-
148
- # 2. LIME Explanation
149
- explainer = LimeTextExplainer(class_names=["Safe", "Phishing"], random_state=42)
150
  tokens = [t for t in cleaned.split() if len(t) > 1]
151
  num_features = max(20, min(len(tokens), 60))
152
 
153
  exp = explainer.explain_instance(
154
- cleaned, predict_proba_fn, labels=(1,), num_features=num_features, num_samples=200
 
 
 
155
  )
156
- exp_list = exp.as_list(label=1)
 
 
 
 
 
157
 
158
- lime_words = [{"word": word, "score": float(score)} for word, score in exp_list]
159
  recommendations = generate_recommendations(exp_list, email_text, pred)
160
 
161
  result = {
162
- "id": str(uuid.uuid4()),
163
- "timestamp": datetime.now().isoformat(),
164
  "email_preview": email_text[:120] + ("..." if len(email_text) > 120 else ""),
165
- "cleaned_text": cleaned,
166
- "prediction": pred,
167
- "confidence": round(confidence, 4),
168
- "risk_score": round(risk_score, 4),
169
- "lime_words": lime_words,
170
  "recommendations": recommendations,
171
  }
172
 
@@ -176,6 +284,65 @@ def analyze_email(request: EmailRequest):
176
 
177
  return result
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  @app.get("/history")
180
  def get_history():
181
  return scan_history
 
1
  from fastapi import FastAPI, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
+ import pickle
5
  import re
6
  import os
7
  import uuid
 
 
 
8
  from datetime import datetime
9
  from lime.lime_text import LimeTextExplainer
10
+ from typing import List, Optional
 
11
 
12
+ app = FastAPI(title="XAI Phishing Detector API")
13
 
14
  app.add_middleware(
15
  CORSMiddleware,
 
19
  allow_headers=["*"],
20
  )
21
 
22
+ # ---- Load model & vectorizer ----
23
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 
24
 
 
 
25
# Deserialize the trained classifier and TF-IDF vectorizer from disk.
# NOTE(review): pickle.load assumes model.pkl / vectorizer.pkl are trusted
# build artifacts — never point these paths at user-supplied files.
try:
    # Use context managers so the file handles are closed even on failure
    # (the original `pickle.load(open(...))` leaked the descriptors).
    with open(os.path.join(BASE_DIR, "model.pkl"), "rb") as model_file:
        model = pickle.load(model_file)
    with open(os.path.join(BASE_DIR, "vectorizer.pkl"), "rb") as vectorizer_file:
        vectorizer = pickle.load(vectorizer_file)
    print("Model loaded successfully")
except Exception as e:
    # Keep the API importable without artifacts; endpoints that need the
    # model respond with 503 instead (see /analyze and /suggest guards).
    print(f"Warning: Model not found. Run train.py first. Error: {e}")
    model = None
    vectorizer = None
33
 
34
# ---- Load BERT MLM for contextual word suggestions ----
# Optional component: the fill-mask pipeline only powers /suggest. If
# transformers is missing or the model can't be fetched, the API still
# starts and /suggest returns 503 (see suggest_evasions guard).
mlm_pipeline = None
try:
    # Imported lazily inside the try so the core API works without transformers.
    from transformers import pipeline as hf_pipeline
    # top_k=30: number of candidate tokens returned per masked position,
    # later filtered in get_mlm_candidates.
    mlm_pipeline = hf_pipeline("fill-mask", model="distilbert-base-uncased", top_k=30)
    print("DistilBERT MLM loaded successfully for word suggestions")
except Exception as e:
    print(f"Warning: BERT MLM not loaded. /suggest will be unavailable. Error: {e}")
 
 
 
 
 
42
 
43
# ---- In-memory scan history ----
# Process-local list: contents are lost on restart and not shared
# between worker processes.
scan_history: List[dict] = []

# ---- Schemas ----
class EmailRequest(BaseModel):
    """Request body for /analyze: the raw email text to score."""
    email_text: str
49
 
50
class SuggestRequest(BaseModel):
    """Request body for /suggest: the analyzed email plus its LIME output."""
    # Raw email text that was previously run through /analyze.
    email_text: str
    lime_words: List[dict]  # [{"word": str, "score": float}, ...]
    # NOTE(review): mutable default is safe here because pydantic copies
    # field defaults per instance; Field(default_factory=list) would be
    # the more explicit spelling.
    skip_words: List[str] = []  # words already shown — skip these
54
+
55
  class ScanResult(BaseModel):
 
56
  id: str
57
  timestamp: str
58
  email_preview: str
 
62
  lime_words: List[dict]
63
  recommendations: List[str]
64
 
65
+ # ---- Helpers ----
66
  def clean_text(text: str) -> str:
67
  text = text.lower()
68
  def expand_url(match):
 
76
  text = re.sub(r"[^a-z\s]", "", text)
77
  return text
78
 
79
def get_risk_score(text: str) -> float:
    """Score *text* with the trained model and return its phishing
    probability (probability of class 1)."""
    features = vectorizer.transform([clean_text(text)])
    phishing_probability = model.predict_proba(features)[0][1]
    return float(phishing_probability)
83
+
84
# Known phishing-signal keywords mapped to the advisory message surfaced to
# users when the keyword appears in an analyzed email (consumed by
# generate_recommendations, and as a replacement blocklist in find_suggestions).
PHISHING_KEYWORDS = {
    "verify": "Do not ask users to verify sensitive info via email.",
    "password": "Never request passwords through email.",
    "login": "Avoid embedding login links — direct to official site.",
    "bank": "Avoid impersonating banks or financial institutions.",
    "urgent": "Creating urgency is a classic phishing tactic.",
    "click": "Avoid forcing users to click suspicious links.",
    "account": "Avoid threatening account suspension.",
    "update": "Avoid forcing users to perform updates via email.",
    "security": "Avoid fake security warnings.",
    "confirm": "Avoid asking users to confirm personal details.",
    "suspended": "Account suspension threats are common in phishing.",
    "winner": "Prize/winner announcements are common phishing tactics.",
    "free": "Unsolicited free offers are a phishing red flag.",
    "prize": "Lottery/prize claims are common phishing tactics.",
}
100
+
101
def generate_recommendations(exp_list, email_text: str, pred: int) -> List[str]:
    """Build user-facing recommendation strings for an analyzed email.

    Args:
        exp_list: LIME (word, score) pairs; positive score pushes toward phishing.
        email_text: The raw (uncleaned) email text.
        pred: Model prediction — 1 for phishing, 0 for safe.

    Returns:
        De-duplicated list of advisory strings (unordered, built from a set).
    """
    rec = set()
    lowered = email_text.lower()  # hoisted: was recomputed on every keyword check

    # Flag LIME words that both push toward phishing and are known keywords.
    for word, score in exp_list:
        w = word.lower()
        if score > 0 and w in PHISHING_KEYWORDS:
            rec.add(f"⚠ '{word}' detected: {PHISHING_KEYWORDS[w]}")

    # NOTE(review): substring match, so "free" also fires on "freedom";
    # word-boundary matching would be stricter but changes behavior.
    for key, msg in PHISHING_KEYWORDS.items():
        if key in lowered:
            rec.add(f"⚠ '{key}' detected: {msg}")

    if pred == 1:
        rec.add("🚨 Strong phishing signal. Do NOT click links or share personal info.")
    return list(rec)
113
 
114
def predict_proba_fn(texts):
    """LIME callback: normalize raw texts and return per-class probabilities."""
    normalized = [clean_text(sample) for sample in texts]
    feature_matrix = vectorizer.transform(normalized)
    return model.predict_proba(feature_matrix)
118
+
119
# ---- /suggest logic ----
# Words to never suggest as replacements — they'd look obviously wrong in an email.
# (Function words: articles, auxiliaries, prepositions, pronouns, etc.)
STOPWORDS = {
    "the", "a", "an", "is", "are", "was", "were", "be", "been", "being",
    "have", "has", "had", "do", "does", "did", "will", "would", "could",
    "should", "may", "might", "shall", "can", "need", "dare", "ought",
    "used", "to", "of", "in", "on", "at", "by", "for", "with", "about",
    "as", "into", "through", "during", "before", "after", "above", "below",
    "from", "up", "down", "out", "off", "over", "under", "again", "then",
    "once", "here", "there", "when", "where", "why", "how", "all", "both",
    "each", "few", "more", "most", "other", "some", "such", "no", "not",
    "only", "own", "same", "so", "than", "too", "very", "just", "but",
    "and", "or", "nor", "yet", "also", "this", "that", "these", "those",
    "i", "me", "my", "we", "our", "you", "your", "he", "she", "it", "its",
    "they", "them", "their", "what", "which", "who", "whom", "any", "if",
}
135
+
136
def get_mlm_candidates(email_text: str, target_word: str, topn: int = 30) -> List[str]:
    """
    Ask the DistilBERT fill-mask pipeline for context-fitting replacements
    of *target_word* inside *email_text*.

    Returns at most *topn* lowercase candidate words; returns an empty list
    when the MLM is unavailable, the word is absent, or inference fails.
    """
    if mlm_pipeline is None:
        return []

    # Mask the first whole-word, case-insensitive occurrence of the target.
    word_re = re.compile(r'\b' + re.escape(target_word) + r'\b', re.IGNORECASE)
    hit = word_re.search(email_text)
    if not hit:
        return []
    masked = email_text[:hit.start()] + "[MASK]" + email_text[hit.end():]

    # BERT only attends to 512 tokens, so keep a ~400-word window centred
    # on the mask when the message is long.
    pieces = masked.split()
    if len(pieces) > 400:
        mask_idx = next((i for i, piece in enumerate(pieces) if "[MASK]" in piece), 0)
        lo = max(0, mask_idx - 200)
        hi = min(len(pieces), mask_idx + 200)
        masked = " ".join(pieces[lo:hi])

    try:
        usable = []
        for prediction in mlm_pipeline(masked):
            token = prediction["token_str"].strip().lower()
            # Keep alphabetic, non-trivial tokens that aren't the original
            # word or a function word.
            if (token.isalpha()
                    and len(token) > 2
                    and token != target_word.lower()
                    and token not in STOPWORDS):
                usable.append(token)
        return usable[:topn]
    except Exception:
        # Inference failure degrades to "no suggestions" rather than a 500.
        return []
176
+
177
def find_suggestions(email_text: str, original_word: str, base_risk: float) -> List[dict]:
    """
    Propose up to 3 substitutions for a single phishing word, found via BERT MLM,
    that fit the sentence context, are not themselves known phishing keywords,
    and do not increase the model's risk score when swapped in.

    Returns a list of {word, new_risk, delta} dicts, largest risk drop first.
    """
    # Candidates must never be known phishing terms themselves.
    blocklist = set(PHISHING_KEYWORDS.keys()) | {
        "verify", "validate", "confirm", "click", "login", "account", "bank",
        "password", "suspended", "urgent", "security", "winner", "free", "prize",
        "expire", "update", "credit", "payment", "invoice", "authorize", "alert",
        "ssn", "transfer", "wire", "reward", "claim", "restore", "access",
        "immediately", "locked", "blocked", "notification", "dear",
    }

    # Whole-word, case-insensitive matcher for the word being replaced.
    word_re = re.compile(r'\b' + re.escape(original_word) + r'\b', re.IGNORECASE)

    scored = []
    for replacement in get_mlm_candidates(email_text, original_word, topn=30):
        if replacement.lower() in blocklist:
            continue
        # Swap every occurrence in the raw text and re-score the whole email.
        rewritten = word_re.sub(replacement, email_text)
        new_risk = get_risk_score(rewritten)
        improvement = base_risk - new_risk  # positive = risk dropped
        if improvement >= 0:  # discard substitutions that worsen the score
            scored.append({
                "word": replacement,
                "new_risk": round(new_risk, 4),
                "delta": round(improvement, 4),
            })

    # Biggest risk reduction first, capped at three suggestions.
    return sorted(scored, key=lambda item: item["delta"], reverse=True)[:3]
219
+
220
# ---- Routes ----
@app.get("/")
def root():
    """Simple liveness message for the API root."""
    payload = {"status": "XAI Phishing Detector API is running"}
    return payload
224
 
225
  @app.get("/health")
226
  def health():
227
  return {
228
  "status": "ok",
229
+ "model_loaded": model is not None,
230
+ "vectorizer_loaded": vectorizer is not None,
231
+ "mlm_loaded": mlm_pipeline is not None,
 
 
 
 
 
 
232
  }
233
 
234
  @app.post("/analyze", response_model=ScanResult)
235
  def analyze_email(request: EmailRequest):
236
+ if model is None or vectorizer is None:
237
+ raise HTTPException(status_code=503, detail="Model not loaded. Run train.py first.")
238
 
239
  email_text = request.email_text.strip()
240
+ if not email_text:
241
+ raise HTTPException(status_code=400, detail="Email text cannot be empty.")
242
+
243
  cleaned = clean_text(email_text)
244
+ vec = vectorizer.transform([cleaned])
245
+
246
+ pred = int(model.predict(vec)[0])
247
+ proba = model.predict_proba(vec)[0]
248
+ confidence = float(proba[pred])
249
+ risk_score = float(proba[1])
250
 
251
+ explainer = LimeTextExplainer(class_names=["Safe", "Phishing"])
 
 
 
 
 
 
 
 
 
 
252
  tokens = [t for t in cleaned.split() if len(t) > 1]
253
  num_features = max(20, min(len(tokens), 60))
254
 
255
  exp = explainer.explain_instance(
256
+ cleaned,
257
+ predict_proba_fn,
258
+ num_features=num_features,
259
+ num_samples=1000,
260
  )
261
+ exp_list = exp.as_list()
262
+
263
+ lime_words = [
264
+ {"word": word, "score": float(score)}
265
+ for word, score in exp_list
266
+ ]
267
 
 
268
  recommendations = generate_recommendations(exp_list, email_text, pred)
269
 
270
  result = {
271
+ "id": str(uuid.uuid4()),
272
+ "timestamp": datetime.now().isoformat(),
273
  "email_preview": email_text[:120] + ("..." if len(email_text) > 120 else ""),
274
+ "prediction": pred,
275
+ "confidence": round(confidence, 4),
276
+ "risk_score": round(risk_score, 4),
277
+ "lime_words": lime_words,
 
278
  "recommendations": recommendations,
279
  }
280
 
 
284
 
285
  return result
286
 
287
+ @app.post("/suggest")
288
+ def suggest_evasions(request: SuggestRequest):
289
+ """
290
+ Given an analyzed email and its LIME word scores, return substitution
291
+ suggestions for the top 4 phishing-signal words.
292
+
293
+ Response shape:
294
+ [
295
+ {
296
+ "original": "verify",
297
+ "lime_score": 0.045,
298
+ "suggestions": [
299
+ { "word": "check", "new_risk": 0.41, "delta": 0.18 },
300
+ ...
301
+ ]
302
+ },
303
+ ...
304
+ ]
305
+ """
306
+ if model is None or vectorizer is None:
307
+ raise HTTPException(status_code=503, detail="Model not loaded.")
308
+ if mlm_pipeline is None:
309
+ raise HTTPException(status_code=503, detail="BERT MLM not loaded.")
310
+
311
+ email_text = request.email_text.strip()
312
+ if not email_text:
313
+ raise HTTPException(status_code=400, detail="Email text cannot be empty.")
314
+
315
+ # Current risk score (ground truth from the raw email text)
316
+ base_risk = get_risk_score(email_text)
317
+
318
+ # Pick positive LIME words that appear in the email, skip already-shown words
319
+ skip_set = set(w.lower() for w in request.skip_words)
320
+ top_phish_words = [
321
+ w for w in sorted(request.lime_words, key=lambda x: -x["score"])
322
+ if w["score"] > 0
323
+ and w["word"].lower() not in skip_set
324
+ and re.search(r'\b' + re.escape(w["word"]) + r'\b', email_text, re.IGNORECASE)
325
+ ][:6]
326
+
327
+ if not top_phish_words:
328
+ return []
329
+
330
+ output = []
331
+ for entry in top_phish_words:
332
+ word = entry["word"]
333
+ suggestions = find_suggestions(email_text, word, base_risk)
334
+ pattern = re.compile(r"\b" + re.escape(word) + r"\b", re.IGNORECASE)
335
+ occurrences = len(pattern.findall(email_text))
336
+ output.append({
337
+ "original": word,
338
+ "lime_score": round(entry["score"], 4),
339
+ "base_risk": round(base_risk, 4),
340
+ "occurrences": occurrences,
341
+ "suggestions": suggestions,
342
+ })
343
+
344
+ return output
345
+
346
  @app.get("/history")
347
  def get_history():
348
  return scan_history
model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cd8c9d65c76769a559d5f080a1ea6bcf72481664a38c43bd1fb60a242106c5dc
3
+ size 1163695
requirements.txt CHANGED
@@ -5,7 +5,4 @@ scikit-learn==1.4.2
5
  pandas==2.2.2
6
  lime==0.2.0.1
7
  python-multipart==0.0.9
8
- numpy==1.26.4
9
- shap==0.45.1
10
- torch
11
- transformers
 
5
  pandas==2.2.2
6
  lime==0.2.0.1
7
  python-multipart==0.0.9
8
+ spacy==3.8.14
 
 
 
vectorizer.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65fa33f137cafd28a980fb840d00293cfdd8e3459173eb1993e4735f2c63643b
3
+ size 27588922