iamcode6 commited on
Commit
cf94db4
·
verified ·
1 Parent(s): 3ff37b1

Upload ensemble_eval.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ensemble_eval.py +159 -0
ensemble_eval.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Multi-model ensemble + proper TTA — averages softmax across 1+ checkpoints.
2
+
3
+ Usage:
4
+ python ensemble_eval.py best_ema.pt
5
+ python ensemble_eval.py best_ema_v2.pt convnextv2_best_ema.pt
6
+ python ensemble_eval.py best_ema_v2.pt swa_v2.pt convnextv2_best_ema.pt convnextv2_swa.pt
7
+
8
+ Each checkpoint is loaded with its own model_name + img_size from its dict
9
+ (so EVA-02 @ 448 and ConvNeXt V2 @ 384 mix freely). Each runs TTA = identity +
10
+ hflip + vflip + 2 scale crops, then per-image softmaxes are averaged.
11
+
12
+ Saves predictions.csv and prints classification_report.
13
+ """
14
+ from __future__ import annotations
15
+
16
+ import os
17
+ import sys
18
+
19
+ import numpy as np
20
+ import timm
21
+ import torch
22
+ from sklearn.metrics import classification_report, f1_score
23
+ from timm.data import create_transform
24
+ from torch.utils.data import DataLoader
25
+ from torchvision import datasets
26
+
27
+ BASE_EXTRACT_DIR = "./dermnet-skin40-cleaned-dataset"
28
+ DATA_DIR = os.path.join(BASE_EXTRACT_DIR, "kaggle/working/Merged_Dermnet_Skin40")
29
+ TEST_DIR = os.path.join(DATA_DIR, "test")
30
+ BATCH_SIZE = 16
31
+ NUM_WORKERS = 4
32
+ MEAN, STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
33
+
34
+ os.environ["HIP_VISIBLE_DEVICES"] = "0"
35
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
+
37
+
38
+ def make_loader(img_size, crop_pct):
39
+ """Each (img_size, crop_pct) combo gets its own loader for proper TTA."""
40
+ tf = create_transform(input_size=img_size, is_training=False,
41
+ crop_pct=crop_pct, interpolation="bicubic", mean=MEAN, std=STD)
42
+ ds = datasets.ImageFolder(TEST_DIR, transform=tf)
43
+ return ds, DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False,
44
+ num_workers=NUM_WORKERS, pin_memory=True)
45
+
46
+
47
+ def load_model(ckpt_path):
48
+ ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
49
+ model_name = ckpt["model_name"]
50
+ img_size = ckpt["img_size"]
51
+ sd = ckpt["model_state_dict"]
52
+ # Infer num_classes from head weight
53
+ head_keys = [k for k in sd.keys() if k.endswith("head.weight") or k.endswith("fc.weight")]
54
+ num_classes = sd[head_keys[0]].shape[0] if head_keys else 23
55
+ model = timm.create_model(model_name, pretrained=False,
56
+ num_classes=num_classes, img_size=img_size if "vit" in model_name or "eva" in model_name else None)
57
+ if "img_size" in timm.create_model.__code__.co_varnames:
58
+ pass
59
+ # Some non-ViT models reject img_size kwarg — handle gracefully
60
+ try:
61
+ model = timm.create_model(model_name, pretrained=False,
62
+ num_classes=num_classes, img_size=img_size)
63
+ except TypeError:
64
+ model = timm.create_model(model_name, pretrained=False, num_classes=num_classes)
65
+ model.load_state_dict(sd)
66
+ model = model.to(device, memory_format=torch.channels_last).eval()
67
+ print(f" loaded {ckpt_path} ({model_name} @ {img_size}, "
68
+ f"prev acc={ckpt.get('val_acc', 0)*100:.2f}%)")
69
+ return model, img_size
70
+
71
+
72
+ @torch.no_grad()
73
+ def tta_softmax(model, img_size):
74
+ """4-augmentation TTA: identity + hflip @ crop_pct=0.95, plus same pair @ crop_pct=1.0."""
75
+ aggregated = None
76
+ targets = None
77
+ classes = None
78
+
79
+ for crop_pct in [0.95, 1.0]:
80
+ ds, loader = make_loader(img_size, crop_pct)
81
+ if classes is None:
82
+ classes = ds.classes
83
+
84
+ batch_softmax = []
85
+ batch_targets = []
86
+ for x, y in loader:
87
+ x = x.to(device, non_blocking=True).to(memory_format=torch.channels_last)
88
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
89
+ # identity
90
+ p = torch.softmax(model(x), dim=-1).float()
91
+ # hflip
92
+ p = p + torch.softmax(model(torch.flip(x, dims=[-1])), dim=-1).float()
93
+ batch_softmax.append(p.cpu())
94
+ batch_targets.append(y)
95
+
96
+ crop_softmax = torch.cat(batch_softmax) # 2 augs already summed
97
+ if aggregated is None:
98
+ aggregated = crop_softmax
99
+ targets = torch.cat(batch_targets)
100
+ else:
101
+ aggregated = aggregated + crop_softmax
102
+
103
+ # Total: 2 crops × 2 flips = 4 augmentations summed
104
+ aggregated = aggregated / 4.0
105
+ return aggregated, targets, classes
106
+
107
+
108
+ def main():
109
+ if len(sys.argv) < 2:
110
+ print("Usage: python ensemble_eval.py <ckpt1> [ckpt2 ...]")
111
+ sys.exit(1)
112
+
113
+ ckpts = sys.argv[1:]
114
+ print(f"Ensembling {len(ckpts)} checkpoint(s) with TTA (4 augs/model):")
115
+
116
+ all_probs = []
117
+ targets = None
118
+ classes = None
119
+
120
+ for ckpt in ckpts:
121
+ if not os.path.exists(ckpt):
122
+ print(f" SKIP missing: {ckpt}"); continue
123
+ model, img_size = load_model(ckpt)
124
+ probs, t, cls = tta_softmax(model, img_size)
125
+ all_probs.append(probs)
126
+ if targets is None:
127
+ targets, classes = t, cls
128
+ # Free GPU memory between models
129
+ del model; torch.cuda.empty_cache()
130
+
131
+ if not all_probs:
132
+ print("No valid checkpoints loaded."); sys.exit(1)
133
+
134
+ # Equal-weight average across models
135
+ ensemble = torch.stack(all_probs).mean(0)
136
+ preds = ensemble.argmax(1)
137
+
138
+ acc = (preds == targets).float().mean().item()
139
+ f1 = f1_score(targets.numpy(), preds.numpy(), average="macro")
140
+
141
+ print(f"\n{'='*60}")
142
+ print(f"Ensemble of {len(all_probs)} model(s) + 4-aug TTA")
143
+ print(f"{'='*60}")
144
+ print(f"Accuracy: {acc*100:.2f}%")
145
+ print(f"Macro F1: {f1:.4f}")
146
+ print(f"\nPer-class report:")
147
+ print(classification_report(targets.numpy(), preds.numpy(),
148
+ target_names=classes, digits=3, zero_division=0))
149
+
150
+ # Save predictions for further analysis
151
+ np.savetxt("predictions.csv",
152
+ np.column_stack([targets.numpy(), preds.numpy(), ensemble.numpy()]),
153
+ delimiter=",", fmt="%g",
154
+ header="true,pred," + ",".join(f"p_{c}" for c in classes), comments="")
155
+ print(f"Saved predictions.csv ({ensemble.shape[0]} rows)")
156
+
157
+
158
+ if __name__ == "__main__":
159
+ main()