GPReconResNet — BraTS 2020 · 4-channel (T1 T2 T1CE FLAIR) · Axial

Architecture — GPReconResNet

GPReconResNet is a residual reconstruction network adapted for classification:

Hyperparameter            Value
------------------------  ------------------
Residual blocks           14
Starting feature maps     64
Up/down-sampling blocks   2
Activation                Leaky ReLU
Dropout (residual)        0.5
Upsampling                sinc interpolation
Batch normalisation       ✓
3-D mode                  ✗ (2-D)

The reconstruction bottleneck forces the network to learn compact, semantically rich representations that are easy to interpret as class-discriminative maps.
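As an illustration of the hyperparameters above, a residual block in this style might look like the following minimal sketch (Leaky ReLU, batch normalisation, dropout 0.5). This is an assumption-laden toy block, not the exact GPReconResNet implementation, which lives in the GPModels repository:

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Illustrative residual block matching the table above (hypothetical
    layout; see the GPModels repository for the real architecture)."""
    def __init__(self, channels: int = 64, dropout: float = 0.5):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.LeakyReLU(inplace=True),
            nn.Dropout2d(dropout),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        )
        self.act = nn.LeakyReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(x + self.body(x))   # identity shortcut preserves shape

block = ResidualBlock(64)
y = block(torch.randn(1, 64, 240, 240))
print(tuple(y.shape))   # (1, 64, 240, 240)
```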

Overview

This model is part of the GPModels family — a set of inherently-explainable convolutional networks for simultaneous brain-tumour classification and weakly-supervised segmentation from multi-contrast MRI. The models were trained and evaluated on the BraTS 2020 dataset.

  • Task: Multi-class brain-tumour classification (Healthy / LGG / HGG) and weakly-supervised segmentation
  • Orientation: Axial 2-D slices (240 × 240 px)
  • Input channels (order): T1 · T2 · T1CE · FLAIR (channel indices 0 → 3)
  • Output classes: 0 = Healthy, 1 = LGG (Low-grade glioma), 2 = HGG (High-grade glioma)
  • Normalization: Per-image max normalization (divide by slice maximum)
  • Preprint (open access): https://arxiv.org/abs/2206.05148
  • Code & training details: https://github.com/soumickmj/GPModels
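The per-image max normalization listed above can be sketched as follows. Whether the divisor is taken over the whole 4-channel slice or per channel is an assumption here; the repository's preprocessing code is authoritative:

```python
import torch

def max_normalise(slice_4ch: torch.Tensor) -> torch.Tensor:
    """Divide a (4, 240, 240) slice by its maximum intensity
    (sketch only; check the GPModels preprocessing for the exact scheme)."""
    m = slice_4ch.amax()
    return slice_4ch / m if m > 0 else slice_4ch

x = torch.rand(4, 240, 240) * 3000.0   # synthetic raw MRI intensities
x_norm = max_normalise(x)
print(float(x_norm.max()))             # 1.0
```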

Model Inputs / Outputs

Property              Value
--------------------  --------------------------------------------------
Input shape           (B, 4, 240, 240) — float32, max-normalized
Output — train mode   (B, 3) — raw logits
Output — eval mode    ((B, 3), (B, 3, H, W)) — logits and spatial heatmap
Class order           [Healthy, LGG, HGG]

Usage

import torch
from transformers import AutoModel

# Load from the Hugging Face Hub
model = AutoModel.from_pretrained("soumickmj/GPReconResNet_BraTS2020_T1T2T1ceFlair_Axial", trust_remote_code=True)
model.eval()

# Inference (B=1 example)
# x: tensor of shape (1, 4, 240, 240), channels = [T1, T2, T1CE, FLAIR]
x = torch.randn(1, 4, 240, 240)       # replace with a real max-normalised slice
with torch.no_grad():
    logits, _ = model(x)               # eval mode returns (logits, heatmap); heatmap unused here
    probs  = torch.softmax(logits, dim=1)
    pred   = probs.argmax(dim=1)       # 0=Healthy, 1=LGG, 2=HGG

Weakly-Supervised Segmentation (Heatmap Mode)

These models are inherently explainable: by switching to eval mode the global max-pooling (GMP) step is bypassed in the decoder pathway, exposing a full-resolution spatial activation map for every class. No extra labels or re-training are required.

How it works:

  • model.train() → GMP is applied → returns logits (B, n_classes) — classification only
  • model.eval() → GMP is skipped → returns (logits, heatmap) where heatmap has shape (B, n_classes, H, W) — one spatial map per class

The heatmap channels correspond to the same class order as the logits: 0 = Healthy, 1 = LGG, 2 = HGG. The whole-tumour map can be obtained by combining channels 1 and 2.

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("soumickmj/GPReconResNet_BraTS2020_T1T2T1ceFlair_Axial", trust_remote_code=True)
model.eval()                        # ← activates heatmap mode

x = torch.randn(1, 4, 240, 240)    # (B, 4, H, W): [T1, T2, T1CE, FLAIR], max-normalised

with torch.no_grad():
    logits, heatmap = model(x)      # heatmap: (B, 3, H, W)

pred_class = logits.argmax(dim=1)   # (B,)  — 0=Healthy, 1=LGG, 2=HGG

# --- Whole-tumour heatmap (LGG + HGG channels) ---
wt_map = heatmap[:, 1:, :, :].max(dim=1).values   # (B, H, W)

# --- Min-max normalise to [0, 1] ---
wt_flat = wt_map.view(wt_map.size(0), -1)
wt_min  = wt_flat.min(dim=1).values[:, None, None]
wt_max  = wt_flat.max(dim=1).values[:, None, None]
wt_norm = (wt_map - wt_min) / (wt_max - wt_min + 1e-8)   # (B, H, W)

# --- Binary mask via simple threshold ---
binary_mask = (wt_norm > 0.5).float()              # (B, H, W)

For more advanced post-processing (multi-Otsu thresholding, top-k binarisation, morphological clean-up, per-slice result aggregation, etc.) see the full post-processing scripts in the project repository.
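As a flavour of that post-processing, here is a top-k binarisation sketch: instead of a fixed threshold, keep only the highest-activating fraction of pixels per slice. The fraction value is hypothetical; the repository's scripts define the actual parameters:

```python
import torch

def topk_binarise(wt_norm: torch.Tensor, frac: float = 0.02) -> torch.Tensor:
    """Keep the top `frac` fraction of pixels per slice as foreground.
    Sketch only; `frac` is an illustrative choice, not the paper's setting."""
    b, h, w = wt_norm.shape
    k = max(1, int(frac * h * w))                     # number of pixels to keep
    flat = wt_norm.view(b, -1)
    thresh = flat.topk(k, dim=1).values[:, -1:]       # k-th largest value per slice
    return (flat >= thresh).float().view(b, h, w)

mask = topk_binarise(torch.rand(1, 240, 240), frac=0.02)   # (1, 240, 240)
```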

Training Details

  • Dataset: BraTS 2020 (Kaggle release) — non-empty axial slices
  • Splits: Stratified 75 / 25 train-test split (seed 13), then 5-fold CV (fold 0 reported)
  • Optimizer: Adam (lr = 1e-3, weight_decay = 5e-4)
  • Loss: Cross-entropy with class-balanced weights
  • Mixed precision: AMP (fp16)
  • Max epochs: 300 | Grad. accumulation: 16 steps
  • Augmentation: Random rotation ±330°, horizontal & vertical flip
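The class-balanced cross-entropy weights above can be sketched with an inverse-frequency scheme (an assumption; the exact weighting used for training is defined in the GPModels repository):

```python
import torch
import torch.nn as nn

# Toy slice labels: 0=Healthy, 1=LGG, 2=HGG
labels  = torch.tensor([0, 0, 0, 0, 1, 2, 2])
counts  = torch.bincount(labels, minlength=3).float()      # per-class counts
weights = counts.sum() / (len(counts) * counts)            # inverse frequency, normalised

criterion = nn.CrossEntropyLoss(weight=weights)            # rarer classes weigh more
logits = torch.randn(7, 3)
loss = criterion(logits, labels)
```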

Citation

If you use this model, please cite:

@article{chatterjee2026weakly,
  title={Weakly-supervised segmentation using inherently-explainable classification models and their application to brain tumour classification},
  author={Chatterjee, Soumick and Yassin, Hadya and Dubost, Florian and N{\"u}rnberger, Andreas and Speck, Oliver},
  journal={Neurocomputing},
  pages={133460},
  year={2026},
  publisher={Elsevier}
}

License

MIT — see https://github.com/soumickmj/GPModels for full details.
