Spaces:
Sleeping
Sleeping
Commit ·
86b932c
0
Parent(s):
Clean build with correct gitignore
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +6 -0
- .gitignore +2 -0
- README.md +102 -0
- TruthLens_Paper.tex +84 -0
- app.py +641 -0
- config/config.yaml +60 -0
- models/saved/distilbert_model/cm_distil.png +0 -0
- models/saved/distilbert_model/config.json +28 -0
- models/saved/distilbert_model/distilbert_oof.npy +3 -0
- models/saved/distilbert_model/model.safetensors +3 -0
- models/saved/distilbert_model/tokenizer.json +0 -0
- models/saved/distilbert_model/tokenizer_config.json +14 -0
- models/saved/distilbert_model/training_args.bin +3 -0
- models/saved/logistic_model/cm.png +0 -0
- models/saved/logistic_model/logistic_model.pkl +3 -0
- models/saved/logistic_model/lr_oof.npy +3 -0
- models/saved/logistic_model/metrics.json +8 -0
- models/saved/lstm_model/cm.png +0 -0
- models/saved/lstm_model/lstm_oof.npy +3 -0
- models/saved/lstm_model/metrics.json +8 -0
- models/saved/lstm_model/model.pt +3 -0
- models/saved/meta_classifier/cm_meta.png +0 -0
- models/saved/meta_classifier/meta_classifier.pkl +3 -0
- models/saved/meta_classifier/metrics.json +8 -0
- models/saved/roberta_model/cm_roberta.png +0 -0
- models/saved/roberta_model/config.json +29 -0
- models/saved/roberta_model/model.safetensors +3 -0
- models/saved/roberta_model/roberta_oof.npy +3 -0
- models/saved/roberta_model/tokenizer.json +0 -0
- models/saved/roberta_model/tokenizer_config.json +16 -0
- models/saved/roberta_model/training_args.bin +3 -0
- models/saved/tokenizer.pkl +3 -0
- requirements.txt +18 -0
- run_pipeline.py +100 -0
- src/__init__.py +1 -0
- src/models/__init__.py +1 -0
- src/models/distilbert_model.py +201 -0
- src/models/logistic_model.py +141 -0
- src/models/lstm_model.py +314 -0
- src/models/meta_classifier.py +229 -0
- src/models/roberta_model.py +198 -0
- src/stage1_ingestion.py +728 -0
- src/stage2_preprocessing.py +186 -0
- src/stage3_training.py +41 -0
- src/stage4_inference.py +867 -0
- src/utils/__init__.py +1 -0
- src/utils/deduplication.py +256 -0
- src/utils/domain_weights.py +102 -0
- src/utils/freshness.py +71 -0
- src/utils/rag_retrieval.py +157 -0
.gitattributes
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data/
|
| 2 |
+
__pycache__/
|
README.md
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TruthLens: Advanced Fake News Detection Pipeline
|
| 2 |
+
|
| 3 |
+
TruthLens is an end-to-end fake news detection system that moves beyond simple machine learning probabilities. It employs a robust **5-signal weighted scoring framework** built on journalistic standards, combining deep learning models (DistilBERT, RoBERTa), sequence models (LSTM), statistical models (Logistic Regression), and heuristic analysis to deliver explainable verdicts.
|
| 4 |
+
|
| 5 |
+
## 🌟 Key Features
|
| 6 |
+
|
| 7 |
+
* **5-Signal Scoring Framework:**
|
| 8 |
+
* **Source Credibility (30%):** Evaluates outlet reputation, author presence, and source corroboration, including typosquatting checks.
|
| 9 |
+
* **Claim Verification (30%):** Combines AI probability with spaCy-based Named Entity Recognition (NER) and quote attribution analysis.
|
| 10 |
+
* **Linguistic Quality (20%):** Detects sensationalism, superlatives, passive voice, and uses DistilBERT to check if the headline contradicts the body.
|
| 11 |
+
* **Freshness (10%):** Contextual and date-based temporal scoring to detect outdated information.
|
| 12 |
+
* **AI Model Consensus (10%):** Ensemble voting from Logistic Regression, LSTM, DistilBERT, and RoBERTa.
|
| 13 |
+
* **Adversarial Guardrails:** Hard caps and overrides for highly suspicious patterns (Triple Anonymity, Uncited Statistics, Headline Contradictions).
|
| 14 |
+
* **Live Web Corroboration:** RAG (Retrieval-Augmented Generation) pipeline using live search to verify unambiguous claims.
|
| 15 |
+
* **TruthLens UI:** A sleek, dark/light mode adaptable Streamlit dashboard providing detailed explainability down to the specific signals and deductions.
|
| 16 |
+
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
## 📁 Project Structure
|
| 20 |
+
|
| 21 |
+
```text
|
| 22 |
+
fake_news_detection/
|
| 23 |
+
├── app.py # Streamlit frontend (TruthLens UI)
|
| 24 |
+
├── run_pipeline.py # Main script to run pipeline stages
|
| 25 |
+
├── requirements.txt # Python dependencies
|
| 26 |
+
├── src/
|
| 27 |
+
│ ├── stage1_ingestion.py # Downloads and prepares datasets
|
| 28 |
+
│   ├── stage2_preprocessing.py # Cleans text, tokenizes, and saves artifacts
|
| 29 |
+
│ ├── stage3_training.py # Trains models (LR, LSTM, DistilBERT, RoBERTa)
|
| 30 |
+
│ ├── stage4_inference.py # The 5-signal scoring engine and prediction logic
|
| 31 |
+
│ └── utils/
|
| 32 |
+
│ └── rag_retrieval.py # Live web search corroboration functions
|
| 33 |
+
├── data/ # Raw and processed datasets (created during execution)
|
| 34 |
+
└── models/ # Trained models and vectorizers (created during execution)
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
---
|
| 38 |
+
|
| 39 |
+
## 🚀 Getting Started
|
| 40 |
+
|
| 41 |
+
### 1. Installation
|
| 42 |
+
|
| 43 |
+
Ensure you have Python 3.8+ installed. Install the required dependencies:
|
| 44 |
+
|
| 45 |
+
```bash
|
| 46 |
+
pip install -r requirements.txt
|
| 47 |
+
python -m spacy download en_core_web_sm
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
### 2. Running the Pipeline
|
| 51 |
+
|
| 52 |
+
The project is divided into stages. You can run the entire pipeline end-to-end, or run specific stages individually using `run_pipeline.py`.
|
| 53 |
+
|
| 54 |
+
**To run the complete training pipeline (Stages 1 to 3):**
|
| 55 |
+
*Note: This will download datasets, preprocess them, and train all models. It may take a significant amount of time depending on your hardware.*
|
| 56 |
+
|
| 57 |
+
```bash
|
| 58 |
+
python run_pipeline.py --stage 1 2 3
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
**To run individual stages:**
|
| 62 |
+
|
| 63 |
+
* **Stage 1: Data Ingestion**
|
| 64 |
+
Downloads and formats the necessary datasets (e.g., LIAR, ISOT).
|
| 65 |
+
```bash
|
| 66 |
+
python run_pipeline.py --stage 1
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
* **Stage 2: Preprocessing**
|
| 70 |
+
Cleans the text, maps verdicts to binary labels, and prepares DataFrames for training.
|
| 71 |
+
```bash
|
| 72 |
+
python run_pipeline.py --stage 2
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
* **Stage 3: Training**
|
| 76 |
+
Trains the ensemble: Logistic Regression, LSTM, DistilBERT, and RoBERTa. Saves the models to the `/models` directory.
|
| 77 |
+
```bash
|
| 78 |
+
python run_pipeline.py --stage 3
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
* **Stage 4: Evaluation**
|
| 82 |
+
Evaluates the trained pipeline on the holdout test set using the 5-signal inference framework.
|
| 83 |
+
```bash
|
| 84 |
+
python run_pipeline.py --eval
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
---
|
| 88 |
+
|
| 89 |
+
## 🖥️ Running the Application
|
| 90 |
+
|
| 91 |
+
Once the models are trained (or if you already have the pre-trained weights in the `/models` directory), you can launch the TruthLens UI.
|
| 92 |
+
|
| 93 |
+
```bash
|
| 94 |
+
python -m streamlit run app.py
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
This will start a local web server (usually at `http://localhost:8501`).
|
| 98 |
+
|
| 99 |
+
### Using the App:
|
| 100 |
+
1. **Paste text or provide a URL:** You can paste the raw text of an article (with or without a headline) or simply provide a URL for the app to parse automatically.
|
| 101 |
+
2. **Select depth:** Choose Quick, Standard, or Deep analysis.
|
| 102 |
+
3. **View Results:** Explore the four-tier verdict (True, Uncertain, Likely False, False), signal breakdown, adversarial flags, and live web corroboration results.
|
TruthLens_Paper.tex
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
\documentclass[11pt,a4paper]{article}
|
| 2 |
+
\usepackage[utf8]{inputenc}
|
| 3 |
+
\usepackage{amsmath}
|
| 4 |
+
\usepackage{amsfonts}
|
| 5 |
+
\usepackage{amssymb}
|
| 6 |
+
\usepackage{graphicx}
|
| 7 |
+
\usepackage{hyperref}
|
| 8 |
+
\usepackage{geometry}
|
| 9 |
+
\usepackage{booktabs}
|
| 10 |
+
|
| 11 |
+
\geometry{margin=1in}
|
| 12 |
+
|
| 13 |
+
\title{\textbf{TruthLens: A 5-Signal Weighted Architecture \\for Explainable Fake News Detection}}
|
| 14 |
+
\author{Your Name \\ \textit{Your Institution/University} \\ \texttt{email@domain.com}}
|
| 15 |
+
\date{\today}
|
| 16 |
+
|
| 17 |
+
\begin{document}
|
| 18 |
+
|
| 19 |
+
\maketitle
|
| 20 |
+
|
| 21 |
+
\begin{abstract}
|
| 22 |
+
The proliferation of misinformation and fake news necessitates robust automated detection systems. Traditional machine learning approaches often lack explainability, relying solely on black-box probabilistic outputs without verifying journalistic integrity. In this paper, we present TruthLens, an end-to-end framework integrating deep language models (DistilBERT, RoBERTa), sequence models (LSTM), and statistical baselines (Logistic Regression) into a unique 5-signal weighted scoring architecture. The framework assesses source credibility, claim verification via Named Entity Recognition (NER), linguistic quality, temporal freshness, and ensemble model consensus. Furthermore, the system incorporates adversarial overrides and live Retrieval-Augmented Generation (RAG) corroboration to guard against typosquatting, triple anonymity, and hallucinated statistics. Our methodology transitions away from a simple binary classifier to an explainable, multi-faceted verification engine, significantly improving both robustness and user trust in algorithmic verdicts.
|
| 23 |
+
\end{abstract}
|
| 24 |
+
|
| 25 |
+
\section{Introduction}
|
| 26 |
+
As digital media consumption accelerates, the threat of fabricated content, deliberately crafted to mislead, has escalated dramatically. The challenge of automated fake news detection lies not only in accurately classifying claims as ``True'' or ``False'', but in providing actionable contexts and explanations to content moderators and end-users. Current state-of-the-art transformer models achieve high accuracy on standard datasets; however, their black-box nature provides little insight into \textit{why} a verdict was reached. Furthermore, they are highly susceptible to temporal drift—where a genuinely true article becomes factually outdated—and adversarial manipulation, such as credible-sounding linguistic structures applied to unverified sources.
|
| 27 |
+
|
| 28 |
+
To address these shortcomings, we developed TruthLens. TruthLens diverges from standard classification pipelines by acting as a programmatic misinformation analyst. It does not blindly trust model consensus; instead, it synthesizes AI probabilities with deterministic heuristics based on journalistic standards.
|
| 29 |
+
|
| 30 |
+
\section{Problem Definition}
|
| 31 |
+
Given an article comprising an optional headline ($H$), a body text ($T$), a source domain ($D$), and a publication date ($P$), the objective is to map the document to one of four explainable verdicts: \texttt{TRUE}, \texttt{UNCERTAIN}, \texttt{LIKELY FALSE}, or \texttt{FALSE}.
|
| 32 |
+
|
| 33 |
+
A secondary objective is to dynamically generate a confidence score ($C \in \{LOW, MEDIUM, HIGH\}$) and qualitative justifications (e.g., deductions applied, adversarial flags triggered). The system must account for edge cases such as missing publication dates, anonymous or typo-squatted domains, sensationalist headlines subtly contradicting the body, and the verification of quoted entities.
|
| 34 |
+
|
| 35 |
+
\section{Methodology}
|
| 36 |
+
The TruthLens pipeline is partitioned into four distinct operational stages, culminating in a real-time 5-signal evaluation engine.
|
| 37 |
+
|
| 38 |
+
\subsection{Stage 1: Data Ingestion}
|
| 39 |
+
In the first stage, raw datasets are fetched, homogenized, and aggregated. Disparate datasets originally formatted for binary or multi-class detection are ingested to form a consolidated corpus. During this stage, critical metadata (such as publication timestamps and source URLs) is preserved if available, or flagged for contextual imputation later.
|
| 40 |
+
|
| 41 |
+
\subsection{Stage 2: Preprocessing \& Feature Extraction}
|
| 42 |
+
Text entries undergo rigorous cleaning to standardize capitalization, extract URLs, expand contractions, and remove non-alphanumeric noise without losing semantic meaning. We construct vocabulary indices and tokenization mappings necessary for deep learning backbones. Cleaned datasets are persisted to disk to accelerate experimental reiteration.
|
| 43 |
+
|
| 44 |
+
\subsection{Stage 3: Ensemble Training}
|
| 45 |
+
We utilize a heterogeneous ensemble of four models, each capturing distinct features of the input:
|
| 46 |
+
\begin{itemize}
|
| 47 |
+
\item \textbf{Logistic Regression:} A statistical baseline utilizing TF-IDF vectors to capture simple lexical patterns heavily correlated with misinformation.
|
| 48 |
+
\item \textbf{LSTM:} A recurrent neural network designed to capture sequential linguistic dependencies.
|
| 49 |
+
\item \textbf{DistilBERT \& RoBERTa:} High-capacity, attention-based deep contextual language models fine-tuned to detect semantic nuances and internal contradictions.
|
| 50 |
+
\end{itemize}
|
| 51 |
+
|
| 52 |
+
\subsection{Stage 4: 5-Signal Inference Framework}
|
| 53 |
+
The core innovation of TruthLens is the Stage 4 inference engine, which combines model predictions with four deterministic pillars:
|
| 54 |
+
\begin{enumerate}
|
| 55 |
+
\item \textbf{Source Credibility (30\%):} Evaluates the article domain against a known credible database, checks for explicit author attribution (bylines), and penalizes subtle spelling manipulations (typosquatting).
|
| 56 |
+
\item \textbf{Claim Verification (30\%):} Utilizes spaCy for Named Entity Recognition (NER), tracking the ratio of attributed quotes. The meta-classifier probability is blended dynamically with these entity-level checks.
|
| 57 |
+
\item \textbf{Linguistic Quality (20\%):} Applies rule-based deductions for sensationalism (e.g., excessive capitalized words, superlatives), passive voice overuse, and deploys DistilBERT to compute a cosine similarity check identifying headline-body contradictions.
|
| 58 |
+
\item \textbf{Temporal Freshness (10\%):} Operates via a dual-case system. Case A uses explicit publication dates with a non-linear decay function. Case B parses contextual temporal cues (e.g., ``yesterday'', current year) when explicit dates are absent.
|
| 59 |
+
\item \textbf{Model Vote Consensus (10\%):} Represents the pure ensemble vote ratio.
|
| 60 |
+
\end{enumerate}
|
| 61 |
+
|
| 62 |
+
Finally, adversarial overrides cap maximum scores at 25\% if critical journalistic standards are violated (e.g., ``Triple Anonymity''), ensuring weak models do not inadvertently legitimize fabricated articles. When freshness enters an ambiguous threshold, a Retrieval-Augmented Generation (RAG) module queries live web APIs to corroborate claims with recent index updates.
|
| 63 |
+
|
| 64 |
+
\section{Data Description}
|
| 65 |
+
The model is trained on a homogenized amalgamation of well-established fake news benchmarks, primarily the ISOT Fake News Dataset and the LIAR Dataset.
|
| 66 |
+
|
| 67 |
+
\subsection{Data Divisions \& Preprocessing}
|
| 68 |
+
The aggregated data is structured into a typical split architecture to ensure maximum generalizability:
|
| 69 |
+
\begin{itemize}
|
| 70 |
+
\item \textbf{Training Set (70\%):} Utilized for fitting the TF-IDF vectorizer, tuning the LSTM layer weights, and fine-tuning the Transformer heads.
|
| 71 |
+
\item \textbf{Validation Set (15\%):} Employed to monitor epoch loss, apply early stopping during deep model training, and tune the interpolation thresholds for the 5-signal weights.
|
| 72 |
+
\item \textbf{Holdout / Test Set (15\%):} Kept strictly isolated. Evaluation operates under the constraint of mapping the legacy truth classes to the newly defined 4-tier verdict system (\texttt{TRUE}/\texttt{UNCERTAIN} mapped to 1; \texttt{LIKELY FALSE}/\texttt{FALSE} mapped to 0).
|
| 73 |
+
\end{itemize}
|
| 74 |
+
|
| 75 |
+
\section{Results}
|
| 76 |
+
Integration of the 5-signal weighting framework demonstrated a marked improvement in explainability. While the baseline ensemble alone yielded high raw accuracy on synthetic datasets, it failed to identify maliciously injected adversarial examples (e.g., valid text attributed to ``c-n-n.com''). The complete infrastructure successfully caps adversarial scores, raising overall real-world robustness. On the isolated hold-out set, the system dynamically flags low-confidence articles, ensuring that unverifiable claims are strictly prevented from passing a theoretical ``publishing'' boundary. Live RAG integration further eliminated false positives previously caused by temporal drift in the training data.
|
| 77 |
+
|
| 78 |
+
\section{Discussion}
|
| 79 |
+
TruthLens introduces a pivotal shift from pure binary classification to diagnostic evaluation in fake news detection. By hardcoding journalism heuristics (source validation, NER quoting, temporality) atop deep learning embeddings, the model successfully simulates a human analyst's workflow.
|
| 80 |
+
|
| 81 |
+
A notable outcome of our framework is the treatment of Confidence. By decoupling probability from confidence (e.g., lowering confidence for texts under 50 words or lacking entities), the user interface can visually distinguish ``truthfulness'' from ``reliability''.
|
| 82 |
+
Future work includes expanding the Live RAG validation checks with multi-hop reasoning LLMs, and expanding the author-attribution regex engine to cover a broader range of international linguistic formats.
|
| 83 |
+
|
| 84 |
+
\end{document}
|
app.py
ADDED
|
@@ -0,0 +1,641 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import json
|
| 4 |
+
import time
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import numpy as np
|
| 7 |
+
import streamlit as st
|
| 8 |
+
|
| 9 |
+
# Bootstrap: make the app's own directory importable (for `src.*` modules)
# no matter which working directory Streamlit was launched from.
_ROOT = os.path.dirname(os.path.abspath(__file__))
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)
|
| 12 |
+
|
| 13 |
+
# ── Page config ──────────────────────────────────────────────────────────────
|
| 14 |
+
# Must be the first Streamlit call in the script; configures the browser tab
# and the overall page chrome before any widgets render.
st.set_page_config(
    page_title="TruthLens · Fake News Detector",
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="collapsed",
)
|
| 20 |
+
|
| 21 |
+
# ── Global CSS ───────────────────────────────────────────────────────────────
|
| 22 |
+
st.markdown("""
|
| 23 |
+
<style>
|
| 24 |
+
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700;800&display=swap');
|
| 25 |
+
|
| 26 |
+
/* ── Reset ── */
|
| 27 |
+
html, body, [data-testid="stAppViewContainer"] {
|
| 28 |
+
font-family: 'Inter', sans-serif;
|
| 29 |
+
background: #f4f6fb;
|
| 30 |
+
color: #1e293b;
|
| 31 |
+
}
|
| 32 |
+
[data-testid="stMain"] { background: #f4f6fb; }
|
| 33 |
+
.block-container {
|
| 34 |
+
padding-top: 2.5rem !important;
|
| 35 |
+
padding-bottom: 2rem !important;
|
| 36 |
+
max-width: 920px;
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
/* ── Remove Streamlit chrome ── */
|
| 40 |
+
header[data-testid="stHeader"] { display: none; }
|
| 41 |
+
footer { display: none; }
|
| 42 |
+
#MainMenu { display: none; }
|
| 43 |
+
[data-testid="stSidebar"] { display: none; }
|
| 44 |
+
|
| 45 |
+
/* ── Predict button ── */
|
| 46 |
+
.stButton > button[kind="primary"] {
|
| 47 |
+
background: linear-gradient(135deg, #3b82f6 0%, #6366f1 100%) !important;
|
| 48 |
+
color: #fff !important;
|
| 49 |
+
border: none !important;
|
| 50 |
+
border-radius: 12px !important;
|
| 51 |
+
font-weight: 700 !important;
|
| 52 |
+
font-size: 1.05rem !important;
|
| 53 |
+
letter-spacing: 0.02em;
|
| 54 |
+
padding: 0.75rem 2rem !important;
|
| 55 |
+
transition: transform 0.15s, box-shadow 0.2s;
|
| 56 |
+
box-shadow: 0 4px 16px rgba(59,130,246,0.2);
|
| 57 |
+
}
|
| 58 |
+
.stButton > button[kind="primary"]:hover {
|
| 59 |
+
transform: translateY(-1px);
|
| 60 |
+
box-shadow: 0 6px 24px rgba(59,130,246,0.3) !important;
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
/* ── Tab styling ── */
|
| 64 |
+
[data-testid="stTabs"] button {
|
| 65 |
+
color: #94a3b8 !important;
|
| 66 |
+
font-size: 0.92rem !important;
|
| 67 |
+
font-weight: 500 !important;
|
| 68 |
+
padding: 10px 20px !important;
|
| 69 |
+
}
|
| 70 |
+
[data-testid="stTabs"] button[aria-selected="true"] {
|
| 71 |
+
color: #1e293b !important;
|
| 72 |
+
border-bottom: 2px solid #3b82f6 !important;
|
| 73 |
+
font-weight: 600 !important;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
/* ── Verdict banner ── */
|
| 77 |
+
.verdict-box {
|
| 78 |
+
border-radius: 16px;
|
| 79 |
+
padding: 32px 36px;
|
| 80 |
+
margin-bottom: 28px;
|
| 81 |
+
display: flex;
|
| 82 |
+
align-items: center;
|
| 83 |
+
gap: 24px;
|
| 84 |
+
animation: fadeSlide 0.5s ease;
|
| 85 |
+
}
|
| 86 |
+
@keyframes fadeSlide {
|
| 87 |
+
from { opacity: 0; transform: translateY(-16px); }
|
| 88 |
+
to { opacity: 1; transform: translateY(0); }
|
| 89 |
+
}
|
| 90 |
+
.verdict-emoji { font-size: 3.5rem; line-height: 1; }
|
| 91 |
+
.verdict-label { font-size: 1.8rem; font-weight: 800; letter-spacing: -0.03em; }
|
| 92 |
+
.verdict-conf { font-size: 1rem; opacity: 0.85; margin-top: 6px; font-weight: 400; }
|
| 93 |
+
.verdict-explain { font-size: 0.88rem; color: #64748b; margin-top: 6px; line-height: 1.5; }
|
| 94 |
+
|
| 95 |
+
/* ── Info cards ── */
|
| 96 |
+
.info-card {
|
| 97 |
+
background: #ffffff;
|
| 98 |
+
border: 1px solid #e2e8f0;
|
| 99 |
+
border-radius: 12px;
|
| 100 |
+
padding: 20px 24px;
|
| 101 |
+
margin: 12px 0;
|
| 102 |
+
line-height: 1.6;
|
| 103 |
+
color: #475569;
|
| 104 |
+
}
|
| 105 |
+
.info-card b { color: #1e293b; }
|
| 106 |
+
|
| 107 |
+
/* ── Freshness bar ── */
|
| 108 |
+
.fresh-track { background: #e2e8f0; border-radius: 8px; height: 12px; margin: 10px 0 6px; overflow: hidden; }
|
| 109 |
+
.fresh-fill { height: 100%; border-radius: 8px; transition: width 0.8s ease; }
|
| 110 |
+
|
| 111 |
+
/* ── Source card ── */
|
| 112 |
+
.source-card {
|
| 113 |
+
background: #ffffff;
|
| 114 |
+
border: 1px solid #e2e8f0;
|
| 115 |
+
border-radius: 12px;
|
| 116 |
+
padding: 18px 22px;
|
| 117 |
+
margin: 10px 0;
|
| 118 |
+
display: flex;
|
| 119 |
+
justify-content: space-between;
|
| 120 |
+
align-items: flex-start;
|
| 121 |
+
gap: 16px;
|
| 122 |
+
}
|
| 123 |
+
.source-text { flex: 1; font-size: 0.88rem; line-height: 1.5; color: #475569; }
|
| 124 |
+
.source-score { text-align: center; min-width: 60px; }
|
| 125 |
+
.source-score-val { font-size: 1.4rem; font-weight: 700; font-family: 'Inter', sans-serif; }
|
| 126 |
+
.source-score-tag { font-size: 0.65rem; text-transform: uppercase; letter-spacing: 0.1em; margin-top: 4px; }
|
| 127 |
+
|
| 128 |
+
/* ── Hero ── */
|
| 129 |
+
.hero-wrap { text-align: center; padding: 60px 20px 40px; }
|
| 130 |
+
.hero-icon { font-size: 4rem; margin-bottom: 16px; }
|
| 131 |
+
.hero-title { font-size: 2.4rem; font-weight: 800; letter-spacing: -0.04em; color: #0f172a; }
|
| 132 |
+
.hero-sub { font-size: 1.05rem; color: #64748b; margin-top: 12px; line-height: 1.6; max-width: 520px; margin-left: auto; margin-right: auto; }
|
| 133 |
+
|
| 134 |
+
/* ── How-it-works ── */
|
| 135 |
+
.how-grid { display: grid; grid-template-columns: repeat(3, 1fr); gap: 16px; margin: 36px 0; }
|
| 136 |
+
.how-card {
|
| 137 |
+
background: #ffffff;
|
| 138 |
+
border: 1px solid #e2e8f0;
|
| 139 |
+
border-radius: 12px;
|
| 140 |
+
padding: 24px;
|
| 141 |
+
text-align: center;
|
| 142 |
+
box-shadow: 0 1px 3px rgba(0,0,0,0.04);
|
| 143 |
+
}
|
| 144 |
+
.how-num { font-size: 2rem; margin-bottom: 8px; }
|
| 145 |
+
.how-title { font-size: 0.95rem; font-weight: 600; margin-bottom: 6px; color: #0f172a; }
|
| 146 |
+
.how-desc { font-size: 0.82rem; color: #64748b; line-height: 1.5; }
|
| 147 |
+
|
| 148 |
+
/* ── Verdict legend ── */
|
| 149 |
+
.legend-row {
|
| 150 |
+
display: flex;
|
| 151 |
+
gap: 24px;
|
| 152 |
+
justify-content: center;
|
| 153 |
+
flex-wrap: wrap;
|
| 154 |
+
margin: 20px 0;
|
| 155 |
+
}
|
| 156 |
+
.legend-item { font-size: 0.85rem; color: #64748b; }
|
| 157 |
+
|
| 158 |
+
/* ── Metric overrides ── */
|
| 159 |
+
[data-testid="stMetric"] {
|
| 160 |
+
background: #ffffff;
|
| 161 |
+
border: 1px solid #e2e8f0;
|
| 162 |
+
border-radius: 10px;
|
| 163 |
+
padding: 14px 18px !important;
|
| 164 |
+
box-shadow: 0 1px 3px rgba(0,0,0,0.04);
|
| 165 |
+
}
|
| 166 |
+
[data-testid="stMetricLabel"] { color: #64748b !important; font-size: 0.78rem !important; }
|
| 167 |
+
[data-testid="stMetricValue"] { color: #0f172a !important; font-size: 1.3rem !important; }
|
| 168 |
+
|
| 169 |
+
/* ── Expander ── */
|
| 170 |
+
[data-testid="stExpander"] {
|
| 171 |
+
background: #ffffff !important;
|
| 172 |
+
border: 1px solid #e2e8f0 !important;
|
| 173 |
+
border-radius: 10px !important;
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
/* ── Text inputs ── */
|
| 177 |
+
[data-testid="stTextInput"] input, [data-testid="stTextArea"] textarea {
|
| 178 |
+
background: #ffffff !important;
|
| 179 |
+
border: 1px solid #cbd5e1 !important;
|
| 180 |
+
border-radius: 8px !important;
|
| 181 |
+
color: #1e293b !important;
|
| 182 |
+
}
|
| 183 |
+
[data-testid="stTextInput"] input:focus, [data-testid="stTextArea"] textarea:focus {
|
| 184 |
+
border-color: #3b82f6 !important;
|
| 185 |
+
box-shadow: 0 0 0 2px rgba(59,130,246,0.15) !important;
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
/* ── Select slider / radio ── */
|
| 189 |
+
[data-testid="stSlider"] label, .stRadio label { color: #475569 !important; }
|
| 190 |
+
|
| 191 |
+
/* ── Progress bar ── */
|
| 192 |
+
[data-testid="stProgress"] > div > div > div > div { background: linear-gradient(90deg, #3b82f6, #6366f1) !important; }
|
| 193 |
+
</style>
|
| 194 |
+
""", unsafe_allow_html=True)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
# ── Cached inference loader ──────────────────────────────────────────────────
|
| 198 |
+
@st.cache_resource(show_spinner=False)
def load_pipeline():
    """Import and cache the heavy inference entry points.

    The import is deferred into this function (rather than done at module
    top level) so the landing page renders instantly; ``st.cache_resource``
    guarantees the import — and any model loading it triggers — happens at
    most once per server process.

    Returns:
        tuple: ``(predict_article, ModelNotTrainedError)`` from
        ``src.stage4_inference``.
    """
    from src.stage4_inference import predict_article, ModelNotTrainedError
    return predict_article, ModelNotTrainedError
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
# ── Session state ────────────────────────────────────────────────────────────
# Seed defaults only for keys that were not set on a previous rerun.
_SESSION_DEFAULTS = {"analyzed": False, "last_result": None, "last_input": ""}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
# =============================================================================
|
| 211 |
+
# LANDING PAGE (shown before any analysis)
|
| 212 |
+
# =============================================================================
|
| 213 |
+
if not st.session_state["analyzed"]:
|
| 214 |
+
|
| 215 |
+
# ── Hero section ──
|
| 216 |
+
st.markdown("""
|
| 217 |
+
<div class="hero-wrap">
|
| 218 |
+
<div class="hero-icon">🔍</div>
|
| 219 |
+
<div class="hero-title">TruthLens</div>
|
| 220 |
+
<div class="hero-sub">
|
| 221 |
+
Paste any news article or drop a URL — our AI will tell you
|
| 222 |
+
if it's real, fake, or outdated in seconds.
|
| 223 |
+
</div>
|
| 224 |
+
</div>
|
| 225 |
+
""", unsafe_allow_html=True)
|
| 226 |
+
|
| 227 |
+
# ── How it works ──
|
| 228 |
+
st.markdown("""
|
| 229 |
+
<div class="how-grid">
|
| 230 |
+
<div class="how-card">
|
| 231 |
+
<div class="how-num">📋</div>
|
| 232 |
+
<div class="how-title">Paste or Link</div>
|
| 233 |
+
<div class="how-desc">Drop in the article text or a URL. We'll extract everything automatically.</div>
|
| 234 |
+
</div>
|
| 235 |
+
<div class="how-card">
|
| 236 |
+
<div class="how-num">⚡</div>
|
| 237 |
+
<div class="how-title">Instant Analysis</div>
|
| 238 |
+
<div class="how-desc">Our AI analyzes language patterns, checks freshness, and searches live sources.</div>
|
| 239 |
+
</div>
|
| 240 |
+
<div class="how-card">
|
| 241 |
+
<div class="how-num">✅</div>
|
| 242 |
+
<div class="how-title">Get Your Verdict</div>
|
| 243 |
+
<div class="how-desc">See a clear REAL / FAKE / OUTDATED verdict with a confidence score and explanation.</div>
|
| 244 |
+
</div>
|
| 245 |
+
</div>
|
| 246 |
+
""", unsafe_allow_html=True)
|
| 247 |
+
|
| 248 |
+
# ── Input area ──
|
| 249 |
+
input_tab = st.radio("How would you like to provide the article?",
|
| 250 |
+
["✍️ Write or paste text", "🔗 Paste a URL"],
|
| 251 |
+
horizontal=True, label_visibility="visible")
|
| 252 |
+
|
| 253 |
+
input_text, input_title, input_url, input_date, input_domain = "", "", "", "", ""
|
| 254 |
+
|
| 255 |
+
if input_tab == "✍️ Write or paste text":
|
| 256 |
+
input_title = st.text_input("Headline (optional)",
|
| 257 |
+
placeholder="e.g. Breaking: Scientists discover high-speed interstellar travel")
|
| 258 |
+
input_text = st.text_area("Article content",
|
| 259 |
+
height=180,
|
| 260 |
+
placeholder="Paste the full article body here…")
|
| 261 |
+
# ── Auto-extract title from pasted text if headline field is empty ──
|
| 262 |
+
if not input_title.strip() and input_text.strip():
|
| 263 |
+
if input_text.lower().startswith("title:"):
|
| 264 |
+
lines = input_text.split("\n", 1)
|
| 265 |
+
input_title = lines[0].replace("Title:", "").replace("title:", "").strip()
|
| 266 |
+
input_text = lines[1].replace("Body:", "").replace("body:", "").strip() if len(lines) > 1 else ""
|
| 267 |
+
else:
|
| 268 |
+
# Fallback: first sentence is title
|
| 269 |
+
input_title = input_text.split(".")[0].strip()
|
| 270 |
+
|
| 271 |
+
else:
|
| 272 |
+
input_url = st.text_input("Article URL",
|
| 273 |
+
placeholder="https://www.example.com/news/breaking-story")
|
| 274 |
+
st.caption("We'll automatically extract the title, body, and publish date.")
|
| 275 |
+
|
| 276 |
+
# ── Analysis mode (kept minimal — user doesn't need to understand internals)
|
| 277 |
+
speed = st.select_slider("Analysis depth",
|
| 278 |
+
options=["Quick", "Standard", "Deep"],
|
| 279 |
+
value="Deep",
|
| 280 |
+
help="Quick ≈ 2 sec · Standard ≈ 10 sec · Deep ≈ 30 sec (most accurate)")
|
| 281 |
+
speed_map = {"Quick": "fast", "Standard": "balanced", "Deep": "full"}
|
| 282 |
+
selected_mode = speed_map[speed]
|
| 283 |
+
|
| 284 |
+
# ── Predict button ──
|
| 285 |
+
predict_clicked = st.button("🔍 Check this article", use_container_width=True, type="primary")
|
| 286 |
+
|
| 287 |
+
# ── Verdict legend ──
|
| 288 |
+
st.markdown("""
|
| 289 |
+
<div class="legend-row">
|
| 290 |
+
<div class="legend-item">🟢 Verified True</div>
|
| 291 |
+
<div class="legend-item">🔴 Likely Fake</div>
|
| 292 |
+
<div class="legend-item">🟡 Outdated</div>
|
| 293 |
+
<div class="legend-item">🟠 Needs Review</div>
|
| 294 |
+
</div>
|
| 295 |
+
""", unsafe_allow_html=True)
|
| 296 |
+
|
| 297 |
+
# ── Execute prediction ──
|
| 298 |
+
if predict_clicked:
|
| 299 |
+
# Validate
|
| 300 |
+
if input_tab == "✍️ Write or paste text":
|
| 301 |
+
if not input_text or len(input_text.split()) < 10:
|
| 302 |
+
st.warning("⚠️ Please paste at least a few sentences so we can analyze it properly.")
|
| 303 |
+
st.stop()
|
| 304 |
+
else:
|
| 305 |
+
if not input_url:
|
| 306 |
+
st.warning("⚠️ Please enter a URL first.")
|
| 307 |
+
st.stop()
|
| 308 |
+
try:
|
| 309 |
+
import newspaper
|
| 310 |
+
from urllib.parse import urlparse
|
| 311 |
+
art = newspaper.Article(input_url)
|
| 312 |
+
art.download()
|
| 313 |
+
art.parse()
|
| 314 |
+
input_title = art.title or ""
|
| 315 |
+
input_text = art.text or ""
|
| 316 |
+
input_date = art.publish_date.isoformat() if art.publish_date else ""
|
| 317 |
+
input_domain = urlparse(input_url).netloc
|
| 318 |
+
if len(input_text.split()) < 10:
|
| 319 |
+
st.warning("⚠️ Couldn't extract enough text from that URL. Try pasting the article directly.")
|
| 320 |
+
st.stop()
|
| 321 |
+
except Exception:
|
| 322 |
+
st.error("❌ Couldn't fetch that URL. Please check the link or paste the text directly.")
|
| 323 |
+
st.stop()
|
| 324 |
+
|
| 325 |
+
predict_article, ModelNotTrainedError = load_pipeline()
|
| 326 |
+
|
| 327 |
+
with st.status("🔍 Analyzing article…", expanded=True) as status:
|
| 328 |
+
st.write("📖 Reading article…")
|
| 329 |
+
time.sleep(0.3)
|
| 330 |
+
st.write("🧠 Running AI analysis…")
|
| 331 |
+
try:
|
| 332 |
+
result = predict_article(
|
| 333 |
+
title=input_title,
|
| 334 |
+
text=input_text,
|
| 335 |
+
source_domain=input_domain,
|
| 336 |
+
published_date=input_date,
|
| 337 |
+
mode=selected_mode,
|
| 338 |
+
)
|
| 339 |
+
st.write("🕐 Checking article freshness…")
|
| 340 |
+
st.write("🌐 Searching live sources…")
|
| 341 |
+
status.update(label="✅ Done!", state="complete")
|
| 342 |
+
st.session_state["last_result"] = result
|
| 343 |
+
st.session_state["last_input"] = input_text
|
| 344 |
+
st.session_state["analyzed"] = True
|
| 345 |
+
st.rerun()
|
| 346 |
+
except ModelNotTrainedError:
|
| 347 |
+
status.update(label="❌ Setup required", state="error")
|
| 348 |
+
st.error("The AI models haven't been trained yet.")
|
| 349 |
+
st.info("Ask your administrator to run: `python run_pipeline.py --stage 1 2 3`")
|
| 350 |
+
st.stop()
|
| 351 |
+
except Exception as e:
|
| 352 |
+
status.update(label="❌ Error", state="error")
|
| 353 |
+
st.error(f"Something went wrong: {e}")
|
| 354 |
+
st.stop()
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
# =============================================================================
|
| 359 |
+
# RESULTS PAGE (shown after analysis)
|
| 360 |
+
# =============================================================================
|
| 361 |
+
else:
|
| 362 |
+
res = st.session_state["last_result"]
|
| 363 |
+
verdict = res.get("verdict", "UNKNOWN")
|
| 364 |
+
final_score = res.get("final_score", 0.0)
|
| 365 |
+
scores = res.get("scores", {})
|
| 366 |
+
confidence = res.get("confidence", "MEDIUM")
|
| 367 |
+
action = res.get("recommended_action", "Flag for review")
|
| 368 |
+
top_reasons = res.get("top_reasons", [])
|
| 369 |
+
missing_signals = res.get("missing_signals", [])
|
| 370 |
+
adv_flags = res.get("adversarial_flags", [])
|
| 371 |
+
wc = res.get("word_count", 0)
|
| 372 |
+
probas = res.get("base_model_probas", {})
|
| 373 |
+
votes = res.get("base_model_votes", {})
|
| 374 |
+
fresh_case = res.get("freshness_case", "B")
|
| 375 |
+
fresh_signals = res.get("freshness_signals_found", [])
|
| 376 |
+
deductions = res.get("deductions_applied", [])
|
| 377 |
+
entities = res.get("entities_found", [])
|
| 378 |
+
|
| 379 |
+
# ── Map verdict to display ──
|
| 380 |
+
V = {
|
| 381 |
+
"TRUE": {"bg":"#f0fdf4", "bdr":"#86efac", "icon":"🟢", "label":"This appears to be true", "color":"#15803d",
|
| 382 |
+
"explain":"Source, claims, language, and AI models all align with credible journalism."},
|
| 383 |
+
"UNCERTAIN": {"bg":"#fff7ed", "bdr":"#fdba74", "icon":"🟠", "label":"Uncertain — needs review", "color":"#c2410c",
|
| 384 |
+
"explain":"Mixed signals detected. We recommend verifying the sources yourself before sharing."},
|
| 385 |
+
"LIKELY FALSE": {"bg":"#fef2f2", "bdr":"#fca5a5", "icon":"🔴", "label":"Likely false", "color":"#b91c1c",
|
| 386 |
+
"explain":"Multiple signals indicate this content may be fabricated or misleading."},
|
| 387 |
+
"FALSE": {"bg":"#fef2f2", "bdr":"#fca5a5", "icon":"⛔", "label":"This looks fake", "color":"#991b1b",
|
| 388 |
+
"explain":"Strong evidence of misinformation. Do not share without independent verification."},
|
| 389 |
+
}
|
| 390 |
+
vc = V.get(verdict, {"bg":"#f8fafc","bdr":"#cbd5e1","icon":"⚪","label":verdict,"color":"#475569",
|
| 391 |
+
"explain":"Analysis complete."})
|
| 392 |
+
|
| 393 |
+
# ── Verdict banner ──
|
| 394 |
+
score_pct = final_score * 100
|
| 395 |
+
st.markdown(f"""
|
| 396 |
+
<div class="verdict-box" style="background:{vc['bg']}; border:1px solid {vc['bdr']};">
|
| 397 |
+
<div class="verdict-emoji">{vc['icon']}</div>
|
| 398 |
+
<div>
|
| 399 |
+
<div class="verdict-label" style="color:{vc['color']};">{vc['label']}</div>
|
| 400 |
+
<div class="verdict-conf" style="color:{vc['color']};">Score: {score_pct:.0f}% · Confidence: {confidence}</div>
|
| 401 |
+
<div class="verdict-explain">{vc['explain']}</div>
|
| 402 |
+
</div>
|
| 403 |
+
</div>
|
| 404 |
+
""", unsafe_allow_html=True)
|
| 405 |
+
|
| 406 |
+
# ── Recommended action badge ──
|
| 407 |
+
action_colors = {
|
| 408 |
+
"Publish": ("#f0fdf4", "#15803d"),
|
| 409 |
+
"Flag for review": ("#fff7ed", "#c2410c"),
|
| 410 |
+
"Suppress": ("#fef2f2", "#b91c1c"),
|
| 411 |
+
"Escalate": ("#fef2f2", "#991b1b"),
|
| 412 |
+
}
|
| 413 |
+
abg, acol = action_colors.get(action, ("#f8fafc", "#475569"))
|
| 414 |
+
st.markdown(f"""
|
| 415 |
+
<div style="background:{abg}; border-radius:8px; padding:10px 16px; display:inline-block; margin-bottom:24px;">
|
| 416 |
+
<span style="font-weight:600; color:{acol};">Recommended: {action}</span>
|
| 417 |
+
</div>
|
| 418 |
+
""", unsafe_allow_html=True)
|
| 419 |
+
|
| 420 |
+
# ── Tabs ──
|
| 421 |
+
tab_why, tab_fresh, tab_sources, tab_details = st.tabs(
|
| 422 |
+
["🧠 Why this verdict?", "🕐 Freshness", "🌐 Live sources", "📋 Details"]
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
# ── TAB 1: Why this verdict ──────────────────────────────────────────
|
| 426 |
+
with tab_why:
|
| 427 |
+
|
| 428 |
+
# ── 5-Signal Score Breakdown ──
|
| 429 |
+
st.markdown("#### Signal Breakdown")
|
| 430 |
+
SIGNAL_INFO = [
|
| 431 |
+
("Source", "source", "Is the outlet known and accountable?"),
|
| 432 |
+
("Claims", "claim", "Are facts verifiable with named entities?"),
|
| 433 |
+
("Language", "linguistic", "Is the writing neutral and attributed?"),
|
| 434 |
+
("Freshness", "freshness", "How recent is the content?"),
|
| 435 |
+
("AI Models", "model_vote", "What do the AI models think?"),
|
| 436 |
+
]
|
| 437 |
+
WEIGHTS = {"source": "30%", "claim": "30%", "linguistic": "20%", "freshness": "10%", "model_vote": "10%"}
|
| 438 |
+
|
| 439 |
+
cols = st.columns(5)
|
| 440 |
+
for i, (label, key, desc) in enumerate(SIGNAL_INFO):
|
| 441 |
+
val = scores.get(key, 0.0)
|
| 442 |
+
pct = val * 100
|
| 443 |
+
if pct >= 70:
|
| 444 |
+
col_hex = "#15803d"
|
| 445 |
+
elif pct >= 50:
|
| 446 |
+
col_hex = "#ca8a04"
|
| 447 |
+
else:
|
| 448 |
+
col_hex = "#b91c1c"
|
| 449 |
+
with cols[i]:
|
| 450 |
+
st.markdown(f"""
|
| 451 |
+
<div style="text-align:center; background:#ffffff; border:1px solid #e2e8f0;
|
| 452 |
+
border-radius:10px; padding:16px 8px; box-shadow:0 1px 3px rgba(0,0,0,0.04);">
|
| 453 |
+
<div style="font-size:1.6rem; font-weight:800; color:{col_hex};">{pct:.0f}%</div>
|
| 454 |
+
<div style="font-size:0.85rem; font-weight:600; color:#0f172a; margin-top:4px;">{label}</div>
|
| 455 |
+
<div style="font-size:0.7rem; color:#94a3b8; margin-top:2px;">Weight: {WEIGHTS[key]}</div>
|
| 456 |
+
</div>
|
| 457 |
+
""", unsafe_allow_html=True)
|
| 458 |
+
|
| 459 |
+
st.markdown("")
|
| 460 |
+
|
| 461 |
+
# ── Progress bars for each signal ──
|
| 462 |
+
for label, key, desc in SIGNAL_INFO:
|
| 463 |
+
val = scores.get(key, 0.0)
|
| 464 |
+
st.caption(f"**{label}** — {desc}")
|
| 465 |
+
st.progress(min(val, 1.0))
|
| 466 |
+
|
| 467 |
+
st.markdown("---")
|
| 468 |
+
|
| 469 |
+
# ── Top Reasons ──
|
| 470 |
+
if top_reasons:
|
| 471 |
+
st.markdown("#### Key Factors")
|
| 472 |
+
for r in top_reasons:
|
| 473 |
+
if any(neg in r.lower() for neg in ["fake", "false", "unknown", "not", "manipulation", "adversarial", "sensationalism", "reduces", "could not", "inconsistent", "missing"]):
|
| 474 |
+
st.markdown(f"🔴 {r}")
|
| 475 |
+
else:
|
| 476 |
+
st.markdown(f"🟢 {r}")
|
| 477 |
+
|
| 478 |
+
st.markdown("---")
|
| 479 |
+
|
| 480 |
+
# ── What did each AI model think? ──
|
| 481 |
+
st.markdown("#### AI Model Votes")
|
| 482 |
+
MODEL_NAMES = [
|
| 483 |
+
("Statistical", "logistic", "lr_proba"),
|
| 484 |
+
("Language", "lstm", "lstm_proba"),
|
| 485 |
+
("Deep A", "distilbert", "distilbert_proba"),
|
| 486 |
+
("Deep B", "roberta", "roberta_proba"),
|
| 487 |
+
]
|
| 488 |
+
mcols = st.columns(len(MODEL_NAMES))
|
| 489 |
+
for i, (nice_name, vote_key, pk) in enumerate(MODEL_NAMES):
|
| 490 |
+
vote_val = votes.get(vote_key)
|
| 491 |
+
prob_val = probas.get(pk)
|
| 492 |
+
with mcols[i]:
|
| 493 |
+
if vote_val is None or prob_val is None or np.isnan(prob_val):
|
| 494 |
+
st.metric(nice_name, "Skipped")
|
| 495 |
+
else:
|
| 496 |
+
lbl = "Real" if int(vote_val) == 1 else "Fake"
|
| 497 |
+
st.metric(nice_name, lbl, f"{prob_val*100:.0f}%")
|
| 498 |
+
|
| 499 |
+
if res.get("short_text_warning"):
|
| 500 |
+
st.warning("⚠️ Short article (under 50 words) — confidence is dampened.")
|
| 501 |
+
st.caption(f"Article length: {wc} words")
|
| 502 |
+
|
| 503 |
+
# ── TAB 2: Freshness ─────────────────────────────────────────────────
|
| 504 |
+
with tab_fresh:
|
| 505 |
+
fresh_val = scores.get("freshness", 0.5)
|
| 506 |
+
bar_pct = int(fresh_val * 100)
|
| 507 |
+
|
| 508 |
+
if fresh_val >= 0.70:
|
| 509 |
+
fbg, flbl, fdesc = "#f0fdf4", "🟢 Fresh", "This article appears to be recent."
|
| 510 |
+
fbar = "#16a34a"
|
| 511 |
+
elif fresh_val >= 0.40:
|
| 512 |
+
fbg, flbl, fdesc = "#fefce8", "🟡 Moderate", "Article may not be very recent."
|
| 513 |
+
fbar = "#ca8a04"
|
| 514 |
+
else:
|
| 515 |
+
fbg, flbl, fdesc = "#fef2f2", "🔴 Outdated", "This article appears to be old."
|
| 516 |
+
fbar = "#dc2626"
|
| 517 |
+
|
| 518 |
+
st.markdown(f"""
|
| 519 |
+
<div style="background:{fbg}; border-radius:12px; padding:20px 24px; margin-bottom:20px;">
|
| 520 |
+
<div style="font-size:1.2rem; font-weight:600;">{flbl}</div>
|
| 521 |
+
<div style="font-size:0.88rem; color:#64748b; margin-top:8px;">{fdesc}</div>
|
| 522 |
+
<div class="fresh-track">
|
| 523 |
+
<div class="fresh-fill" style="width:{bar_pct}%; background:{fbar};"></div>
|
| 524 |
+
</div>
|
| 525 |
+
<div style="font-size:0.8rem; color:#6b7280; margin-top:4px;">Freshness: {fresh_val:.0%}</div>
|
| 526 |
+
</div>
|
| 527 |
+
""", unsafe_allow_html=True)
|
| 528 |
+
|
| 529 |
+
# Case indicator
|
| 530 |
+
case_label = "📅 Date-based scoring" if fresh_case == "A" else "🔎 Contextual signal scanning (no date found)"
|
| 531 |
+
st.markdown(f"""
|
| 532 |
+
<div class="info-card">
|
| 533 |
+
<b>Method:</b> {case_label}
|
| 534 |
+
</div>
|
| 535 |
+
""", unsafe_allow_html=True)
|
| 536 |
+
|
| 537 |
+
# Signals found (Case B)
|
| 538 |
+
if fresh_case == "B" and fresh_signals:
|
| 539 |
+
st.markdown("**Signals detected:**")
|
| 540 |
+
for sig in fresh_signals:
|
| 541 |
+
st.markdown(f"✅ {sig}")
|
| 542 |
+
elif fresh_case == "B":
|
| 543 |
+
st.caption("No contextual freshness signals were found in the article text.")
|
| 544 |
+
|
| 545 |
+
# ── TAB 3: Live sources ──────────────────────────────────────────────
|
| 546 |
+
with tab_sources:
|
| 547 |
+
rag_data = res.get("rag_results")
|
| 548 |
+
source_list = []
|
| 549 |
+
if isinstance(rag_data, dict):
|
| 550 |
+
source_list = rag_data.get("data", [])
|
| 551 |
+
elif isinstance(rag_data, list):
|
| 552 |
+
source_list = rag_data
|
| 553 |
+
|
| 554 |
+
if not source_list:
|
| 555 |
+
st.markdown("""
|
| 556 |
+
<div class="info-card">
|
| 557 |
+
<b>Live source check was not triggered</b><br><br>
|
| 558 |
+
Live source verification runs when freshness is ambiguous.
|
| 559 |
+
This analysis relied on the 5-signal scoring framework instead.
|
| 560 |
+
</div>
|
| 561 |
+
""", unsafe_allow_html=True)
|
| 562 |
+
else:
|
| 563 |
+
st.caption(f"Compared against {len(source_list)} live web results.")
|
| 564 |
+
for item in source_list:
|
| 565 |
+
snippet = item.get("snippet", "")
|
| 566 |
+
sim = item.get("similarity", 0.0)
|
| 567 |
+
if sim > 0.65:
|
| 568 |
+
sc_col, sc_tag = "#16a34a", "Supports"
|
| 569 |
+
elif sim < 0.30:
|
| 570 |
+
sc_col, sc_tag = "#dc2626", "Conflicts"
|
| 571 |
+
else:
|
| 572 |
+
sc_col, sc_tag = "#ca8a04", "Neutral"
|
| 573 |
+
|
| 574 |
+
st.markdown(f"""
|
| 575 |
+
<div class="source-card">
|
| 576 |
+
<div class="source-text">{snippet}</div>
|
| 577 |
+
<div class="source-score">
|
| 578 |
+
<div class="source-score-val" style="color:{sc_col};">{sim:.0%}</div>
|
| 579 |
+
<div class="source-score-tag" style="color:{sc_col};">{sc_tag}</div>
|
| 580 |
+
</div>
|
| 581 |
+
</div>
|
| 582 |
+
""", unsafe_allow_html=True)
|
| 583 |
+
|
| 584 |
+
# ── TAB 4: Details ───────────────────────────────────────────────────
|
| 585 |
+
with tab_details:
|
| 586 |
+
|
| 587 |
+
# ── Missing Signals ──
|
| 588 |
+
if missing_signals:
|
| 589 |
+
st.markdown("#### ⚠️ Missing Signals")
|
| 590 |
+
for ms in missing_signals:
|
| 591 |
+
st.markdown(f"- {ms}")
|
| 592 |
+
st.markdown("")
|
| 593 |
+
|
| 594 |
+
# ── Adversarial Flags ──
|
| 595 |
+
if adv_flags:
|
| 596 |
+
st.markdown("#### 🚩 Adversarial Flags Triggered")
|
| 597 |
+
for af in adv_flags:
|
| 598 |
+
st.error(f"🚩 {af}")
|
| 599 |
+
st.caption("Adversarial flags cap the final score at 25% maximum.")
|
| 600 |
+
st.markdown("")
|
| 601 |
+
|
| 602 |
+
# ── Linguistic Deductions ──
|
| 603 |
+
if deductions:
|
| 604 |
+
st.markdown("#### 📝 Linguistic Deductions")
|
| 605 |
+
for d in deductions:
|
| 606 |
+
st.markdown(f"- {d}")
|
| 607 |
+
st.markdown("")
|
| 608 |
+
|
| 609 |
+
# ── Named Entities Found ──
|
| 610 |
+
if entities:
|
| 611 |
+
st.markdown("#### 🏷️ Entities Detected")
|
| 612 |
+
st.markdown(", ".join([f"`{e}`" for e in entities]))
|
| 613 |
+
q_attr = res.get("quotes_attributed", 0)
|
| 614 |
+
q_total = res.get("quotes_total", 0)
|
| 615 |
+
if q_total > 0:
|
| 616 |
+
st.caption(f"Quotes: {q_attr}/{q_total} attributed")
|
| 617 |
+
st.markdown("")
|
| 618 |
+
|
| 619 |
+
# ── Summary Table ──
|
| 620 |
+
st.markdown("#### Analysis Summary")
|
| 621 |
+
rows = [
|
| 622 |
+
("Verdict", vc["label"]),
|
| 623 |
+
("Final Score", f"{score_pct:.1f}%"),
|
| 624 |
+
("Confidence", confidence),
|
| 625 |
+
("Action", action),
|
| 626 |
+
("Word Count", str(wc)),
|
| 627 |
+
("Freshness", f"{scores.get('freshness', 0):.0%} (Case {fresh_case})"),
|
| 628 |
+
]
|
| 629 |
+
df_rep = pd.DataFrame(rows, columns=["Field", "Value"])
|
| 630 |
+
st.dataframe(df_rep, use_container_width=True, hide_index=True, height=240)
|
| 631 |
+
|
| 632 |
+
with st.expander("🔧 Raw JSON (for developers)"):
|
| 633 |
+
st.code(json.dumps(res, indent=2, default=str), language="json")
|
| 634 |
+
|
| 635 |
+
# ── Analyze another ──
|
| 636 |
+
st.markdown("---")
|
| 637 |
+
if st.button("← Analyze another article", use_container_width=True):
|
| 638 |
+
st.session_state["analyzed"] = False
|
| 639 |
+
st.session_state["last_result"] = None
|
| 640 |
+
st.rerun()
|
| 641 |
+
|
config/config.yaml
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ═══════════════════════════════════════════════════════════
|
| 2 |
+
# Fake News Detection System — Configuration
|
| 3 |
+
# ═══════════════════════════════════════════════════════════
|
| 4 |
+
|
| 5 |
+
paths:
|
| 6 |
+
# Raw dataset root (relative to project root)
|
| 7 |
+
dataset_root: "../Dataset"
|
| 8 |
+
# Processed data output
|
| 9 |
+
processed_dir: "data/processed"
|
| 10 |
+
# Train/val/test splits
|
| 11 |
+
splits_dir: "data/splits"
|
| 12 |
+
# Saved model weights
|
| 13 |
+
models_dir: "models/saved"
|
| 14 |
+
# Logs
|
| 15 |
+
logs_dir: "logs"
|
| 16 |
+
# GloVe embeddings
|
| 17 |
+
glove_path: "../glove.6B.100d.txt"
|
| 18 |
+
|
| 19 |
+
dataset:
|
| 20 |
+
min_domain_samples: 20
|
| 21 |
+
dedup_threshold: 0.92
|
| 22 |
+
dedup_batch_size: 64
|
| 23 |
+
|
| 24 |
+
preprocessing:
|
| 25 |
+
max_tfidf_features: 50000
|
| 26 |
+
lstm_max_len: 256
|
| 27 |
+
bert_max_len: 128
|
| 28 |
+
short_text_threshold: 50
|
| 29 |
+
min_word_count: 3
|
| 30 |
+
|
| 31 |
+
training:
|
| 32 |
+
lstm_batch_size: 64
|
| 33 |
+
lstm_epochs: 10
|
| 34 |
+
bert_batch_size: 8
|
| 35 |
+
bert_epochs: 3
|
| 36 |
+
lr_learning_rate: 2e-5
|
| 37 |
+
roberta_learning_rate: 1e-5
|
| 38 |
+
|
| 39 |
+
inference:
|
| 40 |
+
true_threshold: 0.55
|
| 41 |
+
fresh_threshold: 0.70
|
| 42 |
+
outdated_threshold: 0.40
|
| 43 |
+
undated_freshness_score: 0.35
|
| 44 |
+
max_multiplier: 10
|
| 45 |
+
|
| 46 |
+
rag:
|
| 47 |
+
top_k: 5
|
| 48 |
+
max_result_age_days: 90
|
| 49 |
+
support_threshold: 0.65
|
| 50 |
+
conflict_threshold: 0.30
|
| 51 |
+
|
| 52 |
+
drift:
|
| 53 |
+
word_count_std_alert: 2.0
|
| 54 |
+
fake_ratio_upper: 0.80
|
| 55 |
+
fake_ratio_lower: 0.20
|
| 56 |
+
rolling_window: 100
|
| 57 |
+
|
| 58 |
+
holdout:
|
| 59 |
+
stratified_test_size: 0.10
|
| 60 |
+
random_state: 42
|
models/saved/distilbert_model/cm_distil.png
ADDED
|
models/saved/distilbert_model/config.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"activation": "gelu",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"DistilBertForSequenceClassification"
|
| 5 |
+
],
|
| 6 |
+
"attention_dropout": 0.1,
|
| 7 |
+
"bos_token_id": null,
|
| 8 |
+
"dim": 768,
|
| 9 |
+
"dropout": 0.1,
|
| 10 |
+
"dtype": "float32",
|
| 11 |
+
"eos_token_id": null,
|
| 12 |
+
"hidden_dim": 3072,
|
| 13 |
+
"initializer_range": 0.02,
|
| 14 |
+
"max_position_embeddings": 512,
|
| 15 |
+
"model_type": "distilbert",
|
| 16 |
+
"n_heads": 12,
|
| 17 |
+
"n_layers": 6,
|
| 18 |
+
"pad_token_id": 0,
|
| 19 |
+
"problem_type": "single_label_classification",
|
| 20 |
+
"qa_dropout": 0.1,
|
| 21 |
+
"seq_classif_dropout": 0.2,
|
| 22 |
+
"sinusoidal_pos_embds": false,
|
| 23 |
+
"tie_weights_": true,
|
| 24 |
+
"tie_word_embeddings": true,
|
| 25 |
+
"transformers_version": "5.5.3",
|
| 26 |
+
"use_cache": false,
|
| 27 |
+
"vocab_size": 30522
|
| 28 |
+
}
|
models/saved/distilbert_model/distilbert_oof.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fe20e00deb86d90904d65166c92605f5afbe923986eba9ea9d980cf6c2555992
|
| 3 |
+
size 161240
|
models/saved/distilbert_model/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dd61e968af64237e2024f8ccfbeb8e7df658f7bde6a3310826ad7767dcd1118b
|
| 3 |
+
size 267832560
|
models/saved/distilbert_model/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/saved/distilbert_model/tokenizer_config.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"backend": "tokenizers",
|
| 3 |
+
"cls_token": "[CLS]",
|
| 4 |
+
"do_lower_case": true,
|
| 5 |
+
"is_local": false,
|
| 6 |
+
"mask_token": "[MASK]",
|
| 7 |
+
"model_max_length": 512,
|
| 8 |
+
"pad_token": "[PAD]",
|
| 9 |
+
"sep_token": "[SEP]",
|
| 10 |
+
"strip_accents": null,
|
| 11 |
+
"tokenize_chinese_chars": true,
|
| 12 |
+
"tokenizer_class": "BertTokenizer",
|
| 13 |
+
"unk_token": "[UNK]"
|
| 14 |
+
}
|
models/saved/distilbert_model/training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4bc15d538a51fdb305b84c6674de46ba456c3288bbf8030c105f2034004edc06
|
| 3 |
+
size 5329
|
models/saved/logistic_model/cm.png
ADDED
|
models/saved/logistic_model/logistic_model.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:07d15b87ba7fc5e352c37141f8ec2f4a182fff0e8ff5efa26ef6f8cef1bd9db3
|
| 3 |
+
size 2383153
|
models/saved/logistic_model/lr_oof.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:02da29135c36a6752466dbe6b18ef62dbfdb4dca131b035546ce1720ba9eae69
|
| 3 |
+
size 322352
|
models/saved/logistic_model/metrics.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"roc_auc": 0.9668348166617273,
|
| 3 |
+
"bucket_accuracy": {
|
| 4 |
+
"short": 0.6078748651564185,
|
| 5 |
+
"medium": 0.9741119807344973,
|
| 6 |
+
"long": 0.986362371277484
|
| 7 |
+
}
|
| 8 |
+
}
|
models/saved/lstm_model/cm.png
ADDED
|
models/saved/lstm_model/lstm_oof.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5f26e0d9f8ac71e5d60591c007e07e7a164c25de363dedeb73769a521dc548b4
|
| 3 |
+
size 161240
|
models/saved/lstm_model/metrics.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"roc_auc": 0.9730530691595465,
|
| 3 |
+
"bucket_accuracy": {
|
| 4 |
+
"short": 0.6197411003236246,
|
| 5 |
+
"medium": 0.9704996989765202,
|
| 6 |
+
"long": 0.9944336209295853
|
| 7 |
+
}
|
| 8 |
+
}
|
models/saved/lstm_model/model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b5fb0cf853e49b9410620e76842792df7ac50ace32b47bb16d23a10d29165534
|
| 3 |
+
size 21639159
|
models/saved/meta_classifier/cm_meta.png
ADDED
|
models/saved/meta_classifier/meta_classifier.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:62807b550c6db9514b69c346a497d9897dfd9db8566b8e806d1cf8fe233384fa
|
| 3 |
+
size 276517
|
models/saved/meta_classifier/metrics.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"roc_auc": 0.9770497393987524,
|
| 3 |
+
"bucket_accuracy": {
|
| 4 |
+
"short": 0.6564185544768069,
|
| 5 |
+
"medium": 0.9867549668874173,
|
| 6 |
+
"long": 0.9961035346507097
|
| 7 |
+
}
|
| 8 |
+
}
|
models/saved/roberta_model/cm_roberta.png
ADDED
|
models/saved/roberta_model/config.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_cross_attention": false,
|
| 3 |
+
"architectures": [
|
| 4 |
+
"RobertaForSequenceClassification"
|
| 5 |
+
],
|
| 6 |
+
"attention_probs_dropout_prob": 0.1,
|
| 7 |
+
"bos_token_id": 0,
|
| 8 |
+
"classifier_dropout": null,
|
| 9 |
+
"dtype": "float32",
|
| 10 |
+
"eos_token_id": 2,
|
| 11 |
+
"hidden_act": "gelu",
|
| 12 |
+
"hidden_dropout_prob": 0.1,
|
| 13 |
+
"hidden_size": 768,
|
| 14 |
+
"initializer_range": 0.02,
|
| 15 |
+
"intermediate_size": 3072,
|
| 16 |
+
"is_decoder": false,
|
| 17 |
+
"layer_norm_eps": 1e-05,
|
| 18 |
+
"max_position_embeddings": 514,
|
| 19 |
+
"model_type": "roberta",
|
| 20 |
+
"num_attention_heads": 12,
|
| 21 |
+
"num_hidden_layers": 12,
|
| 22 |
+
"pad_token_id": 1,
|
| 23 |
+
"problem_type": "single_label_classification",
|
| 24 |
+
"tie_word_embeddings": true,
|
| 25 |
+
"transformers_version": "5.5.3",
|
| 26 |
+
"type_vocab_size": 1,
|
| 27 |
+
"use_cache": false,
|
| 28 |
+
"vocab_size": 50265
|
| 29 |
+
}
|
models/saved/roberta_model/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7c3b04b2a6f1e1c303664a07df3c950f97f4571b2f5bfebfce7b5851edbac430
|
| 3 |
+
size 498612800
|
models/saved/roberta_model/roberta_oof.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c60da5d3c8e05b0562eafac205e93d2f8ec51649e01675b9ab9d9a087d6385ce
|
| 3 |
+
size 161240
|
models/saved/roberta_model/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/saved/roberta_model/tokenizer_config.json
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": false,
|
| 3 |
+
"backend": "tokenizers",
|
| 4 |
+
"bos_token": "<s>",
|
| 5 |
+
"cls_token": "<s>",
|
| 6 |
+
"eos_token": "</s>",
|
| 7 |
+
"errors": "replace",
|
| 8 |
+
"is_local": false,
|
| 9 |
+
"mask_token": "<mask>",
|
| 10 |
+
"model_max_length": 512,
|
| 11 |
+
"pad_token": "<pad>",
|
| 12 |
+
"sep_token": "</s>",
|
| 13 |
+
"tokenizer_class": "RobertaTokenizer",
|
| 14 |
+
"trim_offsets": true,
|
| 15 |
+
"unk_token": "<unk>"
|
| 16 |
+
}
|
models/saved/roberta_model/training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b7c09d4d3610e99efbd18eb40d7b9c533e6d00d9ce19be89748b4377176c5d72
|
| 3 |
+
size 5329
|
models/saved/tokenizer.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6f106d0d56f5cc7a71b6012570e2f9c4385b0462b3473c6191957e7963d3979d
|
| 3 |
+
size 4361433
|
requirements.txt
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pandas>=2.0.0
|
| 2 |
+
numpy>=1.24.0
|
| 3 |
+
scikit-learn>=1.3.0
|
| 4 |
+
torch>=2.0.0
|
| 5 |
+
transformers>=4.30.0
|
| 6 |
+
datasets>=2.14.0
|
| 7 |
+
sentence-transformers>=2.2.0
|
| 8 |
+
xgboost>=2.0.0
|
| 9 |
+
spacy>=3.6.0
|
| 10 |
+
duckduckgo-search>=4.0.0
|
| 11 |
+
streamlit>=1.28.0
|
| 12 |
+
plotly>=5.18.0
|
| 13 |
+
newspaper3k>=0.2.8
|
| 14 |
+
pyyaml>=6.0
|
| 15 |
+
tqdm>=4.65.0
|
| 16 |
+
joblib>=1.3.0
|
| 17 |
+
lxml_html_clean
|
| 18 |
+
https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl
|
run_pipeline.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import argparse
|
| 4 |
+
import subprocess
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s")
|
| 8 |
+
logger = logging.getLogger("run_pipeline")
|
| 9 |
+
|
| 10 |
+
def execute_stage(stage_num):
    """Run one pipeline stage (1-3) as a subprocess; abort the process on failure."""
    logger.info(f"========== TRIGGERING STAGE {stage_num} ==========")

    # Map stage number -> entry-point script.
    stage_scripts = {
        1: "src/stage1_ingestion.py",
        2: "src/stage2_preprocessing.py",
        3: "src/stage3_training.py",
    }
    script = stage_scripts.get(stage_num)
    if script is None:
        logger.error(f"Unknown Stage: {stage_num}")
        return

    if not os.path.exists(script):
        logger.error(f"Cannot find script: {script}")
        sys.exit(1)

    # Use the current interpreter so virtual environments are respected.
    completed = subprocess.run([sys.executable, script])
    if completed.returncode != 0:
        logger.error(f"Stage {stage_num} failed!")
        sys.exit(1)
    logger.info(f"========== STAGE {stage_num} FINISHED ==========\n")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def execute_evaluation():
    """Benchmark the full inference pipeline on the stratified hold-out split.

    Loads data/splits/df_holdout.csv, runs predict_article() on every row with
    RAG disabled, maps the 4-tier verdict down to the binary label space, and
    logs accuracy, a classification report and the confusion matrix.

    Exits with status 1 if the hold-out CSV is missing.
    """
    logger.info("========== TRIGGERING FINAL HOLD-OUT BENCHMARK ==========")
    import pandas as pd
    import numpy as np
    from tqdm import tqdm
    from sklearn.metrics import classification_report, accuracy_score, confusion_matrix

    # Needs to be imported late so it doesn't fail if dependencies aren't setup
    from src.stage4_inference import predict_article

    df_path = "data/splits/df_holdout.csv"
    if not os.path.exists(df_path):
        logger.error(f"Holdout file missing at {df_path}. Run Stages 1-3 first.")
        sys.exit(1)

    df = pd.read_csv(df_path)
    logger.info(f"Loaded {len(df)} Stratified Holdout records.")

    y_true = df["binary_label"].values
    y_pred = []

    logger.info("Executing isolated pipeline inference across holdout targets (RAG safely bypassed)...")
    logger.info("NOTE: Since this evaluates the entire heavy 4-model ensemble locally, it may take several minutes.")

    # Index from iterrows() is unused; discard it explicitly.
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Benchmarking Evaluator"):
        # We manually map the inference parameters directly into the ultimate test pipeline
        res = predict_article(
            title=row.get("title", ""),
            text=row.get("text", ""),
            source_domain=row.get("source_domain", ""),
            published_date=row.get("published_date", ""),
            mode="full",
            trigger_rag=False
        )

        # New 4-tier verdict mapping:
        #   TRUE / UNCERTAIN       -> 1 (real news)
        #   LIKELY FALSE / FALSE   -> 0 (fake news)
        v = res["verdict"]
        pred_label = 1 if v in ("TRUE", "UNCERTAIN") else 0
        y_pred.append(pred_label)

    y_pred = np.array(y_pred)
    acc = accuracy_score(y_true, y_pred)

    # (f-prefix removed: this banner has no placeholders)
    logger.info("\n================ BENCHMARK RESULTS ================")
    logger.info(f"Final Architecture Accuracy: {acc * 100:.2f}%")
    logger.info("\n" + classification_report(y_true, y_pred, target_names=["Fake News (0)", "True News (1)"]))
    logger.info(f"Confusion Matrix:\n{confusion_matrix(y_true, y_pred)}")
    logger.info("===================================================\n")
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
if __name__ == "__main__":
    # CLI entry point: run selected pipeline stages and/or the hold-out benchmark.
    parser = argparse.ArgumentParser(description="Fake News Detection System Pipeline")
    parser.add_argument("--stage", nargs="+", type=int, choices=[1, 2, 3], help="Specify stages to run (e.g. --stage 1 2 3)")
    parser.add_argument("--eval", action="store_true", help="Evaluate the architecture natively on the stratified holdout benchmark")

    args = parser.parse_args()

    # Stages run in the order given on the command line; a failing stage
    # terminates the process inside execute_stage().
    if args.stage:
        for s in args.stage:
            execute_stage(s)

    if args.eval:
        execute_evaluation()

    # No arguments at all: print usage instead of silently doing nothing.
    if not args.stage and not args.eval:
        parser.print_help()
