# ResNet-50 U-Net — Kvasir-SEG Polyp Segmentation
Binary segmentation model for gastrointestinal polyp detection, trained on the Kvasir-SEG dataset.
## Architecture
A U-Net with a ResNet-50 encoder pretrained on ImageNet.
| Component | Details |
|---|---|
| Encoder | ResNet-50 (ImageNet pretrained), all layers fine-tuned |
| Skip connections | Concatenation of encoder and decoder feature maps at each scale |
| Decoder | Transposed-conv ×2 upsampling + double conv block at each level |
| Output | Single-channel logit map (1 × H × W) |
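A decoder level of this kind can be sketched as a PyTorch module. This is an illustration of the "transposed-conv ×2 upsampling + double conv" pattern, not the repo's actual `model.py`; the channel arguments and BatchNorm placement are assumptions.

```python
import torch
import torch.nn as nn

class DecoderBlock(nn.Module):
    """Upsample x2 with a transposed conv, concatenate the skip, then double conv.

    Sketch of one U-Net decoder level; not the repository's actual implementation.
    """

    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        # Transposed conv doubles spatial resolution (kernel 2, stride 2)
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        # Double conv block after concatenating the encoder skip features
        self.conv = nn.Sequential(
            nn.Conv2d(out_ch + skip_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, skip):
        x = self.up(x)                    # e.g. (1, 2048, 8, 8) -> (1, 1024, 16, 16)
        x = torch.cat([x, skip], dim=1)   # concatenation skip connection
        return self.conv(x)
```

For example, the bottleneck-to-stride-16 level would be `DecoderBlock(2048, 1024, 1024)`, consuming the 8×8 bottleneck and the 16×16 skip features.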
Feature map sizes for a 256×256 input:

- stride 2 -> 64 ch, 128×128
- stride 4 -> 256 ch, 64×64
- stride 8 -> 512 ch, 32×32
- stride 16 -> 1024 ch, 16×16
- stride 32 -> 2048 ch, 8×8 (bottleneck)
## Loss Function
Medical image segmentation suffers from severe class imbalance (background pixels vastly outnumber polyp pixels). To address this, the model is trained with a combined loss:
Loss = 0.5 × Focal + 0.5 × Tversky
| Loss | Parameters | Role |
|---|---|---|
| Focal | α = 0.8, γ = 2 | Down-weights easy background pixels; focuses learning on hard examples |
| Tversky | α = 0.3, β = 0.7 | Penalises false negatives more than false positives, improving recall on small polyps |
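A minimal sketch of this combined loss, assuming the standard binary focal formulation on raw logits (the exact variant used in training is not specified here):

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=0.8, gamma=2.0):
    """Binary focal loss on raw logits; alpha weights the positive (polyp) class."""
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = torch.exp(-bce)  # probability assigned to the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    # (1 - p_t)^gamma down-weights easy, well-classified pixels
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()

def tversky_loss(logits, targets, alpha=0.3, beta=0.7, eps=1e-6):
    """Tversky loss: alpha weights false positives, beta weights false negatives.

    beta > alpha penalises missed polyp pixels harder, boosting recall.
    """
    probs = torch.sigmoid(logits)
    tp = (probs * targets).sum()
    fp = (probs * (1 - targets)).sum()
    fn = ((1 - probs) * targets).sum()
    return 1 - (tp + eps) / (tp + alpha * fp + beta * fn + eps)

def combined_loss(logits, targets):
    # Loss = 0.5 * Focal + 0.5 * Tversky, as used in training
    return 0.5 * focal_loss(logits, targets) + 0.5 * tversky_loss(logits, targets)
```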
## Training
- Dataset: 880 train / 60 val / 60 test (Kvasir-SEG split)
- Input size: 256×256, normalised with ImageNet mean/std
- Batch size: 8
- Epochs: 20
- LR schedule: Cosine decay with 10% linear warm-up
- Hyperparameter search: Optuna over `learning_rate` ∈ {5e-4, 1e-3} and `weight_decay` ∈ {1e-4, 1e-3}
- Best weights: snapshot of the epoch with highest validation Dice across all trials
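The LR schedule can be expressed as a multiplier function for `torch.optim.lr_scheduler.LambdaLR`. This is a sketch; applying the warm-up over the first 10% of optimizer steps (rather than epochs) is an assumption.

```python
import math

def lr_lambda(step, total_steps, warmup_frac=0.1):
    """Linear warm-up over the first 10% of steps, then cosine decay to zero.

    Returns a multiplier on the base learning rate.
    """
    warmup_steps = max(1, int(warmup_frac * total_steps))
    if step < warmup_steps:
        # Linear ramp from 1/warmup_steps up to 1.0
        return (step + 1) / warmup_steps
    # Cosine decay from 1.0 down to 0.0 over the remaining steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))
```

Usage would be `LambdaLR(optimizer, lambda s: lr_lambda(s, total_steps))` with `scheduler.step()` called once per batch.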
## Data Augmentation (train only)
- Spatial (applied identically to image and mask): random horizontal/vertical flip, ±20° rotation, affine (translate ±5%, scale 0.9–1.1), elastic deformation (α=60, σ=6)
- Photometric (image only, p=0.8): color jitter (brightness/contrast ±0.3, saturation ±0.2, hue ±0.05), Gaussian blur (3×3, σ∈[0.1, 2.0])
## Usage
```python
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from model import ResNet50UNet

# Instantiate the architecture, then load the fine-tuned weights from the Hub
model = ResNet50UNet(num_classes=1, pretrained=False)
weights_path = hf_hub_download(repo_id="VioletaR/kvasir-seg-unet", filename="model.safetensors")
model.load_state_dict(load_file(weights_path))
model.eval()
```
Inputs must be 256×256 RGB tensors normalised with ImageNet statistics
(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).
The model outputs raw logits; apply `torch.sigmoid` and threshold at 0.5 to obtain the binary mask.