Nick-2x committed on
Commit
103e422
·
verified ·
1 Parent(s): 9c8c464

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import torch
5
+
6
# ASGI application object; the route decorators in this file register on it.
app = FastAPI()

# Hugging Face Hub id of the phishing classifier. The original note describes
# it as a multimodal phishing detector covering URLs, SMS and email.
MODEL_ID = "ealvaradob/bert-finetuned-phishing"

# The tokenizer and model are created once, at import time, and are shared by
# every request handled by this process.
print("Loading model... This might take a minute as it's a 'large' BERT model.")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
14
+
15
class URLInput(BaseModel):
    """Request body accepted by the ``/predict`` endpoint."""

    # Text to classify, supplied by the client under the "url" key.
    url: str
17
+
18
@app.get("/")
async def root():
    """Report that the service is up.

    Returns:
        A one-key dict whose ``status`` value is a fixed liveness message.
    """
    status_message = "URL Phishing Detector API is running"
    return {"status": status_message}
21
+
22
def _is_phishing_label(label_name: str) -> bool:
    """Return True when a model label name denotes the phishing class."""
    normalized = label_name.strip().lower()
    # Match the generic positive-class names exactly: a bare substring test
    # for "1" would also fire on names such as "LABEL_10" or "benign_v1".
    return normalized in {"label_1", "1"} or "phishing" in normalized


def _forward_probs(inputs) -> list:
    """Run the model on tokenized *inputs* and return per-class probabilities."""
    with torch.no_grad():
        logits = model(**inputs).logits
    # Softmax over the class axis turns logits into probabilities; the batch
    # holds a single sequence, so take row 0.
    return torch.nn.functional.softmax(logits, dim=-1)[0].tolist()


@app.post("/predict")
async def predict_url(data: URLInput):
    """Classify the submitted text as phishing or legitimate.

    Args:
        data: Request body whose ``url`` field holds the text to score.

    Returns:
        ``{"error": "Invalid URL provided"}`` when the text, after stripping
        surrounding whitespace, is shorter than four characters. Otherwise a
        dict with the echoed ``url``, the ``prediction`` ("phishing" or
        "legitimate"), the winning ``confidence`` rounded to four decimals,
        the per-label ``raw_scores`` and the boolean ``is_malicious``.
    """
    import asyncio

    # 1. Basic pre-check. Validate the stripped text so whitespace-only input
    #    is rejected. The error is still returned as a normal JSON body so
    #    existing clients keep seeing the same response contract.
    url_text = data.url.strip() if data.url else ""
    if len(url_text) < 4:
        return {"error": "Invalid URL provided"}

    # 2. Tokenize and predict. Tokenization stays on the event-loop thread so
    #    the shared tokenizer object is never used from several worker threads
    #    at once; only the blocking forward pass is handed to the default
    #    executor, which keeps the event loop free for other requests.
    inputs = tokenizer(url_text, return_tensors="pt", truncation=True, max_length=512)
    loop = asyncio.get_running_loop()
    probs = await loop.run_in_executor(None, _forward_probs, inputs)

    # 3. Dynamic label mapping: pair each probability with the label name the
    #    model config assigns to that class index.
    confidences = {model.config.id2label[i]: prob for i, prob in enumerate(probs)}

    # Pick the label with the highest probability.
    label_name, confidence = max(confidences.items(), key=lambda item: item[1])
    is_phishing = _is_phishing_label(label_name)

    return {
        "url": data.url,
        "prediction": "phishing" if is_phishing else "legitimate",
        "confidence": round(confidence, 4),
        "raw_scores": confidences,
        "is_malicious": is_phishing,
    }
56
+
57
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the file is run as a script.
    import uvicorn

    # Listen on all interfaces; 7860 is the standard port for Hugging Face Spaces.
    bind_host = "0.0.0.0"
    bind_port = 7860
    uvicorn.run(app, host=bind_host, port=bind_port)