|
src/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Fake News Detection — Source Package
|
src/models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Fake News Detection — Models Package
|
src/models/distilbert_model.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from sklearn.model_selection import train_test_split
|
| 9 |
+
from transformers import (
|
| 10 |
+
AutoTokenizer,
|
| 11 |
+
AutoModelForSequenceClassification,
|
| 12 |
+
Trainer,
|
| 13 |
+
TrainingArguments,
|
| 14 |
+
DataCollatorWithPadding
|
| 15 |
+
)
|
| 16 |
+
from datasets import Dataset
|
| 17 |
+
|
| 18 |
+
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 19 |
+
if str(_PROJECT_ROOT) not in sys.path:
|
| 20 |
+
sys.path.insert(0, str(_PROJECT_ROOT))
|
| 21 |
+
|
| 22 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s")
|
| 23 |
+
logger = logging.getLogger("distilbert_model")
|
| 24 |
+
|
| 25 |
+
def train_distilbert(cfg, splits_dir, save_dir):
    """Fine-tune DistilBERT on the train split and export stacking artifacts.

    A single "proxy" model is trained on an 80% stratified subset of the
    train split, then used to score the FULL train set; those P(class=1)
    scores are saved as distilbert_oof.npy for the meta-classifier.
    NOTE(review): these scores are not strictly out-of-fold — the proxy has
    seen 80% of the rows it scores; confirm the meta-classifier tolerates
    this optimism (a strict K-fold variant is sketched below this function).

    Args:
        cfg: parsed config.yaml dict (preprocessing/training sections read).
        splits_dir: directory containing df_train.csv and df_val.csv.
        save_dir: destination for model, tokenizer, OOF array and CM plot.
    """
    os.makedirs(save_dir, exist_ok=True)

    # 1. Load Data
    train_df = pd.read_csv(os.path.join(splits_dir, "df_train.csv"))
    val_df = pd.read_csv(os.path.join(splits_dir, "df_val.csv"))

    # Replace missing article text with empty strings before tokenization.
    train_df["clean_text"] = train_df["clean_text"].fillna("")
    val_df["clean_text"] = val_df["clean_text"].fillna("")

    # Hyperparameters from config, with fallbacks.
    maxlen = cfg.get("preprocessing", {}).get("bert_max_len", 512)
    batch_size = cfg.get("training", {}).get("bert_batch_size", 16)
    epochs = cfg.get("training", {}).get("bert_epochs", 3)
    lr = float(cfg.get("training", {}).get("lr_learning_rate", 2e-5))

    logger.info("Loading DistilBERT tokenizer...")
    model_name = "distilbert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # 2. Tokenization Helper
    def tokenize_function(examples):
        # padding=False: dynamic padding is applied later by DataCollatorWithPadding.
        return tokenizer(examples["text"], padding=False, truncation=True, max_length=maxlen)

    # 3. Create OOF Proxy Split (80/20) safely to accelerate pipeline training (avoid 5-fold computation cost)
    idx_train, idx_meta_val = train_test_split(
        range(len(train_df)), test_size=0.20,
        stratify=train_df["binary_label"], random_state=42
    )

    subset_train_df = train_df.iloc[idx_train].copy()

    # 4. Convert to HuggingFace Datasets
    hf_sub_train = Dataset.from_pandas(pd.DataFrame({
        "text": subset_train_df["clean_text"], "labels": subset_train_df["binary_label"]
    }), preserve_index=False)

    hf_full_train = Dataset.from_pandas(pd.DataFrame({
        "text": train_df["clean_text"], "labels": train_df["binary_label"]
    }), preserve_index=False)

    hf_val = Dataset.from_pandas(pd.DataFrame({
        "text": val_df["clean_text"], "labels": val_df["binary_label"]
    }), preserve_index=False)

    logger.info("Tokenizing datasets...")
    hf_sub_train = hf_sub_train.map(tokenize_function, batched=True)
    hf_full_train = hf_full_train.map(tokenize_function, batched=True)
    hf_val = hf_val.map(tokenize_function, batched=True)

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    # 5. Initialize Model
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

    # 6. Trainer Setup
    training_args = TrainingArguments(
        output_dir=os.path.join(save_dir, "checkpoints"),
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=lr,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=2,  # effective train batch = 2 * batch_size
        dataloader_num_workers=2,
        num_train_epochs=epochs,
        weight_decay=0.01,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",  # best checkpoint = lowest val loss
        greater_is_better=False,
        fp16=torch.cuda.is_available(),  # mixed precision only on GPU
        disable_tqdm=False
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=hf_sub_train,
        eval_dataset=hf_val,
        processing_class=tokenizer,
        data_collator=data_collator,
    )

    # 7. Train
    logger.info("Starting DistilBERT internal proxy training...")
    trainer.train()

    # 8. Save Model
    logger.info("Saving final fine-tuned model...")
    trainer.save_model(save_dir)
    tokenizer.save_pretrained(save_dir)

    # 9. Extract OOF over the entire training set
    logger.info("Generating OOF predictions on full train set proxy wrapper...")
    oof_preds = trainer.predict(hf_full_train)
    # probabilities for class 1 (True)
    oof_probas = torch.softmax(torch.tensor(oof_preds.predictions), dim=-1)[:, 1].numpy()
    np.save(os.path.join(save_dir, "distilbert_oof.npy"), oof_probas)
    logger.info("Saved distilbert_oof.npy")

    # Validation evaluation mapped later by main loop, or manually if desired.
    val_preds_out = trainer.predict(hf_val)
    val_probas = torch.softmax(torch.tensor(val_preds_out.predictions), dim=-1)[:, 1].numpy()

    # Reuse the shared CM plotting helper; 0.5 threshold on P(class=1).
    from src.models.logistic_model import plot_and_save_cm
    plot_and_save_cm(
        val_df["binary_label"],
        (val_probas > 0.5).astype(int),
        os.path.join(save_dir, "cm.png"),
        title="DistilBERT Confusion Matrix"
    )

    logger.info("DistilBERT Training completed!")
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# ====================================================================
|
| 140 |
+
# OPTIONAL: Full K-Fold OOF (GPU-intensive)
|
| 141 |
+
# --------------------------------------------------------------------
|
| 142 |
+
# The strategy above saves enormous compute by generating a single
|
| 143 |
+
# proxy model to predict the full training pool. A strict K-Fold
|
| 144 |
+
# architecture requires training DistilBERT 5 entirely separate
|
| 145 |
+
# instances which spans roughly 15+ epochs locally. Use below
|
| 146 |
+
# if massive parallel A100s are available.
|
| 147 |
+
#
|
| 148 |
+
"""
|
| 149 |
+
from sklearn.model_selection import StratifiedKFold
|
| 150 |
+
|
| 151 |
+
def strict_kfold_distilbert(train_df, tokenize_function, data_collator, lr, batch_size, epochs, save_dir):
|
| 152 |
+
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
|
| 153 |
+
oof_probas = np.zeros(len(train_df), dtype=np.float32)
|
| 154 |
+
|
| 155 |
+
for fold, (train_idx, val_idx) in enumerate(skf.split(train_df, train_df["binary_label"])):
|
| 156 |
+
logger.info(f"Training Fold {fold+1}/5")
|
| 157 |
+
df_train = train_df.iloc[train_idx].copy()
|
| 158 |
+
df_val = train_df.iloc[val_idx].copy()
|
| 159 |
+
|
| 160 |
+
ds_train = Dataset.from_pandas(pd.DataFrame({"text": df_train["clean_text"], "labels": df_train["binary_label"]}), preserve_index=False).map(tokenize_function, batched=True)
|
| 161 |
+
ds_val = Dataset.from_pandas(pd.DataFrame({"text": df_val["clean_text"], "labels": df_val["binary_label"]}), preserve_index=False).map(tokenize_function, batched=True)
|
| 162 |
+
|
| 163 |
+
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
|
| 164 |
+
|
| 165 |
+
training_args = TrainingArguments(
|
| 166 |
+
output_dir=os.path.join(save_dir, f"fold_{fold}"),
|
| 167 |
+
eval_strategy="epoch",
|
| 168 |
+
save_strategy="epoch",
|
| 169 |
+
learning_rate=lr,
|
| 170 |
+
per_device_train_batch_size=batch_size,
|
| 171 |
+
num_train_epochs=epochs,
|
| 172 |
+
fp16=torch.cuda.is_available(),
|
| 173 |
+
load_best_model_at_end=True,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
trainer = Trainer(
|
| 177 |
+
model=model,
|
| 178 |
+
args=training_args,
|
| 179 |
+
train_dataset=ds_train,
|
| 180 |
+
eval_dataset=ds_val,
|
| 181 |
+
data_collator=data_collator,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
trainer.train()
|
| 185 |
+
fold_preds = trainer.predict(ds_val)
|
| 186 |
+
oof_probas[val_idx] = torch.softmax(torch.tensor(fold_preds.predictions), dim=-1)[:, 1].numpy()
|
| 187 |
+
|
| 188 |
+
np.save(os.path.join(save_dir, "distilbert_oof.npy"), oof_probas)
|
| 189 |
+
"""
|
| 190 |
+
# ====================================================================
|
| 191 |
+
|
| 192 |
+
if __name__ == "__main__":
    # Standalone entry point: read config.yaml and kick off training.
    import yaml

    config_file = os.path.join(_PROJECT_ROOT, "config", "config.yaml")
    with open(config_file, "r", encoding="utf-8") as fh:
        config = yaml.safe_load(fh)

    paths = config["paths"]
    splits_path = os.path.join(_PROJECT_ROOT, paths["splits_dir"])
    model_path = os.path.join(_PROJECT_ROOT, paths["models_dir"], "distilbert_model")

    train_distilbert(config, splits_path, model_path)
