DevPatel0611 committed on
Commit 86b932c · 0 Parent(s)

Clean build with correct gitignore

Files changed (50)
  1. .gitattributes +6 -0
  2. .gitignore +2 -0
  3. README.md +102 -0
  4. TruthLens_Paper.tex +84 -0
  5. app.py +641 -0
  6. config/config.yaml +60 -0
  7. models/saved/distilbert_model/cm_distil.png +0 -0
  8. models/saved/distilbert_model/config.json +28 -0
  9. models/saved/distilbert_model/distilbert_oof.npy +3 -0
  10. models/saved/distilbert_model/model.safetensors +3 -0
  11. models/saved/distilbert_model/tokenizer.json +0 -0
  12. models/saved/distilbert_model/tokenizer_config.json +14 -0
  13. models/saved/distilbert_model/training_args.bin +3 -0
  14. models/saved/logistic_model/cm.png +0 -0
  15. models/saved/logistic_model/logistic_model.pkl +3 -0
  16. models/saved/logistic_model/lr_oof.npy +3 -0
  17. models/saved/logistic_model/metrics.json +8 -0
  18. models/saved/lstm_model/cm.png +0 -0
  19. models/saved/lstm_model/lstm_oof.npy +3 -0
  20. models/saved/lstm_model/metrics.json +8 -0
  21. models/saved/lstm_model/model.pt +3 -0
  22. models/saved/meta_classifier/cm_meta.png +0 -0
  23. models/saved/meta_classifier/meta_classifier.pkl +3 -0
  24. models/saved/meta_classifier/metrics.json +8 -0
  25. models/saved/roberta_model/cm_roberta.png +0 -0
  26. models/saved/roberta_model/config.json +29 -0
  27. models/saved/roberta_model/model.safetensors +3 -0
  28. models/saved/roberta_model/roberta_oof.npy +3 -0
  29. models/saved/roberta_model/tokenizer.json +0 -0
  30. models/saved/roberta_model/tokenizer_config.json +16 -0
  31. models/saved/roberta_model/training_args.bin +3 -0
  32. models/saved/tokenizer.pkl +3 -0
  33. requirements.txt +18 -0
  34. run_pipeline.py +100 -0
  35. src/__init__.py +1 -0
  36. src/models/__init__.py +1 -0
  37. src/models/distilbert_model.py +201 -0
  38. src/models/logistic_model.py +141 -0
  39. src/models/lstm_model.py +314 -0
  40. src/models/meta_classifier.py +229 -0
  41. src/models/roberta_model.py +198 -0
  42. src/stage1_ingestion.py +728 -0
  43. src/stage2_preprocessing.py +186 -0
  44. src/stage3_training.py +41 -0
  45. src/stage4_inference.py +867 -0
  46. src/utils/__init__.py +1 -0
  47. src/utils/deduplication.py +256 -0
  48. src/utils/domain_weights.py +102 -0
  49. src/utils/freshness.py +71 -0
  50. src/utils/rag_retrieval.py +157 -0
.gitattributes ADDED
@@ -0,0 +1,6 @@
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+ data/
+ __pycache__/
README.md ADDED
@@ -0,0 +1,102 @@
+ # TruthLens: Advanced Fake News Detection Pipeline
+
+ TruthLens is an end-to-end fake news detection system that moves beyond simple machine-learning probabilities. It employs a robust **5-signal weighted scoring framework** built on journalistic standards, combining deep learning models (DistilBERT, RoBERTa), sequence models (LSTM), statistical models (Logistic Regression), and heuristic analysis to deliver explainable verdicts.
+
+ ## 🌟 Key Features
+
+ * **5-Signal Scoring Framework:**
+     * **Source Credibility (30%):** Evaluates outlet reputation, author presence, and source corroboration, including typosquatting checks.
+     * **Claim Verification (30%):** Combines AI probability with spaCy-based Named Entity Recognition (NER) and quote-attribution analysis.
+     * **Linguistic Quality (20%):** Detects sensationalism, superlatives, and passive voice, and uses DistilBERT to check whether the headline contradicts the body.
+     * **Freshness (10%):** Contextual and date-based temporal scoring to detect outdated information.
+     * **AI Model Consensus (10%):** Ensemble voting from Logistic Regression, LSTM, DistilBERT, and RoBERTa.
+ * **Adversarial Guardrails:** Hard caps and overrides for highly suspicious patterns (triple anonymity, uncited statistics, headline contradictions).
+ * **Live Web Corroboration:** RAG (Retrieval-Augmented Generation) pipeline using live search to verify unambiguous claims.
+ * **TruthLens UI:** A sleek Streamlit dashboard with dark/light mode support, providing explainability down to the specific signals and deductions.
+
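As a rough illustration, the five weights above combine into a single score. The function and signal keys below are a sketch mirroring the README's percentages, not the actual API of `src/stage4_inference.py`:

```python
# Hypothetical sketch of the 5-signal weighted combination; the key names
# and function are illustrative, only the weights come from the README.
WEIGHTS = {
    "source_credibility": 0.30,
    "claim_verification": 0.30,
    "linguistic_quality": 0.20,
    "freshness": 0.10,
    "model_consensus": 0.10,
}

def combine_signals(scores: dict) -> float:
    """Weighted sum of per-signal scores, each expected in [0, 1]."""
    return sum(WEIGHTS[k] * scores[k] for k in WEIGHTS)

print(combine_signals({
    "source_credibility": 0.8,
    "claim_verification": 0.6,
    "linguistic_quality": 0.9,
    "freshness": 1.0,
    "model_consensus": 0.75,
}))  # ~0.775
```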
+ ---
+
+ ## 📁 Project Structure
+
+ ```text
+ fake_news_detection/
+ ├── app.py                    # Streamlit frontend (TruthLens UI)
+ ├── run_pipeline.py           # Main script to run pipeline stages
+ ├── requirements.txt          # Python dependencies
+ ├── src/
+ │   ├── stage1_ingestion.py       # Downloads and prepares datasets
+ │   ├── stage2_preprocessing.py   # Cleans text, tokenizes, and saves artifacts
+ │   ├── stage3_training.py        # Trains models (LR, LSTM, DistilBERT, RoBERTa)
+ │   ├── stage4_inference.py       # The 5-signal scoring engine and prediction logic
+ │   └── utils/
+ │       └── rag_retrieval.py      # Live web search corroboration functions
+ ├── data/                     # Raw and processed datasets (created during execution)
+ └── models/                   # Trained models and vectorizers (created during execution)
+ ```
+
+ ---
+
+ ## 🚀 Getting Started
+
+ ### 1. Installation
+
+ Ensure you have Python 3.8+ installed. Install the required dependencies:
+
+ ```bash
+ pip install -r requirements.txt
+ python -m spacy download en_core_web_sm
+ ```
+
+ ### 2. Running the Pipeline
+
+ The project is divided into stages. You can run the entire pipeline end-to-end, or run specific stages individually using `run_pipeline.py`.
+
+ **To run the complete training pipeline (Stages 1 to 3):**
+ *Note: This will download datasets, preprocess them, and train all models. It may take a significant amount of time depending on your hardware.*
+
+ ```bash
+ python run_pipeline.py --stage 1 2 3
+ ```
+
+ **To run individual stages:**
+
+ * **Stage 1: Data Ingestion**
+   Downloads and formats the necessary datasets (e.g., LIAR, ISOT).
+   ```bash
+   python run_pipeline.py --stage 1
+   ```
+ * **Stage 2: Preprocessing**
+   Cleans the text, maps verdicts to binary labels, and prepares DataFrames for training.
+   ```bash
+   python run_pipeline.py --stage 2
+   ```
+ * **Stage 3: Training**
+   Trains the ensemble: Logistic Regression, LSTM, DistilBERT, and RoBERTa. Saves the models to the `/models` directory.
+   ```bash
+   python run_pipeline.py --stage 3
+   ```
+ * **Stage 4: Evaluation**
+   Evaluates the trained pipeline on the holdout test set using the 5-signal inference framework.
+   ```bash
+   python run_pipeline.py --eval
+   ```
+
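The contents of `run_pipeline.py` are not shown in this view; a minimal sketch of a dispatcher that would accept the flags used above might look like this (the parser shape is inferred from the commands, not taken from the actual script):

```python
# Minimal sketch of a CLI compatible with the commands above.
# The real run_pipeline.py may define its arguments differently.
import argparse

def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="TruthLens pipeline runner")
    parser.add_argument("--stage", nargs="+", type=int, choices=[1, 2, 3],
                        help="Pipeline stages to run, in order")
    parser.add_argument("--eval", action="store_true",
                        help="Evaluate on the holdout test set (Stage 4)")
    return parser

args = build_parser().parse_args(["--stage", "1", "2", "3"])
print(args.stage)  # [1, 2, 3]
```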
+ ---
+
+ ## 🖥️ Running the Application
+
+ Once the models are trained (or if you already have the pre-trained weights in the `/models` directory), you can launch the TruthLens UI:
+
+ ```bash
+ python -m streamlit run app.py
+ ```
+
+ This will start a local web server (usually at `http://localhost:8501`).
+
+ ### Using the App
+ 1. **Paste text or provide a URL:** Paste the raw text of an article (with or without a headline), or provide a URL for the app to parse automatically.
+ 2. **Select depth:** Choose Quick, Standard, or Deep analysis.
+ 3. **View Results:** Explore the four-tier verdict (True, Uncertain, Likely False, False), signal breakdown, adversarial flags, and live web corroboration results.
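The typosquatting check mentioned under Source Credibility can be sketched as a near-match comparison against known outlets. The whitelist, threshold, and function below are illustrative, not the actual `src/utils/domain_weights.py` implementation:

```python
# Illustrative typosquatting check: flag domains that are a close fuzzy
# match to a known outlet without being an exact match. The whitelist and
# the 0.8 ratio threshold are made up for this sketch.
from difflib import SequenceMatcher

KNOWN_DOMAINS = {"cnn.com", "bbc.com", "reuters.com", "nytimes.com"}

def looks_typosquatted(domain: str) -> bool:
    domain = domain.lower()
    if domain.startswith("www."):
        domain = domain[4:]
    if domain in KNOWN_DOMAINS:
        return False  # exact match to a credible outlet
    # A near-miss (e.g. "c-n-n.com" vs "cnn.com") is suspicious
    return any(SequenceMatcher(None, domain, known).ratio() > 0.8
               for known in KNOWN_DOMAINS)

print(looks_typosquatted("cnn.com"))    # False
print(looks_typosquatted("c-n-n.com"))  # True
```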
TruthLens_Paper.tex ADDED
@@ -0,0 +1,84 @@
+ \documentclass[11pt,a4paper]{article}
+ \usepackage[utf8]{inputenc}
+ \usepackage{amsmath}
+ \usepackage{amsfonts}
+ \usepackage{amssymb}
+ \usepackage{graphicx}
+ \usepackage{hyperref}
+ \usepackage{geometry}
+ \usepackage{booktabs}
+
+ \geometry{margin=1in}
+
+ \title{\textbf{TruthLens: A 5-Signal Weighted Architecture \\for Explainable Fake News Detection}}
+ \author{Your Name \\ \textit{Your Institution/University} \\ \texttt{email@domain.com}}
+ \date{\today}
+
+ \begin{document}
+
+ \maketitle
+
+ \begin{abstract}
+ The proliferation of misinformation and fake news necessitates robust automated detection systems. Traditional machine learning approaches often lack explainability, relying solely on black-box probabilistic outputs without verifying journalistic integrity. In this paper, we present TruthLens, an end-to-end framework integrating deep language models (DistilBERT, RoBERTa), sequence models (LSTM), and statistical baselines (Logistic Regression) into a unique 5-signal weighted scoring architecture. The framework assesses source credibility, claim verification via Named Entity Recognition (NER), linguistic quality, temporal freshness, and ensemble model consensus. Furthermore, the system incorporates adversarial overrides and live Retrieval-Augmented Generation (RAG) corroboration to guard against typosquatting, triple anonymity, and hallucinated statistics. Our methodology transitions away from a simple binary classifier to an explainable, multi-faceted verification engine, significantly improving both robustness and user trust in algorithmic verdicts.
+ \end{abstract}
+
+ \section{Introduction}
+ As digital media consumption accelerates, the threat of fabricated content, deliberately crafted to mislead, has escalated dramatically. The challenge of automated fake news detection lies not only in accurately classifying claims as ``True'' or ``False'', but in providing actionable context and explanations to content moderators and end users. Current state-of-the-art transformer models achieve high accuracy on standard datasets; however, their black-box nature provides little insight into \textit{why} a verdict was reached. Furthermore, they are highly susceptible to temporal drift (where a genuinely true article becomes factually outdated) and to adversarial manipulation, such as credible-sounding linguistic structures applied to unverified sources.
+
+ To address these shortcomings, we developed TruthLens. TruthLens diverges from standard classification pipelines by acting as a programmatic misinformation analyst. It does not blindly trust model consensus; instead, it synthesizes AI probabilities with deterministic heuristics based on journalistic standards.
+
+ \section{Problem Definition}
+ Given an article comprising an optional headline ($H$), a body text ($T$), a source domain ($D$), and a publication date ($P$), the objective is to map the document to one of four explainable verdicts: \texttt{TRUE}, \texttt{UNCERTAIN}, \texttt{LIKELY FALSE}, or \texttt{FALSE}.
+
+ A secondary objective is to dynamically generate a confidence score ($C \in \{\text{LOW}, \text{MEDIUM}, \text{HIGH}\}$) and qualitative justifications (e.g., deductions applied, adversarial flags triggered). The system must account for edge cases such as missing publication dates, anonymous or typosquatted domains, sensationalist headlines that subtly contradict the body, and the verification of quoted entities.
+
+ \section{Methodology}
+ The TruthLens pipeline is partitioned into four distinct operational stages, culminating in a real-time 5-signal evaluation engine.
+
+ \subsection{Stage 1: Data Ingestion}
+ In the first stage, raw datasets are fetched, homogenized, and aggregated. Disparate datasets originally formatted for binary or multi-class detection are ingested to form a consolidated corpus. During this stage, critical metadata (such as publication timestamps and source URLs) is preserved if available, or flagged for contextual imputation later.
+
+ \subsection{Stage 2: Preprocessing \& Feature Extraction}
+ Text entries undergo rigorous cleaning to standardize capitalization, extract URLs, expand contractions, and remove non-alphanumeric noise without losing semantic meaning. We construct the vocabulary indices and tokenization mappings required by the deep learning backbones. Cleaned datasets are persisted to disk to accelerate experimental iteration.
+
+ \subsection{Stage 3: Ensemble Training}
+ We utilize a heterogeneous ensemble of four models, each capturing distinct features of the input:
+ \begin{itemize}
+     \item \textbf{Logistic Regression:} A statistical baseline utilizing TF-IDF vectors to capture simple lexical patterns heavily correlated with misinformation.
+     \item \textbf{LSTM:} A recurrent neural network designed to capture sequential linguistic dependencies.
+     \item \textbf{DistilBERT \& RoBERTa:} High-capacity, attention-based deep contextual language models fine-tuned to detect semantic nuances and internal contradictions.
+ \end{itemize}
+
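The consensus signal derived later from this ensemble reduces to a vote ratio over the four base models. A toy sketch (function name and example votes are ours):

```python
# Toy sketch of the ensemble vote ratio (signal 5 below): the fraction of
# base models voting "real". Model names mirror the paper; values are made up.
def vote_ratio(votes: dict) -> float:
    """Fraction of base models voting 1 ("real")."""
    return sum(votes.values()) / len(votes)

votes = {"logistic": 1, "lstm": 0, "distilbert": 1, "roberta": 1}
print(vote_ratio(votes))  # 0.75
```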
+ \subsection{Stage 4: 5-Signal Inference Framework}
+ The core innovation of TruthLens is the Stage 4 inference engine, which combines model predictions with four deterministic pillars:
+ \begin{enumerate}
+     \item \textbf{Source Credibility (30\%):} Evaluates the article domain against a database of known credible outlets, checks for explicit author attribution (bylines), and penalizes subtle spelling manipulations (typosquatting).
+     \item \textbf{Claim Verification (30\%):} Utilizes spaCy for Named Entity Recognition (NER), tracking the ratio of attributed quotes. The meta-classifier probability is blended dynamically with these entity-level checks.
+     \item \textbf{Linguistic Quality (20\%):} Applies rule-based deductions for sensationalism (e.g., excessive capitalized words, superlatives) and passive-voice overuse, and deploys DistilBERT to compute a cosine-similarity check identifying headline--body contradictions.
+     \item \textbf{Temporal Freshness (10\%):} Operates via a dual-case system. Case A uses explicit publication dates with a non-linear decay function. Case B parses contextual temporal cues (e.g., ``yesterday'', the current year) when explicit dates are absent.
+     \item \textbf{Model Vote Consensus (10\%):} Represents the pure ensemble vote ratio.
+ \end{enumerate}
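Written out, the weighted combination implied by the five percentages above takes the following form. The symbols, and the exponential shape of the Case A decay, are our illustrative notation rather than the system's exact formulas:

```latex
% Illustrative notation: each S_i is a per-signal score in [0, 1].
\begin{equation*}
  S = 0.30\,S_{\mathrm{src}} + 0.30\,S_{\mathrm{claim}} + 0.20\,S_{\mathrm{ling}}
      + 0.10\,S_{\mathrm{fresh}} + 0.10\,S_{\mathrm{vote}}
\end{equation*}
% One plausible non-linear decay for Case A freshness, with article age
% \Delta t in days and a tunable decay rate \lambda > 0:
\begin{equation*}
  S_{\mathrm{fresh}}(\Delta t) = e^{-\lambda\,\Delta t}
\end{equation*}
```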
+
+ Finally, adversarial overrides cap maximum scores at 25\% if critical journalistic standards are violated (e.g., ``Triple Anonymity''), ensuring weak models do not inadvertently legitimize fabricated articles. When freshness enters an ambiguous threshold, a Retrieval-Augmented Generation (RAG) module queries live web APIs to corroborate claims against recent index updates.
+
+ \section{Data Description}
+ The model is trained on a homogenized amalgamation of well-established fake news benchmarks, primarily the ISOT Fake News Dataset and the LIAR dataset.
+
+ \subsection{Data Divisions \& Preprocessing}
+ The aggregated data is structured into a conventional split to ensure maximum generalizability:
+ \begin{itemize}
+     \item \textbf{Training Set (70\%):} Utilized for fitting the TF-IDF vectorizer, tuning the LSTM layer weights, and fine-tuning the Transformer heads.
+     \item \textbf{Validation Set (15\%):} Employed to monitor epoch loss, apply early stopping during deep model training, and tune the interpolation thresholds for the 5-signal weights.
+     \item \textbf{Holdout / Test Set (15\%):} Kept strictly isolated. Evaluation operates under the constraint of mapping the legacy truth classes to the newly defined 4-tier verdict system (\texttt{TRUE}/\texttt{UNCERTAIN} mapped to 1; \texttt{LIKELY FALSE}/\texttt{FALSE} mapped to 0).
+ \end{itemize}
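For holdout evaluation, the 4-tier-to-binary mapping described above reduces to a lookup; a minimal sketch (the function name is ours, the mapping itself is as stated):

```python
# Minimal sketch of the 4-tier → binary mapping used for holdout evaluation,
# exactly as described in the text above.
VERDICT_TO_BINARY = {
    "TRUE": 1,
    "UNCERTAIN": 1,
    "LIKELY FALSE": 0,
    "FALSE": 0,
}

def to_binary(verdict: str) -> int:
    """Map a 4-tier verdict (case-insensitive) to the legacy binary label."""
    return VERDICT_TO_BINARY[verdict.upper()]

print(to_binary("Uncertain"))  # 1
```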
+
+ \section{Results}
+ Integration of the 5-signal weighting framework demonstrated a marked improvement in explainability. While the baseline ensemble alone yielded high raw accuracy on synthetic datasets, it failed to identify maliciously injected adversarial examples (e.g., valid text attributed to ``c-n-n.com''). The complete infrastructure successfully caps adversarial scores, raising overall real-world robustness. On the isolated holdout set, the system dynamically flags low-confidence articles, ensuring that unverifiable claims are strictly prevented from passing a theoretical ``publishing'' boundary. Live RAG integration further eliminated false positives previously caused by temporal drift in the training data.
+
+ \section{Discussion}
+ TruthLens introduces a pivotal shift from pure binary classification to diagnostic evaluation in fake news detection. By hardcoding journalism heuristics (source validation, NER-based quote tracking, temporality) atop deep learning embeddings, the model successfully simulates a human analyst's workflow.
+
+ A notable outcome of our framework is the treatment of confidence. By decoupling probability from confidence (e.g., lowering confidence for texts under 50 words or lacking entities), the user interface can visually separate ``truthfulness'' from ``reliability''.
+ Future work includes expanding the live RAG validation checks with multi-hop reasoning LLMs, and extending the author-attribution regex engine to cover a broader range of international linguistic formats.
+
+ \end{document}
app.py ADDED
@@ -0,0 +1,641 @@
+ import os
+ import sys
+ import json
+ import time
+ import pandas as pd
+ import numpy as np
+ import streamlit as st
+
+ _ROOT = os.path.dirname(os.path.abspath(__file__))
+ if _ROOT not in sys.path:
+     sys.path.insert(0, _ROOT)
+
+ # ── Page config ──────────────────────────────────────────────────────────────
+ st.set_page_config(
+     page_title="TruthLens · Fake News Detector",
+     page_icon="🔍",
+     layout="wide",
+     initial_sidebar_state="collapsed",
+ )
+
+ # ── Global CSS ───────────────────────────────────────────────────────────────
+ st.markdown("""
+ <style>
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700;800&display=swap');
+
+ /* ── Reset ── */
+ html, body, [data-testid="stAppViewContainer"] {
+     font-family: 'Inter', sans-serif;
+     background: #f4f6fb;
+     color: #1e293b;
+ }
+ [data-testid="stMain"] { background: #f4f6fb; }
+ .block-container {
+     padding-top: 2.5rem !important;
+     padding-bottom: 2rem !important;
+     max-width: 920px;
+ }
+
+ /* ── Remove Streamlit chrome ── */
+ header[data-testid="stHeader"] { display: none; }
+ footer { display: none; }
+ #MainMenu { display: none; }
+ [data-testid="stSidebar"] { display: none; }
+
+ /* ── Predict button ── */
+ .stButton > button[kind="primary"] {
+     background: linear-gradient(135deg, #3b82f6 0%, #6366f1 100%) !important;
+     color: #fff !important;
+     border: none !important;
+     border-radius: 12px !important;
+     font-weight: 700 !important;
+     font-size: 1.05rem !important;
+     letter-spacing: 0.02em;
+     padding: 0.75rem 2rem !important;
+     transition: transform 0.15s, box-shadow 0.2s;
+     box-shadow: 0 4px 16px rgba(59,130,246,0.2);
+ }
+ .stButton > button[kind="primary"]:hover {
+     transform: translateY(-1px);
+     box-shadow: 0 6px 24px rgba(59,130,246,0.3) !important;
+ }
+
+ /* ── Tab styling ── */
+ [data-testid="stTabs"] button {
+     color: #94a3b8 !important;
+     font-size: 0.92rem !important;
+     font-weight: 500 !important;
+     padding: 10px 20px !important;
+ }
+ [data-testid="stTabs"] button[aria-selected="true"] {
+     color: #1e293b !important;
+     border-bottom: 2px solid #3b82f6 !important;
+     font-weight: 600 !important;
+ }
+
+ /* ── Verdict banner ── */
+ .verdict-box {
+     border-radius: 16px;
+     padding: 32px 36px;
+     margin-bottom: 28px;
+     display: flex;
+     align-items: center;
+     gap: 24px;
+     animation: fadeSlide 0.5s ease;
+ }
+ @keyframes fadeSlide {
+     from { opacity: 0; transform: translateY(-16px); }
+     to { opacity: 1; transform: translateY(0); }
+ }
+ .verdict-emoji { font-size: 3.5rem; line-height: 1; }
+ .verdict-label { font-size: 1.8rem; font-weight: 800; letter-spacing: -0.03em; }
+ .verdict-conf { font-size: 1rem; opacity: 0.85; margin-top: 6px; font-weight: 400; }
+ .verdict-explain { font-size: 0.88rem; color: #64748b; margin-top: 6px; line-height: 1.5; }
+
+ /* ── Info cards ── */
+ .info-card {
+     background: #ffffff;
+     border: 1px solid #e2e8f0;
+     border-radius: 12px;
+     padding: 20px 24px;
+     margin: 12px 0;
+     line-height: 1.6;
+     color: #475569;
+ }
+ .info-card b { color: #1e293b; }
+
+ /* ── Freshness bar ── */
+ .fresh-track { background: #e2e8f0; border-radius: 8px; height: 12px; margin: 10px 0 6px; overflow: hidden; }
+ .fresh-fill { height: 100%; border-radius: 8px; transition: width 0.8s ease; }
+
+ /* ── Source card ── */
+ .source-card {
+     background: #ffffff;
+     border: 1px solid #e2e8f0;
+     border-radius: 12px;
+     padding: 18px 22px;
+     margin: 10px 0;
+     display: flex;
+     justify-content: space-between;
+     align-items: flex-start;
+     gap: 16px;
+ }
+ .source-text { flex: 1; font-size: 0.88rem; line-height: 1.5; color: #475569; }
+ .source-score { text-align: center; min-width: 60px; }
+ .source-score-val { font-size: 1.4rem; font-weight: 700; font-family: 'Inter', sans-serif; }
+ .source-score-tag { font-size: 0.65rem; text-transform: uppercase; letter-spacing: 0.1em; margin-top: 4px; }
+
+ /* ── Hero ── */
+ .hero-wrap { text-align: center; padding: 60px 20px 40px; }
+ .hero-icon { font-size: 4rem; margin-bottom: 16px; }
+ .hero-title { font-size: 2.4rem; font-weight: 800; letter-spacing: -0.04em; color: #0f172a; }
+ .hero-sub { font-size: 1.05rem; color: #64748b; margin-top: 12px; line-height: 1.6; max-width: 520px; margin-left: auto; margin-right: auto; }
+
+ /* ── How-it-works ── */
+ .how-grid { display: grid; grid-template-columns: repeat(3, 1fr); gap: 16px; margin: 36px 0; }
+ .how-card {
+     background: #ffffff;
+     border: 1px solid #e2e8f0;
+     border-radius: 12px;
+     padding: 24px;
+     text-align: center;
+     box-shadow: 0 1px 3px rgba(0,0,0,0.04);
+ }
+ .how-num { font-size: 2rem; margin-bottom: 8px; }
+ .how-title { font-size: 0.95rem; font-weight: 600; margin-bottom: 6px; color: #0f172a; }
+ .how-desc { font-size: 0.82rem; color: #64748b; line-height: 1.5; }
+
+ /* ── Verdict legend ── */
+ .legend-row {
+     display: flex;
+     gap: 24px;
+     justify-content: center;
+     flex-wrap: wrap;
+     margin: 20px 0;
+ }
+ .legend-item { font-size: 0.85rem; color: #64748b; }
+
+ /* ── Metric overrides ── */
+ [data-testid="stMetric"] {
+     background: #ffffff;
+     border: 1px solid #e2e8f0;
+     border-radius: 10px;
+     padding: 14px 18px !important;
+     box-shadow: 0 1px 3px rgba(0,0,0,0.04);
+ }
+ [data-testid="stMetricLabel"] { color: #64748b !important; font-size: 0.78rem !important; }
+ [data-testid="stMetricValue"] { color: #0f172a !important; font-size: 1.3rem !important; }
+
+ /* ── Expander ── */
+ [data-testid="stExpander"] {
+     background: #ffffff !important;
+     border: 1px solid #e2e8f0 !important;
+     border-radius: 10px !important;
+ }
+
+ /* ── Text inputs ── */
+ [data-testid="stTextInput"] input, [data-testid="stTextArea"] textarea {
+     background: #ffffff !important;
+     border: 1px solid #cbd5e1 !important;
+     border-radius: 8px !important;
+     color: #1e293b !important;
+ }
+ [data-testid="stTextInput"] input:focus, [data-testid="stTextArea"] textarea:focus {
+     border-color: #3b82f6 !important;
+     box-shadow: 0 0 0 2px rgba(59,130,246,0.15) !important;
+ }
+
+ /* ── Select slider / radio ── */
+ [data-testid="stSlider"] label, .stRadio label { color: #475569 !important; }
+
+ /* ── Progress bar ── */
+ [data-testid="stProgress"] > div > div > div > div { background: linear-gradient(90deg, #3b82f6, #6366f1) !important; }
+ </style>
+ """, unsafe_allow_html=True)
+
+
+ # ── Cached inference loader ──────────────────────────────────────────────────
+ @st.cache_resource(show_spinner=False)
+ def load_pipeline():
+     from src.stage4_inference import predict_article, ModelNotTrainedError
+     return predict_article, ModelNotTrainedError
+
+
+ # ── Session state ────────────────────────────────────────────────────────────
+ for k, v in [("analyzed", False), ("last_result", None), ("last_input", "")]:
+     if k not in st.session_state:
+         st.session_state[k] = v
+
+
+ # =============================================================================
+ # LANDING PAGE (shown before any analysis)
+ # =============================================================================
+ if not st.session_state["analyzed"]:
+
+     # ── Hero section ──
+     st.markdown("""
+     <div class="hero-wrap">
+         <div class="hero-icon">🔍</div>
+         <div class="hero-title">TruthLens</div>
+         <div class="hero-sub">
+             Paste any news article or drop a URL — our AI will tell you
+             if it's real, fake, or outdated in seconds.
+         </div>
+     </div>
+     """, unsafe_allow_html=True)
+
+     # ── How it works ──
+     st.markdown("""
+     <div class="how-grid">
+         <div class="how-card">
+             <div class="how-num">📋</div>
+             <div class="how-title">Paste or Link</div>
+             <div class="how-desc">Drop in the article text or a URL. We'll extract everything automatically.</div>
+         </div>
+         <div class="how-card">
+             <div class="how-num">⚡</div>
+             <div class="how-title">Instant Analysis</div>
+             <div class="how-desc">Our AI analyzes language patterns, checks freshness, and searches live sources.</div>
+         </div>
+         <div class="how-card">
+             <div class="how-num">✅</div>
+             <div class="how-title">Get Your Verdict</div>
+             <div class="how-desc">See a clear REAL / FAKE / OUTDATED verdict with a confidence score and explanation.</div>
+         </div>
+     </div>
+     """, unsafe_allow_html=True)
+
+     # ── Input area ──
+     input_tab = st.radio("How would you like to provide the article?",
+                          ["✍️ Write or paste text", "🔗 Paste a URL"],
+                          horizontal=True, label_visibility="visible")
+
+     input_text, input_title, input_url, input_date, input_domain = "", "", "", "", ""
+
+     if input_tab == "✍️ Write or paste text":
+         input_title = st.text_input("Headline (optional)",
+                                     placeholder="e.g. Breaking: Scientists discover high-speed interstellar travel")
+         input_text = st.text_area("Article content",
+                                   height=180,
+                                   placeholder="Paste the full article body here…")
+         # ── Auto-extract title from pasted text if headline field is empty ──
+         if not input_title.strip() and input_text.strip():
+             if input_text.lower().startswith("title:"):
+                 lines = input_text.split("\n", 1)
+                 input_title = lines[0].replace("Title:", "").replace("title:", "").strip()
+                 input_text = lines[1].replace("Body:", "").replace("body:", "").strip() if len(lines) > 1 else ""
+             else:
+                 # Fallback: first sentence is title
+                 input_title = input_text.split(".")[0].strip()
+
+     else:
+         input_url = st.text_input("Article URL",
+                                   placeholder="https://www.example.com/news/breaking-story")
+         st.caption("We'll automatically extract the title, body, and publish date.")
+
+     # ── Analysis mode (kept minimal — user doesn't need to understand internals)
+     speed = st.select_slider("Analysis depth",
+                              options=["Quick", "Standard", "Deep"],
+                              value="Deep",
+                              help="Quick ≈ 2 sec · Standard ≈ 10 sec · Deep ≈ 30 sec (most accurate)")
+     speed_map = {"Quick": "fast", "Standard": "balanced", "Deep": "full"}
+     selected_mode = speed_map[speed]
+
+     # ── Predict button ──
+     predict_clicked = st.button("🔍 Check this article", use_container_width=True, type="primary")
+
+     # ── Verdict legend ──
+     st.markdown("""
+     <div class="legend-row">
+         <div class="legend-item">🟢 Verified True</div>
+         <div class="legend-item">🔴 Likely Fake</div>
+         <div class="legend-item">🟡 Outdated</div>
+         <div class="legend-item">🟠 Needs Review</div>
+     </div>
+     """, unsafe_allow_html=True)
+
+     # ── Execute prediction ──
+     if predict_clicked:
+         # Validate
+         if input_tab == "✍️ Write or paste text":
+             if not input_text or len(input_text.split()) < 10:
+                 st.warning("⚠️ Please paste at least a few sentences so we can analyze it properly.")
+                 st.stop()
+         else:
+             if not input_url:
+                 st.warning("⚠️ Please enter a URL first.")
+                 st.stop()
+             try:
+                 import newspaper
+                 from urllib.parse import urlparse
+                 art = newspaper.Article(input_url)
+                 art.download()
+                 art.parse()
+                 input_title = art.title or ""
+                 input_text = art.text or ""
+                 input_date = art.publish_date.isoformat() if art.publish_date else ""
+                 input_domain = urlparse(input_url).netloc
+                 if len(input_text.split()) < 10:
+                     st.warning("⚠️ Couldn't extract enough text from that URL. Try pasting the article directly.")
+                     st.stop()
+             except Exception:
+                 st.error("❌ Couldn't fetch that URL. Please check the link or paste the text directly.")
+                 st.stop()
+
+         predict_article, ModelNotTrainedError = load_pipeline()
+
+         with st.status("🔍 Analyzing article…", expanded=True) as status:
+             st.write("📖 Reading article…")
+             time.sleep(0.3)
+             st.write("🧠 Running AI analysis…")
+             try:
+                 result = predict_article(
+                     title=input_title,
+                     text=input_text,
+                     source_domain=input_domain,
+                     published_date=input_date,
+                     mode=selected_mode,
+                 )
+                 st.write("🕐 Checking article freshness…")
+                 st.write("🌐 Searching live sources…")
+                 status.update(label="✅ Done!", state="complete")
+                 st.session_state["last_result"] = result
+                 st.session_state["last_input"] = input_text
+                 st.session_state["analyzed"] = True
+                 st.rerun()
+             except ModelNotTrainedError:
+                 status.update(label="❌ Setup required", state="error")
+                 st.error("The AI models haven't been trained yet.")
+                 st.info("Ask your administrator to run: `python run_pipeline.py --stage 1 2 3`")
+                 st.stop()
+             except Exception as e:
+                 status.update(label="❌ Error", state="error")
+                 st.error(f"Something went wrong: {e}")
+                 st.stop()
+
356
+
357
+
358
+ # =============================================================================
359
+ # RESULTS PAGE (shown after analysis)
360
+ # =============================================================================
361
+ else:
362
+ res = st.session_state["last_result"]
363
+ verdict = res.get("verdict", "UNKNOWN")
364
+ final_score = res.get("final_score", 0.0)
365
+ scores = res.get("scores", {})
366
+ confidence = res.get("confidence", "MEDIUM")
367
+ action = res.get("recommended_action", "Flag for review")
368
+ top_reasons = res.get("top_reasons", [])
369
+ missing_signals = res.get("missing_signals", [])
370
+ adv_flags = res.get("adversarial_flags", [])
371
+ wc = res.get("word_count", 0)
372
+ probas = res.get("base_model_probas", {})
373
+ votes = res.get("base_model_votes", {})
374
+ fresh_case = res.get("freshness_case", "B")
375
+ fresh_signals = res.get("freshness_signals_found", [])
376
+ deductions = res.get("deductions_applied", [])
377
+ entities = res.get("entities_found", [])
378
+
379
+ # ── Map verdict to display ──
380
+ V = {
381
+ "TRUE": {"bg":"#f0fdf4", "bdr":"#86efac", "icon":"🟢", "label":"This appears to be true", "color":"#15803d",
382
+ "explain":"Source, claims, language, and AI models all align with credible journalism."},
383
+ "UNCERTAIN": {"bg":"#fff7ed", "bdr":"#fdba74", "icon":"🟠", "label":"Uncertain — needs review", "color":"#c2410c",
384
+ "explain":"Mixed signals detected. We recommend verifying the sources yourself before sharing."},
385
+ "LIKELY FALSE": {"bg":"#fef2f2", "bdr":"#fca5a5", "icon":"🔴", "label":"Likely false", "color":"#b91c1c",
386
+ "explain":"Multiple signals indicate this content may be fabricated or misleading."},
387
+ "FALSE": {"bg":"#fef2f2", "bdr":"#fca5a5", "icon":"⛔", "label":"This looks fake", "color":"#991b1b",
388
+ "explain":"Strong evidence of misinformation. Do not share without independent verification."},
389
+ }
390
+ vc = V.get(verdict, {"bg":"#f8fafc","bdr":"#cbd5e1","icon":"⚪","label":verdict,"color":"#475569",
391
+ "explain":"Analysis complete."})
392
+
393
+ # ── Verdict banner ──
394
+ score_pct = final_score * 100
395
+ st.markdown(f"""
396
+ <div class="verdict-box" style="background:{vc['bg']}; border:1px solid {vc['bdr']};">
397
+ <div class="verdict-emoji">{vc['icon']}</div>
398
+ <div>
399
+ <div class="verdict-label" style="color:{vc['color']};">{vc['label']}</div>
400
+ <div class="verdict-conf" style="color:{vc['color']};">Score: {score_pct:.0f}% · Confidence: {confidence}</div>
401
+ <div class="verdict-explain">{vc['explain']}</div>
402
+ </div>
403
+ </div>
404
+ """, unsafe_allow_html=True)
405
+
406
+ # ── Recommended action badge ──
407
+ action_colors = {
408
+ "Publish": ("#f0fdf4", "#15803d"),
409
+ "Flag for review": ("#fff7ed", "#c2410c"),
410
+ "Suppress": ("#fef2f2", "#b91c1c"),
411
+ "Escalate": ("#fef2f2", "#991b1b"),
412
+ }
413
+ abg, acol = action_colors.get(action, ("#f8fafc", "#475569"))
414
+ st.markdown(f"""
415
+ <div style="background:{abg}; border-radius:8px; padding:10px 16px; display:inline-block; margin-bottom:24px;">
416
+ <span style="font-weight:600; color:{acol};">Recommended: {action}</span>
417
+ </div>
418
+ """, unsafe_allow_html=True)
419
+
420
+ # ── Tabs ──
421
+ tab_why, tab_fresh, tab_sources, tab_details = st.tabs(
422
+ ["🧠 Why this verdict?", "🕐 Freshness", "🌐 Live sources", "📋 Details"]
423
+ )
424
+
425
+ # ── TAB 1: Why this verdict ──────────────────────────────────────────
426
+ with tab_why:
427
+
428
+ # ── 5-Signal Score Breakdown ──
429
+ st.markdown("#### Signal Breakdown")
430
+ SIGNAL_INFO = [
431
+ ("Source", "source", "Is the outlet known and accountable?"),
432
+ ("Claims", "claim", "Are facts verifiable with named entities?"),
433
+ ("Language", "linguistic", "Is the writing neutral and attributed?"),
434
+ ("Freshness", "freshness", "How recent is the content?"),
435
+ ("AI Models", "model_vote", "What do the AI models think?"),
436
+ ]
437
+ WEIGHTS = {"source": "30%", "claim": "30%", "linguistic": "20%", "freshness": "10%", "model_vote": "10%"}
438
+
439
+ cols = st.columns(5)
440
+ for i, (label, key, desc) in enumerate(SIGNAL_INFO):
441
+ val = scores.get(key, 0.0)
442
+ pct = val * 100
443
+ if pct >= 70:
444
+ col_hex = "#15803d"
445
+ elif pct >= 50:
446
+ col_hex = "#ca8a04"
447
+ else:
448
+ col_hex = "#b91c1c"
449
+ with cols[i]:
450
+ st.markdown(f"""
451
+ <div style="text-align:center; background:#ffffff; border:1px solid #e2e8f0;
452
+ border-radius:10px; padding:16px 8px; box-shadow:0 1px 3px rgba(0,0,0,0.04);">
453
+ <div style="font-size:1.6rem; font-weight:800; color:{col_hex};">{pct:.0f}%</div>
454
+ <div style="font-size:0.85rem; font-weight:600; color:#0f172a; margin-top:4px;">{label}</div>
455
+ <div style="font-size:0.7rem; color:#94a3b8; margin-top:2px;">Weight: {WEIGHTS[key]}</div>
456
+ </div>
457
+ """, unsafe_allow_html=True)
458
+
459
+ st.markdown("")
460
+
461
+ # ── Progress bars for each signal ──
462
+ for label, key, desc in SIGNAL_INFO:
463
+ val = scores.get(key, 0.0)
464
+ st.caption(f"**{label}** — {desc}")
465
+ st.progress(min(val, 1.0))
466
+
467
+ st.markdown("---")
468
+
469
+ # ── Top Reasons ──
470
+ if top_reasons:
471
+ st.markdown("#### Key Factors")
472
+ for r in top_reasons:
473
+ if any(neg in r.lower() for neg in ["fake", "false", "unknown", "not", "manipulation", "adversarial", "sensationalism", "reduces", "could not", "inconsistent", "missing"]):
474
+ st.markdown(f"🔴 {r}")
475
+ else:
476
+ st.markdown(f"🟢 {r}")
477
+
478
+ st.markdown("---")
479
+
480
+ # ── What did each AI model think? ──
481
+ st.markdown("#### AI Model Votes")
482
+ MODEL_NAMES = [
483
+ ("Statistical", "logistic", "lr_proba"),
484
+ ("Language", "lstm", "lstm_proba"),
485
+ ("Deep A", "distilbert", "distilbert_proba"),
486
+ ("Deep B", "roberta", "roberta_proba"),
487
+ ]
488
+ mcols = st.columns(len(MODEL_NAMES))
489
+ for i, (nice_name, vote_key, pk) in enumerate(MODEL_NAMES):
490
+ vote_val = votes.get(vote_key)
491
+ prob_val = probas.get(pk)
492
+ with mcols[i]:
493
+ if vote_val is None or prob_val is None or np.isnan(prob_val):
494
+ st.metric(nice_name, "Skipped")
495
+ else:
496
+ lbl = "Real" if int(vote_val) == 1 else "Fake"
497
+ st.metric(nice_name, lbl, f"{prob_val*100:.0f}%")
498
+
499
+ if res.get("short_text_warning"):
500
+ st.warning("⚠️ Short article (under 50 words) — confidence is dampened.")
501
+ st.caption(f"Article length: {wc} words")
502
+
503
+ # ── TAB 2: Freshness ─────────────────────────────────────────────────
504
+ with tab_fresh:
505
+ fresh_val = scores.get("freshness", 0.5)
506
+ bar_pct = int(fresh_val * 100)
507
+
508
+ if fresh_val >= 0.70:
509
+ fbg, flbl, fdesc = "#f0fdf4", "🟢 Fresh", "This article appears to be recent."
510
+ fbar = "#16a34a"
511
+ elif fresh_val >= 0.40:
512
+ fbg, flbl, fdesc = "#fefce8", "🟡 Moderate", "Article may not be very recent."
513
+ fbar = "#ca8a04"
514
+ else:
515
+ fbg, flbl, fdesc = "#fef2f2", "🔴 Outdated", "This article appears to be old."
516
+ fbar = "#dc2626"
517
+
518
+ st.markdown(f"""
519
+ <div style="background:{fbg}; border-radius:12px; padding:20px 24px; margin-bottom:20px;">
520
+ <div style="font-size:1.2rem; font-weight:600;">{flbl}</div>
521
+ <div style="font-size:0.88rem; color:#64748b; margin-top:8px;">{fdesc}</div>
522
+ <div class="fresh-track">
523
+ <div class="fresh-fill" style="width:{bar_pct}%; background:{fbar};"></div>
524
+ </div>
525
+ <div style="font-size:0.8rem; color:#6b7280; margin-top:4px;">Freshness: {fresh_val:.0%}</div>
526
+ </div>
527
+ """, unsafe_allow_html=True)
528
+
529
+ # Case indicator
530
+ case_label = "📅 Date-based scoring" if fresh_case == "A" else "🔎 Contextual signal scanning (no date found)"
531
+ st.markdown(f"""
532
+ <div class="info-card">
533
+ <b>Method:</b> {case_label}
534
+ </div>
535
+ """, unsafe_allow_html=True)
536
+
537
+ # Signals found (Case B)
538
+ if fresh_case == "B" and fresh_signals:
539
+ st.markdown("**Signals detected:**")
540
+ for sig in fresh_signals:
541
+ st.markdown(f"✅ {sig}")
542
+ elif fresh_case == "B":
543
+ st.caption("No contextual freshness signals were found in the article text.")
544
+
545
+ # ── TAB 3: Live sources ──────────────────────────────────────────────
546
+ with tab_sources:
547
+ rag_data = res.get("rag_results")
548
+ source_list = []
549
+ if isinstance(rag_data, dict):
550
+ source_list = rag_data.get("data", [])
551
+ elif isinstance(rag_data, list):
552
+ source_list = rag_data
553
+
554
+ if not source_list:
555
+ st.markdown("""
556
+ <div class="info-card">
557
+ <b>Live source check was not triggered</b><br><br>
558
+ Live source verification runs when freshness is ambiguous.
559
+ This analysis relied on the 5-signal scoring framework instead.
560
+ </div>
561
+ """, unsafe_allow_html=True)
562
+ else:
563
+ st.caption(f"Compared against {len(source_list)} live web results.")
564
+ for item in source_list:
565
+ snippet = item.get("snippet", "")
566
+ sim = item.get("similarity", 0.0)
567
+ if sim > 0.65:
568
+ sc_col, sc_tag = "#16a34a", "Supports"
569
+ elif sim < 0.30:
570
+ sc_col, sc_tag = "#dc2626", "Conflicts"
571
+ else:
572
+ sc_col, sc_tag = "#ca8a04", "Neutral"
573
+
574
+ st.markdown(f"""
575
+ <div class="source-card">
576
+ <div class="source-text">{snippet}</div>
577
+ <div class="source-score">
578
+ <div class="source-score-val" style="color:{sc_col};">{sim:.0%}</div>
579
+ <div class="source-score-tag" style="color:{sc_col};">{sc_tag}</div>
580
+ </div>
581
+ </div>
582
+ """, unsafe_allow_html=True)
583
+
584
+ # ── TAB 4: Details ───────────────────────────────────────────────────
585
+ with tab_details:
586
+
587
+ # ── Missing Signals ──
588
+ if missing_signals:
589
+ st.markdown("#### ⚠️ Missing Signals")
590
+ for ms in missing_signals:
591
+ st.markdown(f"- {ms}")
592
+ st.markdown("")
593
+
594
+ # ── Adversarial Flags ──
595
+ if adv_flags:
596
+ st.markdown("#### 🚩 Adversarial Flags Triggered")
597
+ for af in adv_flags:
598
+ st.error(f"🚩 {af}")
599
+ st.caption("Adversarial flags cap the final score at 25% maximum.")
600
+ st.markdown("")
601
+
602
+ # ── Linguistic Deductions ──
603
+ if deductions:
604
+ st.markdown("#### 📝 Linguistic Deductions")
605
+ for d in deductions:
606
+ st.markdown(f"- {d}")
607
+ st.markdown("")
608
+
609
+ # ── Named Entities Found ──
610
+ if entities:
611
+ st.markdown("#### 🏷️ Entities Detected")
612
+ st.markdown(", ".join([f"`{e}`" for e in entities]))
613
+ q_attr = res.get("quotes_attributed", 0)
614
+ q_total = res.get("quotes_total", 0)
615
+ if q_total > 0:
616
+ st.caption(f"Quotes: {q_attr}/{q_total} attributed")
617
+ st.markdown("")
618
+
619
+ # ── Summary Table ──
620
+ st.markdown("#### Analysis Summary")
621
+ rows = [
622
+ ("Verdict", vc["label"]),
623
+ ("Final Score", f"{score_pct:.1f}%"),
624
+ ("Confidence", confidence),
625
+ ("Action", action),
626
+ ("Word Count", str(wc)),
627
+ ("Freshness", f"{scores.get('freshness', 0):.0%} (Case {fresh_case})"),
628
+ ]
629
+ df_rep = pd.DataFrame(rows, columns=["Field", "Value"])
630
+ st.dataframe(df_rep, use_container_width=True, hide_index=True, height=240)
631
+
632
+ with st.expander("🔧 Raw JSON (for developers)"):
633
+ st.code(json.dumps(res, indent=2, default=str), language="json")
634
+
635
+ # ── Analyze another ──
636
+ st.markdown("---")
637
+ if st.button("← Analyze another article", use_container_width=True):
638
+ st.session_state["analyzed"] = False
639
+ st.session_state["last_result"] = None
640
+ st.rerun()
641
+
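The Signal Breakdown in the results page above displays fixed weights (Source 30%, Claims 30%, Language 20%, Freshness 10%, AI Models 10%). Assuming the backend combines the five signal scores with the same linear weighting shown in the UI captions, the aggregation can be sketched as:

```python
# Weighted aggregation of the five signal scores shown in the UI.
# Assumption: the backend uses the same linear combination as the
# "Weight" captions in the Signal Breakdown cards.
WEIGHTS = {"source": 0.30, "claim": 0.30, "linguistic": 0.20,
           "freshness": 0.10, "model_vote": 0.10}

def combine_signals(scores: dict) -> float:
    """Return the weighted final score in [0, 1]; missing signals count as 0."""
    return sum(WEIGHTS[k] * scores.get(k, 0.0) for k in WEIGHTS)

scores = {"source": 0.8, "claim": 0.6, "linguistic": 0.9,
          "freshness": 0.5, "model_vote": 1.0}
final = combine_signals(scores)  # 0.24 + 0.18 + 0.18 + 0.05 + 0.10 ≈ 0.75
```

The `final_score` rendered in the verdict banner is then formatted as a percentage (`score_pct = final_score * 100`).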
config/config.yaml ADDED
@@ -0,0 +1,60 @@
1
+ # ═══════════════════════════════════════════════════════════
2
+ # Fake News Detection System — Configuration
3
+ # ═══════════════════════════════════════════════════════════
4
+
5
+ paths:
6
+ # Raw dataset root (relative to project root)
7
+ dataset_root: "../Dataset"
8
+ # Processed data output
9
+ processed_dir: "data/processed"
10
+ # Train/val/test splits
11
+ splits_dir: "data/splits"
12
+ # Saved model weights
13
+ models_dir: "models/saved"
14
+ # Logs
15
+ logs_dir: "logs"
16
+ # GloVe embeddings
17
+ glove_path: "../glove.6B.100d.txt"
18
+
19
+ dataset:
20
+ min_domain_samples: 20
21
+ dedup_threshold: 0.92
22
+ dedup_batch_size: 64
23
+
24
+ preprocessing:
25
+ max_tfidf_features: 50000
26
+ lstm_max_len: 256
27
+ bert_max_len: 128
28
+ short_text_threshold: 50
29
+ min_word_count: 3
30
+
31
+ training:
32
+ lstm_batch_size: 64
33
+ lstm_epochs: 10
34
+ bert_batch_size: 8
35
+ bert_epochs: 3
36
+ lr_learning_rate: 2e-5
37
+ roberta_learning_rate: 1e-5
38
+
39
+ inference:
40
+ true_threshold: 0.55
41
+ fresh_threshold: 0.70
42
+ outdated_threshold: 0.40
43
+ undated_freshness_score: 0.35
44
+ max_multiplier: 10
45
+
46
+ rag:
47
+ top_k: 5
48
+ max_result_age_days: 90
49
+ support_threshold: 0.65
50
+ conflict_threshold: 0.30
51
+
52
+ drift:
53
+ word_count_std_alert: 2.0
54
+ fake_ratio_upper: 0.80
55
+ fake_ratio_lower: 0.20
56
+ rolling_window: 100
57
+
58
+ holdout:
59
+ stratified_test_size: 0.10
60
+ random_state: 42
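The `inference` block above defines the freshness thresholds (`fresh_threshold: 0.70`, `outdated_threshold: 0.40`) that the app's Freshness tab uses to pick a band. A minimal sketch of that banding, with illustrative band names matching the UI's Fresh/Moderate/Outdated labels:

```python
# Freshness banding driven by the config thresholds. Band names are
# illustrative; the app renders them as 🟢 Fresh / 🟡 Moderate / 🔴 Outdated.
FRESH_THRESHOLD = 0.70     # from inference.fresh_threshold
OUTDATED_THRESHOLD = 0.40  # from inference.outdated_threshold

def freshness_band(score: float) -> str:
    if score >= FRESH_THRESHOLD:
        return "fresh"
    if score >= OUTDATED_THRESHOLD:
        return "moderate"
    return "outdated"

print(freshness_band(0.35))  # -> outdated
```

Note the boundaries are inclusive on the upper band, matching the `>=` comparisons in `app.py`.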
models/saved/distilbert_model/cm_distil.png ADDED
models/saved/distilbert_model/config.json ADDED
@@ -0,0 +1,28 @@
1
+ {
2
+ "activation": "gelu",
3
+ "architectures": [
4
+ "DistilBertForSequenceClassification"
5
+ ],
6
+ "attention_dropout": 0.1,
7
+ "bos_token_id": null,
8
+ "dim": 768,
9
+ "dropout": 0.1,
10
+ "dtype": "float32",
11
+ "eos_token_id": null,
12
+ "hidden_dim": 3072,
13
+ "initializer_range": 0.02,
14
+ "max_position_embeddings": 512,
15
+ "model_type": "distilbert",
16
+ "n_heads": 12,
17
+ "n_layers": 6,
18
+ "pad_token_id": 0,
19
+ "problem_type": "single_label_classification",
20
+ "qa_dropout": 0.1,
21
+ "seq_classif_dropout": 0.2,
22
+ "sinusoidal_pos_embds": false,
23
+ "tie_weights_": true,
24
+ "tie_word_embeddings": true,
25
+ "transformers_version": "5.5.3",
26
+ "use_cache": false,
27
+ "vocab_size": 30522
28
+ }
models/saved/distilbert_model/distilbert_oof.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe20e00deb86d90904d65166c92605f5afbe923986eba9ea9d980cf6c2555992
3
+ size 161240
models/saved/distilbert_model/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd61e968af64237e2024f8ccfbeb8e7df658f7bde6a3310826ad7767dcd1118b
3
+ size 267832560
models/saved/distilbert_model/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
models/saved/distilbert_model/tokenizer_config.json ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ "backend": "tokenizers",
3
+ "cls_token": "[CLS]",
4
+ "do_lower_case": true,
5
+ "is_local": false,
6
+ "mask_token": "[MASK]",
7
+ "model_max_length": 512,
8
+ "pad_token": "[PAD]",
9
+ "sep_token": "[SEP]",
10
+ "strip_accents": null,
11
+ "tokenize_chinese_chars": true,
12
+ "tokenizer_class": "BertTokenizer",
13
+ "unk_token": "[UNK]"
14
+ }
models/saved/distilbert_model/training_args.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4bc15d538a51fdb305b84c6674de46ba456c3288bbf8030c105f2034004edc06
3
+ size 5329
models/saved/logistic_model/cm.png ADDED
models/saved/logistic_model/logistic_model.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:07d15b87ba7fc5e352c37141f8ec2f4a182fff0e8ff5efa26ef6f8cef1bd9db3
3
+ size 2383153
models/saved/logistic_model/lr_oof.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:02da29135c36a6752466dbe6b18ef62dbfdb4dca131b035546ce1720ba9eae69
3
+ size 322352
models/saved/logistic_model/metrics.json ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "roc_auc": 0.9668348166617273,
3
+ "bucket_accuracy": {
4
+ "short": 0.6078748651564185,
5
+ "medium": 0.9741119807344973,
6
+ "long": 0.986362371277484
7
+ }
8
+ }
models/saved/lstm_model/cm.png ADDED
models/saved/lstm_model/lstm_oof.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f26e0d9f8ac71e5d60591c007e07e7a164c25de363dedeb73769a521dc548b4
3
+ size 161240
models/saved/lstm_model/metrics.json ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "roc_auc": 0.9730530691595465,
3
+ "bucket_accuracy": {
4
+ "short": 0.6197411003236246,
5
+ "medium": 0.9704996989765202,
6
+ "long": 0.9944336209295853
7
+ }
8
+ }
models/saved/lstm_model/model.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5fb0cf853e49b9410620e76842792df7ac50ace32b47bb16d23a10d29165534
3
+ size 21639159
models/saved/meta_classifier/cm_meta.png ADDED
models/saved/meta_classifier/meta_classifier.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:62807b550c6db9514b69c346a497d9897dfd9db8566b8e806d1cf8fe233384fa
3
+ size 276517
models/saved/meta_classifier/metrics.json ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "roc_auc": 0.9770497393987524,
3
+ "bucket_accuracy": {
4
+ "short": 0.6564185544768069,
5
+ "medium": 0.9867549668874173,
6
+ "long": 0.9961035346507097
7
+ }
8
+ }
models/saved/roberta_model/cm_roberta.png ADDED
models/saved/roberta_model/config.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "add_cross_attention": false,
3
+ "architectures": [
4
+ "RobertaForSequenceClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "bos_token_id": 0,
8
+ "classifier_dropout": null,
9
+ "dtype": "float32",
10
+ "eos_token_id": 2,
11
+ "hidden_act": "gelu",
12
+ "hidden_dropout_prob": 0.1,
13
+ "hidden_size": 768,
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": 3072,
16
+ "is_decoder": false,
17
+ "layer_norm_eps": 1e-05,
18
+ "max_position_embeddings": 514,
19
+ "model_type": "roberta",
20
+ "num_attention_heads": 12,
21
+ "num_hidden_layers": 12,
22
+ "pad_token_id": 1,
23
+ "problem_type": "single_label_classification",
24
+ "tie_word_embeddings": true,
25
+ "transformers_version": "5.5.3",
26
+ "type_vocab_size": 1,
27
+ "use_cache": false,
28
+ "vocab_size": 50265
29
+ }
models/saved/roberta_model/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7c3b04b2a6f1e1c303664a07df3c950f97f4571b2f5bfebfce7b5851edbac430
3
+ size 498612800
models/saved/roberta_model/roberta_oof.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c60da5d3c8e05b0562eafac205e93d2f8ec51649e01675b9ab9d9a087d6385ce
3
+ size 161240
models/saved/roberta_model/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
models/saved/roberta_model/tokenizer_config.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "backend": "tokenizers",
4
+ "bos_token": "<s>",
5
+ "cls_token": "<s>",
6
+ "eos_token": "</s>",
7
+ "errors": "replace",
8
+ "is_local": false,
9
+ "mask_token": "<mask>",
10
+ "model_max_length": 512,
11
+ "pad_token": "<pad>",
12
+ "sep_token": "</s>",
13
+ "tokenizer_class": "RobertaTokenizer",
14
+ "trim_offsets": true,
15
+ "unk_token": "<unk>"
16
+ }
models/saved/roberta_model/training_args.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b7c09d4d3610e99efbd18eb40d7b9c533e6d00d9ce19be89748b4377176c5d72
3
+ size 5329
models/saved/tokenizer.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f106d0d56f5cc7a71b6012570e2f9c4385b0462b3473c6191957e7963d3979d
3
+ size 4361433
requirements.txt ADDED
@@ -0,0 +1,18 @@
1
+ pandas>=2.0.0
2
+ numpy>=1.24.0
3
+ scikit-learn>=1.3.0
4
+ torch>=2.0.0
5
+ transformers>=4.30.0
6
+ datasets>=2.14.0
7
+ sentence-transformers>=2.2.0
8
+ xgboost>=2.0.0
9
+ spacy>=3.6.0
10
+ duckduckgo-search>=4.0.0
11
+ streamlit>=1.28.0
12
+ plotly>=5.18.0
13
+ newspaper3k>=0.2.8
14
+ pyyaml>=6.0
15
+ tqdm>=4.65.0
16
+ joblib>=1.3.0
17
+ lxml_html_clean
18
+ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl
run_pipeline.py ADDED
@@ -0,0 +1,100 @@
1
+ import os
2
+ import sys
3
+ import argparse
4
+ import subprocess
5
+ import logging
6
+
7
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s")
8
+ logger = logging.getLogger("run_pipeline")
9
+
10
+ def execute_stage(stage_num):
11
+ logger.info(f"========== TRIGGERING STAGE {stage_num} ==========")
12
+ if stage_num == 1:
13
+ script = "src/stage1_ingestion.py"
14
+ elif stage_num == 2:
15
+ script = "src/stage2_preprocessing.py"
16
+ elif stage_num == 3:
17
+ script = "src/stage3_training.py"
18
+ else:
19
+ logger.error(f"Unknown Stage: {stage_num}")
20
+ return
21
+
22
+ if not os.path.exists(script):
23
+ logger.error(f"Cannot find script: {script}")
24
+ sys.exit(1)
25
+
26
+ res = subprocess.run([sys.executable, script])
27
+ if res.returncode != 0:
28
+ logger.error(f"Stage {stage_num} failed!")
29
+ sys.exit(1)
30
+ logger.info(f"========== STAGE {stage_num} FINISHED ==========\n")
31
+
32
+
33
+ def execute_evaluation():
34
+ logger.info("========== TRIGGERING FINAL HOLD-OUT BENCHMARK ==========")
35
+ import pandas as pd
36
+ import numpy as np
37
+ from tqdm import tqdm
38
+ from sklearn.metrics import classification_report, accuracy_score, confusion_matrix
39
+
40
+ # Imported late so the pipeline CLI still works if inference dependencies aren't set up
41
+ from src.stage4_inference import predict_article
42
+
43
+ df_path = "data/splits/df_holdout.csv"
44
+ if not os.path.exists(df_path):
45
+ logger.error(f"Holdout file missing at {df_path}. Run Stages 1-3 first.")
46
+ sys.exit(1)
47
+
48
+ df = pd.read_csv(df_path)
49
+ logger.info(f"Loaded {len(df)} Stratified Holdout records.")
50
+
51
+ y_true = df["binary_label"].values
52
+ y_pred = []
53
+
54
+ logger.info("Running full-pipeline inference on the holdout set (RAG disabled for this benchmark)...")
55
+ logger.info("NOTE: this runs the full 4-model ensemble locally, so it may take several minutes.")
56
+
57
+ for i, row in tqdm(df.iterrows(), total=len(df), desc="Benchmarking Evaluator"):
58
+ # Feed each holdout row's fields directly into the inference pipeline
59
+ res = predict_article(
60
+ title=row.get("title", ""),
61
+ text=row.get("text", ""),
62
+ source_domain=row.get("source_domain", ""),
63
+ published_date=row.get("published_date", ""),
64
+ mode="full",
65
+ trigger_rag=False
66
+ )
67
+
68
+ # New 4-tier verdict mapping:
69
+ # TRUE / UNCERTAIN → 1 (real news)
70
+ # LIKELY FALSE / FALSE → 0 (fake news)
71
+ v = res["verdict"]
72
+ pred_label = 1 if v in ("TRUE", "UNCERTAIN") else 0
73
+ y_pred.append(pred_label)
74
+
75
+ y_pred = np.array(y_pred)
76
+ acc = accuracy_score(y_true, y_pred)
77
+
78
+ logger.info(f"\n================ BENCHMARK RESULTS ================")
79
+ logger.info(f"Final Architecture Accuracy: {acc * 100:.2f}%")
80
+ logger.info("\n" + classification_report(y_true, y_pred, target_names=["Fake News (0)", "True News (1)"]))
81
+ logger.info(f"Confusion Matrix:\n{confusion_matrix(y_true, y_pred)}")
82
+ logger.info("===================================================\n")
83
+
84
+
85
+ if __name__ == "__main__":
86
+ parser = argparse.ArgumentParser(description="Fake News Detection System Pipeline")
87
+ parser.add_argument("--stage", nargs="+", type=int, choices=[1, 2, 3], help="Specify stages to run (e.g. --stage 1 2 3)")
88
+ parser.add_argument("--eval", action="store_true", help="Evaluate the full pipeline on the stratified holdout benchmark")
89
+
90
+ args = parser.parse_args()
91
+
92
+ if args.stage:
93
+ for s in args.stage:
94
+ execute_stage(s)
95
+
96
+ if args.eval:
97
+ execute_evaluation()
98
+
99
+ if not args.stage and not args.eval:
100
+ parser.print_help()
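The evaluation loop above maps the 4-tier verdicts to binary labels before scoring (TRUE/UNCERTAIN → real, LIKELY FALSE/FALSE → fake). That mapping, plus a dependency-free accuracy check on toy data, can be sketched as:

```python
# Verdict → binary mapping used by the holdout benchmark in run_pipeline.py:
# TRUE / UNCERTAIN count as real news (1), LIKELY FALSE / FALSE as fake (0).
def verdict_to_label(verdict: str) -> int:
    return 1 if verdict in ("TRUE", "UNCERTAIN") else 0

# Plain-Python accuracy for a quick sanity check (sklearn not required).
def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

verdicts = ["TRUE", "FALSE", "UNCERTAIN", "LIKELY FALSE"]
y_pred = [verdict_to_label(v) for v in verdicts]  # [1, 0, 1, 0]
print(accuracy([1, 0, 0, 0], y_pred))             # -> 0.75
```

This conservative mapping means an UNCERTAIN verdict never counts against the "real" class in the benchmark.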
src/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Fake News Detection — Source Package
src/models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Fake News Detection — Models Package
src/models/distilbert_model.py ADDED
@@ -0,0 +1,201 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ import logging
5
+ import torch
6
+ import numpy as np
7
+ import pandas as pd
8
+ from sklearn.model_selection import train_test_split
9
+ from transformers import (
10
+ AutoTokenizer,
11
+ AutoModelForSequenceClassification,
12
+ Trainer,
13
+ TrainingArguments,
14
+ DataCollatorWithPadding
15
+ )
16
+ from datasets import Dataset
17
+
18
+ _PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
19
+ if str(_PROJECT_ROOT) not in sys.path:
20
+ sys.path.insert(0, str(_PROJECT_ROOT))
21
+
22
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s")
23
+ logger = logging.getLogger("distilbert_model")
24
+
25
+ def train_distilbert(cfg, splits_dir, save_dir):
26
+ os.makedirs(save_dir, exist_ok=True)
27
+
28
+ # 1. Load Data
29
+ train_df = pd.read_csv(os.path.join(splits_dir, "df_train.csv"))
30
+ val_df = pd.read_csv(os.path.join(splits_dir, "df_val.csv"))
31
+
32
+ train_df["clean_text"] = train_df["clean_text"].fillna("")
33
+ val_df["clean_text"] = val_df["clean_text"].fillna("")
34
+
35
+ maxlen = cfg.get("preprocessing", {}).get("bert_max_len", 512)
36
+ batch_size = cfg.get("training", {}).get("bert_batch_size", 16)
37
+ epochs = cfg.get("training", {}).get("bert_epochs", 3)
38
+ lr = float(cfg.get("training", {}).get("lr_learning_rate", 2e-5))
39
+
40
+ logger.info("Loading DistilBERT tokenizer...")
41
+ model_name = "distilbert-base-uncased"
42
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
43
+
44
+ # 2. Tokenization Helper
45
+ def tokenize_function(examples):
46
+ return tokenizer(examples["text"], padding=False, truncation=True, max_length=maxlen)
47
+
48
+ # 3. Create an 80/20 proxy split for OOF predictions (avoids the cost of full 5-fold CV)
49
+ idx_train, idx_meta_val = train_test_split(
50
+ range(len(train_df)), test_size=0.20,
51
+ stratify=train_df["binary_label"], random_state=42
52
+ )
53
+
54
+ subset_train_df = train_df.iloc[idx_train].copy()
55
+
56
+ # 4. Convert to HuggingFace Datasets
57
+ hf_sub_train = Dataset.from_pandas(pd.DataFrame({
58
+ "text": subset_train_df["clean_text"], "labels": subset_train_df["binary_label"]
59
+ }), preserve_index=False)
60
+
61
+ hf_full_train = Dataset.from_pandas(pd.DataFrame({
62
+ "text": train_df["clean_text"], "labels": train_df["binary_label"]
63
+ }), preserve_index=False)
64
+
65
+ hf_val = Dataset.from_pandas(pd.DataFrame({
66
+ "text": val_df["clean_text"], "labels": val_df["binary_label"]
67
+ }), preserve_index=False)
68
+
69
+ logger.info("Tokenizing datasets...")
70
+ hf_sub_train = hf_sub_train.map(tokenize_function, batched=True)
71
+ hf_full_train = hf_full_train.map(tokenize_function, batched=True)
72
+ hf_val = hf_val.map(tokenize_function, batched=True)
73
+
74
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
75
+
76
+ # 5. Initialize Model
77
+ model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
78
+
79
+ # 6. Trainer Setup
80
+ training_args = TrainingArguments(
81
+ output_dir=os.path.join(save_dir, "checkpoints"),
82
+ eval_strategy="epoch",
83
+ save_strategy="epoch",
84
+ learning_rate=lr,
85
+ per_device_train_batch_size=batch_size,
86
+ per_device_eval_batch_size=batch_size,
87
+ gradient_accumulation_steps=2,
88
+ dataloader_num_workers=2,
89
+ num_train_epochs=epochs,
90
+ weight_decay=0.01,
91
+ load_best_model_at_end=True,
92
+ metric_for_best_model="eval_loss",
93
+ greater_is_better=False,
94
+ fp16=torch.cuda.is_available(),
95
+ disable_tqdm=False
96
+ )
97
+
98
+ trainer = Trainer(
99
+ model=model,
100
+ args=training_args,
101
+ train_dataset=hf_sub_train,
102
+ eval_dataset=hf_val,
103
+ processing_class=tokenizer,
104
+ data_collator=data_collator,
105
+ )
106
+
107
+ # 7. Train
108
+ logger.info("Starting DistilBERT training on the 80/20 proxy split...")
109
+ trainer.train()
110
+
111
+ # 8. Save Model
112
+ logger.info("Saving final fine-tuned model...")
113
+ trainer.save_model(save_dir)
114
+ tokenizer.save_pretrained(save_dir)
115
+
116
+ # 9. Extract OOF over the entire training set
117
+ logger.info("Generating proxy OOF predictions over the full training set...")
118
+ oof_preds = trainer.predict(hf_full_train)
119
+ # probabilities for class 1 (True)
120
+ oof_probas = torch.softmax(torch.tensor(oof_preds.predictions), dim=-1)[:, 1].numpy()
121
+ np.save(os.path.join(save_dir, "distilbert_oof.npy"), oof_probas)
122
+ logger.info("Saved distilbert_oof.npy")
123
+
124
+ # Evaluate on the validation split for the confusion-matrix plot below
125
+ val_preds_out = trainer.predict(hf_val)
126
+ val_probas = torch.softmax(torch.tensor(val_preds_out.predictions), dim=-1)[:, 1].numpy()
127
+
128
+ from src.models.logistic_model import plot_and_save_cm
129
+ plot_and_save_cm(
130
+ val_df["binary_label"],
131
+ (val_probas > 0.5).astype(int),
132
+ os.path.join(save_dir, "cm_distil.png"),
133
+ title="DistilBERT Confusion Matrix"
134
+ )
135
+
136
+ logger.info("DistilBERT Training completed!")
137
+
138
+
139
+ # ====================================================================
140
+ # OPTIONAL: Full K-Fold OOF (GPU-intensive)
141
+ # --------------------------------------------------------------------
142
+ # The strategy above saves compute by training a single proxy model
143
+ # and using it to predict the full training pool. Strict K-fold OOF
144
+ # requires training five entirely separate DistilBERT instances,
145
+ # spanning roughly 15+ epochs locally. Use the code below only if
146
+ # multiple parallel A100-class GPUs are available.
147
+ #
148
+ """
149
+ from sklearn.model_selection import StratifiedKFold
150
+
151
+ def strict_kfold_distilbert(train_df, tokenize_function, data_collator, lr, batch_size, epochs, save_dir):
152
+ skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
153
+ oof_probas = np.zeros(len(train_df), dtype=np.float32)
154
+
155
+ for fold, (train_idx, val_idx) in enumerate(skf.split(train_df, train_df["binary_label"])):
156
+ logger.info(f"Training Fold {fold+1}/5")
157
+ df_train = train_df.iloc[train_idx].copy()
158
+ df_val = train_df.iloc[val_idx].copy()
159
+
160
+ ds_train = Dataset.from_pandas(pd.DataFrame({"text": df_train["clean_text"], "labels": df_train["binary_label"]}), preserve_index=False).map(tokenize_function, batched=True)
161
+ ds_val = Dataset.from_pandas(pd.DataFrame({"text": df_val["clean_text"], "labels": df_val["binary_label"]}), preserve_index=False).map(tokenize_function, batched=True)
162
+
163
+ model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
164
+
165
+ training_args = TrainingArguments(
166
+ output_dir=os.path.join(save_dir, f"fold_{fold}"),
167
+ eval_strategy="epoch",
168
+ save_strategy="epoch",
169
+ learning_rate=lr,
170
+ per_device_train_batch_size=batch_size,
171
+ num_train_epochs=epochs,
172
+ fp16=torch.cuda.is_available(),
173
+ load_best_model_at_end=True,
174
+ )
175
+
176
+ trainer = Trainer(
177
+ model=model,
178
+ args=training_args,
179
+ train_dataset=ds_train,
180
+ eval_dataset=ds_val,
181
+ data_collator=data_collator,
182
+ )
183
+
184
+ trainer.train()
185
+ fold_preds = trainer.predict(ds_val)
186
+ oof_probas[val_idx] = torch.softmax(torch.tensor(fold_preds.predictions), dim=-1)[:, 1].numpy()
187
+
188
+ np.save(os.path.join(save_dir, "distilbert_oof.npy"), oof_probas)
189
+ """
190
+ # ====================================================================
191
+
192
+ if __name__ == "__main__":
193
+ import yaml
194
+ cfg_path = os.path.join(_PROJECT_ROOT, "config", "config.yaml")
195
+ with open(cfg_path, "r", encoding="utf-8") as file:
196
+ config = yaml.safe_load(file)
197
+
198
+ s_dir = os.path.join(_PROJECT_ROOT, config["paths"]["splits_dir"])
199
+ m_dir = os.path.join(_PROJECT_ROOT, config["paths"]["models_dir"], "distilbert_model")
200
+
201
+ train_distilbert(config, s_dir, m_dir)
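The tradeoff described in the comment block above can be sketched on toy data. This is illustrative only (scikit-learn stand-ins, not the project's transformer pipeline): a single proxy model trained on a stratified 80% split scores the entire pool in one pass, at the cost of mildly optimistic predictions on its own training rows.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Toy stand-in for the expensive transformer features
X, y = make_classification(n_samples=500, n_features=20, random_state=42)

# One stratified 80/20 split instead of 5 folds: train a single proxy model...
idx_tr, idx_val = train_test_split(range(len(X)), test_size=0.2,
                                   stratify=y, random_state=42)
proxy = LogisticRegression(max_iter=1000).fit(X[idx_tr], y[idx_tr])

# ...and use it to score the full training pool
# (only the 20% held-out rows are truly out-of-fold)
pseudo_oof = proxy.predict_proba(X)[:, 1]
assert pseudo_oof.shape == (500,)
```

A strict K-Fold replaces the single `fit` with one fit per fold, so every row is scored by a model that never saw it, at roughly 5x the training cost.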
src/models/logistic_model.py ADDED
@@ -0,0 +1,141 @@
+import os
+import sys
+import json
+import logging
+import time
+import numpy as np
+import pandas as pd
+import joblib
+from sklearn.pipeline import Pipeline
+from sklearn.compose import ColumnTransformer
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.preprocessing import OneHotEncoder
+from sklearn.linear_model import LogisticRegression
+from sklearn.model_selection import StratifiedKFold, cross_val_predict, GridSearchCV
+from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
+from matplotlib import pyplot as plt
+
+_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+if str(_PROJECT_ROOT) not in sys.path:
+    sys.path.insert(0, str(_PROJECT_ROOT))
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s")
+logger = logging.getLogger("logistic_model")
+
+
+def load_data(splits_dir):
+    """Load the train and val DataFrames, keeping clean_text and text_length_bucket."""
+    train_df = pd.read_csv(os.path.join(splits_dir, "df_train.csv"))
+    val_df = pd.read_csv(os.path.join(splits_dir, "df_val.csv"))
+
+    # Fill NaN just in case
+    train_df["clean_text"] = train_df["clean_text"].fillna("")
+    val_df["clean_text"] = val_df["clean_text"].fillna("")
+
+    return train_df, val_df
+
+
+def plot_and_save_cm(y_true, y_pred, path, title="Logistic Regression Confusion Matrix"):
+    """Save a confusion matrix as a PNG."""
+    cm = confusion_matrix(y_true, y_pred)
+    fig, ax = plt.subplots(figsize=(5, 5))
+    ax.matshow(cm, cmap=plt.cm.Blues, alpha=0.3)
+    for i in range(cm.shape[0]):
+        for j in range(cm.shape[1]):
+            ax.text(x=j, y=i, s=cm[i, j], va='center', ha='center', size='xx-large')
+    plt.xlabel('Predicted Label')
+    plt.ylabel('True Label')
+    plt.title(title)
+    plt.tight_layout()
+    plt.savefig(path)
+    plt.close()
+
+
+def train_logistic_model(cfg, splits_dir, save_dir):
+    logger.info("Initializing Logistic Regression training...")
+    os.makedirs(save_dir, exist_ok=True)
+
+    train_df, val_df = load_data(splits_dir)
+    y_train = train_df["binary_label"].values
+    y_val = val_df["binary_label"].values
+
+    max_features = cfg.get("preprocessing", {}).get("max_tfidf_features", 50000)
+
+    # ColumnTransformer combining TF-IDF text features with the categorical length bucket
+    preprocessor = ColumnTransformer(
+        transformers=[
+            ("tfidf", TfidfVectorizer(max_features=max_features, ngram_range=(1, 2)), "clean_text"),
+            ("cat", OneHotEncoder(handle_unknown="ignore"), ["text_length_bucket"])
+        ],
+        remainder="drop"
+    )
+
+    # Define model
+    log_reg = LogisticRegression(class_weight="balanced", random_state=42, max_iter=1000)
+
+    pipeline = Pipeline(steps=[
+        ("preprocessor", preprocessor),
+        ("classifier", log_reg)
+    ])
+
+    # K-Fold OOF predictions
+    logger.info("Generating 5-fold OOF predictions on the train set...")
+    cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
+    # method='predict_proba' returns a 2D array of shape [n_samples, 2]
+    oof_probas = cross_val_predict(pipeline, train_df, y_train, cv=cv, method='predict_proba', n_jobs=-1)
+
+    np.save(os.path.join(save_dir, "lr_oof.npy"), oof_probas[:, 1])
+    logger.info("Saved OOF predictions (lr_oof.npy)")
+
+    # Hyperparameter tuning of C on the full train split via GridSearchCV
+    logger.info("Hyperparameter tuning C over 5 folds...")
+    param_grid = {'classifier__C': [0.1, 1.0, 10.0]}
+
+    grid_search = GridSearchCV(pipeline, param_grid, cv=cv, scoring='f1_macro', n_jobs=-1)
+    grid_search.fit(train_df, y_train)
+
+    best_pipeline = grid_search.best_estimator_
+    logger.info(f"Best parameter C: {grid_search.best_params_['classifier__C']}")
+
+    # Validation evaluation
+    val_probas = best_pipeline.predict_proba(val_df)[:, 1]
+    val_preds = (val_probas >= 0.5).astype(int)
+
+    logger.info("Validation Classification Report:\n" + classification_report(y_val, val_preds))
+    roc_auc = roc_auc_score(y_val, val_probas)
+    logger.info(f"ROC-AUC: {roc_auc:.4f}")
+
+    # Generate evaluation artifacts
+    plot_and_save_cm(y_val, val_preds, os.path.join(save_dir, "cm.png"))
+
+    # Compute accuracy per text-length bucket on val
+    bucket_acc = {}
+    for b in ["short", "medium", "long"]:
+        b_mask = (val_df["text_length_bucket"] == b)
+        if b_mask.sum() > 0:
+            acc = (val_preds[b_mask] == y_val[b_mask]).mean()
+            bucket_acc[b] = acc
+
+    metrics = {
+        "roc_auc": float(roc_auc),
+        "bucket_accuracy": {k: float(v) for k, v in bucket_acc.items()}
+    }
+    with open(os.path.join(save_dir, "metrics.json"), "w") as f:
+        json.dump(metrics, f, indent=2)
+
+    # Save pipeline
+    joblib.dump(best_pipeline, os.path.join(save_dir, "logistic_model.pkl"))
+    logger.info("Saved Logistic Regression pipeline to logistic_model.pkl.")
+
+
+if __name__ == "__main__":
+    import yaml
+    cfg_path = os.path.join(_PROJECT_ROOT, "config", "config.yaml")
+    with open(cfg_path, "r", encoding="utf-8") as file:
+        config = yaml.safe_load(file)
+
+    s_dir = os.path.join(_PROJECT_ROOT, config["paths"]["splits_dir"])
+    m_dir = os.path.join(_PROJECT_ROOT, config["paths"]["models_dir"], "logistic_model")
+
+    t0 = time.time()
+    train_logistic_model(config, s_dir, m_dir)
+    print(f"Total time: {time.time() - t0:.2f}s")
src/models/lstm_model.py ADDED
@@ -0,0 +1,314 @@
+import os
+import sys
+import json
+import logging
+import time
+import pickle
+import copy
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import TensorDataset, DataLoader, Subset
+from sklearn.model_selection import StratifiedKFold
+from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
+from matplotlib import pyplot as plt
+
+_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+if str(_PROJECT_ROOT) not in sys.path:
+    sys.path.insert(0, str(_PROJECT_ROOT))
+
+# We need the tokenizer from stage 2 so texts_to_sequences matches the training-time vocabulary
+from src.stage2_preprocessing import KerasStyleTokenizer
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s")
+logger = logging.getLogger("lstm_model")
+
+# ── Architecture ──────────────────────────────────────
+class SpatialDropout1D(nn.Module):
+    def __init__(self, p=0.3):
+        super().__init__()
+        self.p = p
+
+    def forward(self, x):
+        if not self.training or self.p == 0:
+            return x
+        # x is (batch, seq_len, embed_dim)
+        # convert to (batch, embed_dim, seq_len)
+        x = x.permute(0, 2, 1)
+        # 1D spatial dropout is equivalent to 2D dropout with height 1:
+        # dropout2d drops entire channels (here, the embedding dimensions)
+        x = x.unsqueeze(3)
+        x = F.dropout2d(x, p=self.p, training=self.training)
+        x = x.squeeze(3)
+        return x.permute(0, 2, 1)
+
+class BiLSTMClassifier(nn.Module):
+    def __init__(self, vocab_size, embedding_matrix=None):
+        super().__init__()
+        # Embedding(vocab_size, 100)
+        self.embedding = nn.Embedding(vocab_size, 100, padding_idx=0)
+        if embedding_matrix is not None:
+            self.embedding.weight.data.copy_(torch.from_numpy(embedding_matrix))
+            self.embedding.weight.requires_grad = False
+
+        self.spatial_drop = SpatialDropout1D(0.3)
+
+        # Bi-LSTM(100 -> 128, bidirectional)
+        self.lstm1 = nn.LSTM(100, 128, bidirectional=True, batch_first=True)
+        # Bi-LSTM(256 -> 64, bidirectional)
+        self.lstm2 = nn.LSTM(256, 64, bidirectional=True, batch_first=True)
+
+        # Linear(128, 64) + ReLU
+        self.fc1 = nn.Linear(128, 64)
+        self.dropout = nn.Dropout(0.4)
+        # Linear(64, 1); the sigmoid is folded into BCEWithLogitsLoss during training
+        self.fc2 = nn.Linear(64, 1)
+
+    def forward(self, x):
+        h = self.embedding(x)
+        h = self.spatial_drop(h)
+
+        h, _ = self.lstm1(h)
+        # Keras `return_sequences=False` on the 2nd LSTM corresponds to
+        # taking the final hidden state of the sequence
+        _, (h_n, _) = self.lstm2(h)
+
+        # h_n shape for a Bi-LSTM: (2, batch, hidden_size)
+        # Concatenate forward and backward final states
+        h_concat = torch.cat((h_n[-2, :, :], h_n[-1, :, :]), dim=1)  # shape: (batch, 128)
+
+        out = F.relu(self.fc1(h_concat))
+        out = self.dropout(out)
+        logits = self.fc2(out)
+
+        return logits.squeeze(1)
+
+
+# ── Utilities ──────────────────────────────────────
+def pad_sequences(sequences, maxlen=512, padding='post'):
+    padded = np.zeros((len(sequences), maxlen), dtype=np.int64)
+    for i, seq in enumerate(sequences):
+        seq = seq[:maxlen]
+        if padding == 'post':
+            padded[i, :len(seq)] = seq
+        else:
+            padded[i, -len(seq):] = seq
+    return padded
+
+def load_glove_embeddings(glove_path, word_index, embed_dim=100):
+    logger.info(f"Loading GloVe embeddings from {glove_path}...")
+    embeddings_index = {}
+    with open(glove_path, "r", encoding="utf-8") as f:
+        for line in f:
+            values = line.split()
+            word = values[0]
+            coefs = np.asarray(values[1:], dtype='float32')
+            embeddings_index[word] = coefs
+
+    vocab_size = len(word_index) + 1  # +1 for the padding index
+    embedding_matrix = np.zeros((vocab_size, embed_dim), dtype=np.float32)
+    hits, misses = 0, 0
+    for word, i in word_index.items():
+        embedding_vector = embeddings_index.get(word)
+        if embedding_vector is not None:
+            embedding_matrix[i] = embedding_vector
+            hits += 1
+        else:
+            misses += 1
+    logger.info(f"GloVe mapped: {hits} hits, {misses} misses.")
+    return embedding_matrix, vocab_size
+
+def plot_and_save_cm(y_true, y_pred, path):
+    cm = confusion_matrix(y_true, (np.array(y_pred) > 0.5).astype(int))
+    fig, ax = plt.subplots(figsize=(5, 5))
+    ax.matshow(cm, cmap=plt.cm.Blues, alpha=0.3)
+    for i in range(cm.shape[0]):
+        for j in range(cm.shape[1]):
+            ax.text(x=j, y=i, s=cm[i, j], va='center', ha='center', size='xx-large')
+    plt.xlabel('Predicted Label')
+    plt.ylabel('True Label')
+    plt.title('Bi-LSTM Confusion Matrix')
+    plt.tight_layout()
+    plt.savefig(path)
+    plt.close()
+
+# ── Training Loop ──────────────────────────────────────
+def train_epoch(model, loader, optimizer, criterion, device):
+    model.train()
+    total_loss = 0
+    for x_batch, y_batch in loader:
+        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
+        optimizer.zero_grad()
+        logits = model(x_batch)
+        loss = criterion(logits, y_batch)
+        loss.backward()
+        optimizer.step()
+        total_loss += loss.item() * x_batch.size(0)
+    return total_loss / len(loader.dataset)
+
+@torch.no_grad()
+def eval_model(model, loader, criterion, device):
+    model.eval()
+    total_loss = 0
+    all_preds = []
+    for x_batch, y_batch in loader:
+        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
+        logits = model(x_batch)
+        loss = criterion(logits, y_batch)
+        total_loss += loss.item() * x_batch.size(0)
+        probas = torch.sigmoid(logits).cpu().numpy()
+        all_preds.extend(probas)
+    return total_loss / len(loader.dataset), np.array(all_preds)
+
+def train_lstm_logic(cfg, splits_dir, save_dir, glove_path):
+    os.makedirs(save_dir, exist_ok=True)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    logger.info(f"Using device: {device}")
+
+    # Load dataset splits
+    train_df = pd.read_csv(os.path.join(splits_dir, "df_train.csv"))
+    val_df = pd.read_csv(os.path.join(splits_dir, "df_val.csv"))
+    y_train = train_df["binary_label"].values.astype(np.float32)
+    y_val = val_df["binary_label"].values.astype(np.float32)
+
+    with open(os.path.join(_PROJECT_ROOT, cfg["paths"]["models_dir"], "tokenizer.pkl"), "rb") as f:
+        tokenizer = pickle.load(f)
+
+    maxlen = cfg.get("preprocessing", {}).get("lstm_max_len", 512)
+    batch_size = cfg.get("training", {}).get("lstm_batch_size", 64)
+    epochs = cfg.get("training", {}).get("lstm_epochs", 10)
+
+    logger.info("Transforming texts to padded sequences...")
+    X_train_seq = tokenizer.texts_to_sequences(train_df["clean_text"].fillna(""))
+    X_val_seq = tokenizer.texts_to_sequences(val_df["clean_text"].fillna(""))
+
+    X_train_pad = pad_sequences(X_train_seq, maxlen=maxlen, padding='post')
+    X_val_pad = pad_sequences(X_val_seq, maxlen=maxlen, padding='post')
+
+    # Embedding matrix
+    emb_matrix, vocab_size = load_glove_embeddings(glove_path, tokenizer.word_index)
+
+    # Class balancing: pos_weight = n_negative / n_positive
+    class_counts = np.bincount(y_train.astype(int))
+    pos_weight = torch.tensor([class_counts[0] / class_counts[1]], dtype=torch.float32).to(device)
+
+    # Datasets
+    train_tensor = TensorDataset(torch.from_numpy(X_train_pad).long(), torch.from_numpy(y_train))
+    val_tensor = TensorDataset(torch.from_numpy(X_val_pad).long(), torch.from_numpy(y_val))
+    val_loader = DataLoader(val_tensor, batch_size=batch_size, shuffle=False)
+
+    # --- 5-Fold OOF Predictions ---
+    logger.info("Starting 5-fold OOF generation...")
+    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
+    oof_preds = np.zeros_like(y_train, dtype=np.float32)
+
+    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
+
+    for fold, (t_idx, v_idx) in enumerate(skf.split(X_train_pad, y_train)):
+        logger.info(f"OOF Fold {fold+1}/5")
+
+        fold_train_ds = Subset(train_tensor, t_idx)
+        fold_val_ds = Subset(train_tensor, v_idx)
+        fold_train_loader = DataLoader(fold_train_ds, batch_size=batch_size, shuffle=True)
+        fold_val_loader = DataLoader(fold_val_ds, batch_size=batch_size, shuffle=False)
+
+        model = BiLSTMClassifier(vocab_size, emb_matrix).to(device)
+        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=1, factor=0.5)
+
+        best_val_loss = float('inf')
+        patience_counter = 0
+        best_weights = copy.deepcopy(model.state_dict())
+
+        for ep in range(epochs):  # or cap at 3-4 epochs for OOF to save time
+            t_loss = train_epoch(model, fold_train_loader, optimizer, criterion, device)
+            v_loss, v_preds = eval_model(model, fold_val_loader, criterion, device)
+            scheduler.step(v_loss)
+
+            if v_loss < best_val_loss:
+                best_val_loss = v_loss
+                best_weights = copy.deepcopy(model.state_dict())
+                patience_counter = 0
+            else:
+                patience_counter += 1
+                if patience_counter >= 3:
+                    break
+
+        # Predict the held-out fold with the best weights
+        model.load_state_dict(best_weights)
+        _, fold_best_preds = eval_model(model, fold_val_loader, criterion, device)
+        oof_preds[v_idx] = fold_best_preds
+
+    np.save(os.path.join(save_dir, "lstm_oof.npy"), oof_preds)
+    logger.info("Saved OOF predictions (lstm_oof.npy).")
+
+    # --- Final Training on ALL Data ---
+    logger.info("Starting final model training on the full train split...")
+    train_loader = DataLoader(train_tensor, batch_size=batch_size, shuffle=True)
+    model = BiLSTMClassifier(vocab_size, emb_matrix).to(device)
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=1, factor=0.5)
+
+    best_val_loss = float('inf')
+    best_weights = copy.deepcopy(model.state_dict())
+    patience_counter = 0
+
+    for ep in range(epochs):
+        t_loss = train_epoch(model, train_loader, optimizer, criterion, device)
+        v_loss, v_preds = eval_model(model, val_loader, criterion, device)
+        scheduler.step(v_loss)
+
+        logger.info(f"  Epoch {ep+1}/{epochs} | Train Loss: {t_loss:.4f} | Val Loss: {v_loss:.4f}")
+        if v_loss < best_val_loss:
+            best_val_loss = v_loss
+            best_weights = copy.deepcopy(model.state_dict())
+            patience_counter = 0
+        else:
+            patience_counter += 1
+            if patience_counter >= 3:
+                logger.info("  EarlyStopping triggered.")
+                break
+
+    model.load_state_dict(best_weights)
+    torch.save(model.state_dict(), os.path.join(save_dir, "model.pt"))
+    logger.info("Saved final LSTM weights.")
+
+    # Evaluate on the validation split
+    _, val_preds_probas = eval_model(model, val_loader, criterion, device)
+    val_preds_binary = (val_preds_probas >= 0.5).astype(int)
+
+    logger.info("Validation Classification Report:\n" + classification_report(y_val, val_preds_binary))
+    roc_auc = roc_auc_score(y_val, val_preds_probas)
+    logger.info(f"ROC-AUC: {roc_auc:.4f}")
+
+    plot_and_save_cm(y_val, val_preds_probas, os.path.join(save_dir, "cm.png"))
+
+    bucket_acc = {}
+    for b in ["short", "medium", "long"]:
+        b_mask = (val_df["text_length_bucket"] == b).values
+        if b_mask.sum() > 0:
+            acc = (val_preds_binary[b_mask] == y_val[b_mask]).mean()
+            bucket_acc[b] = acc
+
+    metrics = {
+        "roc_auc": float(roc_auc),
+        "bucket_accuracy": {k: float(v) for k, v in bucket_acc.items()}
+    }
+    with open(os.path.join(save_dir, "metrics.json"), "w") as f:
+        json.dump(metrics, f, indent=2)
+
+if __name__ == "__main__":
+    import yaml
+    cfg_path = os.path.join(_PROJECT_ROOT, "config", "config.yaml")
+    with open(cfg_path, "r", encoding="utf-8") as file:
+        config = yaml.safe_load(file)
+
+    s_dir = os.path.join(_PROJECT_ROOT, config["paths"]["splits_dir"])
+    m_dir = os.path.join(_PROJECT_ROOT, config["paths"]["models_dir"], "lstm_model")
+    g_path = os.path.join(_PROJECT_ROOT, config["paths"]["glove_path"])
+
+    t0 = time.time()
+    train_lstm_logic(config, s_dir, m_dir, g_path)
+    print(f"Total time: {time.time() - t0:.2f}s")
src/models/meta_classifier.py ADDED
@@ -0,0 +1,229 @@
+import os
+import sys
+import json
+import logging
+import pickle
+import joblib
+import torch
+import numpy as np
+import pandas as pd
+from xgboost import XGBClassifier
+from sklearn.calibration import CalibratedClassifierCV
+from sklearn.preprocessing import OneHotEncoder
+from sklearn.compose import ColumnTransformer
+from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
+from matplotlib import pyplot as plt
+from torch.utils.data import TensorDataset, DataLoader
+
+_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+if str(_PROJECT_ROOT) not in sys.path:
+    sys.path.insert(0, str(_PROJECT_ROOT))
+
+from src.models.lstm_model import BiLSTMClassifier, pad_sequences, load_glove_embeddings
+from src.models.logistic_model import plot_and_save_cm
+from src.stage2_preprocessing import KerasStyleTokenizer
+# The tokenizer was pickled from a __main__ context; expose the class there so unpickling works
+setattr(sys.modules['__main__'], 'KerasStyleTokenizer', KerasStyleTokenizer)
+
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s")
+logger = logging.getLogger("meta_classifier")
+
+def build_meta_features(df, lr_proba, lstm_proba, distil_proba, roberta_proba, is_train=True, preprocessor=None):
+    """
+    Construct the meta-feature matrix.
+    If is_train is True, the preprocessor is fit on the categorical columns.
+    """
+    df_meta = pd.DataFrame({
+        "lr_proba": lr_proba,
+        "lstm_proba": lstm_proba,
+        "distilbert_proba": distil_proba,
+        "roberta_proba": roberta_proba,
+        "word_count": df["word_count"],
+        "has_date": df["has_date"].astype(int),
+        "freshness_score": df["freshness_score"]
+    })
+
+    # Categoricals to encode
+    cats = df[["text_length_bucket", "source_domain"]].fillna("unknown")
+
+    if is_train:
+        preprocessor = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
+        cat_features = preprocessor.fit_transform(cats)
+    else:
+        cat_features = preprocessor.transform(cats)
+
+    X_meta = np.hstack((df_meta.values, cat_features))
+    return X_meta, preprocessor
+
+
+def train_meta_classifier(cfg, splits_dir, models_dir):
+    save_dir = os.path.join(models_dir, "meta_classifier")
+    os.makedirs(save_dir, exist_ok=True)
+
+    logger.info("Loading dataset splits...")
+    train_df = pd.read_csv(os.path.join(splits_dir, "df_train.csv"))
+    val_df = pd.read_csv(os.path.join(splits_dir, "df_val.csv"))
+
+    y_train = train_df["binary_label"].values
+    y_val = val_df["binary_label"].values
+
+    # ── 1. Load OOF predictions for the train set ──
+    logger.info("Gathering base model OOF predictions...")
+    try:
+        lr_oof = np.load(os.path.join(models_dir, "logistic_model", "lr_oof.npy"))
+        lstm_oof = np.load(os.path.join(models_dir, "lstm_model", "lstm_oof.npy"))
+        distil_oof = np.load(os.path.join(models_dir, "distilbert_model", "distilbert_oof.npy"))
+        roberta_oof = np.load(os.path.join(models_dir, "roberta_model", "roberta_oof.npy"))
+    except FileNotFoundError as e:
+        logger.error(f"Missing OOF file: {e}. Please ensure all base models have trained completely.")
+        return
+
+    # Damp RoBERTa probabilities; the same 0.92 factor is applied to its val predictions below
+    roberta_oof = roberta_oof * 0.92
+
+    X_meta_train, meta_preprocessor = build_meta_features(
+        train_df, lr_oof, lstm_oof, distil_oof, roberta_oof, is_train=True
+    )
+
+    # ── 2. Generate base model predictions for the validation set ──
+    # The meta-classifier needs a held-out set for early stopping, so we predict it here.
+    logger.info("Generating base model predictions for the validation set...")
+
+    # Logistic Regression
+    lr_pipeline = joblib.load(os.path.join(models_dir, "logistic_model", "logistic_model.pkl"))
+    lr_val = lr_pipeline.predict_proba(val_df)[:, 1]
+
+    # LSTM
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    with open(os.path.join(models_dir, "tokenizer.pkl"), "rb") as f:
+        tok = pickle.load(f)
+    glove_path = os.path.join(_PROJECT_ROOT, cfg["paths"]["glove_path"])
+    emb_matrix, vocab_size = load_glove_embeddings(glove_path, tok.word_index)
+
+    maxlen = cfg.get("preprocessing", {}).get("lstm_max_len", 512)
+    X_val_seq = tok.texts_to_sequences(val_df["clean_text"].fillna(""))
+    X_val_pad = pad_sequences(X_val_seq, maxlen=maxlen, padding='post')
+
+    lstm_model = BiLSTMClassifier(vocab_size, emb_matrix).to(device)
+    lstm_model.load_state_dict(torch.load(os.path.join(models_dir, "lstm_model", "model.pt"), map_location=device))
+    lstm_model.eval()
+
+    val_loader = DataLoader(TensorDataset(torch.from_numpy(X_val_pad).long()), batch_size=64, shuffle=False)
+    lstm_val_preds = []
+    with torch.no_grad():
+        for x_b in val_loader:
+            logits = lstm_model(x_b[0].to(device))
+            lstm_val_preds.extend(torch.sigmoid(logits).cpu().numpy())
+    lstm_val = np.array(lstm_val_preds)
+
+    # DistilBERT
+    d_tok = AutoTokenizer.from_pretrained(os.path.join(models_dir, "distilbert_model"))
+    d_mod = AutoModelForSequenceClassification.from_pretrained(os.path.join(models_dir, "distilbert_model")).to(device)
+    d_mod.eval()
+
+    distil_val = []
+    with torch.no_grad():
+        for text in val_df["clean_text"].fillna(""):
+            inputs = d_tok(text, padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)
+            out = d_mod(**inputs)
+            distil_val.append(torch.softmax(out.logits, dim=-1)[0, 1].item())
+    distil_val = np.array(distil_val)
+
+    # RoBERTa
+    r_tok = AutoTokenizer.from_pretrained(os.path.join(models_dir, "roberta_model"))
+    r_mod = AutoModelForSequenceClassification.from_pretrained(os.path.join(models_dir, "roberta_model")).to(device)
+    r_mod.eval()
+
+    roberta_val = []
+    with torch.no_grad():
+        for text in val_df["clean_text"].fillna(""):
+            inputs = r_tok(text, padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)
+            out = r_mod(**inputs)
+            roberta_val.append(torch.softmax(out.logits, dim=-1)[0, 1].item())
+    roberta_val = np.array(roberta_val) * 0.92
+
+    X_meta_val, _ = build_meta_features(
+        val_df, lr_val, lstm_val, distil_val, roberta_val, is_train=False, preprocessor=meta_preprocessor
+    )
+
+    # ── 3. Train Meta-Classifier (XGBoost) ──
+    logger.info("Training XGBoost meta-classifier...")
+    xgb = XGBClassifier(
+        n_estimators=500,
+        learning_rate=0.05,
+        max_depth=5,
+        eval_metric='logloss',
+        early_stopping_rounds=20,
+        random_state=42
+    )
+
+    xgb.fit(
+        X_meta_train, y_train,
+        eval_set=[(X_meta_val, y_val)],
+        verbose=False
+    )
+    logger.info(f"XGBoost best iteration: {xgb.best_iteration}")
+
+    # ── 4. Calibrate Probabilities ──
+    logger.info("Calibrating final probabilities via CalibratedClassifierCV on the val set...")
+    # cv='prefit' means only X_meta_val is used to calibrate the output
+    calibrated_meta = CalibratedClassifierCV(estimator=xgb, method='sigmoid', cv='prefit')
+    calibrated_meta.fit(X_meta_val, y_val)
+
+    # Final val score check
+    final_val_probas = calibrated_meta.predict_proba(X_meta_val)[:, 1]
+
+    # For short texts, dampen confidence toward 0.5 (more uncertain)
+    # rather than making a confident wrong prediction
+    for i in range(len(final_val_probas)):
+        if val_df["word_count"].iloc[i] < 50:
+            final_val_probas[i] = 0.5 + (final_val_probas[i] - 0.5) * 0.6
+
+    final_val_preds = (final_val_probas >= 0.55).astype(int)
+
+    logger.info("Final Meta-Classifier Classification Report:\n" + classification_report(y_val, final_val_preds))
+    roc_auc = roc_auc_score(y_val, final_val_probas)
+    logger.info(f"ROC-AUC: {roc_auc:.4f}")
+
+    plot_and_save_cm(
+        y_val,
+        final_val_preds,
+        os.path.join(save_dir, "cm.png"),
+        title="XGBoost Meta-Classifier Confusion Matrix"
+    )
+
+    bucket_acc = {}
+    for b in ["short", "medium", "long"]:
+        b_mask = (val_df["text_length_bucket"] == b).values
+        if b_mask.sum() > 0:
+            acc = (final_val_preds[b_mask] == y_val[b_mask]).mean()
+            bucket_acc[b] = acc
+
+    metrics = {
+        "roc_auc": float(roc_auc),
+        "bucket_accuracy": {k: float(v) for k, v in bucket_acc.items()}
+    }
+    with open(os.path.join(save_dir, "metrics.json"), "w") as f:
+        json.dump(metrics, f, indent=2)
+
+    # Save model bundle (preprocessor + calibrated XGBoost)
+    bundle = {
+        "preprocessor": meta_preprocessor,
+        "model": calibrated_meta
+    }
+    joblib.dump(bundle, os.path.join(save_dir, "meta_classifier.pkl"))
+    logger.info("Saved Meta-Classifier bundle.")
+
+if __name__ == "__main__":
+    import yaml
+    cfg_path = os.path.join(_PROJECT_ROOT, "config", "config.yaml")
+    with open(cfg_path, "r", encoding="utf-8") as file:
+        config = yaml.safe_load(file)
+
+    train_meta_classifier(
+        config,
+        os.path.join(_PROJECT_ROOT, config["paths"]["splits_dir"]),
+        os.path.join(_PROJECT_ROOT, config["paths"]["models_dir"])
+    )
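The short-text damping in this file is a plain linear shrink of each probability toward 0.5. The arithmetic is easy to verify in isolation (toy probabilities, not model output):

```python
import numpy as np

probas = np.array([0.9, 0.5, 0.1])
damped = 0.5 + (probas - 0.5) * 0.6  # shrink confidence by 40% toward 0.5

assert np.allclose(damped, [0.74, 0.5, 0.26])
```

A probability of exactly 0.5 is unchanged, while confident predictions in either direction move closer to it, so a wrong prediction on a short text costs the downstream threshold decision less.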
src/models/roberta_model.py ADDED
@@ -0,0 +1,198 @@
+ import os
+ import sys
+ import json
+ import logging
+ import torch
+ import numpy as np
+ import pandas as pd
+ from sklearn.model_selection import train_test_split
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForSequenceClassification,
+     Trainer,
+     TrainingArguments,
+     DataCollatorWithPadding
+ )
+ from datasets import Dataset
+
+ _PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ if str(_PROJECT_ROOT) not in sys.path:
+     sys.path.insert(0, str(_PROJECT_ROOT))
+
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s")
+ logger = logging.getLogger("roberta_model")
+
+
+ def train_roberta(cfg, splits_dir, save_dir):
+     os.makedirs(save_dir, exist_ok=True)
+
+     # 1. Load Data
+     train_df = pd.read_csv(os.path.join(splits_dir, "df_train.csv"))
+     val_df = pd.read_csv(os.path.join(splits_dir, "df_val.csv"))
+
+     train_df["clean_text"] = train_df["clean_text"].fillna("")
+     val_df["clean_text"] = val_df["clean_text"].fillna("")
+
+     maxlen = cfg.get("preprocessing", {}).get("bert_max_len", 512)
+     batch_size = cfg.get("training", {}).get("bert_batch_size", 16)
+     epochs = cfg.get("training", {}).get("bert_epochs", 3)
+     lr = float(cfg.get("training", {}).get("roberta_learning_rate", 1e-5))
+
+     logger.info("Loading RoBERTa tokenizer...")
+     model_name = "roberta-base"
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+     # 2. Tokenization Helper
+     def tokenize_function(examples):
+         return tokenizer(examples["text"], padding=False, truncation=True, max_length=maxlen)
+
+     # 3. Create OOF Proxy Split (80/20) safely
+     idx_train, idx_meta_val = train_test_split(
+         range(len(train_df)), test_size=0.20,
+         stratify=train_df["binary_label"], random_state=42
+     )
+
+     subset_train_df = train_df.iloc[idx_train].copy()
+
+     # 4. Convert to HuggingFace Datasets
+     hf_sub_train = Dataset.from_pandas(pd.DataFrame({
+         "text": subset_train_df["clean_text"], "labels": subset_train_df["binary_label"]
+     }), preserve_index=False)
+
+     hf_full_train = Dataset.from_pandas(pd.DataFrame({
+         "text": train_df["clean_text"], "labels": train_df["binary_label"]
+     }), preserve_index=False)
+
+     hf_val = Dataset.from_pandas(pd.DataFrame({
+         "text": val_df["clean_text"], "labels": val_df["binary_label"]
+     }), preserve_index=False)
+
+     logger.info("Tokenizing datasets...")
+     hf_sub_train = hf_sub_train.map(tokenize_function, batched=True)
+     hf_full_train = hf_full_train.map(tokenize_function, batched=True)
+     hf_val = hf_val.map(tokenize_function, batched=True)
+
+     data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+
+     # 5. Initialize Model
+     model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
+
+     # 6. Trainer Setup
+     training_args = TrainingArguments(
+         output_dir=os.path.join(save_dir, "checkpoints"),
+         eval_strategy="epoch",
+         save_strategy="epoch",
+         learning_rate=lr,
+         per_device_train_batch_size=batch_size,
+         per_device_eval_batch_size=batch_size,
+         gradient_accumulation_steps=2,
+         dataloader_num_workers=2,
+         num_train_epochs=epochs,
+         weight_decay=0.01,
+         load_best_model_at_end=True,
+         metric_for_best_model="eval_loss",
+         greater_is_better=False,
+         fp16=torch.cuda.is_available(),
+         disable_tqdm=False
+     )
+
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=hf_sub_train,
+         eval_dataset=hf_val,
+         processing_class=tokenizer,
+         data_collator=data_collator,
+     )
+
+     # 7. Train
+     logger.info("Starting RoBERTa internal proxy training...")
+     trainer.train()
+
+     # 8. Save Model
+     logger.info("Saving final fine-tuned model...")
+     trainer.save_model(save_dir)
+     tokenizer.save_pretrained(save_dir)
+
+     # 9. Extract OOF over the entire training set
+     logger.info("Generating OOF predictions on the full training set (proxy)...")
+     oof_preds = trainer.predict(hf_full_train)
+     # probabilities for class 1 (True)
+     oof_probas = torch.softmax(torch.tensor(oof_preds.predictions), dim=-1)[:, 1].numpy()
+     np.save(os.path.join(save_dir, "roberta_oof.npy"), oof_probas)
+     logger.info("Saved roberta_oof.npy")
+
+     # Validation evaluation
+     val_preds_out = trainer.predict(hf_val)
+     val_probas = torch.softmax(torch.tensor(val_preds_out.predictions), dim=-1)[:, 1].numpy()
+
+     from src.models.logistic_model import plot_and_save_cm
+     plot_and_save_cm(
+         val_df["binary_label"],
+         (val_probas > 0.5).astype(int),
+         os.path.join(save_dir, "cm.png"),
+         title="RoBERTa Confusion Matrix"
+     )
+
+     logger.info("RoBERTa Training completed!")
+
+
+ # ====================================================================
+ # OPTIONAL: Full K-Fold OOF (GPU-intensive)
+ # --------------------------------------------------------------------
+ # A strict 5-fold implementation that produces true out-of-fold
+ # predictions; use it when the compute budget allows (e.g. on a GPU
+ # cluster), instead of the single-split proxy above.
+ #
+ """
+ from sklearn.model_selection import StratifiedKFold
+
+ def strict_kfold_roberta(train_df, tokenize_function, data_collator, lr, batch_size, epochs, save_dir):
+     skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
+     oof_probas = np.zeros(len(train_df), dtype=np.float32)
+
+     for fold, (train_idx, val_idx) in enumerate(skf.split(train_df, train_df["binary_label"])):
+         logger.info(f"Training Fold {fold+1}/5")
+         df_train = train_df.iloc[train_idx].copy()
+         df_val = train_df.iloc[val_idx].copy()
+
+         ds_train = Dataset.from_pandas(pd.DataFrame({"text": df_train["clean_text"], "labels": df_train["binary_label"]}), preserve_index=False).map(tokenize_function, batched=True)
+         ds_val = Dataset.from_pandas(pd.DataFrame({"text": df_val["clean_text"], "labels": df_val["binary_label"]}), preserve_index=False).map(tokenize_function, batched=True)
+
+         model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=2)
+
+         training_args = TrainingArguments(
+             output_dir=os.path.join(save_dir, f"fold_{fold}"),
+             eval_strategy="epoch",
+             save_strategy="epoch",
+             learning_rate=lr,
+             per_device_train_batch_size=batch_size,
+             num_train_epochs=epochs,
+             fp16=torch.cuda.is_available(),
+             load_best_model_at_end=True,
+         )
+
+         trainer = Trainer(
+             model=model,
+             args=training_args,
+             train_dataset=ds_train,
+             eval_dataset=ds_val,
+             data_collator=data_collator,
+         )
+
+         trainer.train()
+         fold_preds = trainer.predict(ds_val)
+         oof_probas[val_idx] = torch.softmax(torch.tensor(fold_preds.predictions), dim=-1)[:, 1].numpy()
+
+     np.save(os.path.join(save_dir, "roberta_oof.npy"), oof_probas)
+ """
+ # ====================================================================
+
+ if __name__ == "__main__":
+     import yaml
+     cfg_path = os.path.join(_PROJECT_ROOT, "config", "config.yaml")
+     with open(cfg_path, "r", encoding="utf-8") as file:
+         config = yaml.safe_load(file)
+
+     s_dir = os.path.join(_PROJECT_ROOT, config["paths"]["splits_dir"])
+     m_dir = os.path.join(_PROJECT_ROOT, config["paths"]["models_dir"], "roberta_model")
+
+     train_roberta(config, s_dir, m_dir)
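Each base model saves a `*_oof.npy` array of P(class=1) over the training rows (here `roberta_oof.npy`; the repo also ships `lr_oof.npy`, `lstm_oof.npy`, and `distilbert_oof.npy`). A minimal sketch of how such arrays stack into meta-features, using synthetic stand-ins — the real wiring lives in the meta-classifier training script:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(42)
n = 200
y = rng.integers(0, 2, size=n)

# Stand-ins for the four saved OOF arrays: each column is one base
# model's P(class=1) for the same n training rows (synthetic here).
oof_cols = [np.clip(y + rng.normal(0, 0.3, size=n), 0, 1) for _ in range(4)]
X_meta = np.column_stack(oof_cols)  # shape (n, 4), one column per base model

# Level-2 learner fits on the stacked probabilities.
meta = LogisticRegression().fit(X_meta, y)
train_acc = meta.score(X_meta, y)
```

In practice the columns would come from `np.load("roberta_oof.npy")` and friends, all aligned to the same `df_train.csv` row order — row alignment is the one invariant stacking cannot survive without.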
src/stage1_ingestion.py ADDED
@@ -0,0 +1,728 @@
+ """
+ stage1_ingestion.py — Load, unify, deduplicate, and persist all datasets.
+
+ Reads from five dataset sources (ISOT, LIAR, Kaggle Combined / News_dataset,
+ Multi-Domain / overall, and the supplementary training folder), maps them into
+ a single canonical schema, performs Sentence-BERT deduplication, and writes the
+ result to ``data/processed/unified.csv`` together with label-distribution
+ statistics in ``data/processed/stats.json``.
+
+ Usage:
+     python -m src.stage1_ingestion   # from fake_news_detection/
+     python src/stage1_ingestion.py   # direct execution
+ """
+
+ from __future__ import annotations
+
+ import json
+ import logging
+ import os
+ import sys
+ import time
+ import uuid
+ from pathlib import Path
+ from typing import Dict, List, Optional, Tuple
+ from urllib.parse import urlparse
+
+ import pandas as pd
+ import yaml
+
+ # ── Ensure project root is on sys.path when running directly ──
+ _SCRIPT_DIR = Path(__file__).resolve().parent
+ _PROJECT_ROOT = _SCRIPT_DIR.parent
+ if str(_PROJECT_ROOT) not in sys.path:
+     sys.path.insert(0, str(_PROJECT_ROOT))
+
+ from src.utils.deduplication import deduplicate_dataframe  # noqa: E402
+ from src.utils.text_utils import clean_empty_texts, build_full_text, word_count  # noqa: E402
+
+ # ═══════════════════════════════════════════════════════════
+ # Logger
+ # ═══════════════════════════════════════════════════════════
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s │ %(levelname)-8s │ %(name)s │ %(message)s",
+     datefmt="%H:%M:%S",
+ )
+ logger = logging.getLogger("stage1_ingestion")
+
+
+ # ═══════════════════════════════════════════════════════════
+ # Config loader
+ # ═══════════════════════════════════════════════════════════
+ def load_config(config_path: Optional[str] = None) -> dict:
+     """Load the YAML configuration file.
+
+     Args:
+         config_path: Explicit path to ``config.yaml``. Falls back to
+             ``<project_root>/config/config.yaml`` if not provided.
+
+     Returns:
+         Parsed configuration dictionary.
+     """
+     if config_path is None:
+         config_path = str(_PROJECT_ROOT / "config" / "config.yaml")
+     with open(config_path, "r", encoding="utf-8") as fh:
+         cfg = yaml.safe_load(fh)
+     return cfg
+
+
+ # ═══════════════════════════════════════════════════════════
+ # Schema constants
+ # ═══════════════════════════════════════════════════════════
+ UNIFIED_COLUMNS = [
+     "article_id",
+     "title",
+     "text",
+     "source_domain",
+     "published_date",
+     "has_date",
+     "binary_label",
+     "dataset_origin",
+ ]
+
+
+ # ═══════════════════════════════════════════════════════════
+ # Helper: extract domain from URL
+ # ═══════════════════════════════════════════════════════════
+ def extract_domain(url: Optional[str]) -> str:
+     """Extract the domain (netloc) from a URL string.
+
+     Args:
+         url: Raw URL (may be ``None`` or malformed).
+
+     Returns:
+         Domain string such as ``"reuters.com"`` or ``"unknown"``.
+     """
+     if not url or not isinstance(url, str):
+         return "unknown"
+     url = url.strip()
+     if not url.startswith(("http://", "https://")):
+         url = "http://" + url
+     try:
+         netloc = urlparse(url).netloc
+         # Strip leading 'www.'
+         if netloc.startswith("www."):
+             netloc = netloc[4:]
+         return netloc if netloc else "unknown"
+     except Exception:
+         return "unknown"
+
+
+ def _try_parse_date(val) -> pd.Timestamp:
+     """Attempt to parse a value into a pandas Timestamp.
+
+     Args:
+         val: Any date-like value.
+
+     Returns:
+         ``pd.Timestamp`` or ``pd.NaT`` on failure.
+     """
+     if pd.isna(val):
+         return pd.NaT
+     try:
+         return pd.to_datetime(val)
+     except Exception:
+         return pd.NaT
+
+
+ # ═══════════════════════════════════════════════════════════
+ # Dataset-specific loaders
+ # ═══════════════════════════════════════════════════════════
+
+ def load_isot(dataset_root: str) -> pd.DataFrame:
+     """Load the ISOT Fake Real News dataset (``True.csv`` + ``Fake.csv``).
+
+     Located at ``<dataset_root>/fake_real/``.
+
+     Args:
+         dataset_root: Path to the top-level Dataset folder.
+
+     Returns:
+         DataFrame in the unified schema.
+     """
+     t0 = time.perf_counter()
+     logger.info("Loading ISOT dataset …")
+
+     base = os.path.join(dataset_root, "fake_real")
+     true_path = os.path.join(base, "True.csv")
+     fake_path = os.path.join(base, "Fake.csv")
+
+     df_true = pd.read_csv(true_path)
+     df_true["binary_label"] = 1
+     df_fake = pd.read_csv(fake_path)
+     df_fake["binary_label"] = 0
+
+     df = pd.concat([df_true, df_fake], ignore_index=True)
+
+     # Columns: title, text, subject, date
+     records: List[dict] = []
+     for _, row in df.iterrows():
+         pub_date = _try_parse_date(row.get("date"))
+         records.append({
+             "article_id": str(uuid.uuid4()),
+             "title": str(row.get("title", "") or ""),
+             "text": str(row.get("text", "") or ""),
+             "source_domain": "unknown",  # ISOT has no URL column
+             "published_date": pub_date,
+             "has_date": not pd.isna(pub_date),
+             "binary_label": int(row["binary_label"]),
+             "dataset_origin": "isot",
+         })
+
+     result = pd.DataFrame(records, columns=UNIFIED_COLUMNS)
+     logger.info(
+         "ISOT loaded: %d rows (True=%d, Fake=%d) in %.1fs",
+         len(result),
+         (result["binary_label"] == 1).sum(),
+         (result["binary_label"] == 0).sum(),
+         time.perf_counter() - t0,
+     )
+     return result
+
+
+ # ─────────────────────────────────────────────────────────
+
+ # LIAR label mapping
+ _LIAR_LABEL_MAP = {
+     "true": 1,
+     "mostly-true": 1,
+     "half-true": 1,
+     "false": 0,
+     "barely-true": 0,
+     "pants-fire": 0,
+ }
+
+ _LIAR_COLNAMES = [
+     "id", "label", "statement", "subject", "speaker",
+     "job_title", "state", "party",
+     "barely_true_cnt", "false_cnt", "half_true_cnt",
+     "mostly_true_cnt", "pants_fire_cnt",
+     "context",
+ ]
+
+
+ def load_liar(dataset_root: str) -> pd.DataFrame:
+     """Load the LIAR dataset (``train.tsv``, ``valid.tsv``, ``test.tsv``).
+
+     Six-class labels are mapped to binary via ``_LIAR_LABEL_MAP``.
+
+     Args:
+         dataset_root: Path to the top-level Dataset folder.
+
+     Returns:
+         DataFrame in the unified schema.
+     """
+     t0 = time.perf_counter()
+     logger.info("Loading LIAR dataset …")
+
+     base = os.path.join(dataset_root, "liar")
+     frames: List[pd.DataFrame] = []
+     for fname in ("train.tsv", "valid.tsv", "test.tsv"):
+         fp = os.path.join(base, fname)
+         if os.path.exists(fp):
+             tmp = pd.read_csv(fp, sep="\t", header=None, names=_LIAR_COLNAMES)
+             frames.append(tmp)
+             logger.info("  %s: %d rows", fname, len(tmp))
+
+     df = pd.concat(frames, ignore_index=True)
+
+     records: List[dict] = []
+     for _, row in df.iterrows():
+         label_str = str(row.get("label", "")).strip().lower()
+         binary = _LIAR_LABEL_MAP.get(label_str)
+         if binary is None:
+             continue  # Skip rows with unrecognised labels
+
+         records.append({
+             "article_id": str(uuid.uuid4()),
+             "title": "",  # LIAR has no title
+             "text": str(row.get("statement", "") or ""),
+             "source_domain": "politifact.com",  # All LIAR data from PolitiFact
+             "published_date": pd.NaT,
+             "has_date": False,
+             "binary_label": binary,
+             "dataset_origin": "liar",
+         })
+
+     result = pd.DataFrame(records, columns=UNIFIED_COLUMNS)
+     logger.info(
+         "LIAR loaded: %d rows (True=%d, Fake=%d) in %.1fs",
+         len(result),
+         (result["binary_label"] == 1).sum(),
+         (result["binary_label"] == 0).sum(),
+         time.perf_counter() - t0,
+     )
+     return result
+
+
+ # ─────────────────────────────────────────────────────────
+
+ def load_kaggle_combined(dataset_root: str) -> pd.DataFrame:
+     """Load the Kaggle Combined / News_dataset folder.
+
+     This folder mirrors the ISOT structure (``True.csv``, ``Fake.csv``).
+
+     Args:
+         dataset_root: Path to the top-level Dataset folder.
+
+     Returns:
+         DataFrame in the unified schema.
+     """
+     t0 = time.perf_counter()
+     logger.info("Loading Kaggle Combined (News_dataset) …")
+
+     # Note: the actual folder name contains an embedded space: "News _dataset"
+     base = os.path.join(dataset_root, "News _dataset")
+     if not os.path.isdir(base):
+         # Fall back to the name without the space
+         base = os.path.join(dataset_root, "News_dataset")
+
+     frames: List[pd.DataFrame] = []
+
+     for fname in os.listdir(base):
+         fpath = os.path.join(base, fname)
+         if not fname.lower().endswith(".csv"):
+             continue
+         try:
+             tmp = pd.read_csv(fpath)
+         except Exception as exc:
+             logger.warning("Could not read %s: %s", fpath, exc)
+             continue
+
+         # Detect label from the filename, else from a label column
+         name_lower = fname.lower()
+         if "true" in name_lower or "real" in name_lower:
+             tmp["binary_label"] = 1
+         elif "fake" in name_lower:
+             tmp["binary_label"] = 0
+         elif "label" in [c.lower() for c in tmp.columns]:
+             label_col = [c for c in tmp.columns if c.lower() == "label"][0]
+             tmp["binary_label"] = tmp[label_col].apply(
+                 lambda x: 1 if str(x).strip().lower() in ("1", "true", "real") else 0
+             )
+         else:
+             logger.warning("Cannot determine label for %s — skipping.", fname)
+             continue
+
+         frames.append(tmp)
+         logger.info("  %s: %d rows", fname, len(tmp))
+
+     if not frames:
+         logger.warning("No CSV files found in Kaggle Combined folder.")
+         return pd.DataFrame(columns=UNIFIED_COLUMNS)
+
+     df = pd.concat(frames, ignore_index=True)
+
+     # Detect column names dynamically
+     col_map = {c.lower().strip(): c for c in df.columns}
+
+     title_col = col_map.get("title")
+     text_col = col_map.get("text") or col_map.get("article") or col_map.get("content")
+     date_col = col_map.get("date") or col_map.get("published_date")
+
+     records: List[dict] = []
+     for _, row in df.iterrows():
+         pub_date = _try_parse_date(row.get(date_col)) if date_col else pd.NaT
+         records.append({
+             "article_id": str(uuid.uuid4()),
+             "title": str(row.get(title_col, "") or "") if title_col else "",
+             "text": str(row.get(text_col, "") or "") if text_col else "",
+             "source_domain": "unknown",
+             "published_date": pub_date,
+             "has_date": not pd.isna(pub_date),
+             "binary_label": int(row["binary_label"]),
+             "dataset_origin": "kaggle_combined",
+         })
+
+     result = pd.DataFrame(records, columns=UNIFIED_COLUMNS)
+     logger.info(
+         "Kaggle Combined loaded: %d rows (True=%d, Fake=%d) in %.1fs",
+         len(result),
+         (result["binary_label"] == 1).sum(),
+         (result["binary_label"] == 0).sum(),
+         time.perf_counter() - t0,
+     )
+     return result
+
+
+ # ─────────────────────────────────────────────────────────
+
+ def _load_txt_folder(folder: str, label: int) -> List[dict]:
+     """Read all ``.txt`` files in *folder* and return a list of record dicts.
+
+     The first non-empty line is treated as the title; the remainder is the
+     body text.
+
+     Args:
+         folder: Directory containing ``.txt`` article files.
+         label: Binary label (0 = Fake, 1 = True) to assign.
+
+     Returns:
+         List of dicts suitable for DataFrame construction.
+     """
+     records: List[dict] = []
+     if not os.path.isdir(folder):
+         return records
+     for fname in sorted(os.listdir(folder)):
+         if not fname.endswith(".txt"):
+             continue
+         fpath = os.path.join(folder, fname)
+         try:
+             with open(fpath, "r", encoding="utf-8", errors="replace") as fh:
+                 lines = fh.read().strip().splitlines()
+         except Exception:
+             continue
+         title = lines[0].strip() if lines else ""
+         body = "\n".join(lines[1:]).strip() if len(lines) > 1 else ""
+         records.append({
+             "article_id": str(uuid.uuid4()),
+             "title": title,
+             "text": body,
+             "source_domain": "unknown",
+             "published_date": pd.NaT,
+             "has_date": False,
+             "binary_label": label,
+             "dataset_origin": "multi_domain",
+         })
+     return records
+
+
+ def load_multi_domain(dataset_root: str) -> pd.DataFrame:
+     """Load the Multi-Domain Fake News dataset (``overall/`` folder).
+
+     Structure::
+
+         overall/overall/
+             fake/   → .txt files (label 0)
+             real/   → .txt files (label 1)
+             celebrityDataset/
+                 fake/   → .txt files (label 0)
+                 legit/  → .txt files (label 1)
+
+     Args:
+         dataset_root: Path to the top-level Dataset folder.
+
+     Returns:
+         DataFrame in the unified schema.
+     """
+     t0 = time.perf_counter()
+     logger.info("Loading Multi-Domain dataset …")
+
+     base = os.path.join(dataset_root, "overall", "overall")
+     records: List[dict] = []
+
+     # Main fake / real folders
+     records.extend(_load_txt_folder(os.path.join(base, "fake"), label=0))
+     records.extend(_load_txt_folder(os.path.join(base, "real"), label=1))
+
+     # Celebrity sub-dataset
+     celeb = os.path.join(base, "celebrityDataset")
+     records.extend(_load_txt_folder(os.path.join(celeb, "fake"), label=0))
+     records.extend(_load_txt_folder(os.path.join(celeb, "legit"), label=1))
+
+     result = pd.DataFrame(records, columns=UNIFIED_COLUMNS)
+     logger.info(
+         "Multi-Domain loaded: %d rows (True=%d, Fake=%d) in %.1fs",
+         len(result),
+         (result["binary_label"] == 1).sum(),
+         (result["binary_label"] == 0).sum(),
+         time.perf_counter() - t0,
+     )
+     return result
+
+
+ # ─────────────────────────────────────────────────────────
+
+ def load_training_folder(dataset_root: str) -> pd.DataFrame:
+     """Load supplementary training data from ``training/training/``.
+
+     Structure mirrors multi-domain with sub-datasets ``celebrityDataset``
+     and ``fakeNewsDataset``, each containing ``fake/`` and ``legit/`` folders.
+
+     Args:
+         dataset_root: Path to the top-level Dataset folder.
+
+     Returns:
+         DataFrame in the unified schema.
+     """
+     t0 = time.perf_counter()
+     logger.info("Loading supplementary training folder …")
+
+     base = os.path.join(dataset_root, "training", "training")
+     records: List[dict] = []
+
+     for subdir in ("celebrityDataset", "fakeNewsDataset"):
+         sub_path = os.path.join(base, subdir)
+         if not os.path.isdir(sub_path):
+             continue
+         fake_recs = _load_txt_folder(os.path.join(sub_path, "fake"), label=0)
+         legit_recs = _load_txt_folder(os.path.join(sub_path, "legit"), label=1)
+         for r in fake_recs + legit_recs:
+             r["dataset_origin"] = f"training_{subdir}"
+         records.extend(fake_recs + legit_recs)
+         logger.info("  %s: %d fake + %d legit", subdir, len(fake_recs), len(legit_recs))
+
+     result = pd.DataFrame(records, columns=UNIFIED_COLUMNS)
+     logger.info(
+         "Training folder loaded: %d rows (True=%d, Fake=%d) in %.1fs",
+         len(result),
+         (result["binary_label"] == 1).sum(),
+         (result["binary_label"] == 0).sum(),
+         time.perf_counter() - t0,
+     )
+     return result
+
+
+ # ─────────────────────────────────────────────────────────
+
+ def load_testing_dataset(dataset_root: str) -> pd.DataFrame:
+     """Load the sacred hold-out Testing_dataset (never used for training).
+
+     Structure::
+
+         Testing_dataset/testingSet/
+             fake/  → .txt files (label 0)
+             real/  → .txt files (label 1)
+
+     The catalog CSVs in this folder provide metadata; the actual article
+     bodies live in the ``fake/`` and ``real/`` sub-folders.
+
+     Args:
+         dataset_root: Path to the top-level Dataset folder.
+
+     Returns:
+         DataFrame in the unified schema with ``dataset_origin = "testing"``.
+     """
+     t0 = time.perf_counter()
+     logger.info("Loading Testing dataset (hold-out) …")
+
+     base = os.path.join(dataset_root, "Testing_dataset", "testingSet")
+     records: List[dict] = []
+
+     fake_recs = _load_txt_folder(os.path.join(base, "fake"), label=0)
+     real_recs = _load_txt_folder(os.path.join(base, "real"), label=1)
+     for r in fake_recs + real_recs:
+         r["dataset_origin"] = "testing"
+     records.extend(fake_recs + real_recs)
+
+     # Optionally enrich with catalog metadata
+     for catalog_name, label in [("Catalog - Fake Articles.csv", 0), ("Catalog - Real Articles.csv", 1)]:
+         cat_path = os.path.join(base, catalog_name)
+         if os.path.exists(cat_path):
+             try:
+                 cat = pd.read_csv(cat_path)
+                 logger.info("  Catalog %s: %d entries", catalog_name, len(cat))
+             except Exception as exc:
+                 logger.warning("  Could not read catalog %s: %s", catalog_name, exc)
+
+     result = pd.DataFrame(records, columns=UNIFIED_COLUMNS)
+     logger.info(
+         "Testing dataset loaded: %d rows (True=%d, Fake=%d) in %.1fs",
+         len(result),
+         (result["binary_label"] == 1).sum(),
+         (result["binary_label"] == 0).sum(),
+         time.perf_counter() - t0,
+     )
+     return result
+
+
+ # ═══════════════════════════════════════════════════════════
+ # Main ingestion pipeline
+ # ═══════════════════════════════════════════════════════════
+
+ def run_ingestion(cfg: dict) -> pd.DataFrame:
+     """Execute the full Stage 1 ingestion pipeline.
+
+     Steps:
+         1. Load all five dataset sources plus the sacred test set.
+         2. Concatenate into a single DataFrame.
+         3. Run Sentence-BERT deduplication on the training pool.
+         4. Carve out a stratified holdout, then persist ``unified.csv``
+            and ``stats.json``.
+
+     Args:
+         cfg: Parsed config dictionary (from ``config.yaml``).
+
+     Returns:
+         The final unified (deduplicated) DataFrame.
+     """
+     pipeline_t0 = time.perf_counter()
+     logger.info("═" * 60)
+     logger.info("  STAGE 1 — INGESTION START")
+     logger.info("═" * 60)
+
+     dataset_root = os.path.abspath(
+         os.path.join(str(_PROJECT_ROOT), cfg["paths"]["dataset_root"])
+     )
+     logger.info("Dataset root resolved to: %s", dataset_root)
+
+     # ── Step 1 : Load each dataset ───────────────────────────
+     t0 = time.perf_counter()
+     df_isot = load_isot(dataset_root)
+     df_liar = load_liar(dataset_root)
+     df_kaggle = load_kaggle_combined(dataset_root)
+     df_multi = load_multi_domain(dataset_root)
+     df_training = load_training_folder(dataset_root)
+     df_testing = load_testing_dataset(dataset_root)
+     load_time = time.perf_counter() - t0
+     logger.info("All datasets loaded in %.1fs", load_time)
+
+     # ── Step 2 : Concatenate ─────────────────────────────────
+     t0 = time.perf_counter()
+     all_frames = [df_isot, df_liar, df_kaggle, df_multi, df_training, df_testing]
+     df_unified = pd.concat(all_frames, ignore_index=True)
+     logger.info(
+         "Unified dataset: %d rows (concat took %.1fs)",
+         len(df_unified), time.perf_counter() - t0,
+     )
+
+     # Log per-origin counts
+     origin_counts = df_unified["dataset_origin"].value_counts()
+     for origin, cnt in origin_counts.items():
+         logger.info("  %-30s %6d rows", origin, cnt)
+
+     # ── Prep ─────────────────────────────────────────────────
+     # FIX 1: Exclude the sacred hold-out from dedup
+     test_mask = df_unified["dataset_origin"] == "testing"
+     test_df = df_unified.loc[test_mask].copy()
+     train_pool_df = df_unified.loc[~test_mask].copy()
+
+     # FIX 3: Remove empty/near-empty texts from the training pool ONLY
+     min_word_count = cfg.get("preprocessing", {}).get("min_word_count", 3)
+     train_before = len(train_pool_df)
+     train_pool_df = clean_empty_texts(train_pool_df, min_word_count=min_word_count)
+     empty_dropped = train_before - len(train_pool_df)
+
+     # Flag short texts in test_df instead of dropping them
+     test_full = test_df.apply(lambda r: build_full_text(r.get("title", ""), r.get("text", "")), axis=1)
+     test_df["short_text_flag"] = test_full.apply(word_count) < min_word_count
+     short_test_flagged = int(test_df["short_text_flag"].sum())
+
+     logger.info("Sacred test rows preserved: %d (flagged %d short texts)", len(test_df), short_test_flagged)
+
+     # ── Step 3 : Deduplication ───────────────────────────────
+     dedup_cfg = cfg.get("dataset", {})
+     threshold = dedup_cfg.get("dedup_threshold", 0.92)
+     batch_size = dedup_cfg.get("dedup_batch_size", 64)
+
+     train_pool_df["_dedup_text"] = (
+         train_pool_df["title"].fillna("") + " " + train_pool_df["text"].fillna("")
+     ).str.strip()
+
+     mask_has_text = train_pool_df["_dedup_text"].str.len() > 10
+     df_with_text = train_pool_df.loc[mask_has_text].copy()
+     df_no_text = train_pool_df.loc[~mask_has_text].copy()
+
+     logger.info(
+         "Dedup candidates (train pool): %d rows with text, %d skipped (too short)",
+         len(df_with_text), len(df_no_text),
+     )
+
+     if len(df_with_text) > 0:
+         exact_counts = len(df_with_text) - len(df_with_text.drop_duplicates(subset=["_dedup_text"]))
+         df_deduped, dedup_stats = deduplicate_dataframe(
+             df_with_text,
+             text_column="_dedup_text",
+             threshold=threshold,
+             batch_size=batch_size,
+             origin_column="dataset_origin",
+         )
+         total_removed = len(df_with_text) - len(df_deduped)
+         semantic_counts = total_removed - exact_counts
+     else:
+         df_deduped = df_with_text
+         dedup_stats = {}
+         exact_counts = 0
+         semantic_counts = 0
+
+     train_pool_deduped = pd.concat([df_deduped, df_no_text], ignore_index=True)
+     train_pool_deduped.drop(columns=["_dedup_text"], inplace=True, errors="ignore")
+
+     # FIX 2: Stratified hold-out carve-out
+     holdout_cfg = cfg.get("holdout", {})
+     stratified_test_size = holdout_cfg.get("stratified_test_size", 0.10)
+     random_state = holdout_cfg.get("random_state", 42)
+
+     from sklearn.model_selection import StratifiedShuffleSplit
+     sss = StratifiedShuffleSplit(n_splits=1, test_size=stratified_test_size, random_state=random_state)
+
+     train_pool_deduped = train_pool_deduped.reset_index(drop=True)
+     train_idx, held_idx = next(sss.split(train_pool_deduped, train_pool_deduped["binary_label"]))
+
+     stratified_holdout = train_pool_deduped.iloc[held_idx].copy()
+     train_pool_final = train_pool_deduped.iloc[train_idx].copy()
+
+     stratified_holdout["dataset_origin"] = "stratified_holdout"
+
+     logger.info("Train pool after carve-out: %d", len(train_pool_final))
+     logger.info("Stratified holdout: %d", len(stratified_holdout))
+     logger.info("Sacred test set: %d", len(test_df))
+
+     train_pool_final["short_text_flag"] = False
+     stratified_holdout["short_text_flag"] = False
+
+     df_final = pd.concat([
+         train_pool_final,
+         stratified_holdout,
+         test_df
+     ], ignore_index=True)
+
+     logger.info("Post-dedup and split total: %d rows", len(df_final))
+
+     # ── Step 4 : Ensure types ────────────────────────────────
+     df_final["published_date"] = pd.to_datetime(
+         df_final["published_date"], errors="coerce"
+     )
+     df_final["has_date"] = df_final["published_date"].notna()
+     df_final["binary_label"] = df_final["binary_label"].astype(int)
+
+     # ── Step 5 : Save unified CSV + stats ────────────────────
+     processed_dir = os.path.join(str(_PROJECT_ROOT), cfg["paths"]["processed_dir"])
+     os.makedirs(processed_dir, exist_ok=True)
+
+     csv_path = os.path.join(processed_dir, "unified.csv")
+     df_final.to_csv(csv_path, index=False)
+     logger.info("Saved unified CSV → %s (%d rows)", csv_path, len(df_final))
+
+     # Stats
+     stats = {
+         "total_rows": len(df_final),
+         "train_pool_rows": len(train_pool_final),
+         "stratified_holdout_rows": len(stratified_holdout),
+         "sacred_test_rows": len(test_df),
+         "fake_count": int((df_final["binary_label"] == 0).sum()),
+         "true_count": int((df_final["binary_label"] == 1).sum()),
+         "has_date_ratio": float(df_final["has_date"].mean()),
+         "empty_texts_dropped": empty_dropped,
+         "short_text_flagged_in_test": short_test_flagged,
+         "dedup_removed_exact": exact_counts,
+         "dedup_removed_semantic": semantic_counts,
+         "per_origin": df_final["dataset_origin"].value_counts().to_dict(),
+         "dedup_stats": {k: int(v) for k, v in dedup_stats.items()},
+     }
+     stats_path = os.path.join(processed_dir, "stats.json")
+     with open(stats_path, "w", encoding="utf-8") as fh:
+         json.dump(stats, fh, indent=2, default=str)
+     logger.info("Saved stats → %s", stats_path)
+
+     pipeline_elapsed = time.perf_counter() - pipeline_t0
+     logger.info("═" * 60)
+     logger.info("  STAGE 1 — INGESTION COMPLETE (%.1fs total)", pipeline_elapsed)
+     logger.info("═" * 60)
+
+     return df_final
+
+
+ # ═══════════════════════════════════════════════════════════
+ # __main__ block for standalone testing
+ # ═══════════════════════════════════════════════════════════
+ if __name__ == "__main__":
+     cfg = load_config()
+     df = run_ingestion(cfg)
+     print("\n=== Final Unified Dataset ===")
+     print(f"Shape: {df.shape}")
+     print(f"\nLabel distribution:\n{df['binary_label'].value_counts()}")
+     print(f"\nOrigin distribution:\n{df['dataset_origin'].value_counts()}")
+     print(f"\nhas_date ratio: {df['has_date'].mean():.2%}")
728
+ print(f"\nSample rows:\n{df.head(3).to_string()}")
src/stage2_preprocessing.py ADDED
@@ -0,0 +1,186 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ import time
5
+ import logging
6
+ import pickle
7
+ import numpy as np
8
+ import pandas as pd
9
+ import yaml
10
+ from sklearn.model_selection import StratifiedShuffleSplit
11
+
12
+ # Fix paths for imports
13
+ _PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
14
+ if str(_PROJECT_ROOT) not in sys.path:
15
+ sys.path.insert(0, str(_PROJECT_ROOT))
16
+
17
+ from src.utils.text_utils import clean_text, build_full_text, word_count, text_length_bucket
18
+ from src.utils.domain_weights import compute_domain_weights
19
+ from src.utils.freshness import apply_freshness_score
20
+
21
+ logging.basicConfig(
22
+ level=logging.INFO,
23
+ format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
24
+ datefmt="%H:%M:%S"
25
+ )
26
+ logger = logging.getLogger("stage2_preprocessing")
27
+
28
+ class KerasStyleTokenizer:
29
+ """A lightweight, PyTorch-compatible word tokenizer mimicking Keras's Tokenizer."""
30
+ def __init__(self, num_words=None, oov_token="<OOV>"):
31
+ self.num_words = num_words
32
+ self.oov_token = oov_token
33
+ self.word_index = {self.oov_token: 1} # 0 is reserved for padding
34
+ self.word_counts = {}
35
+
36
+ def fit_on_texts(self, texts):
37
+ for text in texts:
38
+ # clean_text already removes punctuation, so splitting on whitespace suffices
39
+ words = str(text).split()
40
+ for w in words:
41
+ self.word_counts[w] = self.word_counts.get(w, 0) + 1
42
+
43
+ # Sort by frequency
44
+ sorted_words = sorted(self.word_counts.items(), key=lambda x: x[1], reverse=True)
45
+
46
+ for idx, (w, _) in enumerate(sorted_words):
47
+ if self.num_words and idx >= self.num_words - 2:
48
+ break
49
+ self.word_index[w] = idx + 2
50
+
51
+ def texts_to_sequences(self, texts):
52
+ seqs = []
53
+ for text in texts:
54
+ words = str(text).split()
55
+ seq = [self.word_index.get(w, 1) for w in words]
56
+ seqs.append(seq)
57
+ return seqs
58
+
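A quick standalone sketch of the index mapping (the class is re-declared in condensed form so the snippet runs on its own): index 0 is padding, 1 is `<OOV>`, and the top `num_words - 2` words by frequency get indices from 2 upward.

```python
class KerasStyleTokenizer:
    # Condensed copy of the class above, for a self-contained demo.
    def __init__(self, num_words=None, oov_token="<OOV>"):
        self.num_words = num_words
        self.word_index = {oov_token: 1}  # 0 reserved for padding
        self.word_counts = {}

    def fit_on_texts(self, texts):
        for text in texts:
            for w in str(text).split():
                self.word_counts[w] = self.word_counts.get(w, 0) + 1
        ranked = sorted(self.word_counts.items(), key=lambda x: x[1], reverse=True)
        for idx, (w, _) in enumerate(ranked):
            if self.num_words and idx >= self.num_words - 2:
                break
            self.word_index[w] = idx + 2

    def texts_to_sequences(self, texts):
        return [[self.word_index.get(w, 1) for w in str(t).split()] for t in texts]

tok = KerasStyleTokenizer(num_words=5)
tok.fit_on_texts(["fake news fake", "real news"])
seqs = tok.texts_to_sequences(["fake unknown real"])  # "unknown" maps to OOV id 1
```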
59
+ def truncate_str_array(df, col):
60
+ """Memory fix: force string type for arrays."""
61
+ return df[col].astype(str).values
62
+
63
+ def run_preprocessing(cfg: dict = None):
64
+ t0 = time.perf_counter()
65
+ if cfg is None:
66
+ cfg_path = os.path.join(_PROJECT_ROOT, "config", "config.yaml")
67
+ with open(cfg_path, "r", encoding="utf-8") as f:
68
+ cfg = yaml.safe_load(f)
69
+
70
+ logger.info("STAGE 2: PREPROCESSING START")
71
+ processed_dir = os.path.join(_PROJECT_ROOT, cfg["paths"]["processed_dir"])
72
+ splits_dir = os.path.join(_PROJECT_ROOT, cfg["paths"]["splits_dir"])
73
+ models_dir = os.path.join(_PROJECT_ROOT, cfg["paths"]["models_dir"])
74
+ os.makedirs(splits_dir, exist_ok=True)
75
+ os.makedirs(models_dir, exist_ok=True)
76
+
77
+ # 1. Load Data
78
+ csv_path = os.path.join(processed_dir, "unified.csv")
79
+ df = pd.read_csv(csv_path)
80
+ df["published_date"] = pd.to_datetime(df["published_date"], errors="coerce")
81
+ logger.info("Loaded unified CSV: %d rows", len(df))
82
+
83
+ # 2. Extract Text Length Features & Clean
84
+ # "Concatenate title + '. ' + text as full_text" -> use build_full_text
85
+ df["full_text"] = df.apply(lambda r: build_full_text(
86
+ str(r["title"]) if pd.notna(r["title"]) else "",
87
+ str(r["text"]) if pd.notna(r["text"]) else ""
88
+ ), axis=1)
89
+ # Clean the text: lowercase, strip HTML, URLs, and special characters.
90
+ # clean_text implements exactly this.
91
+ logger.info("Applying text cleaning (HTML, URLs, whitespace, punctuation) ...")
92
+ df["clean_text"] = df["full_text"].apply(clean_text)
93
+
94
+ logger.info("Calculating word counts and text buckets ...")
95
+ df["word_count"] = df["clean_text"].apply(word_count)
96
+ df["text_length_bucket"] = df["word_count"].apply(text_length_bucket)
97
+
98
+ # 3. Domain Weights
99
+ ds_cfg = cfg.get("dataset", {})
100
+ min_domains = ds_cfg.get("min_domain_samples", 20)
101
+ max_multi = cfg.get("inference", {}).get("max_multiplier", 10)
102
+
103
+ logger.info("Computing domain-aware sample weights...")
104
+ df = compute_domain_weights(df, min_domain_samples=min_domains, max_multiplier=max_multi)
105
+
106
+ # 4. Freshness
107
+ logger.info("Applying temporal freshness scores...")
108
+ df = apply_freshness_score(df, is_inference=False)
109
+
110
+ # 5. Train/Val/Test Splits
111
+ # Split assignment:
112
+ # stratified_holdout -> stage 3 proxy
113
+ # testing -> sacred
114
+ # train pool -> split 85/15 into train and val.
115
+
116
+ test_mask = df["dataset_origin"] == "testing"
117
+ holdout_mask = df["dataset_origin"] == "stratified_holdout"
118
+ train_pool_mask = ~(test_mask | holdout_mask)
119
+
120
+ test_df = df[test_mask].copy()
121
+ holdout_df = df[holdout_mask].copy()
122
+ train_pool_df = df[train_pool_mask].copy()
123
+
124
+ # Split train_pool into 85% train, 15% validation using StratifiedShuffleSplit
125
+ sss = StratifiedShuffleSplit(n_splits=1, test_size=0.15, random_state=42)
126
+ train_pool_df = train_pool_df.reset_index(drop=True)
127
+
128
+ train_idx, val_idx = next(sss.split(train_pool_df, train_pool_df["binary_label"]))
129
+ train_df = train_pool_df.iloc[train_idx].copy()
130
+ val_df = train_pool_df.iloc[val_idx].copy()
131
+
132
+ logger.info("SPLITS SUMMARY:")
133
+ logger.info(" Train: %d rows", len(train_df))
134
+ logger.info(" Val: %d rows", len(val_df))
135
+ logger.info(" Holdout: %d rows", len(holdout_df))
136
+ logger.info(" Sacred: %d rows", len(test_df))
137
+
138
+ # 6. Save splits metadata and arrays
139
+ # Saving raw text separately just for PyTorch dataset convenience (faster than pd.read_csv for big models)
140
+ splits_dict = {
141
+ "train": train_df,
142
+ "val": val_df,
143
+ "holdout": holdout_df,
144
+ "test": test_df
145
+ }
146
+
147
+ for split_name, split_data in splits_dict.items():
148
+ np.save(os.path.join(splits_dir, f"X_text_{split_name}.npy"), truncate_str_array(split_data, "clean_text"))
149
+ np.save(os.path.join(splits_dir, f"y_{split_name}.npy"), split_data["binary_label"].values)
150
+ np.save(os.path.join(splits_dir, f"w_{split_name}.npy"), split_data["sample_weight"].values)
151
+
152
+ meta = {
153
+ "size": len(split_data),
154
+ "fake_count": int((split_data["binary_label"] == 0).sum()),
155
+ "true_count": int((split_data["binary_label"] == 1).sum()),
156
+ "word_count_median": float(split_data["word_count"].median()),
157
+ "freshness_mean": float(split_data["freshness_score"].mean())
158
+ }
159
+ with open(os.path.join(splits_dir, f"meta_{split_name}.json"), "w") as f:
160
+ json.dump(meta, f, indent=2)
161
+
162
+ # Save train_ids.csv explicitly
163
+ train_df[["article_id"]].to_csv(os.path.join(splits_dir, "train_ids.csv"), index=False)
164
+ # Also save the full preprocessed test sets to CSV for easy loading during Stage 3 / Evaluation
165
+ train_df.to_csv(os.path.join(splits_dir, "df_train.csv"), index=False)
166
+ val_df.to_csv(os.path.join(splits_dir, "df_val.csv"), index=False)
167
+ holdout_df.to_csv(os.path.join(splits_dir, "df_holdout.csv"), index=False)
168
+ test_df.to_csv(os.path.join(splits_dir, "df_test.csv"), index=False)
169
+
170
+ # 7. Tokenization (LSTM)
171
+ logger.info("Fitting LSTM Tokenizer on Train split...")
172
+ # Cap the LSTM vocabulary size (reuses max_tfidf_features; defaults to 50k).
173
+ vocab_size = cfg.get("preprocessing", {}).get("max_tfidf_features", 50000)
174
+ tok = KerasStyleTokenizer(num_words=vocab_size)
175
+ tok.fit_on_texts(train_df["clean_text"])
176
+
177
+ tok_path = os.path.join(models_dir, "tokenizer.pkl")
178
+ with open(tok_path, "wb") as f:
179
+ pickle.dump(tok, f)
180
+ logger.info(f"Saved tokenizer to {tok_path} (vocab size: {len(tok.word_index)})")
181
+
182
+ t_end = time.perf_counter()
183
+ logger.info("STAGE 2 FINISHED in %.2f seconds", t_end - t0)
184
+
185
+ if __name__ == "__main__":
186
+ run_preprocessing()
src/stage3_training.py ADDED
@@ -0,0 +1,41 @@
1
+ import os
2
+ import sys
3
+ import logging
4
+ import time
5
+ import subprocess
6
+
7
+ _PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
8
+ if str(_PROJECT_ROOT) not in sys.path:
9
+ sys.path.insert(0, str(_PROJECT_ROOT))
10
+
11
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
12
+ logger = logging.getLogger("stage3_training")
13
+
14
+ def run_training(cfg: dict = None):
15
+ t0 = time.perf_counter()
16
+ logger.info("STAGE 3: TRAINING START")
17
+
18
+ python_exe = sys.executable
19
+ models_dir = os.path.join(_PROJECT_ROOT, "src", "models")
20
+
21
+ scripts = [
22
+ ("Logistic Regression", "logistic_model.py"),
23
+ ("Bi-LSTM", "lstm_model.py"),
24
+ ("DistilBERT", "distilbert_model.py"),
25
+ ("RoBERTa", "roberta_model.py"),
26
+ ("Meta-Classifier", "meta_classifier.py")
27
+ ]
28
+
29
+ for name, script_name in scripts:
30
+ script_path = os.path.join(models_dir, script_name)
31
+ logger.info(f"==> Launching {name} Training ({script_name})")
32
+ val = subprocess.run([python_exe, script_path], cwd=_PROJECT_ROOT)
33
+ if val.returncode != 0:
34
+ logger.error(f"{name} aborted with exit code {val.returncode}")
35
+ sys.exit(1)
36
+
37
+ t_end = time.perf_counter()
38
+ logger.info("STAGE 3 FINISHED in %.2f seconds", t_end - t0)
39
+
40
+ if __name__ == "__main__":
41
+ run_training()
src/stage4_inference.py ADDED
@@ -0,0 +1,867 @@
1
+ """
2
+ Stage 4 — Inference Engine (5-Signal Weighted Scoring)
3
+ =====================================================
4
+ Evaluates articles across five independent signals:
5
+ 1. Source Credibility (30%)
6
+ 2. Claim Verification (30%)
7
+ 3. Linguistic Analysis (20%)
8
+ 4. Freshness (10%)
9
+ 5. Ensemble Model Vote (10%)
10
+ Then applies adversarial overrides and maps to a final verdict.
11
+ """
12
+
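The aggregation of these five signals happens later in the module (outside this excerpt). As a minimal sketch of the arithmetic implied by the stated weights — the names and dict layout here are hypothetical, not the module's actual API:

```python
# Hedged sketch: combining the five signal scores under the stated weights.
WEIGHTS = {
    "source": 0.30,      # Source Credibility
    "claims": 0.30,      # Claim Verification
    "linguistic": 0.20,  # Linguistic Analysis
    "freshness": 0.10,   # Freshness
    "ensemble": 0.10,    # Ensemble Model Vote
}

def weighted_score(signals: dict) -> float:
    # Plain weighted sum; adversarial overrides would be applied afterwards.
    return sum(WEIGHTS[k] * signals[k] for k in WEIGHTS)

s = weighted_score({"source": 1.0, "claims": 0.8, "linguistic": 0.9,
                    "freshness": 0.75, "ensemble": 0.6})
```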
13
+ import os
14
+ import re
15
+ import sys
16
+ import yaml
17
+ import logging
18
+ import pickle
19
+ import pandas as pd
20
+ import numpy as np
21
+ import torch
22
+ from datetime import datetime, timezone
23
+
24
+ _PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
25
+ if str(_PROJECT_ROOT) not in sys.path:
26
+ sys.path.insert(0, str(_PROJECT_ROOT))
27
+
28
+ from src.utils.text_utils import clean_text, build_full_text, word_count as wc_func, text_length_bucket
29
+ from src.stage2_preprocessing import KerasStyleTokenizer
30
+
32
+ # Allow unpickling tokenizer.pkl objects that were saved while stage 2 ran as __main__
+ setattr(sys.modules['__main__'], 'KerasStyleTokenizer', KerasStyleTokenizer)
33
+
34
+ logger = logging.getLogger("stage4_inference")
35
+
36
+ # ═════════════════════════════════════════════════════════════════════════════
37
+ # CONSTANTS
38
+ # ═════════════════════════════════════════════════════════════════════════════
39
+
40
+ CREDIBLE_OUTLETS = {
41
+ "reuters.com", "apnews.com", "bbc.com", "bbc.co.uk", "nytimes.com",
42
+ "washingtonpost.com", "theguardian.com", "cnn.com", "cbsnews.com",
43
+ "nbcnews.com", "abcnews.go.com", "npr.org", "pbs.org", "bloomberg.com",
44
+ "wsj.com", "ft.com", "economist.com", "usatoday.com", "time.com",
45
+ "politico.com", "thehill.com", "axios.com", "propublica.org",
46
+ "snopes.com", "factcheck.org", "politifact.com", "fullfact.org",
47
+ "aljazeera.com", "dw.com", "france24.com", "scmp.com",
48
+ "theatlantic.com", "newyorker.com", "wired.com", "nature.com",
49
+ "sciencemag.org", "thelancet.com", "bmj.com", "who.int",
50
+ "un.org", "whitehouse.gov", "gov.uk", "europa.eu",
51
+ "hindustantimes.com", "ndtv.com", "thehindu.com", "indianexpress.com",
52
+ "timesofindia.indiatimes.com", "livemint.com",
53
+ "abc.net.au", "cbc.ca", "globalnews.ca", "stuff.co.nz",
54
+ "forbes.com", "businessinsider.com", "cnbc.com", "techcrunch.com",
55
+ "arstechnica.com", "theverge.com", "engadget.com",
56
+ "espn.com", "bbc.com/sport", "skysports.com",
57
+ }
58
+
59
+ CORROBORATION_OUTLETS_RE = re.compile(
60
+ r"(?i)\b(Reuters|Associated Press|\bAP\b|CBS|BBC|NBC|CNN|"
61
+ r"New York Times|NYT|Washington Post|The Guardian|NPR|PBS|"
62
+ r"Bloomberg|Wall Street Journal|Forbes)\b"
63
+ )
64
+
65
+ AUTHOR_PATTERNS = re.compile(
66
+ r"(?i)\b(by|written by|reporter|staff writer|correspondent|"
67
+ r"contributing writer|author|edited by|reported by)\b\s*[A-Z]"
68
+ )
69
+ BYLINE_NAME_RE = re.compile(r"^[A-Z][a-z]+ [A-Z][a-z]+", re.MULTILINE)
70
+
71
+ SUPERLATIVE_RE = re.compile(
72
+ r"(?i)\b(shocking|massive|unprecedented|bombshell|explosive|"
73
+ r"stunning|jaw-dropping|mind-blowing|unbelievable|outrageous)\b"
74
+ )
75
+ SENSATIONAL_RE = re.compile(
76
+ r"(?i)(you won't believe|what happened next|this is why|"
77
+ r"one weird trick|exposed|destroyed|slammed)"
78
+ )
79
+ NO_ATTRIB_RE = re.compile(
80
+ r"(?i)(sources say|it is believed|reportedly|some people say|"
81
+ r"many believe|rumor has it|anonymous source|unconfirmed reports)"
82
+ )
83
+ PASSIVE_VOICE_RE = re.compile(
84
+ r"(?i)(it is being said|it was reported|it has been claimed|"
85
+ r"it is alleged|it was alleged|it is rumored)"
86
+ )
87
+ QUOTE_RE = re.compile(r'"([^"]{10,})"')
88
+ QUOTE_ATTRIB_RE = re.compile(
89
+ r"(?i)(said|stated|according to|told|announced|confirmed|wrote|called|described|noted|added|explained|argued|claimed)"
90
+ )
91
+
92
+ STAT_RE = re.compile(r"\d+\s*%|\d+\s*(million|billion|trillion)", re.IGNORECASE)
93
+ CITATION_RE = re.compile(
94
+ r"(?i)(according to|source:|study by|data from|published by|research by|"
95
+ r"report by|survey by|analysis by|statistics from)"
96
+ )
97
+
98
+ INSTITUTION_RE = re.compile(
99
+ r"(?i)(university|department of|ministry|commission|institute|agency|"
100
+ r"foundation|world health|WHO|FDA|CDC|NASA|UNICEF|IMF|World Bank)"
101
+ )
102
+ TEMPORAL_RE = re.compile(
103
+ r"(?i)(this week|this month|recently|new report|just released|"
104
+ r"annual forecast|latest data|new study|breaking|today|yesterday)"
105
+ )
106
+
107
+
108
+ class ModelNotTrainedError(Exception):
109
+ def __init__(self, message="Run python run_pipeline.py --stage 3 first"):
110
+ super().__init__(message)
111
+
112
+
113
+ # ═════════════════════════════════════════════════════════════════════════════
114
+ # MODEL LOADING (unchanged from original)
115
+ # ═════════════════════════════════════════════════════════════════════════════
116
+
117
+ _MODEL_CACHE = {}
118
+
119
+ def load_config():
120
+ cfg_path = os.path.join(_PROJECT_ROOT, "config", "config.yaml")
121
+ with open(cfg_path, "r", encoding="utf-8") as f:
122
+ return yaml.safe_load(f)
123
+
124
+ def _get_model(model_name, cfg):
125
+ """Lazy load models."""
126
+ if model_name in _MODEL_CACHE:
127
+ return _MODEL_CACHE[model_name]
128
+
129
+ models_dir = os.path.join(_PROJECT_ROOT, cfg.get("paths", {}).get("models_dir", "models/saved"))
130
+
131
+ if model_name == "logistic":
132
+ import joblib
133
+ fpath = os.path.join(models_dir, "logistic_model", "logistic_model.pkl")
134
+ if not os.path.exists(fpath): raise ModelNotTrainedError()
135
+ _MODEL_CACHE[model_name] = joblib.load(fpath)
136
+
137
+ elif model_name == "lstm":
138
+ from src.models.lstm_model import BiLSTMClassifier, load_glove_embeddings, pad_sequences
139
+ tok_path = os.path.join(models_dir, "tokenizer.pkl")
140
+ if not os.path.exists(tok_path) or not os.path.exists(os.path.join(models_dir, "lstm_model", "model.pt")):
141
+ raise ModelNotTrainedError()
142
+ with open(tok_path, "rb") as f:
143
+ tok = pickle.load(f)
144
+ glove_path = os.path.join(_PROJECT_ROOT, cfg["paths"]["glove_path"])
145
+ emb_matrix, vocab_size = load_glove_embeddings(glove_path, tok.word_index)
146
+
147
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
148
+ model = BiLSTMClassifier(vocab_size, emb_matrix).to(device)
149
+ model.load_state_dict(torch.load(os.path.join(models_dir, "lstm_model", "model.pt"), map_location=device))
150
+ model.eval()
151
+ _MODEL_CACHE[model_name] = (model, tok, device)
152
+
153
+ elif model_name in ("distilbert", "roberta"):
154
+ try:
155
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
156
+ except ImportError:
157
+ raise ModelNotTrainedError("transformers is not installed; pip install transformers")
158
+ d_path = os.path.join(models_dir, f"{model_name}_model")
159
+ if not os.path.exists(os.path.join(d_path, "config.json")):
160
+ raise ModelNotTrainedError()
161
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
162
+ tok = AutoTokenizer.from_pretrained(d_path)
163
+ model = AutoModelForSequenceClassification.from_pretrained(d_path).to(device)
164
+ model.eval()
165
+ _MODEL_CACHE[model_name] = (model, tok, device)
166
+
167
+ elif model_name == "meta":
168
+ import joblib
169
+ fpath = os.path.join(models_dir, "meta_classifier", "meta_classifier.pkl")
170
+ if not os.path.exists(fpath): raise ModelNotTrainedError()
171
+ _MODEL_CACHE[model_name] = joblib.load(fpath)
172
+
173
+ return _MODEL_CACHE[model_name]
174
+
175
+
176
+ # ═════════════════════════════════════════════════════════════════════════════
177
+ # FEATURE EXTRACTION
178
+ # ═════════════════════════════════════════════════════════════════════════════
179
+
180
+ def extract_features(title, text, source_domain, published_date, cfg):
181
+ """Build standardized structural mapping for raw strings."""
182
+ full = build_full_text(title, text)
183
+ clean = clean_text(full)
184
+ wc = wc_func(clean)
185
+ bucket = text_length_bucket(wc)
186
+
187
+ has_date = pd.notna(published_date) and published_date != ""
188
+ if has_date and isinstance(published_date, str):
189
+ try:
190
+ published_date = pd.to_datetime(published_date, utc=True)
191
+ except Exception:
192
+ has_date = False
193
+ published_date = None
194
+ elif has_date:
195
+ try:
196
+ published_date = pd.Timestamp(published_date, tz="UTC")
197
+ except Exception:
198
+ has_date = False
199
+ published_date = None
200
+
201
+ return {
202
+ "clean_text": clean,
203
+ "full_text": full,
204
+ "word_count": wc,
205
+ "text_length_bucket": bucket,
206
+ "has_date": has_date,
207
+ "published_date": published_date,
208
+ "source_domain": source_domain if source_domain else "unknown",
209
+ }
210
+
211
+
212
+ # ═════════════════════════════════════════════════════════════════════════════
213
+ # STEP 1 — SOURCE CREDIBILITY (weight: 30%)
214
+ # ═════════════════════════════════════════════════════════════════════════════
215
+
216
+ def _levenshtein(s1, s2):
217
+ """Minimal Levenshtein distance for typosquatting check."""
218
+ if len(s1) < len(s2):
219
+ return _levenshtein(s2, s1)
220
+ if len(s2) == 0:
221
+ return len(s1)
222
+ prev_row = range(len(s2) + 1)
223
+ for i, c1 in enumerate(s1):
224
+ curr_row = [i + 1]
225
+ for j, c2 in enumerate(s2):
226
+ curr_row.append(min(curr_row[j] + 1, prev_row[j + 1] + 1,
227
+ prev_row[j] + (c1 != c2)))
228
+ prev_row = curr_row
229
+ return prev_row[-1]
230
+
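The distance function above is standard dynamic programming; a quick standalone check (function re-declared so the snippet runs on its own) shows why a threshold of 2 catches typosquats like an extra letter:

```python
def levenshtein(s1, s2):
    # Same row-by-row DP as _levenshtein above.
    if len(s1) < len(s2):
        return levenshtein(s2, s1)
    if len(s2) == 0:
        return len(s1)
    prev_row = range(len(s2) + 1)
    for i, c1 in enumerate(s1):
        curr_row = [i + 1]
        for j, c2 in enumerate(s2):
            curr_row.append(min(curr_row[j] + 1,        # deletion
                                prev_row[j + 1] + 1,    # insertion
                                prev_row[j] + (c1 != c2)))  # substitution
        prev_row = curr_row
    return prev_row[-1]
```

For example, `"reutters.com"` is one edit away from `"reuters.com"`, while two unrelated short domains like `"ft.com"` and `"dw.com"` already sit at distance 2 — which is why the credibility scorer should not run this check on domains that are themselves in the credible set.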
231
+
232
+ def score_source_credibility(source_domain, title, text):
233
+ """
234
+ Step 1: Evaluate source trustworthiness.
235
+ Returns: (score, author_found, typosquatting_detected)
236
+ """
237
+ # ── Early return: no source at all ──
238
+ if not source_domain or source_domain.strip() == "" or source_domain == "unknown":
239
+ # Still check for author in text body
240
+ author_found = bool(AUTHOR_PATTERNS.search(text[:500])) or bool(BYLINE_NAME_RE.search(text[:200]))
241
+ return 0.3, author_found, False
242
+
243
+ domain = source_domain.strip().lower()
244
+
245
+ # ── Typosquatting check ──
+ # Skip when the domain is itself a known outlet; otherwise near-identical
+ # credible domains (e.g. "ft.com" vs "dw.com", edit distance 2) would
+ # incorrectly flag each other as typosquats.
246
+ if domain not in CREDIBLE_OUTLETS:
+ for outlet in CREDIBLE_OUTLETS:
247
+ dist = _levenshtein(domain, outlet)
248
+ if 0 < dist <= 2: # close but not exact
249
+ return 0.0, False, True
250
+
251
+ # ── Component scoring ──
252
+ score = 0.0
253
+
254
+ # Base: any valid domain
255
+ score += 0.20
256
+
257
+ # Known outlet
258
+ if domain in CREDIBLE_OUTLETS:
259
+ score += 0.40
260
+
261
+ # Author verifiability
262
+ search_area = text[:500]
263
+ author_found = bool(AUTHOR_PATTERNS.search(search_area)) or bool(BYLINE_NAME_RE.search(text[:200]))
264
+ if author_found:
265
+ score += 0.20
266
+
267
+ # Corroboration: text mentions other major outlets
268
+ if CORROBORATION_OUTLETS_RE.search(text):
269
+ score += 0.20
270
+
271
+ return min(1.0, score), author_found, False
272
+
273
+
274
+ # ═════════════════════════════════════════════════════════════════════════════
275
+ # STEP 2 — CLAIM VERIFICATION (weight: 30%)
276
+ # ═════════════════════════════════════════════════════════════════════════════
277
+
278
+ _SPACY_NLP = None
279
+
280
+ def _get_spacy():
281
+ global _SPACY_NLP
282
+ if _SPACY_NLP is None:
283
+ import spacy
284
+ try:
285
+ _SPACY_NLP = spacy.load("en_core_web_sm")
286
+ except OSError:
287
+ import subprocess
288
+ subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_sm"], check=True)
289
+ _SPACY_NLP = spacy.load("en_core_web_sm")
290
+ return _SPACY_NLP
291
+
292
+
293
+ def score_claim_verification(meta_proba, clean_text_str, title):
294
+ """
295
+ Step 2: Entity-level claim verification.
296
+ Returns: (claim_score, entities_found, n_verifiable, quotes_attributed, quotes_total)
297
+ """
298
+ nlp = _get_spacy()
299
+ # Process a capped version to avoid memory issues on long articles
300
+ doc = nlp(clean_text_str[:5000])
301
+
302
+ # Sub-step A: Named Entity Extraction
303
+ verifiable_types = {"PERSON", "ORG", "GPE"}
304
+ numeric_types = {"MONEY", "PERCENT", "CARDINAL"}
305
+
306
+ verifiable_ents = [ent.text for ent in doc.ents if ent.label_ in verifiable_types]
307
+ numeric_ents = [ent for ent in doc.ents if ent.label_ in numeric_types]
308
+
309
+ n_verifiable = len(set(verifiable_ents))
310
+
311
+ # Count unverifiable numeric claims (no citation within ±100 chars)
312
+ n_unverifiable = 0
313
+ for ent in numeric_ents:
314
+ start = max(0, ent.start_char - 100)
315
+ end = min(len(clean_text_str), ent.end_char + 100)
316
+ context = clean_text_str[start:end]
317
+ if not CITATION_RE.search(context):
318
+ n_unverifiable += 1
319
+
320
+ # Sub-step B: Quote Attribution
321
+ quotes = QUOTE_RE.findall(clean_text_str[:5000])
322
+ quotes_total = len(quotes)
323
+ quotes_attributed = 0
324
+
325
+ for q in quotes:
326
+ q_pos = clean_text_str.find(q)
327
+ if q_pos == -1:
328
+ continue
329
+ context_start = max(0, q_pos - 50)
330
+ context_end = min(len(clean_text_str), q_pos + len(q) + 50)
331
+ context = clean_text_str[context_start:context_end]
332
+ if QUOTE_ATTRIB_RE.search(context):
333
+ quotes_attributed += 1
334
+
335
+ attributed_ratio = (quotes_attributed / quotes_total) if quotes_total > 0 else 1.0
336
+
337
+ # Sub-step C: Combine
338
+ entity_score = min(1.0, n_verifiable / 3) # 3+ verifiable entities = full marks
339
+ unverifiable_penalty = min(0.15, n_unverifiable * 0.05)
340
+
341
+ claim_score = (meta_proba * 0.60) + (entity_score * 0.25) + (attributed_ratio * 0.15)
342
+ claim_score = max(0.0, min(1.0, claim_score - unverifiable_penalty))
343
+
344
+ entities_found = list(set(verifiable_ents))[:10] # Cap for JSON output
345
+
346
+ return claim_score, entities_found, n_verifiable, quotes_attributed, quotes_total
347
+
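The combination step in Sub-step C can be isolated as a small pure function (a restatement of the formula above, re-declared so it runs standalone):

```python
def combine_claim_score(meta_proba, n_verifiable, quotes_attributed,
                        quotes_total, n_unverifiable):
    # Mirrors Sub-step C of score_claim_verification.
    entity_score = min(1.0, n_verifiable / 3)  # 3+ entities = full marks
    attributed_ratio = quotes_attributed / quotes_total if quotes_total else 1.0
    penalty = min(0.15, n_unverifiable * 0.05)  # capped at 0.15
    score = meta_proba * 0.60 + entity_score * 0.25 + attributed_ratio * 0.15
    return max(0.0, min(1.0, score - penalty))
```

For instance, a meta-model probability of 0.8 with three verifiable entities and all quotes attributed yields 0.8·0.60 + 1.0·0.25 + 1.0·0.15 = 0.88; three uncited numeric claims with no other support hits the 0.15 penalty cap.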
348
+
349
+ # ═════════════════════════════════════════════════════════════════════════════
350
+ # STEP 3 — LINGUISTIC ANALYSIS (weight: 20%)
351
+ # ═════════════════════════════════════════════════════════════════════════════
352
+
353
+ def score_linguistic_quality(title, text, clean_text_str, author_found, cfg=None):
354
+ """
355
+ Step 3: Rule-based linguistic quality scoring.
356
+ Reuses DistilBERT for headline contradiction check.
357
+ Returns: (linguistic_score, deductions_applied, headline_contradicts)
358
+ """
359
+ score = 1.0
360
+ deductions = []
361
+ headline_contradicts = False
362
+ title_str = str(title) if title else ""
363
+
364
+ # ── 1. Sensationalist headline (-0.20) ──
365
+ sensational = False
366
+ if title_str:
367
+ caps_words = re.findall(r"\b[A-Z]{4,}\b", title_str)
368
+ if len(caps_words) >= 1:
369
+ sensational = True
370
+ if "!" in title_str:
371
+ sensational = True
372
+ if SENSATIONAL_RE.search(title_str):
373
+ sensational = True
374
+ if sensational:
375
+ score -= 0.20
376
+ deductions.append("Sensationalist headline detected")
377
+
378
+ # ── 2. Excessive superlatives (-0.15, needs ≥2 matches) ──
379
+ superlative_matches = SUPERLATIVE_RE.findall(clean_text_str)
380
+ if len(superlative_matches) >= 2:
381
+ score -= 0.15
382
+ deductions.append(f"Excessive superlatives ({len(superlative_matches)} found)")
383
+
384
+ # ── 3. No attribution (-0.15) ──
385
+ if NO_ATTRIB_RE.search(clean_text_str):
386
+ score -= 0.15
387
+ deductions.append("Anonymous/vague attribution patterns found")
388
+
389
+ # ── 4. Headline contradicts body (-0.10) ──
390
+ # Guard: only run if title looks like a real headline, not an auto-extracted body sentence
391
+ is_real_headline = (
392
+ title_str
393
+ and len(title_str) > 10
394
+ and len(title_str.split()) <= 15
395
+ and not title_str.lower().startswith(("it has", "it was", "it is", "there was", "there is"))
396
+ and title_str.lower() not in str(text).lower()[:100]
397
+ )
398
+ if is_real_headline:
399
+ body_only = str(text)[:512] # Raw body text, NOT clean_text_str which has title prepended
400
+ try:
401
+ if "distilbert" in _MODEL_CACHE:
402
+ model, tok, device = _MODEL_CACHE["distilbert"]
403
+ with torch.no_grad():
404
+ t_enc = tok(title_str, return_tensors="pt", truncation=True, max_length=64, padding=True).to(device)
405
+ b_enc = tok(body_only, return_tensors="pt", truncation=True, max_length=512, padding=True).to(device)
406
+ t_hidden = model.distilbert(**t_enc).last_hidden_state[:, 0, :] # CLS token
407
+ b_hidden = model.distilbert(**b_enc).last_hidden_state[:, 0, :]
408
+ cos_sim = float(torch.nn.functional.cosine_similarity(t_hidden, b_hidden).item())
409
+ if cos_sim < 0.30:
410
+ headline_contradicts = True
411
+ score -= 0.10
412
+ deductions.append(f"Headline may contradict body (similarity={cos_sim:.2f})")
413
+ except Exception as e:
414
+ # Fallback: simple word overlap against body only
415
+ title_words = set(title_str.lower().split())
416
+ body_words = set(body_only.lower().split())
417
+ overlap = len(title_words & body_words) / max(len(title_words), 1)
418
+ if overlap < 0.15 and len(title_words) > 3:
419
+ headline_contradicts = True
420
+ score -= 0.10
421
+ deductions.append("Headline has very low word overlap with body")
422
+
423
+ # ── 5. Internal contradictions (-0.10) ──
424
+ # Heuristic: negation near repeated noun phrase
425
+ sentences = re.split(r'[.!?]+', clean_text_str[:3000])
426
+ negation_re = re.compile(r"\b(not|no|never|false|deny|denied|incorrect|wrong)\b", re.IGNORECASE)
427
+ noun_counts = {}
428
+ contradiction_found = False
429
+ for sent in sentences:
430
+ words = sent.lower().split()
431
+ # Track candidate nouns (approximation: any lowercased word longer than 3 chars)
432
+ for w in words:
433
+ if len(w) > 3:
434
+ noun_counts[w] = noun_counts.get(w, 0) + 1
435
+ # Check if a repeated noun appears near negation
436
+ if negation_re.search(sent):
437
+ for w in words:
438
+ if noun_counts.get(w, 0) >= 2 and len(w) > 4:
439
+ contradiction_found = True
440
+ break
441
+ if contradiction_found:
442
+ break
443
+ if contradiction_found:
444
+ score -= 0.10
445
+ deductions.append("Possible internal contradiction detected")
446
+
447
+ # ── 6. Passive voice obscuring agency (-0.10) ──
448
+ if PASSIVE_VOICE_RE.search(clean_text_str):
449
+ score -= 0.10
450
+ deductions.append("Passive voice used to obscure agency")
451
+
452
+ # ── 7. Missing byline (-0.05) ──
453
+ if not author_found:
454
+ score -= 0.05
455
+ deductions.append("No byline or author attribution found")
456
+
457
+ score = max(0.0, score)
458
+ return score, deductions, headline_contradicts
459
+
460
+
461
+ # ═════════════════════════════════════════════���═══════════════════════════════
462
+ # STEP 4 — FRESHNESS (weight: 10%)
463
+ # ═════════════════════════════════════════════════════════════════════════════
464
+
465
+ def score_freshness_v2(published_date, has_date, title, text):
466
+ """
467
+ Step 4: Temporal freshness scoring.
468
+ Case A: Date found → bracket-based scoring.
469
+ Case B: No date → contextual signal scanning.
470
+ Returns: (score, case, signals_found)
471
+ """
472
+ if has_date and published_date is not None:
473
+ # ── Case A ──
474
+ now = datetime.now(timezone.utc)
475
+ try:
476
+ if getattr(published_date, 'tzinfo', None) is None:
477
+ published_date = published_date.replace(tzinfo=timezone.utc)
478
+ days_old = (now - published_date).days
479
+ except Exception:
480
+ # Fallback to Case B if date math fails
481
+ return _freshness_case_b(title, text)
482
+
483
+ if days_old < 0:
484
+ days_old = 0
485
+
486
+ if days_old < 30:
487
+ return 1.0, "A", []
488
+ elif days_old <= 180:
489
+ return 0.75, "A", []
490
+ elif days_old <= 730: # 2 years
491
+ return 0.5, "A", []
492
+ else:
493
+ return 0.2, "A", []
494
+ else:
495
+ return _freshness_case_b(title, text)
496
+
497
+
498
+ def _freshness_case_b(title, text):
499
+ """Case B: No date found — scan for contextual freshness signals."""
500
+ combined = str(title) + " " + str(text)
501
+ signals = []
502
+ now = datetime.now()
503
+
504
+ # Signal 1: Current year mentioned (dynamic)
505
+ year_re = re.compile(r"\b(" + str(now.year) + r"|" + str(now.year - 1) + r")\b")
506
+ if year_re.search(combined):
507
+ signals.append(f"Current/recent year mentioned ({now.year} or {now.year-1})")
508
+
509
+ # Signal 2: Temporal phrases
510
+ if TEMPORAL_RE.search(combined):
511
+ signals.append("Temporal freshness phrase detected")
512
+
513
+ # Signal 3: Named institution
514
+ if INSTITUTION_RE.search(combined):
515
+ signals.append("Named institutional publisher found")
516
+
517
+ # Signal 4: Major outlet corroboration
518
+ if CORROBORATION_OUTLETS_RE.search(combined):
519
+ signals.append("Major outlet corroboration cited")
520
+
521
+ score_map = {4: 0.80, 3: 0.70, 2: 0.60, 1: 0.50, 0: 0.40}
522
+ n = min(len(signals), 4)
523
+ return score_map[n], "B", signals
524
+
525
+
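The Case A brackets above can be restated as a standalone helper (an illustrative sketch mirroring the logic, not the module's API):

```python
def freshness_bracket(days_old: int) -> float:
    # Case A brackets: <30 days → 1.0, ≤180 → 0.75, ≤730 (2 years) → 0.5, else 0.2
    days_old = max(days_old, 0)  # clamp future-dated articles to 0, as above
    if days_old < 30:
        return 1.0
    if days_old <= 180:
        return 0.75
    if days_old <= 730:
        return 0.5
    return 0.2

print(freshness_bracket(10), freshness_bracket(400), freshness_bracket(1000))  # 1.0 0.5 0.2
```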
526
+ # ═════════════════════════════════════════════════════════════════════════════
527
+ # STEP 5 — MODEL VOTE (weight: 10%)
528
+ # ═════════════════════════════════════════════════════════════════════════════
529
+
530
+ def score_model_vote(votes):
531
+ """Step 5: Proportion of TRUE votes from the ensemble."""
532
+ if not votes:
533
+ return 0.5
534
+ return sum(votes.values()) / len(votes)
535
+
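`score_model_vote` reduces the ensemble to the fraction of TRUE votes, falling back to a neutral 0.5 when no model ran. For example, with the illustrative vote set below:

```python
votes = {"logistic": 1, "lstm": 0, "distilbert": 1, "roberta": 1}
vote_score = sum(votes.values()) / len(votes) if votes else 0.5  # neutral when empty
print(vote_score)  # 0.75
```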
536
+
537
+ # ═════════════════════════════════════════════════════════════════════════════
538
+ # ADVERSARIAL OVERRIDE
539
+ # ═════════════════════════════════════════════════════════════════════════════
540
+
541
+ def check_adversarial_flags(has_date, author_found, n_verifiable, headline_contradicts,
542
+ typosquatting_detected, text):
543
+ """
544
+ Post-scoring adversarial check.
545
+ Any flag → cap final_score at 0.25.
546
+ Returns: list of triggered flag names.
547
+ """
548
+ flags = []
549
+
550
+ # Flag 1: Triple anonymity
551
+ if not has_date and not author_found and n_verifiable == 0:
552
+ flags.append("Triple anonymity (no date, no author, no named sources)")
553
+
554
+ # Flag 2: Headline contradicts body
555
+ if headline_contradicts:
556
+ flags.append("Headline contradicts article body")
557
+
558
+ # Flag 3: Typosquatting
559
+ if typosquatting_detected:
560
+ flags.append("Domain mimics a known outlet (typosquatting)")
561
+
562
+ # Flag 4: Statistics without traceable source
563
+ stats_found = STAT_RE.findall(text)
564
+ if stats_found:
565
+ # Check if any citation pattern exists in the text
566
+ if not CITATION_RE.search(text):
567
+ flags.append("Statistics cited with no traceable primary source")
568
+
569
+ return flags
570
+
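Any triggered flag caps the final score at 0.25 regardless of how well the other four signals scored — a minimal sketch of the override:

```python
final_score = 0.82  # hypothetical pre-override score
adv_flags = ["Domain mimics a known outlet (typosquatting)"]
if adv_flags:
    final_score = min(final_score, 0.25)  # hard cap on any adversarial flag
print(final_score)  # 0.25
```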
571
+
572
+ # ═════════════════════════════════════════════════════════════════════════════
573
+ # REASON BUILDER
574
+ # ═════════════════════════════════════════════════════════════════════════════
575
+
576
+ def build_reasons_and_missing(scores, n_verifiable, author_found, has_date,
577
+ deductions, adversarial_flags):
578
+ """
579
+ Programmatically generate top_reasons and missing_signals from scores.
580
+ Returns: (reasons[:3], missing_signals)
581
+ """
582
+ reasons = []
583
+ missing = []
584
+
585
+ # ── Negative signals ──
586
+ if scores["source"] < 0.4:
587
+ reasons.append("Source is unknown or not editorially accountable")
588
+ if scores["claim"] < 0.5:
589
+ reasons.append("Core claims could not be fully verified")
590
+ if scores["linguistic"] < 0.7:
591
+ reasons.append("Writing style shows signs of sensationalism or manipulation")
592
+ if scores["freshness"] < 0.5:
593
+ reasons.append("Article age or missing date reduces temporal reliability")
594
+ if scores["model_vote"] < 0.5:
595
+ reasons.append("AI models flagged patterns inconsistent with credible journalism")
596
+
597
+ # ── Positive signals ──
598
+ if scores["source"] >= 0.8:
599
+ reasons.append("Article is from a known, credible outlet")
600
+ if scores["claim"] >= 0.8:
601
+ reasons.append("Core claims are well-attributed with verifiable entities")
602
+ if scores["linguistic"] >= 0.9:
603
+ reasons.append("Writing style is neutral and well-attributed")
604
+ if scores["model_vote"] >= 0.75:
605
+ reasons.append("AI models strongly agree this content is credible")
606
+
607
+ # ── Adversarial flags ──
608
+ for flag in adversarial_flags:
609
+ reasons.append(f"Adversarial flag: {flag}")
610
+
611
+ # ── Missing signals ──
612
+ if not author_found:
613
+ missing.append("Author identity could not be verified")
614
+ if not has_date:
615
+ missing.append("Publication date not found")
616
+ if scores["source"] <= 0.3:
617
+ missing.append("Source domain not recognized")
618
+ if n_verifiable == 0:
619
+ missing.append("No verifiable named entities found in text")
620
+
621
+ return reasons[:3], missing
622
+
623
+
624
+ # ═════════════════════════════════════════════════════════════════════════════
625
+ # MAIN INFERENCE INTERFACE
626
+ # ═════════════════════════════════════════════════════════════════════════════
627
+
628
+ def predict_article(title, text, source_domain, published_date, mode="full", trigger_rag=True):
629
+ """
630
+ 5-Signal weighted scoring inference.
631
+
632
+ Execution order:
633
+ 1. extract_features()
634
+ 2. Run base models (LR/LSTM/DistilBERT/RoBERTa) → probas, votes
635
+ 3. Run meta-classifier → meta_proba
636
+ 4. Step 1: score_source_credibility()
637
+ 5. Step 2: score_claim_verification()
638
+ 6. Step 3: score_linguistic_quality() [needs author_found from Step 1]
639
+ 7. Step 4: score_freshness_v2()
640
+ 8. Step 5: score_model_vote()
641
+ 9. Weighted final score + adversarial override + verdict
642
+ """
643
+ cfg = load_config()
644
+ feat = extract_features(title, text, source_domain, published_date, cfg)
645
+
646
+ probas = {
647
+ "lr_proba": np.nan, "lstm_proba": np.nan,
648
+ "distilbert_proba": np.nan, "roberta_proba": np.nan,
649
+ }
650
+ votes = {}
651
+
652
+ # ── Base Model Inference ──────────────────────────────────────────────
653
+
654
+ # 1. Logistic Regression
655
+ if mode in ("fast", "balanced", "full"):
656
+ lr_pipe = _get_model("logistic", cfg)
657
+ df_lr = pd.DataFrame([{
658
+ "clean_text": feat["clean_text"],
659
+ "word_count": feat["word_count"],
660
+ "text_length_bucket": feat["text_length_bucket"],
661
+ "has_date": 1 if feat["has_date"] else 0,
662
+ "freshness_score": 0.5, # neutral for model input
663
+ "source_domain": feat["source_domain"],
664
+ }])
665
+ try:
666
+ p = float(lr_pipe.predict_proba(df_lr)[:, 1][0])
667
+ probas["lr_proba"] = p
668
+ votes["logistic"] = int(p >= 0.5)
669
+ except Exception as e:
670
+ logger.warning(f"LR inference failed: {e}")
671
+
672
+ # 2. Bi-LSTM
673
+ if mode in ("balanced", "full"):
674
+ lstm_model, tok, device = _get_model("lstm", cfg)
675
+ maxlen = cfg.get("preprocessing", {}).get("lstm_max_len", 512)
676
+ from src.models.lstm_model import pad_sequences
677
+
678
+ seq = tok.texts_to_sequences([feat["clean_text"]])
679
+ pad = pad_sequences(seq, maxlen=maxlen, padding='post')
680
+ t_pad = torch.from_numpy(pad).long().to(device)
681
+
682
+ with torch.no_grad():
683
+ logits = lstm_model(t_pad)
684
+ p = float(torch.sigmoid(logits).cpu().numpy()[0])
685
+ probas["lstm_proba"] = p
686
+ votes["lstm"] = int(p >= 0.5)
687
+
688
+ # 3. Transformers (DistilBERT + RoBERTa)
689
+ if mode == "full":
690
+ for t_name in ("distilbert", "roberta"):
691
+ model, tok, device = _get_model(t_name, cfg)
692
+ inputs = tok(feat["clean_text"], padding=True, truncation=True,
693
+ max_length=512, return_tensors="pt").to(device)
694
+ with torch.no_grad():
695
+ out = model(**inputs)
696
+ p = float(torch.softmax(out.logits, dim=-1)[0, 1].item())
697
+ if t_name == "roberta":
698
+ p = p * 0.92 # RoBERTa TRUE-bias dampening
699
+ probas[t_name + "_proba"] = p
700
+ votes[t_name] = int(p >= 0.5)
701
+
702
+ # 4. Meta-Classifier
703
+ meta_bundle = _get_model("meta", cfg)
704
+ meta_preprocessor = meta_bundle["preprocessor"]
705
+ meta_model = meta_bundle["model"]
706
+
707
+ df_meta = pd.DataFrame([{
708
+ "lr_proba": probas["lr_proba"],
709
+ "lstm_proba": probas["lstm_proba"],
710
+ "distilbert_proba": probas["distilbert_proba"],
711
+ "roberta_proba": probas["roberta_proba"],
712
+ "word_count": feat["word_count"],
713
+ "has_date": 1 if feat["has_date"] else 0,
714
+ "freshness_score": 0.5, # neutral — freshness is now scored separately in Step 4
715
+ }])
716
+
717
+ df_cats = pd.DataFrame([{
718
+ "text_length_bucket": feat["text_length_bucket"],
719
+ "source_domain": feat["source_domain"],
720
+ }])
721
+ cat_feats = meta_preprocessor.transform(df_cats)
722
+ X_meta = np.hstack((df_meta.values, cat_feats))
723
+
724
+ meta_proba = float(meta_model.predict_proba(X_meta)[:, 1][0])
725
+
726
+ # Short-text dampening (under 50 words)
727
+ short_text = feat["word_count"] < 50
728
+ if short_text:
729
+ meta_proba = 0.5 + (meta_proba - 0.5) * 0.6
730
+
731
+ # ── 5-Signal Scoring ─────────────────────────────────────────────────
732
+
733
+ # Step 1: Source Credibility
734
+ source_score, author_found, typosquat = score_source_credibility(
735
+ feat["source_domain"], title, text
736
+ )
737
+
738
+ # Step 2: Claim Verification
739
+ claim_score, entities_found, n_verifiable, q_attr, q_total = score_claim_verification(
740
+ meta_proba, feat["clean_text"], title
741
+ )
742
+
743
+ # Step 3: Linguistic Analysis (depends on author_found from Step 1)
744
+ ling_score, deductions, headline_contradicts = score_linguistic_quality(
745
+ title, text, feat["clean_text"], author_found, cfg
746
+ )
747
+
748
+ # Step 4: Freshness
749
+ fresh_score, fresh_case, fresh_signals = score_freshness_v2(
750
+ feat.get("published_date"), feat["has_date"], title, text
751
+ )
752
+
753
+ # Step 5: Model Vote
754
+ vote_score = score_model_vote(votes)
755
+
756
+ # ── Final Weighted Score ──────────────────────────────────────────────
757
+
758
+ scores = {
759
+ "source": round(source_score, 4),
760
+ "claim": round(claim_score, 4),
761
+ "linguistic": round(ling_score, 4),
762
+ "freshness": round(fresh_score, 4),
763
+ "model_vote": round(vote_score, 4),
764
+ }
765
+
766
+ final_score = (
767
+ source_score * 0.30 +
768
+ claim_score * 0.30 +
769
+ ling_score * 0.20 +
770
+ fresh_score * 0.10 +
771
+ vote_score * 0.10
772
+ )
773
+
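The five signal weights (0.30 / 0.30 / 0.20 / 0.10 / 0.10) sum to 1.0, so the weighted score stays in [0, 1]. A worked example with made-up signal scores:

```python
weights = {"source": 0.30, "claim": 0.30, "linguistic": 0.20,
           "freshness": 0.10, "model_vote": 0.10}
assert abs(sum(weights.values()) - 1.0) < 1e-9  # weights form a convex combination

# Hypothetical signal scores for one article
scores = {"source": 0.8, "claim": 0.6, "linguistic": 0.9,
          "freshness": 0.75, "model_vote": 0.75}
final = sum(scores[k] * w for k, w in weights.items())
print(round(final, 4))  # 0.75
```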
774
+ # ── Adversarial Override ──────────────────────────────────────────────
775
+
776
+ adv_flags = check_adversarial_flags(
777
+ feat["has_date"], author_found, n_verifiable,
778
+ headline_contradicts, typosquat, feat["clean_text"]
779
+ )
780
+ if adv_flags:
781
+ final_score = min(final_score, 0.25)
782
+
783
+ final_score = round(final_score, 4)
784
+
785
+ # ── Verdict ───────────────────────────────────────────────────────────
786
+
787
+ if final_score >= 0.75:
788
+ verdict = "TRUE"
789
+ elif final_score >= 0.55:
790
+ verdict = "UNCERTAIN"
791
+ elif final_score >= 0.35:
792
+ verdict = "LIKELY FALSE"
793
+ else:
794
+ verdict = "FALSE"
795
+
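The thresholds above partition [0, 1] into four verdict bands; restated as a pure function:

```python
def verdict(final_score: float) -> str:
    # [0.75, 1.0] TRUE · [0.55, 0.75) UNCERTAIN · [0.35, 0.55) LIKELY FALSE · [0, 0.35) FALSE
    if final_score >= 0.75:
        return "TRUE"
    if final_score >= 0.55:
        return "UNCERTAIN"
    if final_score >= 0.35:
        return "LIKELY FALSE"
    return "FALSE"

print(verdict(0.80), verdict(0.60), verdict(0.40), verdict(0.25))
```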
796
+ # ── Reasons & Missing Signals ─────────────────────────────────────────
797
+
798
+ top_reasons, missing_signals = build_reasons_and_missing(
799
+ scores, n_verifiable, author_found, feat["has_date"],
800
+ deductions, adv_flags
801
+ )
802
+
803
+ # ── Confidence ────────────────────────────────────────────────────────
804
+
805
+ missing_count = len(missing_signals)
806
+ if adv_flags or missing_count >= 3:
807
+ confidence = "LOW"
808
+ elif verdict == "UNCERTAIN" or missing_count in (1, 2):
809
+ confidence = "MEDIUM"
810
+ elif final_score >= 0.75 or final_score < 0.35:
811
+ confidence = "HIGH"
812
+ else:
813
+ confidence = "MEDIUM"
814
+
815
+ # ── Recommended Action + LOW Guard ───────────────────────────────────
816
+
817
+ action_map = {
818
+ "TRUE": "Publish",
819
+ "UNCERTAIN": "Flag for review",
820
+ "LIKELY FALSE": "Suppress",
821
+ "FALSE": "Escalate",
822
+ }
823
+ recommended_action = action_map[verdict]
824
+
825
+ # Hard rule: LOW confidence → never "Publish"
826
+ if confidence == "LOW" and recommended_action == "Publish":
827
+ recommended_action = "Flag for review"
828
+
829
+ # ── Return Full JSON ──────────────────────────────────────────────────
830
+
831
+ return {
832
+ "verdict": verdict,
833
+ "final_score": final_score,
834
+ "scores": scores,
835
+ "freshness_case": fresh_case,
836
+ "freshness_signals_found": fresh_signals,
837
+ "adversarial_flags": adv_flags,
838
+ "top_reasons": top_reasons,
839
+ "missing_signals": missing_signals,
840
+ "confidence": confidence,
841
+ "recommended_action": recommended_action,
842
+ "base_model_votes": votes,
843
+ "base_model_probas": probas,
844
+ "word_count": feat["word_count"],
845
+ "short_text_warning": short_text,
846
+ "deductions_applied": deductions,
847
+ "entities_found": entities_found,
848
+ "quotes_attributed": q_attr,
849
+ "quotes_total": q_total,
850
+ }
851
+
852
+
853
+ if __name__ == "__main__":
854
+ import json
855
+ try:
856
+ res = predict_article(
857
+ "Breaking: AI solves P=NP",
858
+ "The algorithm has shocked absolutely everyone across the earth entirely "
859
+ "resolving everything overnight. Sources say it is unprecedented.",
860
+ "techcrunch.com",
861
+ datetime.now().isoformat(),
862
+ mode="fast"
863
+ )
864
+ print("Verdict Dict:")
865
+ print(json.dumps(res, indent=2, default=str))
866
+ except ModelNotTrainedError as e:
867
+ print("ERROR:", str(e))
src/utils/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Fake News Detection — Utilities Package
src/utils/deduplication.py ADDED
@@ -0,0 +1,256 @@
1
+ """
2
+ deduplication.py — Fast near-duplicate removal for large datasets.
3
+
4
+ Uses a two-phase strategy:
5
+ 1. **Exact dedup** — hash-based O(n) removal of identical texts.
6
+ 2. **Near-dedup via Sentence-BERT** — encode texts, build a cosine
7
+ similarity index, and remove near-duplicate pairs above a
8
+ configurable threshold. Uses chunked approach with early
9
+ termination to keep runtime feasible on 100K+ rows.
10
+
11
+ The ``all-MiniLM-L6-v2`` model is used for embedding.
12
+ """
13
+
14
+ import hashlib
15
+ import logging
16
+ import time
17
+ from typing import Dict, List, Optional, Set, Tuple
18
+
19
+ import numpy as np
20
+ import pandas as pd
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ _DEFAULT_MODEL = "all-MiniLM-L6-v2"
25
+
26
+
27
+ # ═══════════════════════════════════════════════════════════
28
+ # Phase 1: Exact dedup (hash-based, O(n))
29
+ # ═══════════════════════════════════════════════════════════
30
+
31
+ def _exact_dedup(df: pd.DataFrame, text_column: str) -> Tuple[pd.DataFrame, int]:
32
+ """Remove rows with identical text via SHA-256 hashing.
33
+
34
+ Args:
35
+ df: Input DataFrame.
36
+ text_column: Column to hash for exact comparison.
37
+
38
+ Returns:
39
+ (deduplicated DataFrame, number of rows removed).
40
+ """
41
+ before = len(df)
42
+ hashes: Dict[str, int] = {}
43
+ keep: List[bool] = []
44
+
45
+ for idx, txt in enumerate(df[text_column].fillna("").astype(str)):
46
+ h = hashlib.sha256(txt.encode("utf-8", errors="replace")).hexdigest()
47
+ if h in hashes:
48
+ keep.append(False)
49
+ else:
50
+ hashes[h] = idx
51
+ keep.append(True)
52
+
53
+ df_out = df.loc[keep].reset_index(drop=True)
54
+ removed = before - len(df_out)
55
+ logger.info("Exact dedup: removed %d / %d identical rows", removed, before)
56
+ return df_out, removed
57
+
58
+
59
+ # ═══════════════════════════════════════════════════════════
60
+ # Phase 2: Semantic near-dedup (Sentence-BERT + chunked cosine)
61
+ # ═══════════════════════════════════════════════════════════
62
+
63
+ def _semantic_dedup(
64
+ df: pd.DataFrame,
65
+ text_column: str,
66
+ threshold: float,
67
+ batch_size: int,
68
+ model_name: str,
69
+ max_rows_for_pairwise: int = 30_000,
70
+ ) -> Tuple[pd.DataFrame, int]:
71
+ """Remove near-duplicate rows using Sentence-BERT cosine similarity.
72
+
73
+ For datasets larger than *max_rows_for_pairwise*, the comparison is
74
+ done in a block-diagonal fashion (each chunk vs. itself) to keep
75
+ computation tractable. Cross-chunk duplicates are rare across
76
+ dataset origins, and exact dedup already handles identical pairs.
77
+
78
+ Args:
79
+ df: Input DataFrame (already exact-deduped).
80
+ text_column: Column to encode.
81
+ threshold: Cosine similarity cutoff.
82
+ batch_size: Encoding batch size.
83
+ model_name: Sentence-BERT model name.
84
+ max_rows_for_pairwise: Max rows for full pairwise comparison.
85
+
86
+ Returns:
87
+ (deduplicated DataFrame, number of rows removed).
88
+ """
89
+ from sentence_transformers import SentenceTransformer
90
+ from sklearn.metrics.pairwise import cosine_similarity
91
+
92
+ n = len(df)
93
+ if n < 2:
94
+ return df.copy(), 0
95
+
96
+ texts = df[text_column].fillna("").astype(str).tolist()
97
+ # Truncate long texts to first 256 chars for fast encoding
98
+ texts_trunc = [t[:256] for t in texts]
99
+
100
+ logger.info(
101
+ "Encoding %d texts with %s (batch_size=%d) …",
102
+ n, model_name, batch_size,
103
+ )
104
+ model = SentenceTransformer(model_name)
105
+ embeddings = model.encode(
106
+ texts_trunc,
107
+ batch_size=batch_size,
108
+ show_progress_bar=True,
109
+ convert_to_numpy=True,
110
+ normalize_embeddings=True,
111
+ )
112
+
113
+ duplicate_indices: Set[int] = set()
114
+
115
+ if n <= max_rows_for_pairwise:
116
+ # Full pairwise — feasible for ≤ 30K rows
117
+ logger.info("Running full pairwise cosine similarity (%d × %d) …", n, n)
118
+ chunk_size = 2000
119
+ for start in range(0, n, chunk_size):
120
+ end = min(start + chunk_size, n)
121
+ sim = cosine_similarity(embeddings[start:end], embeddings)
122
+ for li in range(sim.shape[0]):
123
+ gi = start + li
124
+ if gi in duplicate_indices:
125
+ continue
126
+ # Only compare with later-indexed rows
127
+ for j in range(gi + 1, n):
128
+ if j in duplicate_indices:
129
+ continue
130
+ if sim[li, j] >= threshold:
131
+ duplicate_indices.add(j)
132
+ else:
133
+ # For very large datasets: compare within blocks of 10K rows
134
+ logger.info(
135
+ "Dataset too large (%d) for full pairwise — using block dedup",
136
+ n,
137
+ )
138
+ block_size = 10_000
139
+ for block_start in range(0, n, block_size):
140
+ block_end = min(block_start + block_size, n)
141
+ block_emb = embeddings[block_start:block_end]
142
+ block_n = block_end - block_start
143
+ logger.info(
144
+ " Block [%d:%d] (%d rows) …",
145
+ block_start, block_end, block_n,
146
+ )
147
+ sim = cosine_similarity(block_emb, block_emb)
148
+ for li in range(block_n):
149
+ gi = block_start + li
150
+ if gi in duplicate_indices:
151
+ continue
152
+ for lj in range(li + 1, block_n):
153
+ gj = block_start + lj
154
+ if gj in duplicate_indices:
155
+ continue
156
+ if sim[li, lj] >= threshold:
157
+ duplicate_indices.add(gj)
158
+
159
+ removed = len(duplicate_indices)
160
+ if removed > 0:
161
+ keep_mask = np.ones(n, dtype=bool)
162
+ for idx in duplicate_indices:
163
+ keep_mask[idx] = False
164
+ df_out = df.loc[keep_mask].reset_index(drop=True)
165
+ else:
166
+ df_out = df.copy()
167
+
168
+ logger.info("Semantic dedup: removed %d / %d near-duplicate rows", removed, n)
169
+ return df_out, removed
170
+
171
+
172
+ # ═══════════════════════════════════════════════════════════
173
+ # Public API
174
+ # ═══════════════════════════════════════════════════════════
175
+
176
+ def deduplicate_dataframe(
177
+ df: pd.DataFrame,
178
+ text_column: str = "text",
179
+ threshold: float = 0.92,
180
+ batch_size: int = 64,
181
+ model_name: str = _DEFAULT_MODEL,
182
+ origin_column: Optional[str] = "dataset_origin",
183
+ ) -> Tuple[pd.DataFrame, Dict[str, int]]:
184
+ """Remove duplicate rows from *df* (exact + semantic).
185
+
186
+ Args:
187
+ df: Input DataFrame (must contain *text_column*).
188
+ text_column: Column to use for duplicate detection.
189
+ threshold: Cosine similarity cutoff for near-dedup.
190
+ batch_size: Encoding batch size.
191
+ model_name: Sentence-BERT model identifier.
192
+ origin_column: Optional column for per-origin stats.
193
+
194
+ Returns:
195
+ (cleaned DataFrame, stats dict with per-origin removal counts).
196
+ """
197
+ t0 = time.perf_counter()
198
+ logger.info("=" * 60)
199
+ logger.info("Starting deduplication pipeline (threshold=%.2f) …", threshold)
200
+ n_before = len(df)
201
+
202
+ # Phase 1: exact
203
+ df_exact, exact_removed = _exact_dedup(df, text_column)
204
+
205
+ # Phase 2: semantic
206
+ df_final, semantic_removed = _semantic_dedup(
207
+ df_exact,
208
+ text_column=text_column,
209
+ threshold=threshold,
210
+ batch_size=batch_size,
211
+ model_name=model_name,
212
+ )
213
+
214
+ total_removed = n_before - len(df_final)
215
+
216
+ # Build per-origin stats
217
+ stats: Dict[str, int] = {}
218
+ if origin_column and origin_column in df.columns:
219
+ before_counts = df[origin_column].value_counts().to_dict()
220
+ after_counts = df_final[origin_column].value_counts().to_dict()
221
+ for origin in before_counts:
222
+ stats[origin] = before_counts[origin] - after_counts.get(origin, 0)
223
+ else:
224
+ stats["total"] = total_removed
225
+
226
+ elapsed = time.perf_counter() - t0
227
+ logger.info(
228
+ "Dedup complete: %d → %d rows (removed %d, %.1f%%) in %.1fs",
229
+ n_before, len(df_final), total_removed,
230
+ 100 * total_removed / max(n_before, 1), elapsed,
231
+ )
232
+ for origin, cnt in stats.items():
233
+ if cnt > 0:
234
+ logger.info(" %-30s %6d removed", origin, cnt)
235
+ logger.info("=" * 60)
236
+
237
+ return df_final, stats
238
+
239
+
240
+ # ─── standalone test ────────────────────────────────────────
241
+ if __name__ == "__main__":
242
+ logging.basicConfig(level=logging.INFO)
243
+ sample = pd.DataFrame({
244
+ "text": [
245
+ "The president signed the bill into law today.",
246
+ "The president signed the bill into law today.", # exact dup
247
+ "Scientists discover a new species of frog in the Amazon.",
248
+ "A new frog species has been found in the Amazon rainforest.", # near dup
249
+ "Stock markets rallied after a strong jobs report.",
250
+ ],
251
+ "dataset_origin": ["a", "a", "b", "b", "c"],
252
+ })
253
+ clean, info = deduplicate_dataframe(sample, threshold=0.92)
254
+ print(f"\nKept {len(clean)} / {len(sample)} rows")
255
+ print("Stats:", info)
256
+ print(clean[["text", "dataset_origin"]])
src/utils/domain_weights.py ADDED
@@ -0,0 +1,102 @@
1
+ import logging
2
+ import pandas as pd
3
+ import numpy as np
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+ def compute_domain_weights(
8
+ df: pd.DataFrame,
9
+ min_domain_samples: int = 20,
10
+ max_multiplier: float = 10.0
11
+ ) -> pd.DataFrame:
12
+ """
13
+ Compute domain-aware sample weights to handle heavily biased domains.
14
+
15
+ Rules:
16
+ - Domains with < min_domain_samples -> merge into "other"
17
+ - Calculate global class distributions.
18
+ - Calculate per-domain class distributions.
19
+ - Compute weight = global_class_ratio / domain_class_ratio
20
+ - Clip weights at max_multiplier * median_weight
21
+
22
+ Args:
23
+ df: Input DataFrame containing 'source_domain' and 'binary_label'
24
+ min_domain_samples: Threshold below which domains are grouped to 'other'
25
+ max_multiplier: Max multiplier over the median weight to clip extreme weights
26
+
27
+ Returns:
28
+ DataFrame with an additional 'sample_weight' column.
29
+ """
30
+ df = df.copy()
31
+
32
+ # Ensure source_domain exists
33
+ if "source_domain" not in df.columns:
34
+ logger.warning("'source_domain' not found in DataFrame. Returning weights=1.0")
35
+ df["sample_weight"] = 1.0
36
+ return df
37
+
38
+ # 1. Merge small domains into "other"
39
+ domain_counts = df["source_domain"].value_counts()
40
+ small_domains = set(domain_counts[domain_counts < min_domain_samples].index)
41
+
42
+ df["_effective_domain"] = df["source_domain"].apply(
43
+ lambda x: "other" if x in small_domains or not isinstance(x, str) else x
44
+ )
45
+
46
+ # 2. Compute global class ratios
47
+ global_counts = df["binary_label"].value_counts()
48
+ global_total = len(df)
49
+ global_ratio = {
50
+ label: count / global_total
51
+ for label, count in global_counts.items()
52
+ }
53
+
54
+ # 3. Compute domain class ratios and assign weights
55
+ # We group by domain and label to get counts per domain
56
+ domain_label_counts = df.groupby(["_effective_domain", "binary_label"]).size().unstack(fill_value=0)
57
+ domain_totals = domain_label_counts.sum(axis=1)
58
+
59
+ weights_map = {}
60
+ for domain in domain_label_counts.index:
61
+ weights_map[domain] = {}
62
+ d_total = domain_totals[domain]
63
+ for label in global_ratio.keys():
64
+ if label in domain_label_counts.columns:
65
+ d_count = domain_label_counts.loc[domain, label]
66
+ if d_count == 0:
67
+ # Domain has no rows of this class, so this weight is never
68
+ # looked up in practice; keep a neutral fallback anyway.
69
+ weights_map[domain][label] = 1.0
70
+ else:
71
+ d_ratio = d_count / d_total
72
+ weights_map[domain][label] = global_ratio[label] / d_ratio
73
+ else:
74
+ weights_map[domain][label] = 1.0
75
+
76
+ # 4. Map weights back to dataframe
77
+ df["sample_weight"] = df.apply(
78
+ lambda r: weights_map[r["_effective_domain"]].get(r["binary_label"], 1.0),
79
+ axis=1
80
+ )
81
+
82
+ # 5. Clip weights at max_multiplier * median_weight
83
+ median_w = df["sample_weight"].median()
84
+ max_w = max_multiplier * median_w
85
+ df["sample_weight"] = df["sample_weight"].clip(upper=max_w)
86
+
87
+ # Clean up temp col
88
+ df.drop(columns=["_effective_domain"], inplace=True)
89
+
90
+ logger.info("Computed domain weights (median: %.3f, max applied: %.3f)", median_w, df["sample_weight"].max())
91
+
92
+ return df
93
+
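The core rule — weight = global_class_ratio / domain_class_ratio — upweights the class a domain under-represents. For instance, in a globally balanced corpus, a domain that is 90% label 0 gets its rare label-1 rows upweighted 5× (toy numbers):

```python
global_ratio = {0: 0.5, 1: 0.5}   # balanced corpus overall
domain_ratio = {0: 0.9, 1: 0.1}   # this domain is 90% label 0
weights = {label: global_ratio[label] / domain_ratio[label] for label in global_ratio}
print(weights[1])  # 5.0 — the under-represented class gets the large weight
```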
94
+ if __name__ == "__main__":
95
+ # Test script
96
+ data = pd.DataFrame({
97
+ "source_domain": ["nytimes.com"] * 100 + ["fakenews.biz"] * 100 + ["tinyblog.com"] * 5,
98
+ "binary_label": [1] * 90 + [0] * 10 + [0] * 95 + [1] * 5 + [0] * 5
99
+ })
100
+
101
+ out = compute_domain_weights(data, min_domain_samples=20, max_multiplier=10.0)
102
+ print(out.groupby("source_domain")["sample_weight"].mean())
src/utils/freshness.py ADDED
@@ -0,0 +1,71 @@
1
+ from datetime import datetime, timezone
2
+ import pandas as pd
3
+ import numpy as np
4
+
5
+ def calculate_freshness(
6
+ published_date,
7
+ has_date: bool,
8
+ is_inference: bool = False,
9
+ reference_date: datetime = None
10
+ ) -> float:
11
+ """
12
+ Calculate the temporal freshness score for a single article.
13
+
14
+ Rules:
15
+ - score = 1.0 if article is < 30 days old
16
+ - score = max(0.1, 1 - (days_old / 365)) for older articles
17
+ - score = 0.5 if has_date is False (neutral for training)
18
+ - score = 0.35 if has_date is False AND called from inference
19
+
20
+ Args:
21
+ published_date: The published date of the article (datetime or NaT).
22
+ has_date: Boolean flag indicating if a valid date is present.
23
+ is_inference: Whether the scoring is happening during live inference.
24
+ reference_date: The date to compute 'days_old' against (defaults to now).
25
+
26
+ Returns:
27
+ Float score between 0.1 and 1.0.
28
+ """
29
+ if not has_date or pd.isna(published_date):
30
+ return 0.35 if is_inference else 0.50
31
+
32
+ if reference_date is None:
33
+ reference_date = datetime.now(timezone.utc)
34
+
35
+ # Ensure published_date is timezone-aware
36
+ if pd.api.types.is_scalar(published_date) and getattr(published_date, 'tzinfo', None) is None:
37
+ # Assuming UTC if naive, typical for web dates
38
+ try:
39
+ published_date = published_date.replace(tzinfo=timezone.utc)
40
+ except Exception:
41
+ pass
42
+
43
+ days_old = (reference_date - published_date).days
44
+
45
+ # Handle future dates gracefully (e.g., bad parsed data)
46
+ if days_old < 0:
47
+ days_old = 0
48
+
49
+ if days_old < 30:
50
+ return 1.0
51
+
52
+ return max(0.1, 1.0 - (days_old / 365.0))
53
+
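The decay rule can be restated compactly — a self-contained sketch that mirrors the function above rather than importing it:

```python
from datetime import datetime, timedelta, timezone

def freshness(published: datetime, reference: datetime) -> float:
    days_old = max((reference - published).days, 0)  # future dates clamp to 0
    if days_old < 30:
        return 1.0
    return max(0.1, 1.0 - days_old / 365.0)  # linear decay, floored at 0.1

now = datetime.now(timezone.utc)
print(freshness(now - timedelta(days=10), now))   # 1.0
print(freshness(now - timedelta(days=365), now))  # 0.1 (decayed to 0, floored)
```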
54
+ def apply_freshness_score(df: pd.DataFrame, is_inference: bool = False) -> pd.DataFrame:
55
+ """
56
+ Apply freshness scoring to a DataFrame.
57
+ """
58
+ df = df.copy()
59
+ ref_date = datetime.now(timezone.utc)
60
+
61
+ # Vectorized execution wrapper
62
+ df["freshness_score"] = df.apply(
63
+ lambda r: calculate_freshness(
64
+ r.get("published_date"),
65
+ r.get("has_date", pd.notna(r.get("published_date"))),
66
+ is_inference,
67
+ ref_date
68
+ ),
69
+ axis=1
70
+ )
71
+ return df
src/utils/rag_retrieval.py ADDED
@@ -0,0 +1,157 @@
1
+ import os
2
+ import sys
3
+ import yaml
4
+ import logging
5
+ import spacy
6
+ import numpy as np
7
+
8
+ from duckduckgo_search import DDGS
9
+ from sentence_transformers import SentenceTransformer
10
+ from sklearn.metrics.pairwise import cosine_similarity
11
+
12
+ _PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
13
+ if str(_PROJECT_ROOT) not in sys.path:
14
+ sys.path.insert(0, str(_PROJECT_ROOT))
15
+
16
+ logger = logging.getLogger("rag_retrieval")
17
+
18
+ # Lazy-load heavy models once per process (module-level caches)
19
+ _NLP = None
20
+ _SIM_MODEL = None
21
+
22
+ def load_spacy():
23
+ global _NLP
24
+ if _NLP is None:
25
+ try:
26
+ _NLP = spacy.load("en_core_web_sm")
27
+ except OSError:
28
+ logger.info("Downloading spaCy en_core_web_sm model dynamically...")
29
+ import subprocess
30
+ subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_sm"], check=True)
31
+ _NLP = spacy.load("en_core_web_sm")
32
+ return _NLP
33
+
34
+ def load_sim_model():
35
+ global _SIM_MODEL
36
+ if _SIM_MODEL is None:
37
+ _SIM_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
38
+ return _SIM_MODEL
39
+
40
+ def extract_focused_query(title: str, text: str) -> str:
+     """
+     Extract up to three named entities plus short noun phrases to form a focused DDG query.
+     """
+     nlp = load_spacy()
+     # Prefer the title when it is substantial; otherwise merge title and body (capped at 1000 chars)
+     target = title if (isinstance(title, str) and len(title.split()) > 4) else (str(title) + " " + str(text))[:1000]
+
+     doc = nlp(target)
+
+     # 1. Collect entities (ORG, PERSON, GPE, EVENT), deduplicated, capped at three
+     entities = [ent.text for ent in doc.ents if ent.label_ in ['ORG', 'PERSON', 'GPE', 'EVENT']]
+     unique_entities = list(dict.fromkeys(entities))[:3]
+
+     # 2. Top up with short noun phrases if fewer than three entities were found
+     noun_phrases = [chunk.text for chunk in doc.noun_chunks]
+
+     query_parts = unique_entities.copy()
+     if len(query_parts) < 3:
+         for np_chunk in noun_phrases:
+             if np_chunk not in query_parts and len(np_chunk.split()) <= 3:
+                 query_parts.append(np_chunk)
+             if len(query_parts) >= 3:
+                 break
+
+     focused_query = " ".join(query_parts)
+     if not focused_query.strip():
+         # Fallback: first five words of the target text
+         focused_query = " ".join(target.split()[:5])
+
+     return focused_query
+
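The entity-then-noun-phrase selection logic can be sketched without spaCy; `build_query_parts` is an illustrative helper, not part of the module, fed with pre-extracted entities and noun chunks:

```python
def build_query_parts(entities, noun_phrases, limit=3):
    # Deduplicate entities preserving order, then top up with short noun phrases
    parts = list(dict.fromkeys(entities))[:limit]
    for np_chunk in noun_phrases:
        if len(parts) >= limit:
            break
        if np_chunk not in parts and len(np_chunk.split()) <= 3:
            parts.append(np_chunk)
    return " ".join(parts)

print(build_query_parts(["Paris", "Eiffel Tower", "Paris"], ["the mayor", "scrap metal"]))
# → Paris Eiffel Tower the mayor
```

Keeping noun phrases to three words or fewer avoids stuffing the DuckDuckGo query with long clauses that rarely match headlines.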
+ def execute_rag(title: str, text: str):
+     """
+     1. Extract a focused query.
+     2. Run a DuckDuckGo search (top-k results).
+     3. Score each snippet's similarity to the article via all-MiniLM-L6-v2.
+     4. Return snippet evaluations plus an overall verdict.
+     """
+     cfg_path = os.path.join(_PROJECT_ROOT, "config", "config.yaml")
+     with open(cfg_path, "r", encoding="utf-8") as f:
+         cfg = yaml.safe_load(f)
+     rag_cfg = cfg.get("rag", {})
+
+     top_k = rag_cfg.get("top_k", 5)
+     support_thresh = rag_cfg.get("support_threshold", 0.65)
+     conflict_thresh = rag_cfg.get("conflict_threshold", 0.30)
+
+     query = extract_focused_query(title, text)
+     logger.info(f"RAG triggered; extracted search query: {query}")
+
+     search_results = []
+     try:
+         with DDGS() as ddgs:
+             results = list(ddgs.text(query, max_results=top_k))
+             search_results = [r.get("body", r.get("title", "")) for r in results if isinstance(r, dict)]
+     except Exception as e:
+         logger.error(f"DDGS failure: {e}")
+         return {"status": "error", "message": "Search engine failure", "data": []}, "INCONCLUSIVE"
+
+     if not search_results:
+         return {"status": "empty", "message": "No external context found", "data": []}, "INCONCLUSIVE"
+
+     sim_model = load_sim_model()
+
+     # Compare the article (title + body, capped at 2000 chars) against every snippet
+     corpus_text = (str(title) + " " + str(text))[:2000]
+     embed_target = sim_model.encode([corpus_text])
+     embed_search = sim_model.encode(search_results)
+
+     similarities = cosine_similarity(embed_target, embed_search)[0]
+
+     supports = 0
+     conflicts = 0
+     eval_payload = []
+
+     for i, sim in enumerate(similarities):
+         s_float = float(sim)
+         if s_float >= support_thresh:
+             supports += 1
+             nature = "SUPPORTS"
+         elif s_float < conflict_thresh:
+             conflicts += 1
+             nature = "CONFLICTS"
+         else:
+             nature = "NEUTRAL"
+
+         eval_payload.append({
+             "snippet": search_results[i],
+             "similarity": s_float,
+             "nature": nature
+         })
+
+     # RAG verdict: require at least two supporting or two conflicting snippets
+     if supports >= 2:
+         verdict = "CORROBORATED"
+     elif conflicts >= 2:
+         verdict = "CONTRADICTED"
+     else:
+         verdict = "INCONCLUSIVE"
+
+     final_output = {
+         "status": "success",
+         "query": query,
+         "supports": supports,
+         "conflicts": conflicts,
+         "data": eval_payload
+     }
+
+     return final_output, verdict
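The two-vote verdict rule reduces to a few lines in isolation; `rag_verdict` is an illustrative helper (not part of the module) using the same default thresholds:

```python
def rag_verdict(similarities, support_thresh=0.65, conflict_thresh=0.30):
    # Two-vote rule: at least two supporting or two conflicting snippets decide the verdict
    supports = sum(1 for s in similarities if s >= support_thresh)
    conflicts = sum(1 for s in similarities if s < conflict_thresh)
    if supports >= 2:
        return "CORROBORATED"
    if conflicts >= 2:
        return "CONTRADICTED"
    return "INCONCLUSIVE"

print(rag_verdict([0.72, 0.80, 0.40]))  # two snippets clear the 0.65 support threshold
# → CORROBORATED
```

Note that supports are checked first, so a tie (two supports and two conflicts) resolves to CORROBORATED.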
+
+ if __name__ == "__main__":
+     t = "Eiffel Tower sold for scrap metal in surprising Paris decree."
+     tx = "The mayor of Paris declared the tower will be dismantled."
+     o, v = execute_rag(t, tx)
+     import json
+     print("Verdict:", v)
+     print(json.dumps(o, indent=2))
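The thresholds read by `execute_rag` come from a `rag:` section in config/config.yaml. Given the fallback defaults in the code, that section presumably looks something like this (key names taken from the `rag_cfg.get` calls above; values are the in-code defaults):

```yaml
rag:
  top_k: 5                  # max DuckDuckGo results to retrieve
  support_threshold: 0.65   # cosine similarity at or above this counts as SUPPORTS
  conflict_threshold: 0.30  # cosine similarity below this counts as CONFLICTS
```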