File size: 5,841 Bytes
cf94db4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""Multi-model ensemble + proper TTA — averages softmax across 1+ checkpoints.

Usage:
  python ensemble_eval.py best_ema.pt
  python ensemble_eval.py best_ema_v2.pt convnextv2_best_ema.pt
  python ensemble_eval.py best_ema_v2.pt swa_v2.pt convnextv2_best_ema.pt convnextv2_swa.pt

Each checkpoint is loaded with its own model_name + img_size from its dict
(so EVA-02 @ 448 and ConvNeXt V2 @ 384 mix freely). Each runs 4-view TTA
(identity + hflip at crop_pct 0.95 and 1.0), then per-image softmaxes are averaged.

Saves predictions.csv and prints classification_report.
"""
from __future__ import annotations

import os
import sys

import numpy as np
import timm
import torch
from sklearn.metrics import classification_report, f1_score
from timm.data import create_transform
from torch.utils.data import DataLoader
from torchvision import datasets

# Dataset locations: the test split lives under DATA_DIR/test.
BASE_EXTRACT_DIR = "./dermnet-skin40-cleaned-dataset"
DATA_DIR = os.path.join(BASE_EXTRACT_DIR, "kaggle/working/Merged_Dermnet_Skin40")
TEST_DIR = os.path.join(DATA_DIR, "test")

# DataLoader settings shared by every TTA pass.
BATCH_SIZE = 16
NUM_WORKERS = 4

# Per-channel normalisation stats handed to the eval transform
# (presumably the standard ImageNet mean/std -- confirm against training).
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# Unconditionally overwrites HIP_VISIBLE_DEVICES with "0"
# (presumably to restrict a ROCm build to the first GPU).
os.environ["HIP_VISIBLE_DEVICES"] = "0"

# Use the accelerator when torch reports one, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def make_loader(img_size, crop_pct):
    """Build the test dataset and its loader for one (img_size, crop_pct) view.

    Every TTA crop setting needs its own transform, hence its own dataset
    and loader.

    Args:
        img_size: Input size forwarded to timm's ``create_transform``.
        crop_pct: Crop fraction forwarded to ``create_transform``.

    Returns:
        ``(dataset, loader)``: the ``ImageFolder`` over ``TEST_DIR`` and a
        non-shuffled ``DataLoader`` wrapping it.
    """
    transform_kwargs = dict(
        input_size=img_size,
        is_training=False,
        crop_pct=crop_pct,
        interpolation="bicubic",
        mean=MEAN,
        std=STD,
    )
    eval_transform = create_transform(**transform_kwargs)
    dataset = datasets.ImageFolder(TEST_DIR, transform=eval_transform)

    # shuffle=False keeps sample order fixed so separate passes line up.
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    return dataset, loader


def load_model(ckpt_path):
    """Rebuild a timm model from a checkpoint and prepare it for inference.

    The checkpoint dict must contain ``model_name``, ``img_size`` and
    ``model_state_dict``. ``val_acc`` is optional and is only used in the
    log line.

    Args:
        ckpt_path: Path to a checkpoint file readable by ``torch.load``.

    Returns:
        ``(model, img_size)``: the model moved to ``device`` in channels-last
        layout and eval mode, and the image size stored in the checkpoint.

    Raises:
        KeyError: If a required checkpoint key is missing.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints that come from a trusted source.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    model_name = ckpt["model_name"]
    img_size = ckpt["img_size"]
    sd = ckpt["model_state_dict"]

    # Infer num_classes from the first classifier weight found in the state
    # dict; fall back to 23 when no key matches.
    head_keys = [k for k in sd if k.endswith(("head.weight", "fc.weight"))]
    num_classes = sd[head_keys[0]].shape[0] if head_keys else 23

    # Build the model exactly once. Architectures that accept img_size get
    # it; those that reject the kwarg raise TypeError and are rebuilt without.
    try:
        model = timm.create_model(
            model_name,
            pretrained=False,
            num_classes=num_classes,
            img_size=img_size,
        )
    except TypeError:
        model = timm.create_model(
            model_name, pretrained=False, num_classes=num_classes
        )

    model.load_state_dict(sd)
    model = model.to(device, memory_format=torch.channels_last).eval()
    print(f"  loaded {ckpt_path}  ({model_name} @ {img_size}, "
          f"prev acc={ckpt.get('val_acc', 0)*100:.2f}%)")
    return model, img_size