|
src/models/logistic_model.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import time
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import joblib
|
| 9 |
+
from sklearn.pipeline import Pipeline
|
| 10 |
+
from sklearn.compose import ColumnTransformer
|
| 11 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 12 |
+
from sklearn.preprocessing import OneHotEncoder
|
| 13 |
+
from sklearn.linear_model import LogisticRegression
|
| 14 |
+
from sklearn.model_selection import StratifiedKFold, cross_val_predict, GridSearchCV
|
| 15 |
+
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
|
| 16 |
+
from matplotlib import pyplot as plt
|
| 17 |
+
|
| 18 |
+
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 19 |
+
if str(_PROJECT_ROOT) not in sys.path:
|
| 20 |
+
sys.path.insert(0, str(_PROJECT_ROOT))
|
| 21 |
+
|
| 22 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s")
|
| 23 |
+
logger = logging.getLogger("logistic_model")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def load_data(splits_dir):
    """Read the train/val CSV splits and return them as (train_df, val_df)."""
    frames = []
    for fname in ("df_train.csv", "df_val.csv"):
        frame = pd.read_csv(os.path.join(splits_dir, fname))
        # Guard against missing text rows so downstream vectorizers never see NaN.
        frame["clean_text"] = frame["clean_text"].fillna("")
        frames.append(frame)
    return tuple(frames)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def plot_and_save_cm(y_true, y_pred, path, title="Logistic Regression Confusion Matrix"):
    """Render the confusion matrix of (y_true, y_pred) and write it to *path* as a PNG."""
    matrix = confusion_matrix(y_true, y_pred)
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.matshow(matrix, cmap=plt.cm.Blues, alpha=0.3)
    # Overlay the raw count on every cell.
    n_rows, n_cols = matrix.shape
    for r in range(n_rows):
        for c in range(n_cols):
            ax.text(x=c, y=r, s=matrix[r, c], va='center', ha='center', size='xx-large')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title(title)
    plt.tight_layout()
    plt.savefig(path)
    plt.close()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def train_logistic_model(cfg, splits_dir, save_dir):
    """Train the TF-IDF + Logistic Regression baseline and export artifacts.

    Writes to save_dir:
      - lr_oof.npy: 5-fold out-of-fold P(class=1) over the train split,
        consumed by the stacking meta-classifier
      - logistic_model.pkl: best GridSearchCV pipeline refit on full train
      - cm.png / metrics.json: validation confusion matrix, ROC-AUC and
        per-text-length-bucket accuracy

    Args:
        cfg: parsed config.yaml dict.
        splits_dir: directory with df_train.csv / df_val.csv.
        save_dir: output directory (created if missing).
    """
    logger.info("Initializing Logistic Regression Training...")
    os.makedirs(save_dir, exist_ok=True)

    train_df, val_df = load_data(splits_dir)
    y_train = train_df["binary_label"].values
    y_val = val_df["binary_label"].values

    max_features = cfg.get("preprocessing", {}).get("max_tfidf_features", 50000)

    # Define ColumnTransformer for generic pipeline feature stack
    preprocessor = ColumnTransformer(
        transformers=[
            ("tfidf", TfidfVectorizer(max_features=max_features, ngram_range=(1, 2)), "clean_text"),
            ("cat", OneHotEncoder(handle_unknown="ignore"), ["text_length_bucket"])
        ],
        remainder="drop"
    )

    # Define Model
    log_reg = LogisticRegression(class_weight="balanced", random_state=42, max_iter=1000)

    pipeline = Pipeline(steps=[
        ("preprocessor", preprocessor),
        ("classifier", log_reg)
    ])

    # K-Fold OOF Predictions
    logger.info("Generating 5-Fold OOF predictions on Train set...")
    cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    # Using method='predict_proba' returns a 2D array [n_samples, 2]
    oof_probas = cross_val_predict(pipeline, train_df, y_train, cv=cv, method='predict_proba', n_jobs=-1)

    # Only P(class=1) is stored for the meta-classifier.
    np.save(os.path.join(save_dir, "lr_oof.npy"), oof_probas[:, 1])
    logger.info("Saved OOF predictions (lr_oof.npy)")

    # Hyperparameter Tuning on full Train via GridSearch
    logger.info("Hyperparameter tuning C over 5-folds...")
    param_grid = {'classifier__C': [0.1, 1.0, 10.0]}

    grid_search = GridSearchCV(pipeline, param_grid, cv=cv, scoring='f1_macro', n_jobs=-1)
    grid_search.fit(train_df, y_train)

    best_pipeline = grid_search.best_estimator_
    logger.info(f"Best parameter C: {grid_search.best_params_['classifier__C']}")

    # Validation Evaluation
    val_probas = best_pipeline.predict_proba(val_df)[:, 1]
    val_preds = (val_probas >= 0.5).astype(int)

    logger.info("Validation Classification Report:\n" + classification_report(y_val, val_preds))
    roc_auc = roc_auc_score(y_val, val_probas)
    logger.info(f"ROC-AUC: {roc_auc:.4f}")

    # Generate Evaluation Artifacts
    plot_and_save_cm(y_val, val_preds, os.path.join(save_dir, "cm.png"))

    # Compute accuracy per text length bucket on val
    bucket_acc = {}
    for b in ["short", "medium", "long"]:
        b_mask = (val_df["text_length_bucket"] == b)
        if b_mask.sum() > 0:  # skip empty buckets
            acc = (val_preds[b_mask] == y_val[b_mask]).mean()
            bucket_acc[b] = acc

    metrics = {
        "roc_auc": float(roc_auc),
        "bucket_accuracy": {k: float(v) for k, v in bucket_acc.items()}
    }
    with open(os.path.join(save_dir, "metrics.json"), "w") as f:
        json.dump(metrics, f, indent=2)

    # Save Pipeline
    joblib.dump(best_pipeline, os.path.join(save_dir, "logistic_model.pkl"))
    logger.info("Saved Logistic Regression Pipeline to format `logistic_model.pkl`.")
