# dermnet-skin23-convnext / ensemble_eval.py
# Uploaded by iamcode6 via huggingface_hub (commit cf94db4, verified)
"""Multi-model ensemble + proper TTA — averages softmax across 1+ checkpoints.
Usage:
python ensemble_eval.py best_ema.pt
python ensemble_eval.py best_ema_v2.pt convnextv2_best_ema.pt
python ensemble_eval.py best_ema_v2.pt swa_v2.pt convnextv2_best_ema.pt convnextv2_swa.pt
Each checkpoint is loaded with its own model_name + img_size from its dict
(so EVA-02 @ 448 and ConvNeXt V2 @ 384 mix freely). Each runs TTA = {identity,
hflip} x {crop_pct 0.95, 1.0} = 4 augs, then per-image softmaxes are averaged.
Saves predictions.csv and prints classification_report.
"""
from __future__ import annotations
import os
import sys
import numpy as np
import timm
import torch
from sklearn.metrics import classification_report, f1_score
from timm.data import create_transform
from torch.utils.data import DataLoader
from torchvision import datasets
# --- Paths and evaluation settings -------------------------------------------
BASE_EXTRACT_DIR = "./dermnet-skin40-cleaned-dataset"
DATA_DIR = os.path.join(BASE_EXTRACT_DIR, "kaggle/working/Merged_Dermnet_Skin40")
TEST_DIR = os.path.join(DATA_DIR, "test")
BATCH_SIZE = 16
NUM_WORKERS = 4  # DataLoader worker processes
# ImageNet normalization stats — presumably match training-time preprocessing;
# TODO confirm against the training script.
MEAN, STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
# HIP (AMD ROCm) analogue of CUDA_VISIBLE_DEVICES: restrict to GPU 0.
os.environ["HIP_VISIBLE_DEVICES"] = "0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def make_loader(img_size, crop_pct):
    """Build the test-set dataset and loader for one (img_size, crop_pct) TTA variant."""
    transform = create_transform(
        input_size=img_size,
        is_training=False,
        crop_pct=crop_pct,
        interpolation="bicubic",
        mean=MEAN,
        std=STD,
    )
    dataset = datasets.ImageFolder(TEST_DIR, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    return dataset, loader
def load_model(ckpt_path):
    """Reconstruct a model from a checkpoint dict and load its weights.

    The checkpoint carries its own ``model_name`` and ``img_size``, so
    checkpoints of different architectures/resolutions can be ensembled
    together.

    Returns:
        (model, img_size) — model on `device`, eval mode, channels_last.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from trusted sources.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    model_name = ckpt["model_name"]
    img_size = ckpt["img_size"]
    sd = ckpt["model_state_dict"]
    # Infer num_classes from the classifier head weight; fall back to 23.
    head_keys = [k for k in sd if k.endswith("head.weight") or k.endswith("fc.weight")]
    num_classes = sd[head_keys[0]].shape[0] if head_keys else 23
    # ViT/EVA-style models accept an img_size kwarg at construction; most
    # convnets reject it with TypeError, so retry without it.
    # (Bug fix: the original first built the model unconditionally with
    # img_size=None outside any try/except — a TypeError there for non-ViT
    # models would crash before the guarded construction ran — then threw
    # that instance away and built it again. The dead double construction
    # and the no-op co_varnames check are removed.)
    try:
        model = timm.create_model(model_name, pretrained=False,
                                  num_classes=num_classes, img_size=img_size)
    except TypeError:
        model = timm.create_model(model_name, pretrained=False, num_classes=num_classes)
    model.load_state_dict(sd)
    model = model.to(device, memory_format=torch.channels_last).eval()
    print(f" loaded {ckpt_path} ({model_name} @ {img_size}, "
          f"prev acc={ckpt.get('val_acc', 0)*100:.2f}%)")
    return model, img_size
@torch.no_grad()
def tta_softmax(model, img_size):
    """4-augmentation TTA: {identity, hflip} x {crop_pct=0.95, crop_pct=1.0}.

    Runs the full test set once per crop_pct, summing softmax probabilities
    over all four augmentations and dividing by 4.

    Returns:
        (probs, targets, classes) — probs is an [N, C] float32 CPU tensor,
        targets an [N] label tensor, classes the ImageFolder class names.
    """
    aggregated = None
    targets = None
    classes = None
    for crop_pct in [0.95, 1.0]:
        ds, loader = make_loader(img_size, crop_pct)
        if classes is None:
            classes = ds.classes
        batch_softmax = []
        batch_targets = []
        for x, y in loader:
            x = x.to(device, non_blocking=True).to(memory_format=torch.channels_last)
            # Bug fix: autocast device type was hard-coded to "cuda", which
            # breaks the CPU fallback the module-level `device` allows.
            with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
                # identity
                p = torch.softmax(model(x), dim=-1).float()
                # hflip
                p = p + torch.softmax(model(torch.flip(x, dims=[-1])), dim=-1).float()
            batch_softmax.append(p.cpu())
            batch_targets.append(y)
        crop_softmax = torch.cat(batch_softmax)  # 2 augs already summed
        if aggregated is None:
            aggregated = crop_softmax
            targets = torch.cat(batch_targets)
        else:
            aggregated = aggregated + crop_softmax
    # Total: 2 crops x 2 flips = 4 augmentations summed
    aggregated = aggregated / 4.0
    return aggregated, targets, classes
def main():
    """Entry point: TTA-evaluate every checkpoint on argv, average softmaxes,
    print accuracy / macro-F1 / per-class report, and save predictions.csv."""
    if len(sys.argv) < 2:
        print("Usage: python ensemble_eval.py <ckpt1> [ckpt2 ...]")
        sys.exit(1)
    ckpt_paths = sys.argv[1:]
    print(f"Ensembling {len(ckpt_paths)} checkpoint(s) with TTA (4 augs/model):")
    per_model_probs = []
    targets = None
    classes = None
    for ckpt in ckpt_paths:
        if not os.path.exists(ckpt):
            print(f" SKIP missing: {ckpt}")
            continue
        model, img_size = load_model(ckpt)
        probs, t, cls = tta_softmax(model, img_size)
        per_model_probs.append(probs)
        if targets is None:
            targets = t
            classes = cls
        # Drop this model before loading the next to keep GPU memory bounded.
        del model
        torch.cuda.empty_cache()
    if not per_model_probs:
        print("No valid checkpoints loaded.")
        sys.exit(1)
    # Equal-weight average of the per-model TTA softmax tensors.
    ensemble = torch.stack(per_model_probs).mean(0)
    preds = ensemble.argmax(1)
    acc = (preds == targets).float().mean().item()
    f1 = f1_score(targets.numpy(), preds.numpy(), average="macro")
    print(f"\n{'='*60}")
    print(f"Ensemble of {len(per_model_probs)} model(s) + 4-aug TTA")
    print(f"{'='*60}")
    print(f"Accuracy: {acc*100:.2f}%")
    print(f"Macro F1: {f1:.4f}")
    print(f"\nPer-class report:")
    print(classification_report(targets.numpy(), preds.numpy(),
                                target_names=classes, digits=3, zero_division=0))
    # Persist true label, predicted label, and all class probabilities per image.
    np.savetxt("predictions.csv",
               np.column_stack([targets.numpy(), preds.numpy(), ensemble.numpy()]),
               delimiter=",", fmt="%g",
               header="true,pred," + ",".join(f"p_{c}" for c in classes), comments="")
    print(f"Saved predictions.csv ({ensemble.shape[0]} rows)")
# Script entry guard: allows importing this module without running evaluation.
if __name__ == "__main__":
    main()