File size: 1,831 Bytes
a7bb4a3
48c17bd
 
a7bb4a3
48c17bd
 
a7bb4a3
 
 
 
 
 
 
 
48c17bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d04e8ba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
import gradio as gr
import torch
from dotenv import load_dotenv
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Copy variables from a local .env file (if one is found) into the process
# environment before any of the settings below are read.
load_dotenv()

# Identifier passed to ``from_pretrained`` when the model is loaded.
MODEL_NAME = os.environ.get("MODEL_NAME", "Ankit74990/TruthX-DISTILBERT")

# News-provider API keys; each defaults to an empty string when unset.
# NOTE(review): none of these keys is referenced elsewhere in this file —
# confirm whether they are still needed.
NEWSDATA_API_KEY = os.environ.get("NEWSDATA_API_KEY", "")
NEWSAPI_API_KEY = os.environ.get("NEWSAPI_API_KEY", "")
GNEWS_API_KEY = os.environ.get("GNEWS_API_KEY", "")
MEDIASTACK_API_KEY = os.environ.get("MEDIASTACK_API_KEY", "")

# Populated lazily by load_model() on the first prediction request.
tokenizer = model = None

def load_model():
    """Load the tokenizer and classifier for ``MODEL_NAME`` on first use.

    Once both objects are loaded, later calls return immediately. The
    module-level ``tokenizer`` and ``model`` globals are assigned only after
    *both* loads succeed. A failure part-way through (e.g. while fetching
    the model weights) therefore leaves them as ``None``, and the next call
    retries instead of running with a tokenizer but no model.
    """
    global tokenizer, model
    if tokenizer is not None and model is not None:
        return
    loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    loaded_model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    # Evaluation mode: turns off training-only layers such as dropout.
    loaded_model.eval()
    # Publish both together so callers never observe a half-initialised pair.
    tokenizer, model = loaded_tokenizer, loaded_model

def predict(text):
    """Classify ``text`` as real or fake news.

    The model and tokenizer are loaded on first use. Input longer than 512
    tokens is truncated. Only the first row of the batch is scored.

    NOTE(review): assumes class index 0 is "fake" and index 1 is "real" —
    confirm against the model's ``config.id2label``.

    Returns:
        dict with keys ``"label"`` ("Real News" or "Fake News"; a tie counts
        as fake), ``"confidence"`` (the larger of the two probabilities),
        ``"fake_probability"`` and ``"real_probability"``. All three numbers
        are formatted as percentage strings with two decimals.
    """
    load_model()
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=True,
    )
    with torch.no_grad():
        logits = model(**encoded).logits
    first_row = torch.nn.functional.softmax(logits, dim=-1)[0]
    fake_prob = first_row[0].item()
    real_prob = first_row[1].item()

    if real_prob > fake_prob:
        verdict = "Real News"
    else:
        verdict = "Fake News"

    return {
        "label": verdict,
        "confidence": f"{max(fake_prob, real_prob):.2%}",
        "fake_probability": f"{fake_prob:.2%}",
        "real_probability": f"{real_prob:.2%}",
    }

# Gradio UI wiring.
# predict() returns a dict of formatted strings
# ({"label": ..., "confidence": "93.10%", ...}). gr.Label only accepts a
# plain string/number or a {class_name: float_confidence} mapping. Fed this
# dict it either fails float validation or, in older versions, sorts the
# string values and shows the wrong top label. gr.JSON renders the dict
# exactly as returned.
demo = gr.Interface(
    fn=predict,
    inputs="text",
    outputs=gr.JSON(label="Prediction Result"),
    title="TruthX - Fake News Detector",
    description="Enter a news article or headline to determine if it's real or fake news.",
    examples=[
        ["Breaking: Scientists discover new cure for cancer"],
        ["Government announces new policy changes for education"],
        ["Massive scandal revealed involving major corporation"],
    ],
)

if __name__ == "__main__":
    # Listen on all network interfaces (0.0.0.0), port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)