|
| 129 |
+
|
| 130 |
+
if __name__ == "__main__":
    # Standalone entry point: read config.yaml, train, and report wall-clock time.
    import yaml

    config_file = os.path.join(_PROJECT_ROOT, "config", "config.yaml")
    with open(config_file, "r", encoding="utf-8") as fh:
        config = yaml.safe_load(fh)

    splits_path = os.path.join(_PROJECT_ROOT, config["paths"]["splits_dir"])
    model_path = os.path.join(_PROJECT_ROOT, config["paths"]["models_dir"], "logistic_model")

    start = time.time()
    train_logistic_model(config, splits_path, model_path)
    print(f"Total time: {time.time() - start:.2f}s")
|
src/models/lstm_model.py
ADDED
|
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import time
|
| 6 |
+
import pickle
|
| 7 |
+
import copy
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from torch.utils.data import TensorDataset, DataLoader, Subset
|
| 14 |
+
from sklearn.model_selection import StratifiedKFold
|
| 15 |
+
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
|
| 16 |
+
from matplotlib import pyplot as plt
|
| 17 |
+
|
| 18 |
+
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 19 |
+
if str(_PROJECT_ROOT) not in sys.path:
|
| 20 |
+
sys.path.insert(0, str(_PROJECT_ROOT))
|
| 21 |
+
|
| 22 |
+
# We need the Tokenizer from stage 2 to execute texts_to_sequences natively
|
| 23 |
+
from src.stage2_preprocessing import KerasStyleTokenizer
|
| 24 |
+
|
| 25 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s")
|
| 26 |
+
logger = logging.getLogger("lstm_model")
|
| 27 |
+
|
| 28 |
+
# ── Architecture ──────────────────────────────────────
|
| 29 |
+
class SpatialDropout1D(nn.Module):
    """Channel-wise dropout over (batch, seq_len, embed_dim) inputs.

    Mirrors Keras' SpatialDropout1D: whole embedding dimensions are zeroed
    at once rather than individual activations.
    """

    def __init__(self, p=0.3):
        super().__init__()
        self.p = p

    def forward(self, x):
        # Identity in eval mode or when dropout is disabled.
        if not self.training or self.p == 0:
            return x
        # Reshape (batch, seq, emb) -> (batch, emb, seq, 1) so dropout2d
        # treats each embedding dimension as one "channel" and drops it whole.
        channels_first = x.permute(0, 2, 1).unsqueeze(3)
        dropped = F.dropout2d(channels_first, p=self.p, training=self.training)
        return dropped.squeeze(3).permute(0, 2, 1)
|
| 46 |
+
|
| 47 |
+
class BiLSTMClassifier(nn.Module):
    """Two-layer bidirectional LSTM text classifier (PyTorch port of a Keras stack).

    Pipeline: Embedding(100-d, optionally frozen GloVe) -> SpatialDropout1D
    -> Bi-LSTM(128) -> Bi-LSTM(64) -> Linear(64)+ReLU -> Dropout -> Linear(1).
    forward() returns raw logits of shape (batch,); callers apply sigmoid
    (training pairs this with a BCE-with-logits style criterion).
    """

    def __init__(self, vocab_size, embedding_matrix=None):
        # embedding_matrix: optional (vocab_size, 100) float array of pretrained
        # vectors; when supplied, the embedding layer is copied from it and frozen.
        super().__init__()
        # Embedding(vocab_size, 100)
        self.embedding = nn.Embedding(vocab_size, 100, padding_idx=0)
        if embedding_matrix is not None:
            self.embedding.weight.data.copy_(torch.from_numpy(embedding_matrix))
            self.embedding.weight.requires_grad = False  # keep pretrained vectors fixed

        self.spatial_drop = SpatialDropout1D(0.3)

        # Bi-LSTM(100->128, bidirectional=True)
        self.lstm1 = nn.LSTM(100, 128, bidirectional=True, batch_first=True)
        # Bi-LSTM(256->64, bidirectional=True); input 256 = 2 directions * 128
        self.lstm2 = nn.LSTM(256, 64, bidirectional=True, batch_first=True)

        # Linear(128, 64) + ReLU; input 128 = concat of both final directions (2 * 64)
        self.fc1 = nn.Linear(128, 64)
        self.dropout = nn.Dropout(0.4)
        # Linear(64, 1) + Sigmoid (handled via BCEWithLogitsLoss below conceptually, or explicitly applied)
        self.fc2 = nn.Linear(64, 1)

    def forward(self, x):
        # x: (batch, seq_len) int64 token ids; 0 is the padding index.
        h = self.embedding(x)
        h = self.spatial_drop(h)

        h, _ = self.lstm1(h)
        # Taking last states? Typically Keras `return_sequences=False` on the 2nd LSTM
        # means it takes the final hidden state of the sequence
        _, (h_n, _) = self.lstm2(h)

        # h_n shape for Bi-LSTM: (2, batch, hidden_size)
        # Concatenate forward and backward final states
        h_concat = torch.cat((h_n[-2,:,:], h_n[-1,:,:]), dim=1) # shape: (batch, 128)

        out = F.relu(self.fc1(h_concat))
        out = self.dropout(out)
        logits = self.fc2(out)

        return logits.squeeze(1)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# ── Utilities ──────────────────────────────────────
|
| 90 |
+
def pad_sequences(sequences, maxlen=512, padding='post'):
    """Pad/truncate integer sequences to a fixed length (Keras-style).

    Args:
        sequences: iterable of integer token-id lists.
        maxlen: target length; longer sequences keep their first maxlen tokens.
        padding: 'post' pads zeros after the tokens, anything else pads before.

    Returns:
        int64 ndarray of shape (len(sequences), maxlen); 0 is the pad id.
    """
    padded = np.zeros((len(sequences), maxlen), dtype=np.int64)
    for i, seq in enumerate(sequences):
        seq = seq[:maxlen]
        n = len(seq)
        if n == 0:
            # Empty sequence stays all-zero. (Without this guard, 'pre'
            # padding would evaluate padded[i, -0:] — the FULL row — and
            # raise a broadcast ValueError when assigning an empty array.)
            continue
        if padding == 'post':
            padded[i, :n] = seq
        else:
            padded[i, -n:] = seq
    return padded
|
| 99 |
+
|
| 100 |
+
def load_glove_embeddings(glove_path, word_index, embed_dim=100):
    """Build an embedding matrix for *word_index* from a GloVe text file.

    Returns (embedding_matrix, vocab_size): row 0 stays all-zero for padding,
    and any word without a GloVe entry keeps a zero vector.
    """
    logger.info(f"Loading GloVe embeddings from {glove_path}...")
    glove_vectors = {}
    with open(glove_path, "r", encoding="utf-8") as fh:
        for raw in fh:
            parts = raw.split()
            # First token is the word, the remainder its vector components.
            glove_vectors[parts[0]] = np.asarray(parts[1:], dtype='float32')

    vocab_size = len(word_index) + 1  # +1 reserves index 0 for padding
    embedding_matrix = np.zeros((vocab_size, embed_dim), dtype=np.float32)
    hits = misses = 0
    for word, idx in word_index.items():
        vec = glove_vectors.get(word)
        if vec is None:
            misses += 1
        else:
            embedding_matrix[idx] = vec
            hits += 1
    logger.info(f"GloVe mapped: {hits} hits, {misses} misses.")
    return embedding_matrix, vocab_size
|
| 122 |
+
|
| 123 |
+
def plot_and_save_cm(y_true, y_pred, path):
    """Save the Bi-LSTM confusion matrix to *path* as a PNG.

    Args:
        y_true: binary ground-truth labels.
        y_pred: predicted probabilities; thresholded here at 0.5.
        path: output image file path.
    """
    cm = confusion_matrix(y_true, (np.array(y_pred) > 0.5).astype(int))
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.matshow(cm, cmap=plt.cm.Blues, alpha=0.3)
    # Annotate each cell with its raw count.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(x=j, y=i, s=cm[i, j], va='center', ha='center', size='xx-large')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title('Bi-LSTM Confusion Matrix')
    plt.tight_layout()
    plt.savefig(path)
    plt.close()
|
| 136 |
+
|
| 137 |
+
# ── Training Loop ──────────────────────────────────────
|
| 138 |
+
def train_epoch(model, loader, optimizer, criterion, device):
    """Run one optimization pass over *loader*; return the mean per-sample loss."""
    model.train()
    running = 0.0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()
        # Weight by batch size so the final mean is per-sample, not per-batch.
        running += batch_loss.item() * inputs.size(0)
    return running / len(loader.dataset)
|
| 150 |
+
|
| 151 |
+
@torch.no_grad()
def eval_model(model, loader, criterion, device):
    """Evaluate without gradients; return (mean per-sample loss, sigmoid probabilities)."""
    model.eval()
    loss_sum = 0.0
    prob_chunks = []
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        logits = model(inputs)
        loss_sum += criterion(logits, targets).item() * inputs.size(0)
        # Convert raw logits to probabilities on CPU for downstream numpy use.
        prob_chunks.extend(torch.sigmoid(logits).cpu().numpy())
    return loss_sum / len(loader.dataset), np.array(prob_chunks)
|
| 164 |
+
|
| 165 |
+
def train_lstm_logic(cfg, splits_dir, save_dir, glove_path):
    """Train the Bi-LSTM base model and emit its stacking artifacts.

    Pipeline:
      1. Load train/val CSV splits and the pickled Keras-style tokenizer.
      2. Convert texts to padded integer sequences; build the GloVe
         embedding matrix.
      3. Generate 5-fold out-of-fold (OOF) probabilities over the train
         split, saved as ``lstm_oof.npy`` for the meta-classifier.
      4. Retrain on the full train split with early stopping on val loss;
         save best weights (``model.pt``), a confusion matrix image and
         ``metrics.json``.

    Args:
        cfg: parsed config.yaml dict (reads ``paths``, ``preprocessing``
            and ``training`` sections, with defaults below).
        splits_dir: directory containing ``df_train.csv`` / ``df_val.csv``.
        save_dir: output directory for weights, OOF array and metrics.
        glove_path: path to the GloVe embeddings text file.
    """
    os.makedirs(save_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Load splits; labels cast to float32 for BCEWithLogitsLoss.
    train_df = pd.read_csv(os.path.join(splits_dir, "df_train.csv"))
    val_df = pd.read_csv(os.path.join(splits_dir, "df_val.csv"))
    y_train = np.float32(train_df["binary_label"].values)
    y_val = np.float32(val_df["binary_label"].values)

    # Tokenizer was fitted in stage 2 and pickled at the models root.
    with open(os.path.join(_PROJECT_ROOT, cfg["paths"]["models_dir"], "tokenizer.pkl"), "rb") as f:
        tokenizer = pickle.load(f)

    # Hyperparameters with config fallbacks.
    maxlen = cfg.get("preprocessing", {}).get("lstm_max_len", 512)
    batch_size = cfg.get("training", {}).get("lstm_batch_size", 64)
    epochs = cfg.get("training", {}).get("lstm_epochs", 10)

    logger.info("Transforming texts to padded sequences...")
    X_train_seq = tokenizer.texts_to_sequences(train_df["clean_text"].fillna(""))
    X_val_seq = tokenizer.texts_to_sequences(val_df["clean_text"].fillna(""))

    # Post-padding so real tokens come first in every row.
    X_train_pad = pad_sequences(X_train_seq, maxlen=maxlen, padding='post')
    X_val_pad = pad_sequences(X_val_seq, maxlen=maxlen, padding='post')

    # Embedding matrix mapped from the tokenizer vocabulary onto GloVe vectors.
    emb_matrix, vocab_size = load_glove_embeddings(glove_path, tokenizer.word_index)

    # Class imbalance handled via pos_weight = n_negative / n_positive.
    class_counts = np.bincount(y_train.astype(int))
    pos_weight = torch.tensor([class_counts[0] / class_counts[1]], dtype=torch.float32).to(device)

    # Tensor datasets; the val loader is shared by OOF-free final training below.
    train_tensor = TensorDataset(torch.from_numpy(X_train_pad).long(), torch.from_numpy(y_train))
    val_tensor = TensorDataset(torch.from_numpy(X_val_pad).long(), torch.from_numpy(y_val))
    val_loader = DataLoader(val_tensor, batch_size=batch_size, shuffle=False)

    # --- 5-Fold OOF Predictions (leak-free features for the meta-classifier) ---
    logger.info("Starting 5-Fold OOF generation...")
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    oof_preds = np.zeros_like(y_train, dtype=np.float32)

    # Same weighted loss reused for every fold and for final training.
    criterion_kfold = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    for fold, (t_idx, v_idx) in enumerate(skf.split(X_train_pad, y_train)):
        logger.info(f"OOF Fold {fold+1}/5")

        fold_train_ds = Subset(train_tensor, t_idx)
        fold_val_ds = Subset(train_tensor, v_idx)
        fold_train_loader = DataLoader(fold_train_ds, batch_size=batch_size, shuffle=True)
        fold_val_loader = DataLoader(fold_val_ds, batch_size=batch_size, shuffle=False)

        # Fresh model/optimizer per fold so folds are independent.
        model = BiLSTMClassifier(vocab_size, emb_matrix).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=1, factor=0.5)

        best_val_loss = float('inf')
        patience_counter = 0
        best_weights = copy.deepcopy(model.state_dict())

        # Early stopping with patience 3 on fold-val loss.
        for ep in range(epochs):  # Or hardcode early stop tightly for OOF e.g., 3-4 epochs max to save time
            t_loss = train_epoch(model, fold_train_loader, optimizer, criterion_kfold, device)
            v_loss, v_preds = eval_model(model, fold_val_loader, criterion_kfold, device)
            scheduler.step(v_loss)

            if v_loss < best_val_loss:
                best_val_loss = v_loss
                best_weights = copy.deepcopy(model.state_dict())
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= 3:
                    break

        # Restore the best checkpoint before producing fold predictions.
        model.load_state_dict(best_weights)
        _, fold_best_preds = eval_model(model, fold_val_loader, criterion_kfold, device)
        # NOTE(review): assumes eval_model returns a 1-D array (i.e. the model
        # emits squeezed logits); a (k, 1) array would fail this assignment —
        # confirm against BiLSTMClassifier.forward.
        oof_preds[v_idx] = fold_best_preds

    np.save(os.path.join(save_dir, "lstm_oof.npy"), oof_preds)
    logger.info("Saved OOF predictions (lstm_oof.npy).")

    # --- Final Training on ALL Data (early stopping on the real val split) ---
    logger.info("Starting final model training on full Train split...")
    train_loader = DataLoader(train_tensor, batch_size=batch_size, shuffle=True)
    model = BiLSTMClassifier(vocab_size, emb_matrix).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=1, factor=0.5)

    best_val_loss = float('inf')
    best_weights = copy.deepcopy(model.state_dict())
    patience_counter = 0

    for ep in range(epochs):
        t_loss = train_epoch(model, train_loader, optimizer, criterion_kfold, device)
        v_loss, v_preds = eval_model(model, val_loader, criterion_kfold, device)
        scheduler.step(v_loss)

        logger.info(f" Epoch {ep+1}/{epochs} | Train Loss: {t_loss:.4f} | Val Loss: {v_loss:.4f}")
        if v_loss < best_val_loss:
            best_val_loss = v_loss
            best_weights = copy.deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= 3:
                logger.info(" EarlyStopping triggered.")
                break

    # Persist the best (not last) weights.
    model.load_state_dict(best_weights)
    torch.save(model.state_dict(), os.path.join(save_dir, "model.pt"))
    logger.info("Saved final LSTM weights.")

    # Evaluate on the validation split and report.
    _, val_preds_probas = eval_model(model, val_loader, criterion_kfold, device)
    val_preds_binary = (val_preds_probas >= 0.5).astype(int)

    logger.info("Validation Classification Report:\n" + classification_report(y_val, val_preds_binary))
    roc_auc = roc_auc_score(y_val, val_preds_probas)
    logger.info(f"ROC-AUC: {roc_auc:.4f}")

    plot_and_save_cm(y_val, val_preds_probas, os.path.join(save_dir, "cm.png"))

    # Accuracy broken down by text-length bucket (skips empty buckets).
    bucket_acc = {}
    for b in ["short", "medium", "long"]:
        b_mask = (val_df["text_length_bucket"] == b).values
        if b_mask.sum() > 0:
            acc = (val_preds_binary[b_mask] == y_val[b_mask]).mean()
            bucket_acc[b] = acc

    metrics = {
        "roc_auc": float(roc_auc),
        "bucket_accuracy": {k: float(v) for k, v in bucket_acc.items()}
    }
    with open(os.path.join(save_dir, "metrics.json"), "w") as f:
        json.dump(metrics, f, indent=2)
|
| 301 |
+
|
| 302 |
+
if __name__ == "__main__":
    import yaml

    # Resolve the pipeline configuration relative to the project root.
    with open(os.path.join(_PROJECT_ROOT, "config", "config.yaml"), "r", encoding="utf-8") as fh:
        config = yaml.safe_load(fh)

    paths = config["paths"]
    splits_dir = os.path.join(_PROJECT_ROOT, paths["splits_dir"])
    model_dir = os.path.join(_PROJECT_ROOT, paths["models_dir"], "lstm_model")
    glove_file = os.path.join(_PROJECT_ROOT, paths["glove_path"])

    started = time.time()
    train_lstm_logic(config, splits_dir, model_dir, glove_file)
    print(f"Total time: {time.time() - started:.2f}s")
|
src/models/meta_classifier.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import joblib
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
from xgboost import XGBClassifier
|
| 10 |
+
from sklearn.calibration import CalibratedClassifierCV
|
| 11 |
+
from sklearn.preprocessing import OneHotEncoder
|
| 12 |
+
from sklearn.compose import ColumnTransformer
|
| 13 |
+
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
|
| 14 |
+
from matplotlib import pyplot as plt
|
| 15 |
+
from torch.utils.data import TensorDataset, DataLoader
|
| 16 |
+
|
| 17 |
+
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 18 |
+
if str(_PROJECT_ROOT) not in sys.path:
|
| 19 |
+
sys.path.insert(0, str(_PROJECT_ROOT))
|
| 20 |
+
|
| 21 |
+
from src.models.lstm_model import BiLSTMClassifier, pad_sequences
|
| 22 |
+
from src.stage2_preprocessing import KerasStyleTokenizer
|
| 23 |
+
import sys
|
| 24 |
+
setattr(sys.modules['__main__'], 'KerasStyleTokenizer', KerasStyleTokenizer)
|
| 25 |
+
|
| 26 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 27 |
+
|
| 28 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s")
|
| 29 |
+
logger = logging.getLogger("meta_classifier")
|
| 30 |
+
|
| 31 |
+
def build_meta_features(df, lr_proba, lstm_proba, distil_proba, roberta_proba, is_train=True, preprocessor=None):
    """Construct the stacked meta-feature matrix for the meta-classifier.

    Combines the four base-model positive-class probabilities with numeric
    metadata columns and one-hot-encoded categorical columns.

    Args:
        df: DataFrame with columns ``word_count``, ``has_date``,
            ``freshness_score``, ``text_length_bucket``, ``source_domain``.
        lr_proba, lstm_proba, distil_proba, roberta_proba: 1-D arrays of
            per-row probabilities, positionally aligned with ``df`` rows.
        is_train: if True, fit a fresh OneHotEncoder on the categorical
            columns; otherwise apply the supplied ``preprocessor``.
        preprocessor: fitted encoder, required when ``is_train`` is False.

    Returns:
        Tuple ``(X_meta, preprocessor)`` — a 2-D feature array and the
        (possibly newly fitted) encoder.

    Raises:
        ValueError: if ``is_train`` is False and ``preprocessor`` is None
            (previously a confusing AttributeError).
    """
    if not is_train and preprocessor is None:
        raise ValueError("preprocessor must be provided when is_train=False")

    # Use positional arrays rather than index-aligned Series: mixing raw
    # arrays (which get a fresh RangeIndex) with df's Series in the same
    # DataFrame constructor would silently introduce NaNs whenever df has
    # a non-default index.
    df_meta = pd.DataFrame({
        "lr_proba": np.asarray(lr_proba),
        "lstm_proba": np.asarray(lstm_proba),
        "distilbert_proba": np.asarray(distil_proba),
        "roberta_proba": np.asarray(roberta_proba),
        "word_count": df["word_count"].to_numpy(),
        "has_date": df["has_date"].astype(int).to_numpy(),
        "freshness_score": df["freshness_score"].to_numpy()
    })

    # Categoricals to encode; unseen categories are ignored at transform time.
    cats = df[["text_length_bucket", "source_domain"]].fillna("unknown")

    if is_train:
        preprocessor = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
        cat_features = preprocessor.fit_transform(cats)
    else:
        cat_features = preprocessor.transform(cats)

    X_meta = np.hstack((df_meta.values, cat_features))
    return X_meta, preprocessor
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def train_meta_classifier(cfg, splits_dir, models_dir):
    """Train, calibrate and save the XGBoost stacking meta-classifier.

    Steps:
      1. Load saved OOF probability arrays from the four base models
         (train-split features).
      2. Reload each base model and predict the validation split to build
         the val-split meta-features.
      3. Fit XGBoost on the train meta-features with early stopping on val.
      4. Calibrate probabilities on the val split ('prefit' sigmoid) and
         persist metrics, a confusion matrix and a (preprocessor, model)
         bundle.

    Args:
        cfg: parsed config.yaml dict (reads ``paths.glove_path`` and
            ``preprocessing.lstm_max_len``).
        splits_dir: directory containing ``df_train.csv`` / ``df_val.csv``.
        models_dir: root directory holding every base model's artifacts;
            outputs go to ``models_dir/meta_classifier``.
    """
    save_dir = os.path.join(models_dir, "meta_classifier")
    os.makedirs(save_dir, exist_ok=True)

    logger.info("Loading dataset splits...")
    train_df = pd.read_csv(os.path.join(splits_dir, "df_train.csv"))
    val_df = pd.read_csv(os.path.join(splits_dir, "df_val.csv"))

    y_train = train_df["binary_label"].values
    y_val = val_df["binary_label"].values

    # ── 1. Load OOF predictions for Train Set ──
    logger.info("Gathering base model OOF predictions...")
    try:
        lr_oof = np.load(os.path.join(models_dir, "logistic_model", "lr_oof.npy"))
        lstm_oof = np.load(os.path.join(models_dir, "lstm_model", "lstm_oof.npy"))
        distil_oof = np.load(os.path.join(models_dir, "distilbert_model", "distilbert_oof.npy"))
        roberta_oof = np.load(os.path.join(models_dir, "roberta_model", "roberta_oof.npy"))
    except FileNotFoundError as e:
        # Abort (not raise): base models must be trained first.
        logger.error(f"Missing OOF file: {e}. Please ensure all base models have trained completely.")
        return

    # NOTE(review): manual 0.92 down-weighting of RoBERTa probabilities —
    # rationale is undocumented; it is mirrored below for val predictions.
    # Confirm this factor is intentional.
    roberta_oof = roberta_oof * 0.92

    X_meta_train, meta_preprocessor = build_meta_features(
        train_df, lr_oof, lstm_oof, distil_oof, roberta_oof, is_train=True
    )

    # ── 2. Dynamically Generate Val predictions ──
    # Since we need a val set for early stopping, we predict them here.
    logger.info("Generating base model predictions for Validation set...")

    # Logistic: the saved pipeline consumes the raw DataFrame directly.
    lr_pipeline = joblib.load(os.path.join(models_dir, "logistic_model", "logistic_model.pkl"))
    lr_val = lr_pipeline.predict_proba(val_df)[:, 1]

    # LSTM: rebuild tokenizer, embeddings and model, then batch-predict.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    import pickle
    with open(os.path.join(models_dir, "tokenizer.pkl"), "rb") as f:
        tok = pickle.load(f)
    glove_path = os.path.join(_PROJECT_ROOT, cfg["paths"]["glove_path"])
    from src.models.lstm_model import load_glove_embeddings
    emb_matrix, vocab_size = load_glove_embeddings(glove_path, tok.word_index)

    maxlen = cfg.get("preprocessing", {}).get("lstm_max_len", 512)
    X_val_seq = tok.texts_to_sequences(val_df["clean_text"].fillna(""))
    X_val_pad = pad_sequences(X_val_seq, maxlen=maxlen, padding='post')

    lstm_model = BiLSTMClassifier(vocab_size, emb_matrix).to(device)
    lstm_model.load_state_dict(torch.load(os.path.join(models_dir, "lstm_model", "model.pt"), map_location=device))
    lstm_model.eval()

    val_loader = DataLoader(TensorDataset(torch.from_numpy(X_val_pad).long()), batch_size=64, shuffle=False)
    lstm_val_preds = []
    with torch.no_grad():
        for x_b in val_loader:
            logits = lstm_model(x_b[0].to(device))
            lstm_val_preds.extend(torch.sigmoid(logits).cpu().numpy())
    lstm_val = np.array(lstm_val_preds)

    # DistilBERT: one text per forward pass (simple but slow; batching
    # would speed this up).
    d_tok = AutoTokenizer.from_pretrained(os.path.join(models_dir, "distilbert_model"))
    d_mod = AutoModelForSequenceClassification.from_pretrained(os.path.join(models_dir, "distilbert_model")).to(device)
    d_mod.eval()

    distil_val = []
    with torch.no_grad():
        for text in val_df["clean_text"].fillna(""):
            inputs = d_tok(text, padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)
            out = d_mod(**inputs)
            distil_val.append(torch.softmax(out.logits, dim=-1)[0, 1].item())
    distil_val = np.array(distil_val)

    # RoBERTa: same per-text loop; 0.92 damping mirrors the OOF treatment.
    r_tok = AutoTokenizer.from_pretrained(os.path.join(models_dir, "roberta_model"))
    r_mod = AutoModelForSequenceClassification.from_pretrained(os.path.join(models_dir, "roberta_model")).to(device)
    r_mod.eval()

    roberta_val = []
    with torch.no_grad():
        for text in val_df["clean_text"].fillna(""):
            inputs = r_tok(text, padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)
            out = r_mod(**inputs)
            roberta_val.append(torch.softmax(out.logits, dim=-1)[0, 1].item())
    roberta_val = np.array(roberta_val) * 0.92

    X_meta_val, _ = build_meta_features(
        val_df, lr_val, lstm_val, distil_val, roberta_val, is_train=False, preprocessor=meta_preprocessor
    )

    # ── 3. Train Meta-Classifier (XGBoost) ──
    logger.info("Training XGBoost meta-classifier...")
    xgb = XGBClassifier(
        n_estimators=500,
        learning_rate=0.05,
        max_depth=5,
        eval_metric='logloss',
        early_stopping_rounds=20,
        random_state=42
    )

    xgb.fit(
        X_meta_train, y_train,
        eval_set=[(X_meta_val, y_val)],
        verbose=False
    )
    logger.info(f"XGBoost best iteration: {xgb.best_iteration}")

    # ── 4. Calibrate Probabilities ──
    logger.info("Calibrating final probabilities via CalibratedClassifierCV on Val set...")
    # 'prefit' means it will only use X_meta_val to calibrate the output
    # NOTE(review): the val split is reused for early stopping, calibration
    # AND the metrics below, so the reported scores are optimistic.
    calibrated_meta = CalibratedClassifierCV(estimator=xgb, method='sigmoid', cv='prefit')
    calibrated_meta.fit(X_meta_val, y_val)

    # Final Val Score Check
    final_val_probas = calibrated_meta.predict_proba(X_meta_val)[:, 1]

    # For short texts, dampen confidence toward 0.5 (more uncertain)
    # rather than making a confident wrong prediction
    for i in range(len(final_val_probas)):
        if val_df["word_count"].iloc[i] < 50:
            final_val_probas[i] = 0.5 + (final_val_probas[i] - 0.5) * 0.6

    # NOTE(review): decision threshold is 0.55, not 0.5 — confirm this is a
    # deliberate precision/recall trade-off.
    final_val_preds = (final_val_probas >= 0.55).astype(int)

    logger.info("Final Meta-Classifier Classification Report:\n" + classification_report(y_val, final_val_preds))
    roc_auc = roc_auc_score(y_val, final_val_probas)
    logger.info(f"ROC-AUC: {roc_auc:.4f}")

    from src.models.logistic_model import plot_and_save_cm
    plot_and_save_cm(
        y_val,
        final_val_preds,
        os.path.join(save_dir, "cm.png"),
        title="XGBoost Meta-Classifier Confusion Matrix"
    )

    # Accuracy broken down by text-length bucket (skips empty buckets).
    bucket_acc = {}
    for b in ["short", "medium", "long"]:
        b_mask = (val_df["text_length_bucket"] == b).values
        if b_mask.sum() > 0:
            acc = (final_val_preds[b_mask] == y_val[b_mask]).mean()
            bucket_acc[b] = acc

    metrics = {
        "roc_auc": float(roc_auc),
        "bucket_accuracy": {k: float(v) for k, v in bucket_acc.items()}
    }
    with open(os.path.join(save_dir, "metrics.json"), "w") as f:
        json.dump(metrics, f, indent=2)

    # Save Model Bundle (Pre-processor + Calibrated XGBoost)
    bundle = {
        "preprocessor": meta_preprocessor,
        "model": calibrated_meta
    }
    joblib.dump(bundle, os.path.join(save_dir, "meta_classifier.pkl"))
    logger.info("Saved Meta-Classifier bundle.")
