Handling Supervision Scarcity in Chest X-ray Classification: Long-Tailed and Zero-Shot Learning
Paper • 2602.13430 • Published • 1
Top-1 submission for Task 1 (Long-tailed Multi-label Chest X-ray Classification) of the CXR-LT 2026 Challenge.
ConvNeXtV2-Base (timm, global_pool="", drop_path_rate=0.2)
└── spatial feature map (B, 1024, H, W)
    └── BatchNorm2d(1024)
        └── CSRA head (λ=0.1)
            ├── GAP branch : Linear(1024 → 30) → logit_gap
            ├── Attention  : Conv2d(1024→30, 1×1) → Softmax
            │                → class-wise weighted pool → logit_csra
            └── output     : logit_gap + 0.1 × logit_csra
Total parameters: 87,758,367
Stage 1 — Pre-train on MIMIC-CXR (14 classes) [FC/MLP head]
ConvNeXtV2-Base (ImageNet-22k init)
+ head: BN → Dropout → Linear(1024→512) → ReLU → BN → Dropout → Linear(512→14)
├── Phase 1: head-only warm-up (LR 1e-3, 3 epochs)
└── Phase 2: full fine-tune (backbone LR 1e-5 / head LR 1e-4)
Loss: AsymmetricLoss + LogitAdjustment, EMA decay 0.9999
Stage 2 — DB-CAS fine-tune on PadChest (30 classes) [FC head → CSRA head]
Backbone weights resumed from Stage 1; FC head discarded; new CSRA head initialized
├── Double-Balance (DB) sampling: balances label frequency + co-occurrence
└── Class Activation Spatial (CAS): leverages CSRA spatial attention
Loss: AsymmetricLoss, EMA decay 0.9999
| File | Size | Description |
|---|---|---|
| convnextv2_base_mimic-cxr_padchest_csra_dbcas.safetensors | ~351 MB | Weights only (recommended) |
| convnextv2_base_mimic-cxr_padchest_csra_dbcas.pth | ~1.05 GB | Full checkpoint incl. optimizer state |
| model.py | — | timm-compatible model registration |
from huggingface_hub import hf_hub_download
import importlib.util
import sys

# One-time registration of the custom model: fetch model.py from the Hub and
# execute it as a module so its timm registry decorator runs on import.
module_path = hf_hub_download("hieuphamha/cxrlt2026-task1-convnextv2", "model.py")
module_spec = importlib.util.spec_from_file_location("cxrlt", module_path)
cxrlt_module = importlib.util.module_from_spec(module_spec)
# Register in sys.modules *before* executing, per the importlib convention.
sys.modules["cxrlt"] = cxrlt_module
module_spec.loader.exec_module(cxrlt_module)

import timm

