CXR-LT 2026 Task 1 β€” ConvNeXtV2 + CSRA DB-CAS (πŸ† Top-1)

Top-1 submission for Task 1 (Long-tailed Multi-label Chest X-ray Classification) of the CXR-LT 2026 Challenge.


Architecture

ConvNeXtV2-Base  (timm, global_pool="", drop_path_rate=0.2)
  └── spatial feature map  (B, 1024, H, W)
  └── BatchNorm2d(1024)
  └── CSRA head  (Ξ»=0.1)
       β”œβ”€β”€ GAP branch : Linear(1024 β†’ 30)               β†’ logit_gap
       ├── Attention  : Conv2d(1024→30, 1×1) → Softmax
       │                → class-wise weighted pool       → logit_csra
       └── output     : logit_gap + 0.1 × logit_csra

Total parameters: 87,758,367
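The head's pooling arithmetic can be sketched in NumPy with toy shapes (real dims are B×1024×H×W and 30 classes; BatchNorm and the classifier bias are omitted for brevity):

```python
import numpy as np

# Toy CSRA head: a global-average-pooled logit plus an attention-weighted
# spatial logit, mixed by lambda = 0.1, as in the diagram above.
rng = np.random.default_rng(0)
B, C, H, W, K = 2, 8, 4, 4, 3            # batch, channels, spatial, classes
feat = rng.standard_normal((B, C, H, W))
W_cls = rng.standard_normal((K, C))       # classifier weights shared by both branches
att = rng.standard_normal((B, K, H, W))   # 1x1-conv attention scores

# GAP branch: average over space, then linear classifier
logit_gap = feat.mean(axis=(2, 3)) @ W_cls.T                # (B, K)

# Attention branch: softmax over spatial positions per class,
# then class-wise weighted pooling of the feature map
att_flat = att.reshape(B, K, H * W)
att_soft = np.exp(att_flat) / np.exp(att_flat).sum(-1, keepdims=True)
feat_flat = feat.reshape(B, C, H * W)
csra_feat = np.einsum('bkn,bcn->bkc', att_soft, feat_flat)  # (B, K, C)
logit_csra = (csra_feat * W_cls[None]).sum(-1)              # (B, K)

logits = logit_gap + 0.1 * logit_csra                       # shape (2, 3)
```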


Training Pipeline

Stage 1 β€” Pre-train on MIMIC-CXR (14 classes)  [FC/MLP head]
  ConvNeXtV2-Base (ImageNet-22k init)
  + head: BN β†’ Dropout β†’ Linear(1024β†’512) β†’ ReLU β†’ BN β†’ Dropout β†’ Linear(512β†’14)
  β”œβ”€β”€ Phase 1: head-only warm-up  (LR 1e-3, 3 epochs)
  └── Phase 2: full fine-tune    (backbone LR 1e-5 / head LR 1e-4)
      Loss: AsymmetricLoss + LogitAdjustment, EMA decay 0.9999
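The card does not spell out how the two Stage-1 loss ingredients are combined; an illustrative NumPy sketch of their standard forms follows (logit adjustment per Menon et al. shifts logits by class log-priors; Asymmetric Loss per Ridnik et al. down-weights easy negatives via probability shifting). The hyperparameters `tau`, `gamma_pos`, `gamma_neg`, and `clip` are assumptions, not values from this card:

```python
import numpy as np

def asl_with_logit_adjustment(logits, targets, priors,
                              tau=1.0, gamma_pos=0.0, gamma_neg=4.0,
                              clip=0.05, eps=1e-8):
    # Logit adjustment: shift each class logit by tau * log(prior)
    z = logits + tau * np.log(priors)
    p = 1.0 / (1.0 + np.exp(-z))                 # sigmoid probabilities
    p_neg = np.clip(p - clip, 0.0, 1.0)          # probability shifting for negatives
    loss_pos = targets * (1 - p) ** gamma_pos * np.log(p + eps)
    loss_neg = (1 - targets) * p_neg ** gamma_neg * np.log(1 - p_neg + eps)
    return -(loss_pos + loss_neg).mean()

logits  = np.array([[2.0, -1.0, 0.5]])
targets = np.array([[1.0, 0.0, 1.0]])            # multi-hot labels
priors  = np.array([0.5, 0.05, 0.2])             # empirical class frequencies
loss = asl_with_logit_adjustment(logits, targets, priors)
```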

Stage 2 β€” DB-CAS fine-tune on PadChest (30 classes)  [FC head β†’ CSRA head]
  Backbone weights resumed from Stage 1; FC head discarded; new CSRA head initialized
  β”œβ”€β”€ Double-Balance (DB) sampling: balances label frequency + co-occurrence
  └── Class Activation Spatial (CAS): leverages CSRA spatial attention
      Loss: AsymmetricLoss, EMA decay 0.9999
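The exact DB sampling rule is not specified here; as an illustration of the "label frequency" half of the balance, a common long-tailed recipe weights each image by the inverse frequency of its positive labels (the co-occurrence term would further adjust these weights):

```python
import numpy as np

# Illustrative only: 4 images x 3 classes, multi-hot labels.
# Assumes every image has at least one positive label.
labels = np.array([[1, 0, 0],
                   [1, 1, 0],
                   [1, 0, 1],
                   [0, 0, 1]])
class_freq = labels.sum(axis=0)                 # positives per class: [3, 1, 2]
inv_freq = 1.0 / class_freq

# per-image weight: mean inverse frequency over its positive labels,
# so images carrying rare labels are drawn more often
w = (labels * inv_freq).sum(axis=1) / labels.sum(axis=1)
p = w / w.sum()                                 # resampling distribution
```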

Files

File                                                       Size      Description
convnextv2_base_mimic-cxr_padchest_csra_dbcas.safetensors  ~351 MB   Weights only (recommended)
convnextv2_base_mimic-cxr_padchest_csra_dbcas.pth          ~1.05 GB  Full checkpoint incl. optimizer state
model.py                                                   —         timm-compatible model registration

Usage

Option A β€” via timm

from huggingface_hub import hf_hub_download
import importlib.util, sys

# Register custom model (one-time)
path = hf_hub_download("hieuphamha/cxrlt2026-task1-convnextv2", "model.py")
spec = importlib.util.spec_from_file_location("cxrlt", path)
mod  = importlib.util.module_from_spec(spec)
sys.modules["cxrlt"] = mod
spec.loader.exec_module(mod)

import timm
model = timm.create_model("cxrlt2026_task1_csra_dbcas", pretrained=True)
model.eval()

Option B β€” manual load

import torch, torch.nn as nn, timm
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

class CSRA(nn.Module):
    def __init__(self, input_dim, num_classes, lam=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.lam         = lam
        self.classifier  = nn.Linear(input_dim, num_classes)
        self.conv_att    = nn.Conv2d(input_dim, num_classes, kernel_size=1, bias=False)
        self.softmax     = nn.Softmax(dim=2)

    def forward(self, x):
        b, c, h, w = x.size()
        logit_gap  = self.classifier(x.mean(dim=(2, 3)))
        att_score  = self.softmax(self.conv_att(x).view(b, self.num_classes, h * w))
        csra_feat  = torch.bmm(att_score, x.view(b, c, h * w).permute(0, 2, 1))
        logit_csra = (csra_feat * self.classifier.weight.unsqueeze(0)).sum(2) + self.classifier.bias
        return logit_gap + self.lam * logit_csra

class ConvNeXtV2CXR(nn.Module):
    def __init__(self, num_classes=30):
        super().__init__()
        self.backbone = timm.create_model(
            "convnextv2_base", pretrained=False, num_classes=0,
            drop_path_rate=0.2, global_pool="",
        )
        nf        = self.backbone.num_features   # 1024
        self.bn   = nn.BatchNorm2d(nf)
        self.head = CSRA(input_dim=nf, num_classes=num_classes, lam=0.1)

    def forward(self, x):
        return self.head(self.bn(self.backbone(x)))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = ConvNeXtV2CXR(num_classes=30)
model.load_state_dict(
    load_file(hf_hub_download("hieuphamha/cxrlt2026-task1-convnextv2",
              "convnextv2_base_mimic-cxr_padchest_csra_dbcas.safetensors")),
    strict=True,
)
model.eval().to(device)

Inference

import cv2, numpy as np, torch
import torchvision.transforms as T

CLASS_NAMES = [
    'Normal', 'aortic elongation', 'cardiomegaly', 'pleural effusion',
    'Nodule', 'atelectasis', 'pleural thickening', 'aortic atheromatosis',
    'Support Devices', 'alveolar pattern', 'fracture', 'Hernia',
    'Emphysema', 'azygos lobe', 'Hydropneumothorax', 'Kyphosis',
    'Mass', 'Pneumothorax', 'Subcutaneous Emphysema', 'pneumoperitoneo',
    'vascular hilar enlargement', 'vertebral degenerative changes',
    'hyperinflated lung', 'interstitial pattern', 'central venous catheter',
    'hypoexpansion', 'bronchiectasis', 'hemidiaphragm elevation',
    'sternotomy', 'calcified densities',
]

transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def predict(image_path):
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    img = cv2.resize(img, (512, 512))
    x   = transform(np.stack([img] * 3, axis=-1).astype(np.float32) / 255.0)
    x   = x.unsqueeze(0).to(device)
    with torch.no_grad():
        probs = torch.sigmoid(model(x)).squeeze(0).cpu().numpy()
    return dict(zip(CLASS_NAMES, probs.tolist()))

results = predict("chest_xray.png")
for cls, p in sorted(results.items(), key=lambda x: -x[1])[:5]:
    print(f"{p:.3f}  {cls}")

Classes (30 chest X-ray findings)

  1. Normal
  2. aortic elongation
  3. cardiomegaly
  4. pleural effusion
  5. Nodule
  6. atelectasis
  7. pleural thickening
  8. aortic atheromatosis
  9. Support Devices
  10. alveolar pattern
  11. fracture
  12. Hernia
  13. Emphysema
  14. azygos lobe
  15. Hydropneumothorax
  16. Kyphosis
  17. Mass
  18. Pneumothorax
  19. Subcutaneous Emphysema
  20. pneumoperitoneo
  21. vascular hilar enlargement
  22. vertebral degenerative changes
  23. hyperinflated lung
  24. interstitial pattern
  25. central venous catheter
  26. hypoexpansion
  27. bronchiectasis
  28. hemidiaphragm elevation
  29. sternotomy
  30. calcified densities

Input Specification

  • Image size: 512 Γ— 512 (grayscale chest X-ray β†’ 3-channel repeat)
  • Normalization: mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
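This input contract can be verified with a minimal NumPy sketch (no model needed): repeat the grayscale image to 3 channels, scale to [0, 1], apply the ImageNet statistics above, and batch to (1, 3, 512, 512):

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
std  = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

gray = np.random.randint(0, 256, (512, 512), dtype=np.uint8)  # stand-in CXR
x = np.repeat(gray[None].astype(np.float32) / 255.0, 3, axis=0)  # (3, 512, 512)
x = (x - mean) / std                                             # normalize
batch = x[None]                                                  # (1, 3, 512, 512)
```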

Citation

@article{Pham2026HandlingSS,
  title   = {Handling Supervision Scarcity in Chest X-ray Classification:
             Long-Tailed and Zero-Shot Learning},
  author  = {Ha-Hieu Pham and Hai-Dang Nguyen and Thanh-Huy Nguyen and
             Min Xu and Ulas Bagci and Trung-Nghia Le and Huy-Hieu Pham},
  journal = {ArXiv},
  year    = {2026},
  volume  = {abs/2602.13430},
  url     = {https://arxiv.org/abs/2602.13430}
}