|
| 218 |
+
|
| 219 |
+
if __name__ == "__main__":
    import yaml

    # Load pipeline configuration from the project-level YAML file.
    config_file = os.path.join(_PROJECT_ROOT, "config", "config.yaml")
    with open(config_file, "r", encoding="utf-8") as fh:
        cfg = yaml.safe_load(fh)

    splits = os.path.join(_PROJECT_ROOT, cfg["paths"]["splits_dir"])
    models = os.path.join(_PROJECT_ROOT, cfg["paths"]["models_dir"])
    train_meta_classifier(cfg, splits, models)
|
src/models/roberta_model.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from sklearn.model_selection import train_test_split
|
| 9 |
+
from transformers import (
|
| 10 |
+
AutoTokenizer,
|
| 11 |
+
AutoModelForSequenceClassification,
|
| 12 |
+
Trainer,
|
| 13 |
+
TrainingArguments,
|
| 14 |
+
DataCollatorWithPadding
|
| 15 |
+
)
|
| 16 |
+
from datasets import Dataset
|
| 17 |
+
|
| 18 |
+
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 19 |
+
if str(_PROJECT_ROOT) not in sys.path:
|
| 20 |
+
sys.path.insert(0, str(_PROJECT_ROOT))
|
| 21 |
+
|
| 22 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s")
|
| 23 |
+
logger = logging.getLogger("roberta_model")
|
| 24 |
+
|
| 25 |
+
def train_roberta(cfg, splits_dir, save_dir):
    """Fine-tune roberta-base and emit stacking artifacts.

    Trains on an 80% stratified subset of the train split (early stopping
    against the real validation split), saves the model and tokenizer to
    ``save_dir``, then writes ``roberta_oof.npy`` — predictions over the
    FULL train split — plus a validation confusion matrix.

    Args:
        cfg: parsed config.yaml dict (reads ``preprocessing`` and
            ``training`` sections, with defaults below).
        splits_dir: directory containing ``df_train.csv`` / ``df_val.csv``.
        save_dir: output directory for model, tokenizer and OOF array.
    """
    os.makedirs(save_dir, exist_ok=True)

    # 1. Load Data
    train_df = pd.read_csv(os.path.join(splits_dir, "df_train.csv"))
    val_df = pd.read_csv(os.path.join(splits_dir, "df_val.csv"))

    train_df["clean_text"] = train_df["clean_text"].fillna("")
    val_df["clean_text"] = val_df["clean_text"].fillna("")

    # Hyperparameters with config fallbacks.
    maxlen = cfg.get("preprocessing", {}).get("bert_max_len", 512)
    batch_size = cfg.get("training", {}).get("bert_batch_size", 16)
    epochs = cfg.get("training", {}).get("bert_epochs", 3)
    lr = float(cfg.get("training", {}).get("roberta_learning_rate", 1e-5))

    logger.info("Loading RoBERTa tokenizer...")
    model_name = "roberta-base"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # 2. Tokenization Helper — padding deferred to the data collator.
    def tokenize_function(examples):
        return tokenizer(examples["text"], padding=False, truncation=True, max_length=maxlen)

    # 3. Create OOF Proxy Split (80/20) safely
    # NOTE(review): idx_meta_val is computed but never used; only the 80%
    # training indices are consumed below.
    idx_train, idx_meta_val = train_test_split(
        range(len(train_df)), test_size=0.20,
        stratify=train_df["binary_label"], random_state=42
    )

    subset_train_df = train_df.iloc[idx_train].copy()

    # 4. Convert to HuggingFace Datasets
    hf_sub_train = Dataset.from_pandas(pd.DataFrame({
        "text": subset_train_df["clean_text"], "labels": subset_train_df["binary_label"]
    }), preserve_index=False)

    hf_full_train = Dataset.from_pandas(pd.DataFrame({
        "text": train_df["clean_text"], "labels": train_df["binary_label"]
    }), preserve_index=False)

    hf_val = Dataset.from_pandas(pd.DataFrame({
        "text": val_df["clean_text"], "labels": val_df["binary_label"]
    }), preserve_index=False)

    logger.info("Tokenizing datasets...")
    hf_sub_train = hf_sub_train.map(tokenize_function, batched=True)
    hf_full_train = hf_full_train.map(tokenize_function, batched=True)
    hf_val = hf_val.map(tokenize_function, batched=True)

    # Dynamic per-batch padding.
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    # 5. Initialize Model
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

    # 6. Trainer Setup — best checkpoint (lowest eval loss) restored at end.
    training_args = TrainingArguments(
        output_dir=os.path.join(save_dir, "checkpoints"),
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=lr,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=2,
        dataloader_num_workers=2,
        num_train_epochs=epochs,
        weight_decay=0.01,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        fp16=torch.cuda.is_available(),
        disable_tqdm=False
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=hf_sub_train,
        eval_dataset=hf_val,
        processing_class=tokenizer,
        data_collator=data_collator,
    )

    # 7. Train
    logger.info("Starting RoBERTa internal proxy training...")
    trainer.train()

    # 8. Save Model
    logger.info("Saving final fine-tuned model...")
    trainer.save_model(save_dir)
    tokenizer.save_pretrained(save_dir)

    # 9. Extract OOF over the entire training set
    # NOTE(review): the model was trained on 80% of this data, so these
    # "OOF" predictions are partially in-sample; the commented-out K-Fold
    # variant below is the leak-free alternative.
    logger.info("Generating OOF predictions on full train set proxy wrapper...")
    oof_preds = trainer.predict(hf_full_train)
    # probabilities for class 1 (True)
    oof_probas = torch.softmax(torch.tensor(oof_preds.predictions), dim=-1)[:, 1].numpy()
    np.save(os.path.join(save_dir, "roberta_oof.npy"), oof_probas)
    logger.info("Saved roberta_oof.npy")

    # Validation evaluation
    val_preds_out = trainer.predict(hf_val)
    val_probas = torch.softmax(torch.tensor(val_preds_out.predictions), dim=-1)[:, 1].numpy()

    from src.models.logistic_model import plot_and_save_cm
    plot_and_save_cm(
        val_df["binary_label"],
        (val_probas > 0.5).astype(int),
        os.path.join(save_dir, "cm.png"),
        title="RoBERTa Confusion Matrix"
    )

    logger.info("RoBERTa Training completed!")
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# ====================================================================
# OPTIONAL: Full K-Fold OOF (GPU-intensive)
# --------------------------------------------------------------------
# A robust 5-Fold implementation for deploying RoBERTa if unconstrained
# temporal budget scales naturally to GPU cluster arrays.
#
# NOTE: the string literal below is deliberately disabled code kept for
# reference; it is never executed.
"""
from sklearn.model_selection import StratifiedKFold

def strict_kfold_roberta(train_df, tokenize_function, data_collator, lr, batch_size, epochs, save_dir):
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    oof_probas = np.zeros(len(train_df), dtype=np.float32)

    for fold, (train_idx, val_idx) in enumerate(skf.split(train_df, train_df["binary_label"])):
        logger.info(f"Training Fold {fold+1}/5")
        df_train = train_df.iloc[train_idx].copy()
        df_val = train_df.iloc[val_idx].copy()

        ds_train = Dataset.from_pandas(pd.DataFrame({"text": df_train["clean_text"], "labels": df_train["binary_label"]}), preserve_index=False).map(tokenize_function, batched=True)
        ds_val = Dataset.from_pandas(pd.DataFrame({"text": df_val["clean_text"], "labels": df_val["binary_label"]}), preserve_index=False).map(tokenize_function, batched=True)

        model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=2)

        training_args = TrainingArguments(
            output_dir=os.path.join(save_dir, f"fold_{fold}"),
            eval_strategy="epoch",
            save_strategy="epoch",
            learning_rate=lr,
            per_device_train_batch_size=batch_size,
            num_train_epochs=epochs,
            fp16=torch.cuda.is_available(),
            load_best_model_at_end=True,
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=ds_train,
            eval_dataset=ds_val,
            data_collator=data_collator,
        )

        trainer.train()
        fold_preds = trainer.predict(ds_val)
        oof_probas[val_idx] = torch.softmax(torch.tensor(fold_preds.predictions), dim=-1)[:, 1].numpy()

    np.save(os.path.join(save_dir, "roberta_oof.npy"), oof_probas)
"""
# ====================================================================
|
| 188 |
+
|
| 189 |
+
if __name__ == "__main__":
    import yaml

    # Load pipeline configuration from the project-level YAML file.
    with open(os.path.join(_PROJECT_ROOT, "config", "config.yaml"), "r", encoding="utf-8") as fh:
        config = yaml.safe_load(fh)

    splits_dir = os.path.join(_PROJECT_ROOT, config["paths"]["splits_dir"])
    out_dir = os.path.join(_PROJECT_ROOT, config["paths"]["models_dir"], "roberta_model")

    train_roberta(config, splits_dir, out_dir)
|
src/stage1_ingestion.py
ADDED
|
@@ -0,0 +1,728 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
stage1_ingestion.py — Load, unify, deduplicate, and persist all datasets.
|
| 3 |
+
|
| 4 |
+
Reads from five training dataset sources (ISOT, LIAR, Kaggle Combined / News_dataset,
|
| 5 |
+
Multi-Domain / overall, and supplementary training folder), maps them into a
|
| 6 |
+
single canonical schema, performs Sentence-BERT deduplication, and writes the
|
| 7 |
+
result to ``data/processed/unified.csv`` together with label-distribution
|
| 8 |
+
statistics in ``data/processed/stats.json``.
|
| 9 |
+
|
| 10 |
+
Usage:
|
| 11 |
+
python -m src.stage1_ingestion # from fake_news_detection/
|
| 12 |
+
python src/stage1_ingestion.py # direct execution
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
import logging
|
| 19 |
+
import os
|
| 20 |
+
import sys
|
| 21 |
+
import time
|
| 22 |
+
import uuid
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from typing import Dict, List, Optional, Tuple
|
| 25 |
+
from urllib.parse import urlparse
|
| 26 |
+
|
| 27 |
+
import pandas as pd
|
| 28 |
+
import yaml
|
| 29 |
+
|
| 30 |
+
# ── Ensure project root is on sys.path when running directly ──
|
| 31 |
+
_SCRIPT_DIR = Path(__file__).resolve().parent
|
| 32 |
+
_PROJECT_ROOT = _SCRIPT_DIR.parent
|
| 33 |
+
if str(_PROJECT_ROOT) not in sys.path:
|
| 34 |
+
sys.path.insert(0, str(_PROJECT_ROOT))
|
| 35 |
+
|
| 36 |
+
from src.utils.deduplication import deduplicate_dataframe # noqa: E402
|
| 37 |
+
from src.utils.text_utils import clean_empty_texts, build_full_text, word_count
|
| 38 |
+
|
| 39 |
+
# ═══════════════════════════════════════════════════════════
|
| 40 |
+
# Logger
|
| 41 |
+
# ═══════════════════════════════════════════════════════════
|
| 42 |
+
logging.basicConfig(
|
| 43 |
+
level=logging.INFO,
|
| 44 |
+
format="%(asctime)s │ %(levelname)-8s │ %(name)s │ %(message)s",
|
| 45 |
+
datefmt="%H:%M:%S",
|
| 46 |
+
)
|
| 47 |
+
logger = logging.getLogger("stage1_ingestion")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# ═══════════════════════════════════════════════════════════
|
| 51 |
+
# Config loader
|
| 52 |
+
# ═══════════════════════════════════════════════════════════
|
| 53 |
+
def load_config(config_path: Optional[str] = None) -> dict:
    """Read and parse the project YAML configuration.

    Args:
        config_path: Optional explicit location of ``config.yaml``. When
            omitted, ``<project_root>/config/config.yaml`` is used.

    Returns:
        The configuration as a plain dictionary.
    """
    path = config_path if config_path is not None else str(
        _PROJECT_ROOT / "config" / "config.yaml"
    )
    with open(path, "r", encoding="utf-8") as handle:
        return yaml.safe_load(handle)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# ═══════════════════════════════════════════════════════════
|
| 71 |
+
# Schema constants
|
| 72 |
+
# ═══════════════════════════════════════════════════════════
|
| 73 |
+
# Canonical column order shared by every loader below; all of them build
# their output with ``pd.DataFrame(records, columns=UNIFIED_COLUMNS)`` so
# the per-source frames concatenate cleanly in run_ingestion.
UNIFIED_COLUMNS = [
    "article_id",      # fresh uuid4 string per row
    "title",           # headline ("" when the source has none)
    "text",            # article body / statement
    "source_domain",   # bare domain, or "unknown" when no URL exists
    "published_date",  # pd.Timestamp or pd.NaT
    "has_date",        # bool: published_date is not NaT
    "binary_label",    # 1 = True/real news, 0 = Fake
    "dataset_origin",  # which loader produced the row
]
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# ═══════════════════════════════════════════════════════════
|
| 86 |
+
# Helper: extract domain from URL
|
| 87 |
+
# ═══════════════════════════════════════════════════════════
|
| 88 |
+
def extract_domain(url: Optional[str]) -> str:
    """Return the bare domain (netloc) of *url*, or ``"unknown"``.

    Args:
        url: Raw URL; may be ``None``, empty, or malformed.

    Returns:
        A domain such as ``"reuters.com"`` with any leading ``www.``
        removed, or ``"unknown"`` when no domain can be extracted.
    """
    if not isinstance(url, str) or not url:
        return "unknown"

    candidate = url.strip()
    # urlparse only yields a netloc when a scheme is present.
    if not candidate.startswith(("http://", "https://")):
        candidate = "http://" + candidate

    try:
        host = urlparse(candidate).netloc
    except Exception:
        return "unknown"

    host = host[4:] if host.startswith("www.") else host
    return host or "unknown"
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _try_parse_date(val) -> pd.Timestamp:
|
| 113 |
+
"""Attempt to parse a value into a pandas Timestamp.
|
| 114 |
+
|
| 115 |
+
Args:
|
| 116 |
+
val: Any date-like value.
|
| 117 |
+
|
| 118 |
+
Returns:
|
| 119 |
+
``pd.Timestamp`` or ``pd.NaT`` on failure.
|
| 120 |
+
"""
|
| 121 |
+
if pd.isna(val):
|
| 122 |
+
return pd.NaT
|
| 123 |
+
try:
|
| 124 |
+
return pd.to_datetime(val)
|
| 125 |
+
except Exception:
|
| 126 |
+
return pd.NaT
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# ═══════════════════════════════════════════════════════════
|
| 130 |
+
# Dataset-specific loaders
|
| 131 |
+
# ══════════════════════════════════════════��════════════════
|
| 132 |
+
|
| 133 |
+
def load_isot(dataset_root: str) -> pd.DataFrame:
    """Load the ISOT Fake Real News dataset (``True.csv`` + ``Fake.csv``).

    Located at ``<dataset_root>/fake_real/``. The output is built
    column-wise rather than with ``df.iterrows()``: the row-by-row loop
    was pure Python-level work per article and the slowest part of this
    loader on the ~45k-row ISOT corpus.

    Args:
        dataset_root: Path to the top-level Dataset folder.

    Returns:
        DataFrame in the unified schema.
    """
    t0 = time.perf_counter()
    logger.info("Loading ISOT dataset …")

    base = os.path.join(dataset_root, "fake_real")
    true_path = os.path.join(base, "True.csv")
    fake_path = os.path.join(base, "Fake.csv")

    df_true = pd.read_csv(true_path)
    df_true["binary_label"] = 1
    df_fake = pd.read_csv(fake_path)
    df_fake["binary_label"] = 0

    df = pd.concat([df_true, df_fake], ignore_index=True)

    # Columns in the raw files: title, text, subject, date.
    n = len(df)
    # Per-value parsing preserves the old per-row semantics (each bad
    # value independently becomes NaT).
    if "date" in df.columns:
        pub_dates = df["date"].apply(_try_parse_date)
    else:
        pub_dates = pd.Series([pd.NaT] * n)

    # ``str(v or "")`` keeps the original coercion behavior for None/""
    # values from the raw CSVs.
    titles = df["title"].apply(lambda v: str(v or "")) if "title" in df.columns else [""] * n
    texts = df["text"].apply(lambda v: str(v or "")) if "text" in df.columns else [""] * n

    result = pd.DataFrame(
        {
            "article_id": [str(uuid.uuid4()) for _ in range(n)],
            "title": titles,
            "text": texts,
            "source_domain": "unknown",  # ISOT has no URL column
            "published_date": pub_dates,
            "has_date": pub_dates.notna(),
            "binary_label": df["binary_label"].astype(int),
            "dataset_origin": "isot",
        },
        columns=UNIFIED_COLUMNS,
    )
    logger.info(
        "ISOT loaded: %d rows (True=%d, Fake=%d) in %.1fs",
        len(result),
        (result["binary_label"] == 1).sum(),
        (result["binary_label"] == 0).sum(),
        time.perf_counter() - t0,
    )
    return result
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
# ─────────────────────────────────────────────────────────
|
| 185 |
+
|
| 186 |
+
# LIAR label mapping: collapse PolitiFact's six-way truthfulness scale to
# binary. The split point is between "half-true" (kept as 1) and
# "barely-true" (mapped to 0) — a judgment call baked into this project.
_LIAR_LABEL_MAP = {
    "true": 1,
    "mostly-true": 1,
    "half-true": 1,
    "false": 0,
    "barely-true": 0,
    "pants-fire": 0,
}