model = timm.create_model("cxrlt2026_task1_csra_dbcas", pretrained=True)
model.eval()
import torch, torch.nn as nn, timm
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
class CSRA(nn.Module):
    """Class-Specific Residual Attention head.

    Combines a global-average-pooled linear branch with a per-class,
    spatially attended branch; the attended logits are blended in with
    weight ``lam``.
    """

    def __init__(self, input_dim, num_classes, lam=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.lam = lam
        # Shared linear layer: produces the GAP logits and its weight matrix
        # is reused to score the attended per-class features.
        self.classifier = nn.Linear(input_dim, num_classes)
        # 1x1 convolution yielding one attention map per class.
        self.conv_att = nn.Conv2d(input_dim, num_classes, kernel_size=1, bias=False)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        batch, channels, height, width = x.size()
        spatial = height * width
        # Branch 1: classify the globally averaged feature map.
        pooled = x.mean(dim=(2, 3))
        logit_gap = self.classifier(pooled)
        # Branch 2: softmax attention over spatial positions, per class.
        att_score = self.softmax(self.conv_att(x).view(batch, self.num_classes, spatial))
        flat_feat = x.view(batch, channels, spatial).permute(0, 2, 1)  # (B, HW, C)
        csra_feat = torch.bmm(att_score, flat_feat)  # (B, K, C)
        # Score attended features with the same classifier weights/bias.
        weighted = csra_feat * self.classifier.weight.unsqueeze(0)
        logit_csra = weighted.sum(2) + self.classifier.bias
        return logit_gap + self.lam * logit_csra
class ConvNeXtV2CXR(nn.Module):
    """ConvNeXtV2-Base backbone + BatchNorm2d + CSRA classification head."""

    def __init__(self, num_classes=30):
        super().__init__()
        # global_pool="" keeps the spatial feature map the CSRA head needs.
        self.backbone = timm.create_model(
            "convnextv2_base",
            pretrained=False,
            num_classes=0,
            drop_path_rate=0.2,
            global_pool="",
        )
        feat_dim = self.backbone.num_features  # 1024 for convnextv2_base
        self.bn = nn.BatchNorm2d(feat_dim)
        self.head = CSRA(input_dim=feat_dim, num_classes=num_classes, lam=0.1)

    def forward(self, x):
        features = self.bn(self.backbone(x))
        return self.head(features)
# Select GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download the safetensors checkpoint and load it strictly so any
# architecture mismatch fails loudly rather than silently skipping keys.
weights_path = hf_hub_download(
    "hieuphamha/cxrlt2026-task1-convnextv2",
    "convnextv2_base_mimic-cxr_padchest_csra_dbcas.safetensors",
)
state_dict = load_file(weights_path)

model = ConvNeXtV2CXR(num_classes=30)
model.load_state_dict(state_dict, strict=True)
model.eval().to(device)
import cv2, numpy as np, torch
import torchvision.transforms as T
# The 30 PadChest target labels, in the model's output-index order.
# NOTE(review): casing is mixed ('Normal' vs 'cardiomegaly') — kept verbatim,
# since order and spelling must match the trained head.
CLASS_NAMES = ['Normal', 'aortic elongation', 'cardiomegaly', 'pleural effusion', 'Nodule', 'atelectasis', 'pleural thickening', 'aortic atheromatosis', 'Support Devices', 'alveolar pattern', 'fracture', 'Hernia', 'Emphysema', 'azygos lobe', 'Hydropneumothorax', 'Kyphosis', 'Mass', 'Pneumothorax', 'Subcutaneous Emphysema', 'pneumoperitoneo', 'vascular hilar enlargement', 'vertebral degenerative changes', 'hyperinflated lung', 'interstitial pattern', 'central venous catheter', 'hypoexpansion', 'bronchiectasis', 'hemidiaphragm elevation', 'sternotomy', 'calcified densities']
# ImageNet mean/std normalization; input is expected as float RGB in [0, 1].
transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def predict(image_path):
    """Run the model on one chest X-ray and return per-class probabilities.

    Parameters
    ----------
    image_path : str
        Path to an image file readable by OpenCV.

    Returns
    -------
    dict
        Mapping from class name to sigmoid probability (multi-label).

    Raises
    ------
    FileNotFoundError
        If the image cannot be read (missing file or unsupported format).
    """
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising;
        # without this check the failure surfaces as a cryptic resize error.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    img = cv2.resize(img, (512, 512))
    # Replicate the grayscale channel to 3-channel input, scale to [0, 1],
    # then apply the ImageNet normalization transform.
    x = transform(np.stack([img] * 3, axis=-1).astype(np.float32) / 255.0)
    x = x.unsqueeze(0).to(device)
    with torch.no_grad():
        probs = torch.sigmoid(model(x)).squeeze(0).cpu().numpy()
    return dict(zip(CLASS_NAMES, probs.tolist()))
results = predict("chest_xray.png")
# Print the five most confident predictions, highest probability first.
top5 = sorted(results.items(), key=lambda item: item[1], reverse=True)[:5]
for label, prob in top5:
    print(f"{prob:.3f} {label}")
@article{Pham2026HandlingSS,
title = {Handling Supervision Scarcity in Chest X-ray Classification:
Long-Tailed and Zero-Shot Learning},
author = {Ha-Hieu Pham and Hai-Dang Nguyen and Thanh-Huy Nguyen and
Min Xu and Ulas Bagci and Trung-Nghia Le and Huy-Hieu Pham},
journal = {ArXiv},
year = {2026},
volume = {abs/2602.13430},
url = {https://arxiv.org/abs/2602.13430}
}