@torch.no_grad()
def tta_softmax(model, img_size):
    """Average softmax outputs over 4 TTA views of the test set.

    The views are identity + horizontal flip at crop_pct=0.95, plus the same
    pair at crop_pct=1.0.

    Args:
        model: Model already on ``device`` and in eval mode.
        img_size: Input size used to build the eval transforms.

    Returns:
        ``(probs, targets, classes)``: the per-image softmax averaged over
        all views, the labels gathered during the first crop pass, and the
        dataset's class-name list.
    """
    crop_pcts = (0.95, 1.0)
    flips_per_crop = 2  # identity + hflip

    # bf16 autocast is applied on CUDA only. Passing device_type="cuda" on a
    # CPU-only host makes torch warn and disable autocast, so use the real
    # device type with an explicit enabled flag: CPU stays full precision.
    use_amp = device.type == "cuda"

    aggregated = None
    targets = None
    classes = None

    for crop_pct in crop_pcts:
        ds, loader = make_loader(img_size, crop_pct)
        if classes is None:
            classes = ds.classes

        # Labels are identical on every pass (the loader does not shuffle),
        # so collect them only the first time round.
        collect_targets = targets is None
        batch_softmax = []
        batch_targets = []
        for x, y in loader:
            x = x.to(device, non_blocking=True).to(memory_format=torch.channels_last)
            with torch.autocast(device_type=device.type, dtype=torch.bfloat16,
                                enabled=use_amp):
                # identity
                p = torch.softmax(model(x), dim=-1).float()
                # hflip (flip the last, i.e. width, axis)
                p = p + torch.softmax(model(torch.flip(x, dims=[-1])), dim=-1).float()
            batch_softmax.append(p.cpu())
            if collect_targets:
                batch_targets.append(y)

        crop_softmax = torch.cat(batch_softmax)  # both flips already summed
        if aggregated is None:
            aggregated = crop_softmax
        else:
            aggregated = aggregated + crop_softmax
        if collect_targets:
            targets = torch.cat(batch_targets)

    # Divide by the number of views actually summed (2 crops x 2 flips = 4)
    # so the divisor follows the crop list.
    aggregated = aggregated / float(len(crop_pcts) * flips_per_crop)
    return aggregated, targets, classes


def main():
    """Command-line entry point: ensemble the checkpoints given in ``sys.argv``.

    Each existing checkpoint is loaded and run through TTA. The per-model
    probabilities are averaged with equal weight, accuracy / macro-F1 / the
    per-class report are printed, and ``predictions.csv`` is written.
    Exits with status 1 when no checkpoint paths are given or none exist.
    """
    ckpts = sys.argv[1:]
    if not ckpts:
        print("Usage: python ensemble_eval.py <ckpt1> [ckpt2 ...]")
        sys.exit(1)

    print(f"Ensembling {len(ckpts)} checkpoint(s) with TTA (4 augs/model):")

    all_probs = []
    targets = None
    classes = None

    for ckpt in ckpts:
        if not os.path.exists(ckpt):
            print(f"  SKIP missing: {ckpt}")
            continue

        model, img_size = load_model(ckpt)
        model_probs, model_targets, model_classes = tta_softmax(model, img_size)
        all_probs.append(model_probs)

        # Labels and class names are taken from the first model evaluated.
        if targets is None:
            targets = model_targets
            classes = model_classes

        # Drop the model and release cached GPU memory before the next one.
        del model
        torch.cuda.empty_cache()

    if not all_probs:
        print("No valid checkpoints loaded.")
        sys.exit(1)

    # Equal-weight average across models, then pick the top class per image.
    ensemble = torch.stack(all_probs).mean(0)
    preds = ensemble.argmax(1)

    y_true = targets.numpy()
    y_pred = preds.numpy()

    acc = (preds == targets).float().mean().item()
    f1 = f1_score(y_true, y_pred, average="macro")

    rule = "=" * 60
    print(f"\n{rule}")
    print(f"Ensemble of {len(all_probs)} model(s) + 4-aug TTA")
    print(f"{rule}")
    print(f"Accuracy:   {acc*100:.2f}%")
    print(f"Macro F1:   {f1:.4f}")
    print("\nPer-class report:")
    report = classification_report(
        y_true, y_pred, target_names=classes, digits=3, zero_division=0
    )
    print(report)

    # Save predictions for further analysis: true label, predicted label,
    # then one probability column per class.
    header = "true,pred," + ",".join(f"p_{c}" for c in classes)
    table = np.column_stack([y_true, y_pred, ensemble.numpy()])
    np.savetxt("predictions.csv", table, delimiter=",", fmt="%g",
               header=header, comments="")
    print(f"Saved predictions.csv  ({ensemble.shape[0]} rows)")


# Run the ensemble evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()