ResNet-50 U-Net — Kvasir-SEG Polyp Segmentation

Binary segmentation model for gastrointestinal polyp detection, trained on the Kvasir-SEG dataset.

Architecture

A U-Net with a ResNet-50 encoder pretrained on ImageNet.

Component          Details
Encoder            ResNet-50 (ImageNet pretrained), all layers fine-tuned
Skip connections   Concatenation of encoder and decoder feature maps at each scale
Decoder            Transposed-conv ×2 upsampling + double conv block at each level
Output             Single-channel logit map (1 × H × W)

Feature map sizes for a 256×256 input:

stride 2 ->  64ch, 128×128
stride 4 -> 256ch, 64×64
stride 8 -> 512ch, 32×32
stride 16 -> 1024ch, 16×16
stride 32 -> 2048ch, 8×8 (bottleneck)

Loss Function

Medical image segmentation suffers from severe class imbalance (background pixels vastly outnumber polyp pixels). To address this, the model is trained with a combined loss:

Loss = 0.5 × Focal + 0.5 × Tversky

Loss      Parameters         Role
Focal     α = 0.8, γ = 2     Down-weights easy background pixels; focuses learning on hard examples
Tversky   α = 0.3, β = 0.7   Penalises false negatives more than false positives, improving recall on small polyps
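
A minimal sketch of the combined loss, assuming the standard binary focal and Tversky formulations (the card gives the weights and parameters but not the exact implementation):

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=0.8, gamma=2.0):
    """Binary focal loss on raw logits."""
    p = torch.sigmoid(logits)
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = targets * p + (1 - targets) * (1 - p)          # prob of the true class
    alpha_t = targets * alpha + (1 - targets) * (1 - alpha)
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()

def tversky_loss(logits, targets, alpha=0.3, beta=0.7, eps=1e-6):
    """Soft Tversky loss; beta > alpha penalises false negatives more."""
    p = torch.sigmoid(logits)
    tp = (p * targets).sum()
    fp = (p * (1 - targets)).sum()
    fn = ((1 - p) * targets).sum()
    return 1 - (tp + eps) / (tp + alpha * fp + beta * fn + eps)

def combined_loss(logits, targets):
    return 0.5 * focal_loss(logits, targets) + 0.5 * tversky_loss(logits, targets)
```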

Training

  • Dataset: 880 train / 60 val / 60 test (Kvasir-SEG split)
  • Input size: 256×256, normalised with ImageNet mean/std
  • Batch size: 8
  • Epochs: 20
  • LR schedule: Cosine decay with 10% linear warm-up
  • Hyperparameter search: Optuna over learning_rate {5e-4, 1e-3} and weight_decay {1e-4, 1e-3}
  • Best weights: snapshot of the epoch with highest validation Dice across all trials
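
The cosine-decay-with-warm-up schedule can be sketched with a `LambdaLR`; the per-step granularity and the base optimizer below are assumptions, not details from the training run:

```python
import math
import torch

def make_scheduler(optimizer, total_steps, warmup_frac=0.1):
    """Linear warm-up over the first 10% of steps, then cosine decay to ~0."""
    warmup = max(1, int(total_steps * warmup_frac))

    def lr_lambda(step):
        if step < warmup:
            return (step + 1) / warmup              # linear ramp to the base LR
        t = (step - warmup) / max(1, total_steps - warmup)
        return 0.5 * (1 + math.cos(math.pi * t))    # cosine decay

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
```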

Data Augmentation (train only)

Spatial (applied identically to image and mask): Random horizontal/vertical flip, ±20° rotation, affine (translate ±5%, scale 0.9–1.1), elastic deformation (α=60, σ=6)

Photometric (image only, p=0.8): Color jitter (brightness/contrast ±0.3, saturation ±0.2, hue ±0.05), Gaussian blur (3×3, σ∈[0.1, 2.0])

Usage

from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from model import ResNet50UNet

# Build the architecture without ImageNet weights, then load the trained checkpoint.
model = ResNet50UNet(num_classes=1, pretrained=False)
weights_path = hf_hub_download(repo_id="VioletaR/kvasir-seg-unet", filename="model.safetensors")
model.load_state_dict(load_file(weights_path))
model.eval()

Inputs must be 256×256 RGB tensors normalised with ImageNet statistics (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]). The model outputs raw logits; apply torch.sigmoid and threshold at 0.5 for the binary mask.
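
The post-processing step can be wrapped in a small helper; `predict_mask` is a hypothetical name, not part of the released code, and it assumes `model` is the loaded network from the snippet above:

```python
import torch

def predict_mask(model, image, threshold=0.5):
    """image: normalised (N, 3, 256, 256) tensor -> boolean mask (N, 1, 256, 256)."""
    with torch.no_grad():
        logits = model(image)               # raw, unnormalised scores
    return torch.sigmoid(logits) > threshold
```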