# Column names for the header-less LIAR TSV files, in file order.
# Only "label" and "statement" are consumed by load_liar below.
_LIAR_COLNAMES = [
    "id", "label", "statement", "subject", "speaker",
    "job_title", "state", "party",
    "barely_true_cnt", "false_cnt", "half_true_cnt",
    "mostly_true_cnt", "pants_fire_cnt",
    "context",
]
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def load_liar(dataset_root: str) -> pd.DataFrame:
    """Load the LIAR dataset (``train.tsv``, ``valid.tsv``, ``test.tsv``).

    Six-class labels are mapped to binary via ``_LIAR_LABEL_MAP``; rows
    whose label is not in the map are skipped.

    Args:
        dataset_root: Path to the top-level Dataset folder.

    Returns:
        DataFrame in the unified schema (empty when no TSV files exist).
    """
    t0 = time.perf_counter()
    logger.info("Loading LIAR dataset …")

    base = os.path.join(dataset_root, "liar")
    frames: List[pd.DataFrame] = []
    for fname in ("train.tsv", "valid.tsv", "test.tsv"):
        fp = os.path.join(base, fname)
        if os.path.exists(fp):
            tmp = pd.read_csv(fp, sep="\t", header=None, names=_LIAR_COLNAMES)
            frames.append(tmp)
            logger.info("  %s: %d rows", fname, len(tmp))

    # Guard: pd.concat([]) raises ValueError. Every other loader degrades
    # gracefully when its source folder is missing, so do the same here.
    if not frames:
        logger.warning("No LIAR TSV files found under %s", base)
        return pd.DataFrame(columns=UNIFIED_COLUMNS)

    df = pd.concat(frames, ignore_index=True)

    records: List[dict] = []
    for _, row in df.iterrows():
        label_str = str(row.get("label", "")).strip().lower()
        binary = _LIAR_LABEL_MAP.get(label_str)
        if binary is None:
            continue  # Skip rows with unrecognised labels

        records.append({
            "article_id": str(uuid.uuid4()),
            "title": "",  # LIAR has no title
            "text": str(row.get("statement", "") or ""),
            "source_domain": "politifact.com",  # All LIAR data from PolitiFact
            "published_date": pd.NaT,
            "has_date": False,
            "binary_label": binary,
            "dataset_origin": "liar",
        })

    result = pd.DataFrame(records, columns=UNIFIED_COLUMNS)
    logger.info(
        "LIAR loaded: %d rows (True=%d, Fake=%d) in %.1fs",
        len(result),
        (result["binary_label"] == 1).sum(),
        (result["binary_label"] == 0).sum(),
        time.perf_counter() - t0,
    )
    return result
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# ─────────────────────────────────────────────────────────
|
| 260 |
+
|
| 261 |
+
def load_kaggle_combined(dataset_root: str) -> pd.DataFrame:
    """Load the Kaggle Combined / News_dataset folder.

    The folder mirrors the ISOT structure (``True.csv``, ``Fake.csv``);
    labels are inferred from each file name, or from a ``label`` column
    when the name is ambiguous.

    Args:
        dataset_root: Path to the top-level Dataset folder.

    Returns:
        DataFrame in the unified schema (empty when the folder is
        missing or contains no usable CSVs).
    """
    t0 = time.perf_counter()
    logger.info("Loading Kaggle Combined (News_dataset) …")

    # Note: The actual folder name contains a space: "News _dataset"
    base = os.path.join(dataset_root, "News _dataset")
    if not os.path.isdir(base):
        # Fallback without space
        base = os.path.join(dataset_root, "News_dataset")
    if not os.path.isdir(base):
        # Neither variant exists: os.listdir(base) would raise
        # FileNotFoundError, so degrade gracefully like the other loaders.
        logger.warning("Kaggle Combined folder not found under %s", dataset_root)
        return pd.DataFrame(columns=UNIFIED_COLUMNS)

    frames: List[pd.DataFrame] = []

    for fname in os.listdir(base):
        fpath = os.path.join(base, fname)
        if not fname.lower().endswith(".csv"):
            continue
        try:
            tmp = pd.read_csv(fpath)
        except Exception as exc:
            logger.warning("Could not read %s: %s", fpath, exc)
            continue

        # Detect label from the file name first, then from a label column.
        name_lower = fname.lower()
        if "true" in name_lower or "real" in name_lower:
            tmp["binary_label"] = 1
        elif "fake" in name_lower:
            tmp["binary_label"] = 0
        elif "label" in [c.lower() for c in tmp.columns]:
            # Dynamic: if there's a label column, try to map
            label_col = [c for c in tmp.columns if c.lower() == "label"][0]
            tmp["binary_label"] = tmp[label_col].apply(
                lambda x: 1 if str(x).strip().lower() in ("1", "true", "real") else 0
            )
        else:
            logger.warning("Cannot determine label for %s — skipping.", fname)
            continue

        frames.append(tmp)
        logger.info("  %s: %d rows", fname, len(tmp))

    if not frames:
        logger.warning("No CSV files found in Kaggle Combined folder.")
        return pd.DataFrame(columns=UNIFIED_COLUMNS)

    df = pd.concat(frames, ignore_index=True)

    # Detect column names dynamically (case/whitespace-insensitive).
    col_map = {c.lower().strip(): c for c in df.columns}

    title_col = col_map.get("title")
    text_col = col_map.get("text") or col_map.get("article") or col_map.get("content")
    date_col = col_map.get("date") or col_map.get("published_date")

    records: List[dict] = []
    for _, row in df.iterrows():
        pub_date = _try_parse_date(row.get(date_col)) if date_col else pd.NaT
        records.append({
            "article_id": str(uuid.uuid4()),
            "title": str(row.get(title_col, "") or "") if title_col else "",
            "text": str(row.get(text_col, "") or "") if text_col else "",
            "source_domain": "unknown",
            "published_date": pub_date,
            "has_date": not pd.isna(pub_date),
            "binary_label": int(row["binary_label"]),
            "dataset_origin": "kaggle_combined",
        })

    result = pd.DataFrame(records, columns=UNIFIED_COLUMNS)
    logger.info(
        "Kaggle Combined loaded: %d rows (True=%d, Fake=%d) in %.1fs",
        len(result),
        (result["binary_label"] == 1).sum(),
        (result["binary_label"] == 0).sum(),
        time.perf_counter() - t0,
    )
    return result
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
# ─────────────────────────────────────────────────────────
|
| 351 |
+
|
| 352 |
+
def _load_txt_folder(folder: str, label: int) -> List[dict]:
|
| 353 |
+
"""Read all ``.txt`` files in *folder* and return a list of record dicts.
|
| 354 |
+
|
| 355 |
+
The first non-empty line is treated as the title; the remainder is the
|
| 356 |
+
body text.
|
| 357 |
+
|
| 358 |
+
Args:
|
| 359 |
+
folder: Directory containing ``.txt`` article files.
|
| 360 |
+
label: Binary label (0 = Fake, 1 = True) to assign.
|
| 361 |
+
|
| 362 |
+
Returns:
|
| 363 |
+
List of dicts suitable for DataFrame construction.
|
| 364 |
+
"""
|
| 365 |
+
records: List[dict] = []
|
| 366 |
+
if not os.path.isdir(folder):
|
| 367 |
+
return records
|
| 368 |
+
for fname in sorted(os.listdir(folder)):
|
| 369 |
+
if not fname.endswith(".txt"):
|
| 370 |
+
continue
|
| 371 |
+
fpath = os.path.join(folder, fname)
|
| 372 |
+
try:
|
| 373 |
+
with open(fpath, "r", encoding="utf-8", errors="replace") as fh:
|
| 374 |
+
lines = fh.read().strip().splitlines()
|
| 375 |
+
except Exception:
|
| 376 |
+
continue
|
| 377 |
+
title = lines[0].strip() if lines else ""
|
| 378 |
+
body = "\n".join(lines[1:]).strip() if len(lines) > 1 else ""
|
| 379 |
+
records.append({
|
| 380 |
+
"article_id": str(uuid.uuid4()),
|
| 381 |
+
"title": title,
|
| 382 |
+
"text": body,
|
| 383 |
+
"source_domain": "unknown",
|
| 384 |
+
"published_date": pd.NaT,
|
| 385 |
+
"has_date": False,
|
| 386 |
+
"binary_label": label,
|
| 387 |
+
"dataset_origin": "multi_domain",
|
| 388 |
+
})
|
| 389 |
+
return records
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def load_multi_domain(dataset_root: str) -> pd.DataFrame:
    """Load the Multi-Domain Fake News dataset (``overall/`` folder).

    Layout::

        overall/overall/
            fake/               → .txt files (label 0)
            real/               → .txt files (label 1)
            celebrityDataset/
                fake/           → .txt files (label 0)
                legit/          → .txt files (label 1)

    Args:
        dataset_root: Path to the top-level Dataset folder.

    Returns:
        DataFrame in the unified schema.
    """
    start = time.perf_counter()
    logger.info("Loading Multi-Domain dataset …")

    base = os.path.join(dataset_root, "overall", "overall")
    celeb = os.path.join(base, "celebrityDataset")

    # (folder, label) pairs: main fake/real folders followed by the
    # celebrity sub-dataset, in the same order as before.
    sources = [
        (os.path.join(base, "fake"), 0),
        (os.path.join(base, "real"), 1),
        (os.path.join(celeb, "fake"), 0),
        (os.path.join(celeb, "legit"), 1),
    ]

    records: List[dict] = []
    for folder, lbl in sources:
        records.extend(_load_txt_folder(folder, label=lbl))

    result = pd.DataFrame(records, columns=UNIFIED_COLUMNS)
    logger.info(
        "Multi-Domain loaded: %d rows (True=%d, Fake=%d) in %.1fs",
        len(result),
        (result["binary_label"] == 1).sum(),
        (result["binary_label"] == 0).sum(),
        time.perf_counter() - start,
    )
    return result
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
# ─────────────────────────────────────────────────────────
|
| 437 |
+
|
| 438 |
+
def load_training_folder(dataset_root: str) -> pd.DataFrame:
    """Load supplementary training data from ``training/training/``.

    The layout mirrors the multi-domain set: sub-datasets
    ``celebrityDataset`` and ``fakeNewsDataset`` each hold ``fake/`` and
    ``legit/`` folders of ``.txt`` articles.

    Args:
        dataset_root: Path to the top-level Dataset folder.

    Returns:
        DataFrame in the unified schema; ``dataset_origin`` is rewritten
        to ``training_<subdir>`` per sub-dataset.
    """
    start = time.perf_counter()
    logger.info("Loading supplementary training folder …")

    base = os.path.join(dataset_root, "training", "training")
    records: List[dict] = []

    for subdir in ("celebrityDataset", "fakeNewsDataset"):
        sub_path = os.path.join(base, subdir)
        if not os.path.isdir(sub_path):
            continue
        loaded = _load_txt_folder(os.path.join(sub_path, "fake"), label=0)
        n_fake = len(loaded)
        loaded += _load_txt_folder(os.path.join(sub_path, "legit"), label=1)
        n_legit = len(loaded) - n_fake
        # Tag each record with its sub-dataset of origin.
        for rec in loaded:
            rec["dataset_origin"] = f"training_{subdir}"
        records.extend(loaded)
        logger.info("  %s: %d fake + %d legit", subdir, n_fake, n_legit)

    result = pd.DataFrame(records, columns=UNIFIED_COLUMNS)
    logger.info(
        "Training folder loaded: %d rows (True=%d, Fake=%d) in %.1fs",
        len(result),
        (result["binary_label"] == 1).sum(),
        (result["binary_label"] == 0).sum(),
        time.perf_counter() - start,
    )
    return result
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
# ─────────────────────────────────────────────────────────
|
| 479 |
+
|
| 480 |
+
def load_testing_dataset(dataset_root: str) -> pd.DataFrame:
    """Load the sacred hold-out Testing_dataset (never used for training).

    Layout::

        Testing_dataset/testingSet/
            fake/  → .txt files (label 0)
            real/  → .txt files (label 1)

    The catalog CSVs in this folder provide metadata; the actual article
    bodies live in the ``fake/`` and ``real/`` sub-folders. The catalogs
    are currently read only so their sizes can be logged as a sanity
    check — their contents are not merged into the output.

    Args:
        dataset_root: Path to the top-level Dataset folder.

    Returns:
        DataFrame in the unified schema with ``dataset_origin = "testing"``.
    """
    t0 = time.perf_counter()
    logger.info("Loading Testing dataset (hold-out) …")

    base = os.path.join(dataset_root, "Testing_dataset", "testingSet")
    records: List[dict] = []

    fake_recs = _load_txt_folder(os.path.join(base, "fake"), label=0)
    real_recs = _load_txt_folder(os.path.join(base, "real"), label=1)
    # Override the origin set by _load_txt_folder ("multi_domain").
    for r in fake_recs + real_recs:
        r["dataset_origin"] = "testing"
    records.extend(fake_recs + real_recs)

    # The label previously paired with each catalog name was never used,
    # so the loop iterates over the names alone.
    for catalog_name in ("Catalog - Fake Articles.csv", "Catalog - Real Articles.csv"):
        cat_path = os.path.join(base, catalog_name)
        if os.path.exists(cat_path):
            try:
                cat = pd.read_csv(cat_path)
                logger.info("  Catalog %s: %d entries", catalog_name, len(cat))
            except Exception as exc:
                logger.warning("  Could not read catalog %s: %s", catalog_name, exc)

    result = pd.DataFrame(records, columns=UNIFIED_COLUMNS)
    logger.info(
        "Testing dataset loaded: %d rows (True=%d, Fake=%d) in %.1fs",
        len(result),
        (result["binary_label"] == 1).sum(),
        (result["binary_label"] == 0).sum(),
        time.perf_counter() - t0,
    )
    return result
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
# ═══════════════════════════════════════════════════════════
|
| 532 |
+
# Main ingestion pipeline
|
| 533 |
+
# ═══════════════════════════════════════════════════════════
|
| 534 |
+
|
| 535 |
+
def run_ingestion(cfg: dict) -> pd.DataFrame:
    """Execute the full Stage 1 ingestion pipeline.

    Steps:
        1. Load all dataset sources (five training sources plus the
           sacred Testing hold-out).
        2. Concatenate into a single DataFrame.
        3. Clean/deduplicate the training pool only (the hold-out is
           never touched), then carve out a stratified holdout split.
        4. Persist ``unified.csv`` and ``stats.json``.

    Args:
        cfg: Parsed config dictionary (from ``config.yaml``).

    Returns:
        The final unified (deduplicated) DataFrame containing the train
        pool, the stratified holdout, and the sacred test rows.
    """
    pipeline_t0 = time.perf_counter()
    logger.info("═" * 60)
    logger.info(" STAGE 1 — INGESTION START")
    logger.info("═" * 60)

    # dataset_root in config is relative to the project root.
    dataset_root = os.path.abspath(
        os.path.join(str(_PROJECT_ROOT), cfg["paths"]["dataset_root"])
    )
    logger.info("Dataset root resolved to: %s", dataset_root)

    # ── Step 1 : Load each dataset ───────────────────────────
    t0 = time.perf_counter()
    df_isot = load_isot(dataset_root)
    df_liar = load_liar(dataset_root)
    df_kaggle = load_kaggle_combined(dataset_root)
    df_multi = load_multi_domain(dataset_root)
    df_training = load_training_folder(dataset_root)
    df_testing = load_testing_dataset(dataset_root)
    load_time = time.perf_counter() - t0
    logger.info("All datasets loaded in %.1fs", load_time)

    # ── Step 2 : Concatenate ─────────────────────────────────
    # All loaders share UNIFIED_COLUMNS, so the concat aligns cleanly.
    t0 = time.perf_counter()
    all_frames = [df_isot, df_liar, df_kaggle, df_multi, df_training, df_testing]
    df_unified = pd.concat(all_frames, ignore_index=True)
    logger.info(
        "Unified dataset: %d rows (concat took %.1fs)",
        len(df_unified), time.perf_counter() - t0,
    )

    # Log per-origin counts
    origin_counts = df_unified["dataset_origin"].value_counts()
    for origin, cnt in origin_counts.items():
        logger.info("  %-30s %6d rows", origin, cnt)

    # ── Prep ─────────────────────────────────────────────────
    # FIX 1: Exclude Sacred Hold-out from Dedup
    # The "testing" rows must never be dropped or merged with training
    # data, so they are split off before any cleaning.
    test_mask = df_unified["dataset_origin"] == "testing"
    test_df = df_unified.loc[test_mask].copy()
    train_pool_df = df_unified.loc[~test_mask].copy()

    # FIX 3: Remove empty/near-empty texts from training ONLY
    # NOTE(review): clean_empty_texts is assumed to drop rows whose
    # title+text falls below min_word_count — confirm in utils.text_utils.
    min_word_count = cfg.get("preprocessing", {}).get("min_word_count", 3)
    train_before = len(train_pool_df)
    train_pool_df = clean_empty_texts(train_pool_df, min_word_count=min_word_count)
    empty_dropped = train_before - len(train_pool_df)

    # Flag short texts in test_df instead of dropping them
    test_full = test_df.apply(lambda r: build_full_text(r.get("title", ""), r.get("text", "")), axis=1)
    test_df["short_text_flag"] = test_full.apply(word_count) < min_word_count
    short_test_flagged = int(test_df["short_text_flag"].sum())

    logger.info("Sacred test rows preserved: %d (flagged %d short texts)", len(test_df), short_test_flagged)

    # ── Step 3 : Deduplication ───────────────────────────────
    dedup_cfg = cfg.get("dataset", {})
    threshold = dedup_cfg.get("dedup_threshold", 0.92)
    batch_size = dedup_cfg.get("dedup_batch_size", 64)

    # Title + body form the comparison key for both exact and semantic
    # duplicate detection.
    train_pool_df["_dedup_text"] = (
        train_pool_df["title"].fillna("") + " " + train_pool_df["text"].fillna("")
    ).str.strip()

    # Rows with ≤10 characters are too short to embed meaningfully, so
    # they bypass dedup entirely and are re-appended afterwards.
    mask_has_text = train_pool_df["_dedup_text"].str.len() > 10
    df_with_text = train_pool_df.loc[mask_has_text].copy()
    df_no_text = train_pool_df.loc[~mask_has_text].copy()

    logger.info(
        "Dedup candidates (train pool): %d rows with text, %d skipped (too short)",
        len(df_with_text), len(df_no_text),
    )

    if len(df_with_text) > 0:
        # Exact-duplicate count measured up front so the semantic count
        # can be derived from the total removed below.
        exact_counts = len(df_with_text) - len(df_with_text.drop_duplicates(subset=["_dedup_text"]))
        # NOTE(review): deduplicate_dataframe is assumed to remove both
        # exact and Sentence-BERT near-duplicates above `threshold` —
        # confirm in utils.deduplication.
        df_deduped, dedup_stats = deduplicate_dataframe(
            df_with_text,
            text_column="_dedup_text",
            threshold=threshold,
            batch_size=batch_size,
            origin_column="dataset_origin",
        )
        total_removed = len(df_with_text) - len(df_deduped)
        semantic_counts = total_removed - exact_counts
    else:
        df_deduped = df_with_text
        dedup_stats = {}
        exact_counts = 0
        semantic_counts = 0

    train_pool_deduped = pd.concat([df_deduped, df_no_text], ignore_index=True)
    # The helper column must not leak into the persisted CSV.
    train_pool_deduped.drop(columns=["_dedup_text"], inplace=True, errors="ignore")

    # FIX 2: Stratified Holdout Carve-out
    # A label-stratified slice of the (deduplicated) training pool is set
    # aside as an internal evaluation split, distinct from the sacred test.
    holdout_cfg = cfg.get("holdout", {})
    stratified_test_size = holdout_cfg.get("stratified_test_size", 0.10)
    random_state = holdout_cfg.get("random_state", 42)

    from sklearn.model_selection import StratifiedShuffleSplit
    sss = StratifiedShuffleSplit(n_splits=1, test_size=stratified_test_size, random_state=random_state)

    # reset_index so the positional indices from sss.split line up with iloc.
    train_pool_deduped = train_pool_deduped.reset_index(drop=True)
    train_idx, held_idx = next(sss.split(train_pool_deduped, train_pool_deduped['binary_label']))

    stratified_holdout = train_pool_deduped.iloc[held_idx].copy()
    train_pool_final = train_pool_deduped.iloc[train_idx].copy()

    # Re-label so downstream stages can select the split by origin.
    stratified_holdout['dataset_origin'] = 'stratified_holdout'

    logger.info("Train pool after carve-out: %d", len(train_pool_final))
    logger.info("Stratified holdout: %d", len(stratified_holdout))
    logger.info("Sacred test set: %d", len(test_df))

    # Train/holdout rows already passed the min-word-count filter, so the
    # flag is uniformly False there (only test_df carries real flags).
    train_pool_final['short_text_flag'] = False
    stratified_holdout['short_text_flag'] = False

    df_final = pd.concat([
        train_pool_final,
        stratified_holdout,
        test_df
    ], ignore_index=True)

    logger.info("Post-dedup and split total: %d rows", len(df_final))

    # ── Step 4 : Ensure types ────────────────────────────────
    # Dates may have become strings/objects through concat; coerce and
    # recompute has_date so it stays consistent with published_date.
    df_final["published_date"] = pd.to_datetime(
        df_final["published_date"], errors="coerce"
    )
    df_final["has_date"] = df_final["published_date"].notna()
    df_final["binary_label"] = df_final["binary_label"].astype(int)

    # ── Step 5 : Save unified CSV + stats ────────────────────
    processed_dir = os.path.join(str(_PROJECT_ROOT), cfg["paths"]["processed_dir"])
    os.makedirs(processed_dir, exist_ok=True)

    csv_path = os.path.join(processed_dir, "unified.csv")
    df_final.to_csv(csv_path, index=False)
    logger.info("Saved unified CSV → %s (%d rows)", csv_path, len(df_final))

    # Stats
    stats = {
        "total_rows": len(df_final),
        "train_pool_rows": len(train_pool_final),
        "stratified_holdout_rows": len(stratified_holdout),
        "sacred_test_rows": len(test_df),
        "fake_count": int((df_final["binary_label"] == 0).sum()),
        "true_count": int((df_final["binary_label"] == 1).sum()),
        "has_date_ratio": float(df_final["has_date"].mean()),
        "empty_texts_dropped": empty_dropped,
        "short_text_flagged_in_test": short_test_flagged,
        "dedup_removed_exact": exact_counts,
        "dedup_removed_semantic": semantic_counts,
        "per_origin": df_final["dataset_origin"].value_counts().to_dict(),
        # NOTE(review): int() assumes dedup_stats values are counts —
        # non-integral values would be truncated here.
        "dedup_stats": {k: int(v) for k, v in dedup_stats.items()}
    }
    stats_path = os.path.join(processed_dir, "stats.json")
    with open(stats_path, "w", encoding="utf-8") as fh:
        # default=str stringifies any non-JSON-native values (e.g. numpy ints).
        json.dump(stats, fh, indent=2, default=str)
    logger.info("Saved stats → %s", stats_path)

    pipeline_elapsed = time.perf_counter() - pipeline_t0
    logger.info("═" * 60)
    logger.info(" STAGE 1 — INGESTION COMPLETE (%.1fs total)", pipeline_elapsed)
    logger.info("═" * 60)

    return df_final
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
# ═══════════════════════════════════════════════════════════
|
| 718 |
+
# __main__ block for standalone testing
|
| 719 |
+
# ═══════════════════════════════════════════════════════════
|
| 720 |
+
if __name__ == "__main__":
    # Standalone entry point: run the full ingestion pipeline with the
    # default project config and print a summary of the result.
    cfg = load_config()
    df = run_ingestion(cfg)
    print("\n=== Final Unified Dataset ===")
    print(f"Shape: {df.shape}")
    print(f"\nLabel distribution:\n{df['binary_label'].value_counts()}")
    print(f"\nOrigin distribution:\n{df['dataset_origin'].value_counts()}")
    print(f"\nhas_date ratio: {df['has_date'].mean():.2%}")
    print(f"\nSample rows:\n{df.head(3).to_string()}")
|
src/stage2_preprocessing.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import json
|
| 4 |
+
import time
|
| 5 |
+
import logging
|
| 6 |
+
import pickle
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import yaml
|
| 10 |
+
from sklearn.model_selection import StratifiedShuffleSplit
|
| 11 |
+
|
| 12 |
+
# Fix paths for imports
|
| 13 |
+
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 14 |
+
if str(_PROJECT_ROOT) not in sys.path:
|
| 15 |
+
sys.path.insert(0, str(_PROJECT_ROOT))
|
| 16 |
+
|
| 17 |
+
from src.utils.text_utils import clean_text, build_full_text, word_count, text_length_bucket
|
| 18 |
+
from src.utils.domain_weights import compute_domain_weights
|
| 19 |
+
from src.utils.freshness import apply_freshness_score
|
| 20 |
+
|
| 21 |
+
logging.basicConfig(
|
| 22 |
+
level=logging.INFO,
|
| 23 |
+
format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
|
| 24 |
+
datefmt="%H:%M:%S"
|
| 25 |
+
)
|
| 26 |
+
logger = logging.getLogger("stage2_preprocessing")
|
| 27 |
+
|
| 28 |
+
class KerasStyleTokenizer:
    """A lightweight, PyTorch-compatible word tokenizer mimicking Keras's Tokenizer.

    Index 0 is reserved for padding and index 1 for the OOV token; real words
    receive indices from 2 upward in descending frequency order (ties keep
    first-seen order, matching a stable reverse sort).
    """

    def __init__(self, num_words=None, oov_token="<OOV>"):
        # num_words caps the total vocabulary size (padding + OOV + top words).
        self.num_words = num_words
        self.oov_token = oov_token
        self.word_index = {self.oov_token: 1}  # 0 is reserved for padding
        self.word_counts = {}

    def fit_on_texts(self, texts):
        """Accumulate word frequencies and (re)build the word -> index mapping.

        Repeated calls keep accumulating counts, preserving the original
        behavior of the manual ``get() + 1`` loop.
        """
        from collections import Counter

        # Counter counts at C speed; most_common() yields the same stable
        # descending-frequency order as sorted(..., reverse=True) did.
        counter = Counter(self.word_counts)
        for text in texts:
            # clean_text already removes punctuation; just split on whitespace.
            counter.update(str(text).split())
        self.word_counts = dict(counter)

        for idx, (w, _) in enumerate(counter.most_common()):
            # Two slots are reserved (padding=0, OOV=1), so at most
            # num_words - 2 real words enter the index.
            if self.num_words and idx >= self.num_words - 2:
                break
            self.word_index[w] = idx + 2

    def texts_to_sequences(self, texts):
        """Convert each text to a list of word indices, with 1 (OOV) as fallback."""
        return [[self.word_index.get(w, 1) for w in str(text).split()]
                for text in texts]
|
| 58 |
+
|
| 59 |
+
def truncate_str_array(df, col):
    """Memory fix: force string type for arrays."""
    # Coerce the column to plain strings before exporting the raw ndarray so
    # np.save never sees a mixed-dtype object column.
    as_strings = df[col].astype(str)
    return as_strings.values
|
| 62 |
+
|
| 63 |
+
def run_preprocessing(cfg: dict = None) -> None:
    """Stage 2: clean text, derive features, build splits, and fit the tokenizer.

    Reads ``processed_dir/unified.csv`` (produced by Stage 1), then:
      1. builds ``full_text``/``clean_text`` and word-count features,
      2. attaches domain-aware sample weights and freshness scores,
      3. splits rows into train/val/holdout/test by ``dataset_origin``,
      4. saves per-split .npy arrays, metadata JSON, and CSVs to ``splits_dir``,
      5. fits and pickles a KerasStyleTokenizer on the train split.

    Parameters:
        cfg: parsed config dict; when None, config/config.yaml is loaded.
    """
    t0 = time.perf_counter()
    if cfg is None:
        cfg_path = os.path.join(_PROJECT_ROOT, "config", "config.yaml")
        with open(cfg_path, "r", encoding="utf-8") as f:
            cfg = yaml.safe_load(f)

    logger.info("STAGE 2: PREPROCESSING START")
    processed_dir = os.path.join(_PROJECT_ROOT, cfg["paths"]["processed_dir"])
    splits_dir = os.path.join(_PROJECT_ROOT, cfg["paths"]["splits_dir"])
    models_dir = os.path.join(_PROJECT_ROOT, cfg["paths"]["models_dir"])
    os.makedirs(splits_dir, exist_ok=True)
    os.makedirs(models_dir, exist_ok=True)

    # 1. Load Data
    csv_path = os.path.join(processed_dir, "unified.csv")
    df = pd.read_csv(csv_path)
    # Unparseable dates become NaT rather than raising.
    df["published_date"] = pd.to_datetime(df["published_date"], errors="coerce")
    logger.info("Loaded unified CSV: %d rows", len(df))

    # 2. Extract Text Length Features & Clean
    # Concatenate title + '. ' + text as full_text (via build_full_text);
    # NaN titles/bodies are treated as empty strings.
    df["full_text"] = df.apply(lambda r: build_full_text(
        str(r["title"]) if pd.notna(r["title"]) else "",
        str(r["text"]) if pd.notna(r["text"]) else ""
    ), axis=1)
    # clean_text lowercases and strips HTML/URLs/special characters.
    logger.info("Applying text cleaning (HTML, URLs, whitespace, punctuation) ...")
    df["clean_text"] = df["full_text"].apply(clean_text)

    logger.info("Calculating word counts and text buckets ...")
    df["word_count"] = df["clean_text"].apply(word_count)
    df["text_length_bucket"] = df["word_count"].apply(text_length_bucket)

    # 3. Domain Weights
    ds_cfg = cfg.get("dataset", {})
    min_domains = ds_cfg.get("min_domain_samples", 20)
    max_multi = cfg.get("inference", {}).get("max_multiplier", 10)

    logger.info("Computing domain-aware sample weights...")
    df = compute_domain_weights(df, min_domain_samples=min_domains, max_multiplier=max_multi)

    # 4. Freshness
    logger.info("Applying temporal freshness scores...")
    df = apply_freshness_score(df, is_inference=False)

    # 5. Train/Val/Test Splits
    # Split policy:
    #   stratified_holdout -> stage 3 proxy set
    #   testing            -> sacred (never trained on)
    #   remaining pool     -> split 85/15 into train and val.
    test_mask = df["dataset_origin"] == "testing"
    holdout_mask = df["dataset_origin"] == "stratified_holdout"
    train_pool_mask = ~(test_mask | holdout_mask)

    test_df = df[test_mask].copy()
    holdout_df = df[holdout_mask].copy()
    train_pool_df = df[train_pool_mask].copy()

    # Stratify by label so train/val keep the same fake/true balance.
    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.15, random_state=42)
    train_pool_df = train_pool_df.reset_index(drop=True)

    train_idx, val_idx = next(sss.split(train_pool_df, train_pool_df["binary_label"]))
    train_df = train_pool_df.iloc[train_idx].copy()
    val_df = train_pool_df.iloc[val_idx].copy()

    logger.info("SPLITS SUMMARY:")
    logger.info(" Train: %d rows", len(train_df))
    logger.info(" Val: %d rows", len(val_df))
    logger.info(" Holdout: %d rows", len(holdout_df))
    logger.info(" Sacred: %d rows", len(test_df))

    # 6. Save splits metadata and arrays
    # Raw text saved as .npy separately for PyTorch dataset convenience
    # (faster than pd.read_csv for the big models).
    splits_dict = {
        "train": train_df,
        "val": val_df,
        "holdout": holdout_df,
        "test": test_df
    }

    for split_name, split_data in splits_dict.items():
        np.save(os.path.join(splits_dir, f"X_text_{split_name}.npy"), truncate_str_array(split_data, "clean_text"))
        np.save(os.path.join(splits_dir, f"y_{split_name}.npy"), split_data["binary_label"].values)
        np.save(os.path.join(splits_dir, f"w_{split_name}.npy"), split_data["sample_weight"].values)

        # Small per-split summary for sanity checks downstream.
        meta = {
            "size": len(split_data),
            "fake_count": int((split_data["binary_label"] == 0).sum()),
            "true_count": int((split_data["binary_label"] == 1).sum()),
            "word_count_median": float(split_data["word_count"].median()),
            "freshness_mean": float(split_data["freshness_score"].mean())
        }
        with open(os.path.join(splits_dir, f"meta_{split_name}.json"), "w") as f:
            json.dump(meta, f, indent=2)

    # Save train_ids.csv explicitly
    train_df[["article_id"]].to_csv(os.path.join(splits_dir, "train_ids.csv"), index=False)
    # Also save the full preprocessed sets to CSV for easy loading during
    # Stage 3 / Evaluation.
    train_df.to_csv(os.path.join(splits_dir, "df_train.csv"), index=False)
    val_df.to_csv(os.path.join(splits_dir, "df_val.csv"), index=False)
    holdout_df.to_csv(os.path.join(splits_dir, "df_holdout.csv"), index=False)
    test_df.to_csv(os.path.join(splits_dir, "df_test.csv"), index=False)

    # 7. Tokenization (LSTM)
    logger.info("Fitting LSTM Tokenizer on Train split...")
    # NOTE(review): the LSTM vocab cap reuses the 'max_tfidf_features' config
    # key (default 50k) — confirm this key is intended to serve both purposes.
    vocab_size = cfg.get("preprocessing", {}).get("max_tfidf_features", 50000)
    tok = KerasStyleTokenizer(num_words=vocab_size)
    tok.fit_on_texts(train_df["clean_text"])

    tok_path = os.path.join(models_dir, "tokenizer.pkl")
    with open(tok_path, "wb") as f:
        pickle.dump(tok, f)
    logger.info(f"Saved tokenizer to {tok_path} (vocab size: {len(tok.word_index)})")

    t_end = time.perf_counter()
    logger.info("STAGE 2 FINISHED in %.2f seconds", t_end - t0)
|
| 184 |
+
|
| 185 |
+
if __name__ == "__main__":
    # Standalone entry point: run Stage 2 preprocessing with the default config.
    run_preprocessing()
|
src/stage3_training.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import logging
|
| 4 |
+
import time
|
| 5 |
+
import subprocess
|
| 6 |
+
|
| 7 |
+
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 8 |
+
if str(_PROJECT_ROOT) not in sys.path:
|
| 9 |
+
sys.path.insert(0, str(_PROJECT_ROOT))
|
| 10 |
+
|
| 11 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
|
| 12 |
+
logger = logging.getLogger("stage3_training")
|
| 13 |
+
|
| 14 |
+
def run_training(cfg: dict = None):
    """Stage 3: launch each model-training script as a subprocess, in order.

    Aborts the whole stage (process exit code 1) as soon as any script fails.
    The ``cfg`` argument is accepted for pipeline-API symmetry but unused here.
    Scripts run last-to-first dependencies intact: base models precede the
    meta-classifier (which presumably consumes their out-of-fold outputs).
    """
    started = time.perf_counter()
    logger.info("STAGE 3: TRAINING START")

    interpreter = sys.executable
    script_dir = os.path.join(_PROJECT_ROOT, "src", "models")

    # (display name, script file) pairs, executed sequentially.
    training_steps = (
        ("Logistic Regression", "logistic_model.py"),
        ("Bi-LSTM", "lstm_model.py"),
        ("DistilBERT", "distilbert_model.py"),
        ("RoBERTa", "roberta_model.py"),
        ("Meta-Classifier", "meta_classifier.py"),
    )

    for display_name, filename in training_steps:
        full_path = os.path.join(script_dir, filename)
        logger.info(f"==> Launching {display_name} Training ({filename})")
        completed = subprocess.run([interpreter, full_path], cwd=_PROJECT_ROOT)
        if completed.returncode != 0:
            logger.error(f"{display_name} aborted with exit code {completed.returncode}")
            sys.exit(1)

    logger.info("STAGE 3 FINISHED in %.2f seconds", time.perf_counter() - started)
|
| 39 |
+
|
| 40 |
+
if __name__ == "__main__":
    # Standalone entry point: run the full Stage 3 training sequence.
    run_training()
|
src/stage4_inference.py
ADDED
|
@@ -0,0 +1,867 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Stage 4 — Inference Engine (5-Signal Weighted Scoring)
|
| 3 |
+
=====================================================
|
| 4 |
+
Evaluates articles across five independent signals:
|
| 5 |
+
1. Source Credibility (30%)
|
| 6 |
+
2. Claim Verification (30%)
|
| 7 |
+
3. Linguistic Analysis (20%)
|
| 8 |
+
4. Freshness (10%)
|
| 9 |
+
5. Ensemble Model Vote (10%)
|
| 10 |
+
Then applies adversarial overrides and maps to a final verdict.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import re
|
| 15 |
+
import sys
|
| 16 |
+
import yaml
|
| 17 |
+
import logging
|
| 18 |
+
import pickle
|
| 19 |
+
import pandas as pd
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
from datetime import datetime, timezone
|
| 23 |
+
|
| 24 |
+
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 25 |
+
if str(_PROJECT_ROOT) not in sys.path:
|
| 26 |
+
sys.path.insert(0, str(_PROJECT_ROOT))
|
| 27 |
+
|
| 28 |
+
from src.utils.text_utils import clean_text, build_full_text, word_count as wc_func, text_length_bucket
|
| 29 |
+
from src.stage2_preprocessing import KerasStyleTokenizer
|
| 30 |
+
|
| 31 |
+
import sys  # NOTE(review): 'sys' is already imported above — redundant but harmless re-import.

# Pickle compatibility shim: the Stage 2 tokenizer is pickled while
# stage2_preprocessing runs as a script, so the pickle may record the class as
# '__main__.KerasStyleTokenizer'. Registering the class on the __main__ module
# lets pickle.load resolve it when this module is merely imported.
setattr(sys.modules['__main__'], 'KerasStyleTokenizer', KerasStyleTokenizer)

logger = logging.getLogger("stage4_inference")
|
| 35 |
+
|
| 36 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 37 |
+
# CONSTANTS
|
| 38 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 39 |
+
|
| 40 |
+
# Domains treated as established, credible news/reference outlets.
# NOTE(review): "bbc.com/sport" is a URL path, not a bare domain — it can
# never equal a parsed domain string; confirm whether it belongs here.
CREDIBLE_OUTLETS = {
    "reuters.com", "apnews.com", "bbc.com", "bbc.co.uk", "nytimes.com",
    "washingtonpost.com", "theguardian.com", "cnn.com", "cbsnews.com",
    "nbcnews.com", "abcnews.go.com", "npr.org", "pbs.org", "bloomberg.com",
    "wsj.com", "ft.com", "economist.com", "usatoday.com", "time.com",
    "politico.com", "thehill.com", "axios.com", "propublica.org",
    "snopes.com", "factcheck.org", "politifact.com", "fullfact.org",
    "aljazeera.com", "dw.com", "france24.com", "scmp.com",
    "theatlantic.com", "newyorker.com", "wired.com", "nature.com",
    "sciencemag.org", "thelancet.com", "bmj.com", "who.int",
    "un.org", "whitehouse.gov", "gov.uk", "europa.eu",
    "hindustantimes.com", "ndtv.com", "thehindu.com", "indianexpress.com",
    "timesofindia.indiatimes.com", "livemint.com",
    "abc.net.au", "cbc.ca", "globalnews.ca", "stuff.co.nz",
    "forbes.com", "businessinsider.com", "cnbc.com", "techcrunch.com",
    "arstechnica.com", "theverge.com", "engadget.com",
    "espn.com", "bbc.com/sport", "skysports.com",
}

# Mentions of major outlets inside the article body, used as a
# corroboration signal by score_source_credibility.
CORROBORATION_OUTLETS_RE = re.compile(
    r"(?i)\b(Reuters|Associated Press|\bAP\b|CBS|BBC|NBC|CNN|"
    r"New York Times|NYT|Washington Post|The Guardian|NPR|PBS|"
    r"Bloomberg|Wall Street Journal|Forbes)\b"
)

# Byline phrases ("by <Name>", "reported by <Name>", ...) followed by a capital.
AUTHOR_PATTERNS = re.compile(
    r"(?i)\b(by|written by|reporter|staff writer|correspondent|"
    r"contributing writer|author|edited by|reported by)\b\s*[A-Z]"
)
# A "Firstname Lastname" pair at the start of a line — crude byline fallback.
BYLINE_NAME_RE = re.compile(r"^[A-Z][a-z]+ [A-Z][a-z]+", re.MULTILINE)

