# GPReconResNet — BraTS 2020 · 4-channel (T1 T2 T1CE FLAIR) · Axial
## Architecture — GPReconResNet
GPReconResNet is a residual reconstruction network adapted for classification:
| Hyperparameter | Value |
|---|---|
| Residual blocks | 14 |
| Starting feature maps | 64 |
| Up/down-sampling blocks | 2 |
| Activation | Leaky ReLU |
| Dropout (residual) | 0.5 |
| Upsampling | sinc interpolation |
| Batch normalisation | ✗ |
| 3-D mode | ✗ (2-D) |
The reconstruction bottleneck forces the network to learn compact, semantically rich representations that are easy to interpret as class-discriminative maps.
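As an illustration of the hyperparameters in the table, here is a minimal sketch of what one residual block might look like (Leaky ReLU, dropout 0.5, no batch normalisation). This is an assumption for clarity, not the actual implementation; the real block lives in the GPModels repository.

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Illustrative residual block matching the table above:
    Leaky ReLU activation, dropout 0.5, no batch norm.
    The exact GPReconResNet block is defined in the GPModels repo."""

    def __init__(self, channels: int = 64, p_drop: float = 0.5):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.Dropout2d(p_drop),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )
        self.act = nn.LeakyReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Identity skip connection: output keeps the input's shape
        return self.act(x + self.body(x))
```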
## Overview
This model is part of the GPModels family — a set of inherently-explainable convolutional networks for simultaneous brain-tumour classification and weakly-supervised segmentation from multi-contrast MRI. The models were trained and evaluated on the BraTS 2020 dataset.
- Task: Multi-class brain-tumour classification (Healthy / LGG / HGG) and weakly-supervised segmentation
- Orientation: Axial 2-D slices (240 × 240 px)
- Input channels (order): T1 · T2 · T1CE · FLAIR (channel indices 0 → 3)
- Output classes: 0 = Healthy, 1 = LGG (Low-grade glioma), 2 = HGG (High-grade glioma)
- Normalization: Per-image max normalization (divide by slice maximum)
- Preprint (open access): https://arxiv.org/abs/2206.05148
- Code & training details: https://github.com/soumickmj/GPModels
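The per-image max normalisation mentioned above can be sketched as follows. Whether the maximum is taken jointly over all four channels or per channel is an assumption here; the exact preprocessing convention should be checked against the GPModels repository.

```python
import torch

def max_normalise(slice_4ch: torch.Tensor) -> torch.Tensor:
    """Divide a (4, 240, 240) multi-contrast slice by its maximum intensity.

    Assumes a single maximum over all four channels (joint vs per-channel
    is an assumption; verify against the GPModels preprocessing code).
    """
    m = slice_4ch.max()
    return slice_4ch / m if m > 0 else slice_4ch

x = torch.rand(4, 240, 240) * 1000  # synthetic raw intensities
x_norm = max_normalise(x)           # values now lie in [0, 1]
```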
## Model Inputs / Outputs
| Property | Value |
|---|---|
| Input shape | (B, 4, 240, 240) — float32, max-normalized |
| Output — train mode | (B, 3) — raw logits |
| Output — eval mode | ((B, 3), (B, 3, H, W)) — logits and spatial heatmap |
| Class order | [Healthy, LGG, HGG] |
## Usage

```python
import torch
from transformers import AutoModel

# Load from the Hugging Face Hub
model = AutoModel.from_pretrained(
    "soumickmj/GPReconResNet_BraTS2020_T1T2T1ceFlair_Axial",
    trust_remote_code=True,
)
model.eval()  # note: in eval mode the model returns (logits, heatmap) — see the I/O table above

# Inference (B = 1 example)
# x: tensor of shape (1, 4, 240, 240), channels = [T1, T2, T1CE, FLAIR]
x = torch.randn(1, 4, 240, 240)  # replace with a real max-normalised slice
with torch.no_grad():
    logits, _ = model(x)          # logits: (1, 3)
probs = torch.softmax(logits, dim=1)
pred = probs.argmax(dim=1)        # 0=Healthy, 1=LGG, 2=HGG
```
## Weakly-Supervised Segmentation (Heatmap Mode)

These models are inherently explainable: switching to eval mode bypasses the global max-pooling (GMP) step in the decoder pathway, exposing a full-resolution spatial activation map for every class. No extra labels or re-training are required.
How it works:

- `model.train()` → GMP is applied → returns `logits (B, n_classes)` — classification only
- `model.eval()` → GMP is skipped → returns `(logits, heatmap)`, where `heatmap` has shape `(B, n_classes, H, W)` — one spatial map per class
The heatmap channels correspond to the same class order as the logits:
0 = Healthy, 1 = LGG, 2 = HGG. The whole-tumour map can be obtained by combining channels 1 and 2.
```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "soumickmj/GPReconResNet_BraTS2020_T1T2T1ceFlair_Axial",
    trust_remote_code=True,
)
model.eval()  # ← activates heatmap mode

x = torch.randn(1, 4, 240, 240)  # (B, 4, H, W): [T1, T2, T1CE, FLAIR], max-normalised
with torch.no_grad():
    logits, heatmap = model(x)   # heatmap: (B, 3, H, W)

pred_class = logits.argmax(dim=1)  # (B,) — 0=Healthy, 1=LGG, 2=HGG

# --- Whole-tumour heatmap (LGG + HGG channels) ---
wt_map = heatmap[:, 1:, :, :].max(dim=1).values  # (B, H, W)

# --- Min-max normalise to [0, 1] ---
wt_flat = wt_map.view(wt_map.size(0), -1)
wt_min = wt_flat.min(dim=1).values[:, None, None]
wt_max = wt_flat.max(dim=1).values[:, None, None]
wt_norm = (wt_map - wt_min) / (wt_max - wt_min + 1e-8)  # (B, H, W)

# --- Binary mask via simple threshold ---
binary_mask = (wt_norm > 0.5).float()  # (B, H, W)
```
For more advanced post-processing (multi-Otsu thresholding, top-k binarisation, morphological clean-up, per-slice result aggregation, etc.) see the full post-processing scripts in the project repository.
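To illustrate one such step, the fixed 0.5 threshold above can be replaced by a data-driven Otsu threshold. The self-contained NumPy sketch below is a standard single-threshold Otsu, not necessarily the exact multi-Otsu variant used in the repository's scripts.

```python
import numpy as np

def otsu_threshold(img: np.ndarray, bins: int = 256) -> float:
    """Otsu's method for a [0, 1]-normalised map: pick the threshold
    that maximises the between-class variance of the intensity histogram."""
    hist, edges = np.histogram(img.ravel(), bins=bins, range=(0.0, 1.0))
    p = hist.astype(np.float64) / hist.sum()
    centres = (edges[:-1] + edges[1:]) / 2.0
    w0 = np.cumsum(p)             # cumulative background weight
    mu = np.cumsum(p * centres)   # cumulative mean
    mu_t = mu[-1]                 # global mean
    w1 = 1.0 - w0
    valid = (w0 > 0) & (w1 > 0)
    sigma_b = np.zeros_like(w0)
    sigma_b[valid] = (mu_t * w0[valid] - mu[valid]) ** 2 / (w0[valid] * w1[valid])
    return float(centres[np.argmax(sigma_b)])

# Demo on a synthetic bimodal "heatmap"
wt_demo = np.concatenate([np.full(500, 0.2), np.full(500, 0.8)])
t = otsu_threshold(wt_demo)       # lands between the two modes
mask = (wt_demo > t).astype(np.float32)
```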
## Training Details
- Dataset: BraTS 2020 (Kaggle release) — non-empty axial slices
- Splits: Stratified 75 / 25 train-test split (seed 13), then 5-fold CV (fold 0 reported)
- Optimizer: Adam (lr = 1e-3, weight_decay = 5e-4)
- Loss: Cross-entropy with class-balanced weights
- Mixed precision: AMP (fp16)
- Max epochs: 300 | Grad. accumulation: 16 steps
- Augmentation: Random rotation ±330°, horizontal & vertical flip
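The class-balanced cross-entropy above can be set up as in this sketch. The per-class slice counts are hypothetical placeholders, and inverse-frequency weighting is one common choice; the actual weights used in training come from the BraTS 2020 statistics in the repository.

```python
import torch
import torch.nn as nn

# Hypothetical per-class slice counts [Healthy, LGG, HGG] — placeholders,
# NOT the real BraTS 2020 statistics.
counts = torch.tensor([5000.0, 1500.0, 3500.0])

# Inverse-frequency weighting: rarer classes receive larger weights.
weights = counts.sum() / (len(counts) * counts)
criterion = nn.CrossEntropyLoss(weight=weights)

logits = torch.randn(8, 3)            # dummy model outputs
targets = torch.randint(0, 3, (8,))   # dummy class labels
loss = criterion(logits, targets)
```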
## Citation

If you use this model, please cite:
```bibtex
@article{chatterjee2026weakly,
  title={Weakly-supervised segmentation using inherently-explainable classification models and their application to brain tumour classification},
  author={Chatterjee, Soumick and Yassin, Hadya and Dubost, Florian and N{\"u}rnberger, Andreas and Speck, Oliver},
  journal={Neurocomputing},
  pages={133460},
  year={2026},
  publisher={Elsevier}
}
```
## License
MIT — see https://github.com/soumickmj/GPModels for full details.