# Sensational/hyperbolic vocabulary — linguistic-quality deductions.
SUPERLATIVE_RE = re.compile(
    r"(?i)\b(shocking|massive|unprecedented|bombshell|explosive|"
    r"stunning|jaw-dropping|mind-blowing|unbelievable|outrageous)\b"
)
# Classic clickbait phrasings in headlines.
SENSATIONAL_RE = re.compile(
    r"(?i)(you won't believe|what happened next|this is why|"
    r"one weird trick|exposed|destroyed|slammed)"
)
# Vague / anonymous sourcing language.
NO_ATTRIB_RE = re.compile(
    r"(?i)(sources say|it is believed|reportedly|some people say|"
    r"many believe|rumor has it|anonymous source|unconfirmed reports)"
)
# Passive-voice deflections of responsibility for a claim.
PASSIVE_VOICE_RE = re.compile(
    r"(?i)(it is being said|it was reported|it has been claimed|"
    r"it is alleged|it was alleged|it is rumored)"
)
# Direct quotes of 10+ characters, and the attribution verbs expected near them.
QUOTE_RE = re.compile(r'"([^"]{10,})"')
QUOTE_ATTRIB_RE = re.compile(
    r"(?i)(said|stated|according to|told|announced|confirmed|wrote|called|described|noted|added|explained|argued|claimed)"
)

# Numeric claims (percentages / large magnitudes) and citation phrasings
# that should accompany them.
STAT_RE = re.compile(r"\d+\s*%|\d+\s*(million|billion|trillion)", re.IGNORECASE)
CITATION_RE = re.compile(
    r"(?i)(according to|source:|study by|data from|published by|research by|"
    r"report by|survey by|analysis by|statistics from)"
)

# References to institutions that lend a claim verifiability.
INSTITUTION_RE = re.compile(
    r"(?i)(university|department of|ministry|commission|institute|agency|"
    r"foundation|world health|WHO|FDA|CDC|NASA|UNICEF|IMF|World Bank)"
)
# Recency-signalling phrases used by the freshness heuristics.
TEMPORAL_RE = re.compile(
    r"(?i)(this week|this month|recently|new report|just released|"
    r"annual forecast|latest data|new study|breaking|today|yesterday)"
)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class ModelNotTrainedError(Exception):
    """Raised when a required model artifact is missing from disk.

    The default message points the operator at the training stage that
    produces the artifact.
    """

    def __init__(self, message="Run python run_pipeline.py --stage 3 first"):
        super().__init__(message)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 114 |
+
# MODEL LOADING (unchanged from original)
|
| 115 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 116 |
+
|
| 117 |
+
_MODEL_CACHE = {}
|
| 118 |
+
|
| 119 |
+
def load_config():
    """Load and parse the project configuration from config/config.yaml."""
    config_file = os.path.join(_PROJECT_ROOT, "config", "config.yaml")
    with open(config_file, "r", encoding="utf-8") as fh:
        parsed = yaml.safe_load(fh)
    return parsed
|
| 123 |
+
|
| 124 |
+
def _get_model(model_name, cfg):
    """Lazy-load a trained model artifact by name, caching it per process.

    Supported names: 'logistic', 'lstm', 'distilbert', 'roberta', 'meta'.
    Returns the bare estimator for 'logistic'/'meta', or a
    (model, tokenizer, device) tuple for the neural models.
    Raises ModelNotTrainedError when the on-disk artifact is missing.
    NOTE(review): an unrecognized model_name matches no branch and raises
    KeyError from the final cache lookup — confirm callers only pass the
    five known names.
    """
    if model_name in _MODEL_CACHE:
        return _MODEL_CACHE[model_name]

    models_dir = os.path.join(_PROJECT_ROOT, cfg.get("paths", {}).get("models_dir", "models/saved"))

    if model_name == "logistic":
        import joblib
        fpath = os.path.join(models_dir, "logistic_model", "logistic_model.pkl")
        if not os.path.exists(fpath): raise ModelNotTrainedError()
        _MODEL_CACHE[model_name] = joblib.load(fpath)

    elif model_name == "lstm":
        from src.models.lstm_model import BiLSTMClassifier, load_glove_embeddings, pad_sequences
        tok_path = os.path.join(models_dir, "tokenizer.pkl")
        # Both the tokenizer and the weights file must exist.
        if not os.path.exists(tok_path) or not os.path.exists(os.path.join(models_dir, "lstm_model", "model.pt")):
            raise ModelNotTrainedError()
        with open(tok_path, "rb") as f:
            tok = pickle.load(f)
        glove_path = os.path.join(_PROJECT_ROOT, cfg["paths"]["glove_path"])
        # Embedding matrix is rebuilt from GloVe at load time rather than stored.
        emb_matrix, vocab_size = load_glove_embeddings(glove_path, tok.word_index)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = BiLSTMClassifier(vocab_size, emb_matrix).to(device)
        # map_location keeps CUDA-trained checkpoints loadable on CPU-only hosts.
        model.load_state_dict(torch.load(os.path.join(models_dir, "lstm_model", "model.pt"), map_location=device))
        model.eval()
        _MODEL_CACHE[model_name] = (model, tok, device)

    elif model_name in ("distilbert", "roberta"):
        # transformers is optional; its absence is treated like a missing model.
        try:
            from transformers import AutoTokenizer, AutoModelForSequenceClassification
        except ImportError:
            raise ModelNotTrainedError()
        d_path = os.path.join(models_dir, f"{model_name}_model")
        if not os.path.exists(os.path.join(d_path, "config.json")):
            raise ModelNotTrainedError()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tok = AutoTokenizer.from_pretrained(d_path)
        model = AutoModelForSequenceClassification.from_pretrained(d_path).to(device)
        model.eval()
        _MODEL_CACHE[model_name] = (model, tok, device)

    elif model_name == "meta":
        import joblib
        fpath = os.path.join(models_dir, "meta_classifier", "meta_classifier.pkl")
        if not os.path.exists(fpath): raise ModelNotTrainedError()
        _MODEL_CACHE[model_name] = joblib.load(fpath)

    return _MODEL_CACHE[model_name]
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 177 |
+
# FEATURE EXTRACTION
|
| 178 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 179 |
+
|
| 180 |
+
def extract_features(title, text, source_domain, published_date, cfg):
    """Build standardized structural mapping for raw strings.

    Normalizes the raw inputs into the feature dict consumed by the scoring
    steps: concatenated/cleaned text, word count and its length bucket, a
    UTC-normalized publish date (when parseable), and the source domain with
    an "unknown" fallback. The ``cfg`` argument is currently unused here.
    """
    full = build_full_text(title, text)
    clean = clean_text(full)
    wc = wc_func(clean)
    bucket = text_length_bucket(wc)

    # A date counts as present only when non-null AND non-empty-string.
    has_date = pd.notna(published_date) and published_date != ""
    if has_date and isinstance(published_date, str):
        # String date: parse to a tz-aware UTC timestamp; unparseable -> missing.
        try:
            published_date = pd.to_datetime(published_date, utc=True)
        except Exception:
            has_date = False
            published_date = None
    elif has_date:
        # Datetime-like value: coerce to a UTC pandas Timestamp.
        # NOTE(review): an already tz-aware input would make pd.Timestamp raise
        # here and be silently treated as "no date" — confirm that is intended.
        try:
            published_date = pd.Timestamp(published_date, tz="UTC")
        except Exception:
            has_date = False
            published_date = None

    return {
        "clean_text": clean,
        "full_text": full,
        "word_count": wc,
        "text_length_bucket": bucket,
        "has_date": has_date,
        "published_date": published_date,
        "source_domain": source_domain if source_domain else "unknown",
    }
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 213 |
+
# STEP 1 — SOURCE CREDIBILITY (weight: 30%)
|
| 214 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 215 |
+
|
| 216 |
+
def _levenshtein(s1, s2):
|
| 217 |
+
"""Minimal Levenshtein distance for typosquatting check."""
|
| 218 |
+
if len(s1) < len(s2):
|
| 219 |
+
return _levenshtein(s2, s1)
|
| 220 |
+
if len(s2) == 0:
|
| 221 |
+
return len(s1)
|
| 222 |
+
prev_row = range(len(s2) + 1)
|
| 223 |
+
for i, c1 in enumerate(s1):
|
| 224 |
+
curr_row = [i + 1]
|
| 225 |
+
for j, c2 in enumerate(s2):
|
| 226 |
+
curr_row.append(min(curr_row[j] + 1, prev_row[j + 1] + 1,
|
| 227 |
+
prev_row[j] + (c1 != c2)))
|
| 228 |
+
prev_row = curr_row
|
| 229 |
+
return prev_row[-1]
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def score_source_credibility(source_domain, title, text):
    """
    Step 1: Evaluate source trustworthiness.

    Additive components: base credit for any domain (0.20), known credible
    outlet (0.40), verifiable author byline (0.20), corroborating mentions of
    major outlets in the body (0.20), capped at 1.0. A domain within edit
    distance 1-2 of a credible outlet (and not itself credible) is treated as
    a typosquat and scores 0.
    Returns: (score, author_found, typosquatting_detected)
    """
    # ── Early return: no source at all ──
    if not source_domain or source_domain.strip() == "" or source_domain == "unknown":
        # Still check for author in text body
        author_found = bool(AUTHOR_PATTERNS.search(text[:500])) or bool(BYLINE_NAME_RE.search(text[:200]))
        return 0.3, author_found, False

    domain = source_domain.strip().lower()

    # ── Typosquatting check ──
    # BUG FIX: only run this for domains that are NOT themselves credible.
    # Previously a legitimate outlet could be flagged as a typosquat of a
    # *different* outlet at edit distance <= 2 (e.g. "dw.com" is distance 2
    # from "ft.com"), zeroing its score before the exact match was rewarded.
    if domain not in CREDIBLE_OUTLETS:
        for outlet in CREDIBLE_OUTLETS:
            dist = _levenshtein(domain, outlet)
            if 0 < dist <= 2:  # close but not exact
                return 0.0, False, True

    # ── Component scoring ──
    score = 0.0

    # Base: any valid domain
    score += 0.20

    # Known outlet
    if domain in CREDIBLE_OUTLETS:
        score += 0.40

    # Author verifiability (byline phrase or Firstname Lastname near the top)
    search_area = text[:500]
    author_found = bool(AUTHOR_PATTERNS.search(search_area)) or bool(BYLINE_NAME_RE.search(text[:200]))
    if author_found:
        score += 0.20

    # Corroboration: text mentions other major outlets
    if CORROBORATION_OUTLETS_RE.search(text):
        score += 0.20

    return min(1.0, score), author_found, False
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 275 |
+
# STEP 2 — CLAIM VERIFICATION (weight: 30%)
|
| 276 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 277 |
+
|
| 278 |
+
# Module-level cache: the spaCy pipeline is loaded at most once per process.
_SPACY_NLP = None

def _get_spacy():
    """Return the cached spaCy English pipeline, installing the model on first miss.

    NOTE(review): the OSError fallback shells out to ``spacy download``,
    which requires network access — confirm this is acceptable at runtime
    in the deployment environment.
    """
    global _SPACY_NLP
    if _SPACY_NLP is None:
        import spacy
        try:
            _SPACY_NLP = spacy.load("en_core_web_sm")
        except OSError:
            # Model package not installed: download once, then retry the load.
            import subprocess
            subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_sm"], check=True)
            _SPACY_NLP = spacy.load("en_core_web_sm")
    return _SPACY_NLP
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def score_claim_verification(meta_proba, clean_text_str, title):
    """
    Step 2: Entity-level claim verification.

    Blends the meta-classifier probability (60%) with two text-derived
    signals: count of distinct verifiable named entities (25%) and the
    fraction of direct quotes carrying an attribution verb nearby (15%).
    Numeric claims lacking a citation within +/-100 chars incur a penalty
    capped at 0.15. The ``title`` argument is currently unused.
    NOTE(review): meta_proba is presumably P(article is true) — confirm
    against the caller's convention.

    Returns: (claim_score, entities_found, n_verifiable, quotes_attributed, quotes_total)
    """
    nlp = _get_spacy()
    # Process a capped version to avoid memory issues on long articles
    doc = nlp(clean_text_str[:5000])

    # Sub-step A: Named Entity Extraction
    verifiable_types = {"PERSON", "ORG", "GPE"}
    numeric_types = {"MONEY", "PERCENT", "CARDINAL"}

    verifiable_ents = [ent.text for ent in doc.ents if ent.label_ in verifiable_types]
    numeric_ents = [ent for ent in doc.ents if ent.label_ in numeric_types]

    # De-duplicate: repeated mentions of the same entity count once.
    n_verifiable = len(set(verifiable_ents))

    # Count unverifiable numeric claims (no citation within ±100 chars)
    n_unverifiable = 0
    for ent in numeric_ents:
        start = max(0, ent.start_char - 100)
        end = min(len(clean_text_str), ent.end_char + 100)
        context = clean_text_str[start:end]
        if not CITATION_RE.search(context):
            n_unverifiable += 1

    # Sub-step B: Quote Attribution
    quotes = QUOTE_RE.findall(clean_text_str[:5000])
    quotes_total = len(quotes)
    quotes_attributed = 0

    for q in quotes:
        # Locate the quote again to inspect its surrounding context.
        q_pos = clean_text_str.find(q)
        if q_pos == -1:
            continue
        # Attribution verb must appear within 50 chars either side of the quote.
        context_start = max(0, q_pos - 50)
        context_end = min(len(clean_text_str), q_pos + len(q) + 50)
        context = clean_text_str[context_start:context_end]
        if QUOTE_ATTRIB_RE.search(context):
            quotes_attributed += 1

    # No quotes at all is treated as fully attributed (no penalty).
    attributed_ratio = (quotes_attributed / quotes_total) if quotes_total > 0 else 1.0

    # Sub-step C: Combine
    entity_score = min(1.0, n_verifiable / 3)  # 3+ verifiable entities = full marks
    unverifiable_penalty = min(0.15, n_unverifiable * 0.05)

    claim_score = (meta_proba * 0.60) + (entity_score * 0.25) + (attributed_ratio * 0.15)
    claim_score = max(0.0, min(1.0, claim_score - unverifiable_penalty))

    entities_found = list(set(verifiable_ents))[:10]  # Cap for JSON output

    return claim_score, entities_found, n_verifiable, quotes_attributed, quotes_total
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 350 |
+
# STEP 3 — LINGUISTIC ANALYSIS (weight: 20%)
|
| 351 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 352 |
+
|
| 353 |
+
def score_linguistic_quality(title, text, clean_text_str, author_found, cfg=None):
    """
    Step 3: Rule-based linguistic quality scoring (weight: 20%).

    Starts from a perfect 1.0 and applies fixed deductions for each
    detected style problem (sensationalism, missing attribution,
    headline/body mismatch, etc.). Reuses the cached DistilBERT encoder
    for the headline-contradiction check when it is already loaded.

    Args:
        title: Raw article title (may be None/empty).
        text: Raw article body (used for the headline check, since
            clean_text_str has the title prepended).
        clean_text_str: Preprocessed title+body string scanned by the
            regex heuristics.
        author_found: Byline flag produced by Step 1.
        cfg: Config dict; currently unused here — kept for signature
            parity with the other score_* steps.

    Returns:
        (linguistic_score, deductions_applied, headline_contradicts)
        where linguistic_score is clamped to [0, 1] and
        deductions_applied is a human-readable list of what fired.

    NOTE: SENSATIONAL_RE / SUPERLATIVE_RE / NO_ATTRIB_RE /
    PASSIVE_VOICE_RE and _MODEL_CACHE are module-level objects —
    presumably compiled patterns and a model cache defined earlier in
    this file (not visible in this chunk).
    """
    score = 1.0
    deductions = []
    headline_contradicts = False
    title_str = str(title) if title else ""

    # ── 1. Sensationalist headline (-0.20) ──
    # Any one of: an ALL-CAPS word (4+ letters), an exclamation mark,
    # or a SENSATIONAL_RE hit. The deduction is applied at most once.
    sensational = False
    if title_str:
        caps_words = re.findall(r"\b[A-Z]{4,}\b", title_str)
        if len(caps_words) >= 1:
            sensational = True
        if "!" in title_str:
            sensational = True
        if SENSATIONAL_RE.search(title_str):
            sensational = True
    if sensational:
        score -= 0.20
        deductions.append("Sensationalist headline detected")

    # ── 2. Excessive superlatives (-0.15, needs ≥2 matches) ──
    superlative_matches = SUPERLATIVE_RE.findall(clean_text_str)
    if len(superlative_matches) >= 2:
        score -= 0.15
        deductions.append(f"Excessive superlatives ({len(superlative_matches)} found)")

    # ── 3. No attribution (-0.15) ──
    if NO_ATTRIB_RE.search(clean_text_str):
        score -= 0.15
        deductions.append("Anonymous/vague attribution patterns found")

    # ── 4. Headline contradicts body (-0.10) ──
    # Guard: only run if title looks like a real headline, not an auto-extracted body sentence
    is_real_headline = (
        title_str
        and len(title_str) > 10
        and len(title_str.split()) <= 15
        and not title_str.lower().startswith(("it has", "it was", "it is", "there was", "there is"))
        and title_str.lower() not in str(text).lower()[:100]
    )
    if is_real_headline:
        body_only = str(text)[:512]  # Raw body text, NOT clean_text_str which has title prepended
        try:
            # NOTE(review): if DistilBERT is not already cached, neither the
            # embedding check nor the word-overlap fallback runs (the fallback
            # fires only on an exception) — confirm this is intended.
            if "distilbert" in _MODEL_CACHE:
                model, tok, device = _MODEL_CACHE["distilbert"]
                with torch.no_grad():
                    t_enc = tok(title_str, return_tensors="pt", truncation=True, max_length=64, padding=True).to(device)
                    b_enc = tok(body_only, return_tensors="pt", truncation=True, max_length=512, padding=True).to(device)
                    t_hidden = model.distilbert(**t_enc).last_hidden_state[:, 0, :]  # CLS token
                    b_hidden = model.distilbert(**b_enc).last_hidden_state[:, 0, :]
                    cos_sim = float(torch.nn.functional.cosine_similarity(t_hidden, b_hidden).item())
                    # Low CLS-embedding similarity is treated as a contradiction signal.
                    if cos_sim < 0.30:
                        headline_contradicts = True
                        score -= 0.10
                        deductions.append(f"Headline may contradict body (similarity={cos_sim:.2f})")
        except Exception as e:
            # Fallback: simple word overlap against body only
            title_words = set(title_str.lower().split())
            body_words = set(body_only.lower().split())
            overlap = len(title_words & body_words) / max(len(title_words), 1)
            if overlap < 0.15 and len(title_words) > 3:
                headline_contradicts = True
                score -= 0.10
                deductions.append("Headline has very low word overlap with body")

    # ── 5. Internal contradictions (-0.10) ──
    # Heuristic: negation near repeated noun phrase
    # (Counts are accumulated sentence-by-sentence, so only words repeated
    # up to and including the negated sentence can trigger the flag.)
    sentences = re.split(r'[.!?]+', clean_text_str[:3000])
    negation_re = re.compile(r"\b(not|no|never|false|deny|denied|incorrect|wrong)\b", re.IGNORECASE)
    noun_counts = {}
    contradiction_found = False
    for sent in sentences:
        words = sent.lower().split()
        # Track nouns (simple: capitalized words in original text)
        for w in words:
            if len(w) > 3:
                noun_counts[w] = noun_counts.get(w, 0) + 1
        # Check if a repeated noun appears near negation
        if negation_re.search(sent):
            for w in words:
                if noun_counts.get(w, 0) >= 2 and len(w) > 4:
                    contradiction_found = True
                    break
        if contradiction_found:
            break
    if contradiction_found:
        score -= 0.10
        deductions.append("Possible internal contradiction detected")

    # ── 6. Passive voice obscuring agency (-0.10) ──
    if PASSIVE_VOICE_RE.search(clean_text_str):
        score -= 0.10
        deductions.append("Passive voice used to obscure agency")

    # ── 7. Missing byline (-0.05) ──
    if not author_found:
        score -= 0.05
        deductions.append("No byline or author attribution found")

    # Deductions can sum past 1.0; clamp at zero (upper bound is 1.0 by construction).
    score = max(0.0, score)
    return score, deductions, headline_contradicts
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
# ═════════════════════════════════════════════���═══════════════════════════════
|
| 462 |
+
# STEP 4 — FRESHNESS (weight: 10%)
|
| 463 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 464 |
+
|
| 465 |
+
def score_freshness_v2(published_date, has_date, title, text):
    """
    Step 4: Temporal freshness scoring (weight: 10%).

    Case A — a publication date is available: map the article's age in
    days onto a fixed score bracket.
    Case B — no usable date: delegate to _freshness_case_b, which scans
    the text for contextual freshness signals.

    Returns: (score, case, signals_found) where case is "A" or "B" and
    signals_found is non-empty only for Case B.
    """
    # No date at all → straight to contextual scanning.
    if not (has_date and published_date is not None):
        return _freshness_case_b(title, text)

    now = datetime.now(timezone.utc)
    try:
        # Naive datetimes are assumed UTC so the subtraction is legal.
        if getattr(published_date, 'tzinfo', None) is None:
            published_date = published_date.replace(tzinfo=timezone.utc)
        age_days = (now - published_date).days
    except Exception:
        # Date arithmetic failed (e.g. malformed value) → fall back to Case B.
        return _freshness_case_b(title, text)

    # Future-dated articles are treated as brand new.
    age_days = max(age_days, 0)

    # Age brackets: (exclusive upper bound in days, score).
    brackets = ((30, 1.0), (181, 0.75), (731, 0.5))
    for upper_bound, bracket_score in brackets:
        if age_days < upper_bound:
            return bracket_score, "A", []
    return 0.2, "A", []
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
def _freshness_case_b(title, text):
    """
    Case B: no publication date — estimate freshness from contextual signals.

    Four independent signals are scanned in the combined title+body text;
    the number of signals found (capped at 4) indexes a fixed score table.
    Returns (score, "B", signals_found).
    """
    haystack = str(title) + " " + str(text)
    now = datetime.now()

    # Dynamic pattern: mentions of the current or previous calendar year.
    year_re = re.compile(r"\b(" + str(now.year) + r"|" + str(now.year - 1) + r")\b")

    # Ordered (hit, message) pairs — order matters for the returned list.
    checks = (
        (year_re.search(haystack),
         f"Current/recent year mentioned ({now.year} or {now.year-1})"),
        (TEMPORAL_RE.search(haystack), "Temporal freshness phrase detected"),
        (INSTITUTION_RE.search(haystack), "Named institutional publisher found"),
        (CORROBORATION_OUTLETS_RE.search(haystack), "Major outlet corroboration cited"),
    )
    signals = [message for hit, message in checks if hit]

    # More corroborating signals → higher (but never full) freshness score.
    score_table = {4: 0.80, 3: 0.70, 2: 0.60, 1: 0.50, 0: 0.40}
    return score_table[min(len(signals), 4)], "B", signals
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 527 |
+
# STEP 5 — MODEL VOTE (weight: 10%)
|
| 528 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 529 |
+
|
| 530 |
+
def score_model_vote(votes):
    """
    Step 5: fraction of ensemble members that voted TRUE (weight: 10%).

    Args:
        votes: dict mapping model name → 0/1 vote.

    Returns:
        Mean of the votes, or the neutral value 0.5 when no model ran.
    """
    if votes:
        return sum(votes.values()) / len(votes)
    return 0.5
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 538 |
+
# ADVERSARIAL OVERRIDE
|
| 539 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 540 |
+
|
| 541 |
+
def check_adversarial_flags(has_date, author_found, n_verifiable, headline_contradicts,
                            typosquatting_detected, text):
    """
    Post-scoring adversarial check.

    Each triggered flag is a hard red flag; the caller caps the final
    score at 0.25 whenever at least one fires.

    Returns: list of triggered flag names (possibly empty).
    """
    # Ordered (condition, flag message) pairs — order determines output order.
    checks = (
        # Flag 1: nothing about the article is attributable at all.
        (not has_date and not author_found and n_verifiable == 0,
         "Triple anonymity (no date, no author, no named sources)"),
        # Flag 2: headline/body contradiction detected in Step 3.
        (bool(headline_contradicts),
         "Headline contradicts article body"),
        # Flag 3: lookalike domain detected in Step 1.
        (bool(typosquatting_detected),
         "Domain mimics a known outlet (typosquatting)"),
        # Flag 4: statistics appear but no citation pattern anywhere in text.
        (bool(STAT_RE.search(text)) and not CITATION_RE.search(text),
         "Statistics cited with no traceable primary source"),
    )
    return [message for triggered, message in checks if triggered]
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 573 |
+
# REASON BUILDER
|
| 574 |
+
# ═══════════════════════════════��═════════════════════════════════════════════
|
| 575 |
+
|
| 576 |
+
def build_reasons_and_missing(scores, n_verifiable, author_found, has_date,
                              deductions, adversarial_flags):
    """
    Programmatically derive explanation strings from the 5-signal scores.

    Reasons are collected in a fixed order — negative signals first, then
    positive signals, then adversarial flags — and truncated to the top 3.
    Missing-signal strings list metadata that could not be verified.

    Args:
        scores: dict with keys source/claim/linguistic/freshness/model_vote.
        n_verifiable: count of verifiable named entities (Step 2).
        author_found: byline flag (Step 1).
        has_date: whether a publication date was extracted.
        deductions: linguistic deduction list (accepted for interface
            parity; not consulted here).
        adversarial_flags: flag strings from check_adversarial_flags.

    Returns:
        (reasons[:3], missing_signals)
    """
    # (score key, threshold, message) — fires when score < threshold.
    negative_rules = (
        ("source", 0.4, "Source is unknown or not editorially accountable"),
        ("claim", 0.5, "Core claims could not be fully verified"),
        ("linguistic", 0.7, "Writing style shows signs of sensationalism or manipulation"),
        ("freshness", 0.5, "Article age or missing date reduces temporal reliability"),
        ("model_vote", 0.5, "AI models flagged patterns inconsistent with credible journalism"),
    )
    # (score key, threshold, message) — fires when score >= threshold.
    positive_rules = (
        ("source", 0.8, "Article is from a known, credible outlet"),
        ("claim", 0.8, "Core claims are well-attributed with verifiable entities"),
        ("linguistic", 0.9, "Writing style is neutral and well-attributed"),
        ("model_vote", 0.75, "AI models strongly agree this content is credible"),
    )

    reasons = [msg for key, cutoff, msg in negative_rules if scores[key] < cutoff]
    reasons += [msg for key, cutoff, msg in positive_rules if scores[key] >= cutoff]
    reasons += [f"Adversarial flag: {flag}" for flag in adversarial_flags]

    missing = []
    if not author_found:
        missing.append("Author identity could not be verified")
    if not has_date:
        missing.append("Publication date not found")
    if scores["source"] <= 0.3:
        missing.append("Source domain not recognized")
    if n_verifiable == 0:
        missing.append("No verifiable named entities found in text")

    return reasons[:3], missing
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 625 |
+
# MAIN INFERENCE INTERFACE
|
| 626 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 627 |
+
|
| 628 |
+
def predict_article(title, text, source_domain, published_date, mode="full", trigger_rag=True):
    """
    5-Signal weighted scoring inference.

    Execution order:
      1. extract_features()
      2. Run base models (LR/LSTM/DistilBERT/RoBERTa) → probas, votes
      3. Run meta-classifier → meta_proba
      4. Step 1: score_source_credibility()
      5. Step 2: score_claim_verification()
      6. Step 3: score_linguistic_quality() [needs author_found from Step 1]
      7. Step 4: score_freshness_v2()
      8. Step 5: score_model_vote()
      9. Weighted final score + adversarial override + verdict

    Args:
        title: Article title (may be None/empty).
        text: Article body text.
        source_domain: Publisher domain string.
        published_date: Publication date (parsed in extract_features).
        mode: "fast" (LR only), "balanced" (+LSTM), or "full" (+transformers).
        trigger_rag: Accepted but not used in this body — presumably a
            hook for a retrieval step elsewhere; TODO confirm.

    Returns:
        dict with verdict, final_score, per-signal scores, confidence,
        recommended_action, and supporting evidence fields (see the
        return statement for the exact keys).
    """
    cfg = load_config()
    feat = extract_features(title, text, source_domain, published_date, cfg)

    # Per-model probabilities default to NaN so the meta-classifier can
    # distinguish "model not run" from a real probability.
    probas = {
        "lr_proba": np.nan, "lstm_proba": np.nan,
        "distilbert_proba": np.nan, "roberta_proba": np.nan,
    }
    votes = {}

    # ── Base Model Inference ──────────────────────────────────────────────

    # 1. Logistic Regression (runs in every mode)
    if mode in ("fast", "balanced", "full"):
        lr_pipe = _get_model("logistic", cfg)
        df_lr = pd.DataFrame([{
            "clean_text": feat["clean_text"],
            "word_count": feat["word_count"],
            "text_length_bucket": feat["text_length_bucket"],
            "has_date": 1 if feat["has_date"] else 0,
            "freshness_score": 0.5,  # neutral for model input
            "source_domain": feat["source_domain"],
        }])
        try:
            p = float(lr_pipe.predict_proba(df_lr)[:, 1][0])
            probas["lr_proba"] = p
            votes["logistic"] = int(p >= 0.5)
        except Exception as e:
            logger.warning(f"LR inference failed: {e}")

    # 2. Bi-LSTM
    # NOTE(review): unlike the LR branch, this path has no try/except —
    # an inference failure here propagates to the caller; confirm intended.
    if mode in ("balanced", "full"):
        lstm_model, tok, device = _get_model("lstm", cfg)
        maxlen = cfg.get("preprocessing", {}).get("lstm_max_len", 512)
        from src.models.lstm_model import pad_sequences

        seq = tok.texts_to_sequences([feat["clean_text"]])
        pad = pad_sequences(seq, maxlen=maxlen, padding='post')
        t_pad = torch.from_numpy(pad).long().to(device)

        with torch.no_grad():
            logits = lstm_model(t_pad)
            p = float(torch.sigmoid(logits).cpu().numpy()[0])
        probas["lstm_proba"] = p
        votes["lstm"] = int(p >= 0.5)

    # 3. Transformers (DistilBERT + RoBERTa) — full mode only
    if mode == "full":
        for t_name in ("distilbert", "roberta"):
            model, tok, device = _get_model(t_name, cfg)
            inputs = tok(feat["clean_text"], padding=True, truncation=True,
                         max_length=512, return_tensors="pt").to(device)
            with torch.no_grad():
                out = model(**inputs)
                p = float(torch.softmax(out.logits, dim=-1)[0, 1].item())
            if t_name == "roberta":
                p = p * 0.92  # RoBERTa TRUE-bias dampening
            probas[t_name + "_proba"] = p
            votes[t_name] = int(p >= 0.5)

    # 4. Meta-Classifier — stacks base probas + numeric/categorical features
    meta_bundle = _get_model("meta", cfg)
    meta_preprocessor = meta_bundle["preprocessor"]
    meta_model = meta_bundle["model"]

    # Column order must match the meta-classifier's training layout.
    df_meta = pd.DataFrame([{
        "lr_proba": probas["lr_proba"],
        "lstm_proba": probas["lstm_proba"],
        "distilbert_proba": probas["distilbert_proba"],
        "roberta_proba": probas["roberta_proba"],
        "word_count": feat["word_count"],
        "has_date": 1 if feat["has_date"] else 0,
        "freshness_score": 0.5,  # neutral — freshness is now scored separately in Step 4
    }])

    df_cats = pd.DataFrame([{
        "text_length_bucket": feat["text_length_bucket"],
        "source_domain": feat["source_domain"],
    }])
    cat_feats = meta_preprocessor.transform(df_cats)
    X_meta = np.hstack((df_meta.values, cat_feats))

    meta_proba = float(meta_model.predict_proba(X_meta)[:, 1][0])

    # Short-text dampening (under 50 words): pull the probability 40%
    # of the way back toward the neutral 0.5.
    short_text = feat["word_count"] < 50
    if short_text:
        meta_proba = 0.5 + (meta_proba - 0.5) * 0.6

    # ── 5-Signal Scoring ─────────────────────────────────────────────────

    # Step 1: Source Credibility
    source_score, author_found, typosquat = score_source_credibility(
        feat["source_domain"], title, text
    )

    # Step 2: Claim Verification
    claim_score, entities_found, n_verifiable, q_attr, q_total = score_claim_verification(
        meta_proba, feat["clean_text"], title
    )

    # Step 3: Linguistic Analysis (depends on author_found from Step 1)
    ling_score, deductions, headline_contradicts = score_linguistic_quality(
        title, text, feat["clean_text"], author_found, cfg
    )

    # Step 4: Freshness
    fresh_score, fresh_case, fresh_signals = score_freshness_v2(
        feat.get("published_date"), feat["has_date"], title, text
    )

    # Step 5: Model Vote
    vote_score = score_model_vote(votes)

    # ── Final Weighted Score ──────────────────────────────────────────────

    scores = {
        "source": round(source_score, 4),
        "claim": round(claim_score, 4),
        "linguistic": round(ling_score, 4),
        "freshness": round(fresh_score, 4),
        "model_vote": round(vote_score, 4),
    }

    # Weights: source 30%, claim 30%, linguistic 20%, freshness 10%, vote 10%.
    final_score = (
        source_score * 0.30 +
        claim_score * 0.30 +
        ling_score * 0.20 +
        fresh_score * 0.10 +
        vote_score * 0.10
    )

    # ── Adversarial Override ──────────────────────────────────────────────

    adv_flags = check_adversarial_flags(
        feat["has_date"], author_found, n_verifiable,
        headline_contradicts, typosquat, feat["clean_text"]
    )
    # Any adversarial flag hard-caps the score into FALSE territory.
    if adv_flags:
        final_score = min(final_score, 0.25)

    final_score = round(final_score, 4)

    # ── Verdict ───────────────────────────────────────────────────────────

    if final_score >= 0.75:
        verdict = "TRUE"
    elif final_score >= 0.55:
        verdict = "UNCERTAIN"
    elif final_score >= 0.35:
        verdict = "LIKELY FALSE"
    else:
        verdict = "FALSE"

    # ── Reasons & Missing Signals ─────────────────────────────────────────

    top_reasons, missing_signals = build_reasons_and_missing(
        scores, n_verifiable, author_found, feat["has_date"],
        deductions, adv_flags
    )

    # ── Confidence ────────────────────────────────────────────────────────

    missing_count = len(missing_signals)
    if adv_flags or missing_count >= 3:
        confidence = "LOW"
    elif verdict == "UNCERTAIN" or missing_count in (1, 2):
        confidence = "MEDIUM"
    elif final_score >= 0.75 or final_score < 0.35:
        confidence = "HIGH"
    else:
        confidence = "MEDIUM"

    # ── Recommended Action + LOW Guard ────────────────────────────────────

    action_map = {
        "TRUE": "Publish",
        "UNCERTAIN": "Flag for review",
        "LIKELY FALSE": "Suppress",
        "FALSE": "Escalate",
    }
    recommended_action = action_map[verdict]

    # Hard rule: LOW confidence → never "Publish"
    if confidence == "LOW" and recommended_action == "Publish":
        recommended_action = "Flag for review"

    # ── Return Full JSON ──────────────────────────────────────────────────

    return {
        "verdict": verdict,
        "final_score": final_score,
        "scores": scores,
        "freshness_case": fresh_case,
        "freshness_signals_found": fresh_signals,
        "adversarial_flags": adv_flags,
        "top_reasons": top_reasons,
        "missing_signals": missing_signals,
        "confidence": confidence,
        "recommended_action": recommended_action,
        "base_model_votes": votes,
        "base_model_probas": probas,
        "word_count": feat["word_count"],
        "short_text_warning": short_text,
        "deductions_applied": deductions,
        "entities_found": entities_found,
        "quotes_attributed": q_attr,
        "quotes_total": q_total,
    }
|
| 851 |
+
|
| 852 |
+
|
| 853 |
+
if __name__ == "__main__":
    # Smoke test: run the fast-mode pipeline on a deliberately
    # sensationalist sample and dump the verdict as JSON.
    import json

    demo_title = "Breaking: AI solves P=NP"
    demo_body = (
        "The algorithm has shocked absolutely everyone across the earth entirely "
        "resolving everything overnight. Sources say it is unprecedented."
    )
    try:
        verdict_dict = predict_article(
            demo_title,
            demo_body,
            "techcrunch.com",
            datetime.now().isoformat(),
            mode="fast",
        )
        print("Verdict Dict:")
        print(json.dumps(verdict_dict, indent=2, default=str))
    except ModelNotTrainedError as exc:
        print("ERROR:", str(exc))
|
src/utils/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Fake News Detection — Utilities Package
|
src/utils/deduplication.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
deduplication.py — Fast near-duplicate removal for large datasets.
|
| 3 |
+
|
| 4 |
+
Uses a two-phase strategy:
|
| 5 |
+
1. **Exact dedup** — hash-based O(n) removal of identical texts.
|
| 6 |
+
2. **Near-dedup via Sentence-BERT** — encode texts, build a cosine
|
| 7 |
+
similarity index, and remove near-duplicate pairs above a
|
| 8 |
+
configurable threshold. Uses chunked approach with early
|
| 9 |
+
termination to keep runtime feasible on 100K+ rows.
|
| 10 |
+
|
| 11 |
+
The ``all-MiniLM-L6-v2`` model is used for embedding.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import hashlib
|
| 15 |
+
import logging
|
| 16 |
+
import time
|
| 17 |
+
from typing import Dict, List, Optional, Set, Tuple
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import pandas as pd
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
_DEFAULT_MODEL = "all-MiniLM-L6-v2"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# ═══════════════════════════════════════════════════════════
|
| 28 |
+
# Phase 1: Exact dedup (hash-based, O(n))
|
| 29 |
+
# ═══════════════════════════════════════════════════════════
|
| 30 |
+
|
| 31 |
+
def _exact_dedup(df: pd.DataFrame, text_column: str) -> Tuple[pd.DataFrame, int]:
    """Drop rows whose *text_column* is byte-identical to an earlier row.

    Texts are SHA-256 hashed (NaN → "") and the first occurrence of each
    hash is kept, preserving original row order.

    Args:
        df: Input DataFrame.
        text_column: Column to hash for exact comparison.

    Returns:
        (deduplicated DataFrame with a fresh index, number of rows removed).
    """
    n_before = len(df)
    seen: Set[str] = set()
    keep_mask: List[bool] = []

    for txt in df[text_column].fillna("").astype(str):
        digest = hashlib.sha256(txt.encode("utf-8", errors="replace")).hexdigest()
        if digest in seen:
            keep_mask.append(False)
        else:
            seen.add(digest)
            keep_mask.append(True)

    deduped = df.loc[keep_mask].reset_index(drop=True)
    n_removed = n_before - len(deduped)
    logger.info("Exact dedup: removed %d / %d identical rows", n_removed, n_before)
    return deduped, n_removed
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# ═══════════════════════════════════════════════════════════
|
| 60 |
+
# Phase 2: Semantic near-dedup (Sentence-BERT + chunked cosine)
|
| 61 |
+
# ═══════════════════════════════════════════════════════════
|
| 62 |
+
|
| 63 |
+
def _semantic_dedup(
    df: pd.DataFrame,
    text_column: str,
    threshold: float,
    batch_size: int,
    model_name: str,
    max_rows_for_pairwise: int = 30_000,
) -> Tuple[pd.DataFrame, int]:
    """Remove near-duplicate rows using Sentence-BERT cosine similarity.

    For datasets larger than *max_rows_for_pairwise*, the comparison is
    done in a block-diagonal fashion (each chunk vs. itself) to keep
    computation tractable. Cross-chunk duplicates are rare across
    dataset origins, and exact dedup already handles identical pairs.

    The per-row duplicate marking is vectorized with NumPy instead of a
    per-element Python loop: for a surviving row *gi*, every later row
    whose similarity is >= *threshold* is marked. Re-adding an index
    already in the duplicate set is a no-op, so skipping already-marked
    columns (as a scalar loop would) is unnecessary, and rows already
    marked as duplicates still never mark others — the result set is
    identical to the scalar algorithm, just far faster.

    Args:
        df: Input DataFrame (already exact-deduped).
        text_column: Column to encode.
        threshold: Cosine similarity cutoff.
        batch_size: Encoding batch size.
        model_name: Sentence-BERT model name.
        max_rows_for_pairwise: Max rows for full pairwise comparison.

    Returns:
        (deduplicated DataFrame, number of rows removed).
    """
    from sentence_transformers import SentenceTransformer
    from sklearn.metrics.pairwise import cosine_similarity

    n = len(df)
    if n < 2:
        return df.copy(), 0

    texts = df[text_column].fillna("").astype(str).tolist()
    # Truncate long texts to first 256 chars for fast encoding
    texts_trunc = [t[:256] for t in texts]

    logger.info(
        "Encoding %d texts with %s (batch_size=%d) …",
        n, model_name, batch_size,
    )
    model = SentenceTransformer(model_name)
    embeddings = model.encode(
        texts_trunc,
        batch_size=batch_size,
        show_progress_bar=True,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )

    duplicate_indices: Set[int] = set()

    def _mark_duplicates_for_row(sim_row: "np.ndarray", offset: int) -> None:
        # Vectorized inner loop: mark every later-indexed row (global index
        # offset + local position) whose similarity clears the threshold.
        hits = np.nonzero(sim_row >= threshold)[0]
        duplicate_indices.update(int(offset + j) for j in hits)

    if n <= max_rows_for_pairwise:
        # Full pairwise — feasible for ≤ 30K rows; computed in row chunks
        # so only a (chunk_size × n) similarity block is live at once.
        logger.info("Running full pairwise cosine similarity (%d × %d) …", n, n)
        chunk_size = 2000
        for start in range(0, n, chunk_size):
            end = min(start + chunk_size, n)
            sim = cosine_similarity(embeddings[start:end], embeddings)
            for li in range(sim.shape[0]):
                gi = start + li
                if gi in duplicate_indices:
                    continue
                # Only compare with later-indexed rows
                _mark_duplicates_for_row(sim[li, gi + 1:], gi + 1)
    else:
        # For very large datasets: compare within blocks of 10K rows
        logger.info(
            "Dataset too large (%d) for full pairwise — using block dedup",
            n,
        )
        block_size = 10_000
        for block_start in range(0, n, block_size):
            block_end = min(block_start + block_size, n)
            block_emb = embeddings[block_start:block_end]
            block_n = block_end - block_start
            logger.info(
                "  Block [%d:%d] (%d rows) …",
                block_start, block_end, block_n,
            )
            sim = cosine_similarity(block_emb, block_emb)
            for li in range(block_n):
                gi = block_start + li
                if gi in duplicate_indices:
                    continue
                # Later rows within the same block only.
                _mark_duplicates_for_row(sim[li, li + 1:], gi + 1)

    removed = len(duplicate_indices)
    if removed > 0:
        keep_mask = np.ones(n, dtype=bool)
        keep_mask[list(duplicate_indices)] = False
        df_out = df.loc[keep_mask].reset_index(drop=True)
    else:
        df_out = df.copy()

    logger.info("Semantic dedup: removed %d / %d near-duplicate rows", removed, n)
    return df_out, removed
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# ═══════════════════════════════════════════════════════════
|
| 173 |
+
# Public API
|
| 174 |
+
# ═══════════════════════════════════════════════════════════
|
| 175 |
+
|
| 176 |
+
def deduplicate_dataframe(
    df: pd.DataFrame,
    text_column: str = "text",
    threshold: float = 0.92,
    batch_size: int = 64,
    model_name: str = _DEFAULT_MODEL,
    origin_column: Optional[str] = "dataset_origin",
) -> Tuple[pd.DataFrame, Dict[str, int]]:
    """Remove duplicate rows from *df* in two passes: exact, then semantic.

    Args:
        df: Input DataFrame (must contain *text_column*).
        text_column: Column used for duplicate detection.
        threshold: Cosine-similarity cutoff for the semantic pass.
        batch_size: Batch size used when encoding texts.
        model_name: Sentence-BERT model identifier.
        origin_column: Optional column used to break removals down per origin.

    Returns:
        (cleaned DataFrame, stats dict with per-origin removal counts).
    """
    start_time = time.perf_counter()
    logger.info("=" * 60)
    logger.info("Starting deduplication pipeline (threshold=%.2f) …", threshold)
    rows_in = len(df)

    # Pass 1: byte-identical duplicates.
    after_exact, _ = _exact_dedup(df, text_column)

    # Pass 2: embedding-based near-duplicates.
    df_final, _ = _semantic_dedup(
        after_exact,
        text_column=text_column,
        threshold=threshold,
        batch_size=batch_size,
        model_name=model_name,
    )

    total_removed = rows_in - len(df_final)

    # Per-origin breakdown; falls back to a single "total" entry when the
    # origin column is absent.
    if origin_column and origin_column in df.columns:
        pre = df[origin_column].value_counts().to_dict()
        post = df_final[origin_column].value_counts().to_dict()
        stats: Dict[str, int] = {
            origin: pre[origin] - post.get(origin, 0) for origin in pre
        }
    else:
        stats = {"total": total_removed}

    elapsed = time.perf_counter() - start_time
    logger.info(
        "Dedup complete: %d → %d rows (removed %d, %.1f%%) in %.1fs",
        rows_in, len(df_final), total_removed,
        100 * total_removed / max(rows_in, 1), elapsed,
    )
    for origin, cnt in stats.items():
        if cnt > 0:
            logger.info(" %-30s %6d removed", origin, cnt)
    logger.info("=" * 60)

    return df_final, stats
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
# ─── standalone test ────────────────────────────────────────
|
| 241 |
+
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # Tiny fixture: one exact duplicate pair, one semantic near-duplicate pair.
    demo = pd.DataFrame({
        "text": [
            "The president signed the bill into law today.",
            "The president signed the bill into law today.",  # exact dup
            "Scientists discover a new species of frog in the Amazon.",
            "A new frog species has been found in the Amazon rainforest.",  # near dup
            "Stock markets rallied after a strong jobs report.",
        ],
        "dataset_origin": ["a", "a", "b", "b", "c"],
    })
    clean, info = deduplicate_dataframe(demo, threshold=0.92)
    print(f"\nKept {len(clean)} / {len(demo)} rows")
    print("Stats:", info)
    print(clean[["text", "dataset_origin"]])
|
src/utils/domain_weights.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
logger = logging.getLogger(__name__)
|
| 6 |
+
|
| 7 |
+
def compute_domain_weights(
    df: pd.DataFrame,
    min_domain_samples: int = 20,
    max_multiplier: float = 10.0
) -> pd.DataFrame:
    """
    Compute domain-aware sample weights to handle heavily biased domains.

    Rules:
        - Domains with < min_domain_samples -> merge into "other"
        - Calculate global class distributions.
        - Calculate per-domain class distributions.
        - Compute weight = global_class_ratio / domain_class_ratio
        - Clip weights at max_multiplier * median_weight

    Args:
        df: Input DataFrame containing 'source_domain' and 'binary_label'
        min_domain_samples: Threshold below which domains are grouped to 'other'
        max_multiplier: Max multiplier over the median weight to clip extreme weights

    Returns:
        DataFrame with an additional 'sample_weight' column.
    """
    logger = logging.getLogger(__name__)
    df = df.copy()

    # Ensure source_domain exists
    if "source_domain" not in df.columns:
        logger.warning("'source_domain' not found in DataFrame. Returning weights=1.0")
        df["sample_weight"] = 1.0
        return df

    # Guard: an empty frame would otherwise produce degenerate ratios
    # (division by zero / NaN median) below.
    if df.empty:
        df["sample_weight"] = 1.0
        return df

    # 1. Merge small domains into "other" (non-string domains go there too)
    domain_counts = df["source_domain"].value_counts()
    small_domains = set(domain_counts[domain_counts < min_domain_samples].index)

    df["_effective_domain"] = df["source_domain"].apply(
        lambda x: "other" if x in small_domains or not isinstance(x, str) else x
    )

    # 2. Compute global class ratios
    global_counts = df["binary_label"].value_counts()
    global_total = len(df)
    global_ratio = {
        label: count / global_total
        for label, count in global_counts.items()
    }

    # 3. Compute domain class ratios and assign weights.
    # Group by (domain, label) to get a count matrix: rows=domain, cols=label.
    domain_label_counts = df.groupby(["_effective_domain", "binary_label"]).size().unstack(fill_value=0)
    domain_totals = domain_label_counts.sum(axis=1)

    weights_map = {}
    for domain in domain_label_counts.index:
        weights_map[domain] = {}
        d_total = domain_totals[domain]
        for label in global_ratio.keys():
            if label in domain_label_counts.columns:
                d_count = domain_label_counts.loc[domain, label]
                if d_count == 0:
                    # If domain has 0 instances of this class, we won't observe it here anyway,
                    # but set some fallback value.
                    weights_map[domain][label] = 1.0
                else:
                    d_ratio = d_count / d_total
                    weights_map[domain][label] = global_ratio[label] / d_ratio
            else:
                weights_map[domain][label] = 1.0

    # 4. Map weights back to dataframe. A plain zip-lookup avoids the
    # row-wise df.apply(..., axis=1) overhead on large frames.
    df["sample_weight"] = [
        weights_map[dom].get(lab, 1.0)
        for dom, lab in zip(df["_effective_domain"], df["binary_label"])
    ]

    # 5. Clip weights at max_multiplier * median_weight
    median_w = df["sample_weight"].median()
    max_w = max_multiplier * median_w
    df["sample_weight"] = df["sample_weight"].clip(upper=max_w)

    # Clean up temp col
    df.drop(columns=["_effective_domain"], inplace=True)

    logger.info("Computed domain weights (median: %.3f, max applied: %.3f)", median_w, df["sample_weight"].max())

    return df
|
| 93 |
+
|
| 94 |
+
if __name__ == "__main__":
    # Smoke test: one heavily-real domain, one heavily-fake domain, one tiny domain.
    demo = pd.DataFrame({
        "source_domain": ["nytimes.com"] * 100 + ["fakenews.biz"] * 100 + ["tinyblog.com"] * 5,
        "binary_label": [1] * 90 + [0] * 10 + [0] * 95 + [1] * 5 + [0] * 5,
    })
    weighted = compute_domain_weights(demo, min_domain_samples=20, max_multiplier=10.0)
    print(weighted.groupby("source_domain")["sample_weight"].mean())
|
src/utils/freshness.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime, timezone
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
def calculate_freshness(
    published_date,
    has_date: bool,
    is_inference: bool = False,
    reference_date: datetime = None
) -> float:
    """
    Calculate the temporal freshness score for a single article.

    Rules:
        - score = 1.0 if article is < 30 days old
        - score = max(0.1, 1 - (days_old / 365)) for older articles
        - score = 0.5 if has_date is False (neutral for training)
        - score = 0.35 if has_date is False AND called from inference

    Args:
        published_date: The published date of the article (datetime or NaT).
        has_date: Boolean flag indicating if a valid date is present.
        is_inference: Whether the scoring is happening during live inference.
        reference_date: The date to compute 'days_old' against (defaults to now).

    Returns:
        Float score between 0.1 and 1.0.
    """
    # The "no usable date" fallback is deliberately stricter at inference time.
    neutral = 0.35 if is_inference else 0.50
    if not has_date or pd.isna(published_date):
        return neutral

    if reference_date is None:
        reference_date = datetime.now(timezone.utc)

    # Normalise naive timestamps to UTC so the subtraction below is valid.
    # Assuming UTC if naive, typical for web dates.
    if pd.api.types.is_scalar(published_date) and getattr(published_date, 'tzinfo', None) is None:
        try:
            published_date = published_date.replace(tzinfo=timezone.utc)
        except Exception:
            pass

    try:
        days_old = (reference_date - published_date).days
    except TypeError:
        # Mixed naive/aware datetimes, or a non-datetime that slipped past
        # pd.isna: score it as "no usable date" instead of crashing the
        # whole scoring pass.
        return neutral

    # Handle future dates gracefully (e.g., bad parsed data): treat as brand new.
    if days_old < 0:
        days_old = 0

    if days_old < 30:
        return 1.0

    return max(0.1, 1.0 - (days_old / 365.0))
|
| 53 |
+
|
| 54 |
+
def apply_freshness_score(df: pd.DataFrame, is_inference: bool = False) -> pd.DataFrame:
    """Return a copy of *df* with a 'freshness_score' column added.

    Each row is scored by calculate_freshness() against a single reference
    timestamp taken once, so every row in the frame shares the same "now".
    """
    out = df.copy()
    now_utc = datetime.now(timezone.utc)

    def _score_row(row):
        pub = row.get("published_date")
        # has_date defaults to "the date parses as non-null" when the flag
        # column is absent.
        return calculate_freshness(
            pub,
            row.get("has_date", pd.notna(pub)),
            is_inference,
            now_utc,
        )

    out["freshness_score"] = out.apply(_score_row, axis=1)
    return out
|
src/utils/rag_retrieval.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import sys
import yaml
import logging
import spacy
import numpy as np

from duckduckgo_search import DDGS
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Repo root is three directory levels above this file (src/utils/ -> src/ -> root);
# it is pushed onto sys.path so sibling packages import cleanly when this module
# is run directly.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))

logger = logging.getLogger("rag_retrieval")

# Lazy-load massive models securely within the file
_NLP = None        # cached spaCy pipeline, populated on first load_spacy() call
_SIM_MODEL = None  # cached SentenceTransformer, populated on first load_sim_model() call
|
| 21 |
+
|
| 22 |
+
def load_spacy():
    """Return the cached spaCy pipeline, loading (and downloading) it on first use."""
    global _NLP
    if _NLP is not None:
        return _NLP
    try:
        _NLP = spacy.load("en_core_web_sm")
    except OSError:
        # Model not installed yet — fetch it once via the spaCy CLI, then retry.
        logger.info("Downloading spaCy en_core_web_sm model dynamically...")
        import subprocess
        subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_sm"], check=True)
        _NLP = spacy.load("en_core_web_sm")
    return _NLP
|
| 33 |
+
|
| 34 |
+
def load_sim_model():
    """Return the cached similarity SentenceTransformer, instantiating it once."""
    global _SIM_MODEL
    if _SIM_MODEL is not None:
        return _SIM_MODEL
    _SIM_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
    return _SIM_MODEL
|
| 39 |
+
|
| 40 |
+
def extract_focused_query(title: str, text: str) -> str:
    """Build a short, focused DuckDuckGo query from an article.

    Combines up to three unique named entities (ORG / PERSON / GPE / EVENT),
    topped up with short noun phrases when fewer than three entities exist.
    Falls back to the first five words of the analysed text when nothing
    useful is extracted.
    """
    nlp = load_spacy()

    # A reasonably long title is the best query seed on its own; otherwise
    # blend title + body (capped) so short headlines still yield entities.
    use_title_only = isinstance(title, str) and len(title.split()) > 4
    target = title if use_title_only else (str(title) + " " + str(text))[:1000]

    doc = nlp(target)

    # 1. Unique entities, insertion order preserved, capped at three.
    wanted_labels = {'ORG', 'PERSON', 'GPE', 'EVENT'}
    seen = dict.fromkeys(ent.text for ent in doc.ents if ent.label_ in wanted_labels)
    query_parts = list(seen)[:3]

    # 2. Top up with short noun phrases until three terms are collected.
    if len(query_parts) < 3:
        for chunk in doc.noun_chunks:
            phrase = chunk.text
            if phrase not in query_parts and len(phrase.split()) <= 3:
                query_parts.append(phrase)
                if len(query_parts) >= 3:
                    break

    focused_query = " ".join(query_parts)
    if not focused_query.strip():
        # Fallback to pure headline first 5 words
        focused_query = " ".join(target.split()[:5])

    return focused_query
|
| 71 |
+
|
| 72 |
+
def execute_rag(title: str, text: str):
    """
    1. Extracts Query.
    2. DuckDuckGo Search (top 5).
    3. Measure Similarity vs article body via all-MiniLM-L6-v2.
    4. Return strict evaluations.

    Returns:
        (payload, verdict): *payload* is a dict with 'status' and per-snippet
        evaluations; *verdict* is one of "CORROBORATED", "CONTRADICTED",
        "INCONCLUSIVE". Search/network failures degrade to INCONCLUSIVE
        rather than raising.
    """
    # Thresholds are read from the repo config on every call so they can be
    # tuned without restarting the app.
    cfg_path = os.path.join(_PROJECT_ROOT, "config", "config.yaml")
    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    rag_cfg = cfg.get("rag", {})

    top_k = rag_cfg.get("top_k", 5)
    support_thresh = rag_cfg.get("support_threshold", 0.65)
    conflict_thresh = rag_cfg.get("conflict_threshold", 0.30)

    query = extract_focused_query(title, text)
    logger.info(f"RAG Triggered. Extracted Search Query: {query}")

    search_results = []
    try:
        with DDGS() as ddgs:
            results = list(ddgs.text(query, max_results=top_k))
            # Keep the snippet body, falling back to the result title.
            search_results = [r.get("body", r.get("title", "")) for r in results if isinstance(r, dict)]
    except Exception as e:
        # Any DDG failure (rate limit, network) is non-fatal: report and bail.
        logger.error(f"DDGS failure: {e}")
        return {"status": "error", "message": "Search engine failure", "data": []}, "INCONCLUSIVE"

    if not search_results:
        return {"status": "empty", "message": "No external context found", "data": []}, "INCONCLUSIVE"

    sim_model = load_sim_model()

    # Compare target text against all snippets
    corpus_text = (str(title) + " " + str(text))[:2000]  # Cap memory context
    embed_target = sim_model.encode([corpus_text])
    embed_search = sim_model.encode(search_results)

    similarities = cosine_similarity(embed_target, embed_search)[0]

    supports = 0
    conflicts = 0
    eval_payload = []

    # Classify each snippet by cosine similarity to the article.
    # NOTE(review): similarity below conflict_thresh is labelled CONFLICTS,
    # which also captures merely-unrelated snippets — confirm this is intended.
    for i, sim in enumerate(similarities):
        s_float = float(sim)
        if s_float >= support_thresh:
            supports += 1
            nature = "SUPPORTS"
        elif s_float < conflict_thresh:
            conflicts += 1
            nature = "CONFLICTS"
        else:
            nature = "NEUTRAL"

        eval_payload.append({
            "snippet": search_results[i],
            "similarity": s_float,
            "nature": nature
        })

    # Rag Verdict Check: require at least two agreeing snippets either way.
    if supports >= 2:
        verdict = "CORROBORATED"
    elif conflicts >= 2:
        verdict = "CONTRADICTED"
    else:
        verdict = "INCONCLUSIVE"

    final_output = {
        "status": "success",
        "query": query,
        "supports": supports,
        "conflicts": conflicts,
        "data": eval_payload
    }

    return final_output, verdict
|
| 150 |
+
|
| 151 |
+
if __name__ == "__main__":
    import json

    # Quick manual smoke test against an obviously fabricated headline.
    headline = "Eiffel Tower sold for scrap metal in surprising Paris decree."
    body = "The mayor of Paris declared the tower will be dismantled."
    payload, verdict = execute_rag(headline, body)
    print("Verdict:", verdict)
    print(json.dumps(payload, indent=2))
|