AbstractPhil commited on
Commit
463e724
Β·
verified Β·
1 Parent(s): 9601cd1

Create runner.py

Browse files
Files changed (1) hide show
  1. 50k_results/runner.py +591 -0
50k_results/runner.py ADDED
@@ -0,0 +1,591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Large-Scale Geometric Similarity - Cell 10
3
+ ============================================
4
+ 50,000 synthetic character images β†’ FLUX VAE β†’ Geometric Features
5
+ Categories from generator_type field (15 types).
6
+
7
+ Streams from HuggingFace datasets, encodes in batches,
8
+ extracts gate vectors + patch features, computes similarity.
9
+
10
+ Requires Cell 1 (generator.py) and Cell 2 (model.py) in namespace.
11
+ """
12
+
13
+ import os, json, gc, time
14
+ from pathlib import Path
15
+ from collections import Counter, defaultdict
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ import numpy as np
20
+ from PIL import Image
21
+ from torchvision import transforms
22
+ import matplotlib
23
+ matplotlib.use("Agg")
24
+ import matplotlib.pyplot as plt
25
+ import matplotlib.patheffects as pe
26
+
27
+ # ── Config ────────────────────────────────────────────────────────────────────
28
+
29
+ DATASET_ID = "AbstractPhil/synthetic-characters"
30
+ SUBSET = "schnell_full_1_512"
31
+ MODEL_REPO = "AbstractPhil/grid-geometric-multishape"
32
+ MODEL_FILE = "checkpoint_v10/best_model_epoch200.pt"
33
+ VAE_REPO = "black-forest-labs/FLUX.1-schnell"
34
+
35
+ OUTPUT_DIR = "/content/results_50k"
36
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
37
+ IMAGE_SIZE = 128
38
+ FLUX_SCALE = 0.3611
39
+
40
+ # Batch sizes β€” tuned for L4 (24GB VRAM)
41
+ VAE_BATCH = 128 # images per VAE encode
42
+ FEAT_BATCH = 256 # adapted latents per model forward
43
+
44
+ MIN_CATEGORY_SIZE = 50 # drop categories smaller than this
45
+
46
+ img_transform = transforms.Compose([
47
+ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
48
+ transforms.ToTensor(),
49
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
50
+ ])
51
+
52
+ # ── Load Models ───────────────────────────────────────────────────────────────
53
+
54
+ def load_vae():
55
+ from diffusers import AutoencoderKL
56
+ print("Loading FLUX VAE...")
57
+ vae = AutoencoderKL.from_pretrained(
58
+ VAE_REPO, subfolder="vae", torch_dtype=torch.float16,
59
+ ).to(DEVICE).eval()
60
+ print("βœ“ VAE ready")
61
+ return vae
62
+
63
+
64
+ def load_model():
65
+ from huggingface_hub import hf_hub_download
66
+ print("Loading geometric model...")
67
+ path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
68
+ ckpt = torch.load(path, map_location=DEVICE, weights_only=False)
69
+ config = ckpt["config"]
70
+ model = SuperpositionPatchClassifier(
71
+ embed_dim=config["embed_dim"],
72
+ patch_dim=config["patch_dim"],
73
+ n_bootstrap=config["n_bootstrap"],
74
+ n_geometric=config["n_geometric"],
75
+ n_heads=config["n_heads"],
76
+ dropout=0.0,
77
+ ).to(DEVICE).eval()
78
+ model.load_state_dict(ckpt["model_state_dict"])
79
+ print(f"βœ“ Model ready (epoch {ckpt['epoch']})")
80
+ return model
81
+
82
+
83
+ # ── Streaming Encode + Extract ────────────────────────────────────────────────
84
+
85
+ def process_image(img_pil):
86
+ """PIL Image β†’ tensor ready for VAE."""
87
+ return img_transform(img_pil.convert("RGB"))
88
+
89
+
90
+ def adapt_latent(z):
91
+ """(B, 16, H, W) β†’ (B, 8, 16, 16)"""
92
+ B, C, H, W = z.shape
93
+ if H != 16 or W != 16:
94
+ z = F.interpolate(z, size=(16, 16), mode='bilinear', align_corners=False)
95
+ if C == 16:
96
+ z = z.view(B, 8, 2, 16, 16).mean(dim=2)
97
+ return z
98
+
99
+
100
+ @torch.no_grad()
101
+ def extract_gate_vectors(adapted, model):
102
+ """
103
+ adapted: (B, 8, 16, 16)
104
+ Returns: gate_vectors (B, 64, 17), patch_features (B, 64, 256)
105
+ """
106
+ out = model(adapted)
107
+
108
+ local_gates = torch.cat([
109
+ F.softmax(out["local_dim_logits"], dim=-1),
110
+ F.softmax(out["local_curv_logits"], dim=-1),
111
+ torch.sigmoid(out["local_bound_logits"]),
112
+ torch.sigmoid(out["local_axis_logits"]),
113
+ ], dim=-1)
114
+
115
+ struct_gates = torch.cat([
116
+ F.softmax(out["struct_topo_logits"], dim=-1),
117
+ torch.sigmoid(out["struct_neighbor_logits"]),
118
+ F.softmax(out["struct_role_logits"], dim=-1),
119
+ ], dim=-1)
120
+
121
+ gates = torch.cat([local_gates, struct_gates], dim=-1)
122
+ return gates.cpu(), out["patch_features"].cpu()
123
+
124
+
125
+ # ── Dataset wrapper for DataLoader ────────────────────────────────────────────
126
+
127
+ class HFImageDataset(torch.utils.data.Dataset):
128
+ """Wraps HF dataset for PyTorch DataLoader with parallel workers."""
129
+ def __init__(self, hf_ds):
130
+ self.ds = hf_ds
131
+ self.N = len(hf_ds)
132
+
133
+ def __len__(self):
134
+ return self.N
135
+
136
+ def __getitem__(self, idx):
137
+ row = self.ds[idx]
138
+ try:
139
+ tensor = img_transform(row["image"].convert("RGB"))
140
+ except:
141
+ tensor = torch.zeros(3, IMAGE_SIZE, IMAGE_SIZE)
142
+ cat = row.get("generator_type", "unknown")
143
+ rid = row.get("id", idx)
144
+ return tensor, cat, rid
145
+
146
+
147
+ def collate_fn(batch):
148
+ tensors, cats, ids = zip(*batch)
149
+ return torch.stack(tensors), list(cats), list(ids)
150
+
151
+
152
+ def _save_checkpoint(all_gates, all_patch, all_cats, all_ids, n):
153
+ g = torch.cat(all_gates) if isinstance(all_gates[0], torch.Tensor) and all_gates[0].dim() == 3 else torch.cat(all_gates)
154
+ p = torch.cat(all_patch) if isinstance(all_patch[0], torch.Tensor) and all_patch[0].dim() == 3 else torch.cat(all_patch)
155
+ path = os.path.join(OUTPUT_DIR, f"checkpoint_{n}.pt")
156
+ torch.save({"gates": g, "patch_feats": p, "categories": all_cats, "ids": all_ids}, path)
157
+ print(f"\n πŸ’Ύ Checkpoint: {path} ({g.shape[0]} samples)")
158
+
159
+
160
+ def find_latest_checkpoint(output_dir=OUTPUT_DIR):
161
+ """Find highest numbered checkpoint file."""
162
+ import glob
163
+ pattern = os.path.join(output_dir, "checkpoint_*.pt")
164
+ files = glob.glob(pattern)
165
+ if not files:
166
+ return None, 0
167
+ # Extract numbers
168
+ best_n, best_f = 0, None
169
+ for f in files:
170
+ try:
171
+ n = int(os.path.basename(f).replace("checkpoint_", "").replace(".pt", ""))
172
+ if n > best_n:
173
+ best_n, best_f = n, f
174
+ except:
175
+ pass
176
+ return best_f, best_n
177
+
178
+
179
+ def run_extraction(ds, vae, model):
180
+ """
181
+ DataLoader with workers β†’ VAE encode β†’ geometric extract.
182
+ Resumes from latest checkpoint if available.
183
+ Returns: gates (N, 64, 17), patch_feats (N, 64, 256), categories list
184
+ """
185
+ from tqdm import tqdm
186
+
187
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
188
+
189
+ # Check for existing checkpoint
190
+ ckpt_path, resume_from = find_latest_checkpoint()
191
+ if ckpt_path:
192
+ print(f"\nπŸ”„ Resuming from checkpoint: {ckpt_path} ({resume_from} samples)")
193
+ ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
194
+ all_gates = [ckpt["gates"]]
195
+ all_patch = [ckpt["patch_feats"]]
196
+ all_cats = list(ckpt["categories"])
197
+ all_ids = list(ckpt["ids"])
198
+ processed = resume_from
199
+ del ckpt
200
+ gc.collect()
201
+ print(f" βœ“ Loaded {processed} cached samples")
202
+ else:
203
+ all_gates = []
204
+ all_patch = []
205
+ all_cats = []
206
+ all_ids = []
207
+ processed = 0
208
+
209
+ # Skip already-processed samples
210
+ N = len(ds)
211
+ remaining = N - resume_from
212
+
213
+ if remaining <= 0:
214
+ print(f"βœ“ All {N} samples already extracted")
215
+ gates = torch.cat(all_gates)
216
+ patch_feats = torch.cat(all_patch)
217
+ return gates, patch_feats, all_cats, all_ids
218
+
219
+ # Subset dataset to remaining samples
220
+ if resume_from > 0:
221
+ ds_remaining = ds.select(range(resume_from, N))
222
+ print(f" Extracting remaining {remaining} samples...")
223
+ else:
224
+ ds_remaining = ds
225
+
226
+ dataset = HFImageDataset(ds_remaining)
227
+ loader = torch.utils.data.DataLoader(
228
+ dataset,
229
+ batch_size=VAE_BATCH,
230
+ shuffle=False,
231
+ num_workers=8,
232
+ pin_memory=True,
233
+ prefetch_factor=4,
234
+ collate_fn=collate_fn,
235
+ persistent_workers=True,
236
+ )
237
+
238
+ pbar = tqdm(total=remaining, unit="img", desc=f"Extracting (from {resume_from})")
239
+
240
+ for batch_pixels, cats, ids in loader:
241
+ batch_pixels = batch_pixels.to(DEVICE, non_blocking=True)
242
+
243
+ # VAE encode (fp16)
244
+ with torch.no_grad(), torch.cuda.amp.autocast():
245
+ latents = vae.encode(batch_pixels.half()).latent_dist.sample() * FLUX_SCALE
246
+ adapted = adapt_latent(latents.float()) # geometric model expects fp32
247
+
248
+ # Extract in sub-batches
249
+ for fstart in range(0, adapted.shape[0], FEAT_BATCH):
250
+ fend = min(fstart + FEAT_BATCH, adapted.shape[0])
251
+ gates, patch_feats = extract_gate_vectors(adapted[fstart:fend], model)
252
+ all_gates.append(gates)
253
+ all_patch.append(patch_feats)
254
+
255
+ all_cats.extend(cats)
256
+ all_ids.extend(ids)
257
+ processed += len(cats)
258
+
259
+ pbar.update(len(cats))
260
+
261
+ # Periodic checkpoint
262
+ if processed % SAVE_EVERY < VAE_BATCH and processed >= SAVE_EVERY:
263
+ _save_checkpoint(all_gates, all_patch, all_cats, all_ids, processed)
264
+
265
+ pbar.close()
266
+ print(f"βœ“ Processed {processed} images total")
267
+
268
+ # Final checkpoint
269
+ _save_checkpoint(all_gates, all_patch, all_cats, all_ids, processed)
270
+
271
+ gates = torch.cat(all_gates)
272
+ patch_feats = torch.cat(all_patch)
273
+
274
+ return gates, patch_feats, all_cats, all_ids
275
+
276
+
277
+ # ── Build Representations ─────────────────────────────────────────────────────
278
+
279
+ def build_reps(gates, patch_feats):
280
+ N = gates.shape[0]
281
+
282
+ # Mean pool on GPU (49k Γ— 64 Γ— 256 is 3.2GB β€” fits L4)
283
+ global_feats = patch_feats.to(DEVICE).mean(dim=1).cpu() # (N, 256)
284
+ torch.cuda.empty_cache()
285
+
286
+ # Normalize on GPU per-rep
287
+ reps = {
288
+ "gate_vectors": F.normalize(gates.reshape(N, -1).to(DEVICE), dim=-1).cpu(),
289
+ "patch_feat": F.normalize(patch_feats.reshape(N, -1).to(DEVICE), dim=-1).cpu(),
290
+ "global_feat": F.normalize(global_feats.to(DEVICE), dim=-1).cpu(),
291
+ }
292
+ torch.cuda.empty_cache()
293
+ return reps, global_feats
294
+
295
+
296
+ # ── Category Similarity (size-weighted) ───────────────────────────────────────
297
+
298
+ def compute_similarity(reps, cat_indices, cat_names):
299
+ """
300
+ GPU-accelerated chunked similarity.
301
+ Computes only the category blocks needed.
302
+ """
303
+ results = {}
304
+
305
+ for rep_name, features in reps.items():
306
+ print(f" Computing: {rep_name}...")
307
+ features_gpu = features.to(DEVICE)
308
+ n_cats = len(cat_names)
309
+ cat_matrix = np.zeros((n_cats, n_cats))
310
+
311
+ for i, ci in enumerate(cat_names):
312
+ fi = features_gpu[cat_indices[ci]] # (ni, D) on GPU
313
+ for j, cj in enumerate(cat_names):
314
+ if j < i:
315
+ # Symmetric β€” reuse
316
+ cat_matrix[i, j] = cat_matrix[j, i]
317
+ continue
318
+
319
+ fj = features_gpu[cat_indices[cj]] # (nj, D) on GPU
320
+
321
+ # Chunked matmul on GPU
322
+ chunk = 4000
323
+ block_sums = 0.0
324
+ block_count = 0
325
+ diag_sum = 0.0
326
+ diag_count = 0
327
+
328
+ for s in range(0, fi.shape[0], chunk):
329
+ sim = fi[s:s+chunk] @ fj.T # (chunk, nj) on GPU
330
+ if i == j:
331
+ # Exclude self-similarity on diagonal
332
+ row_offset = s
333
+ for r in range(sim.shape[0]):
334
+ global_r = row_offset + r
335
+ if global_r < sim.shape[1]:
336
+ diag_sum += sim[r, global_r].item()
337
+ diag_count += 1
338
+ block_sums += sim.sum().item()
339
+ block_count += sim.numel()
340
+ else:
341
+ block_sums += sim.sum().item()
342
+ block_count += sim.numel()
343
+
344
+ if i == j:
345
+ # Within: total minus diagonal, divided by off-diagonal count
346
+ val = (block_sums - diag_sum) / max(block_count - diag_count, 1)
347
+ else:
348
+ val = block_sums / max(block_count, 1)
349
+
350
+ cat_matrix[i, j] = float(val)
351
+ if j > i:
352
+ cat_matrix[j, i] = float(val)
353
+
354
+ del features_gpu
355
+ torch.cuda.empty_cache()
356
+
357
+ # Size-weighted between
358
+ sizes = {c: len(cat_indices[c]) for c in cat_names}
359
+ total = sum(sizes.values())
360
+
361
+ between_sum, between_pairs = 0.0, 0
362
+ for i, ci in enumerate(cat_names):
363
+ for j, cj in enumerate(cat_names):
364
+ if i != j:
365
+ n_pairs = sizes[ci] * sizes[cj]
366
+ between_sum += cat_matrix[i, j] * n_pairs
367
+ between_pairs += n_pairs
368
+ between_mean = between_sum / max(between_pairs, 1)
369
+
370
+ discriminability = {}
371
+ for i, ci in enumerate(cat_names):
372
+ cross_sum, cross_n = 0.0, 0
373
+ for j, cj in enumerate(cat_names):
374
+ if i != j:
375
+ cross_sum += cat_matrix[i, j] * sizes[cj]
376
+ cross_n += sizes[cj]
377
+ cat_between = cross_sum / max(cross_n, 1)
378
+ discriminability[ci] = float(cat_matrix[i, i] - cat_between)
379
+
380
+ overall = sum(discriminability[c] * sizes[c] / total for c in cat_names)
381
+
382
+ results[rep_name] = {
383
+ "matrix": cat_matrix,
384
+ "within": {c: float(cat_matrix[i, i]) for i, c in enumerate(cat_names)},
385
+ "between_mean": float(between_mean),
386
+ "discriminability": discriminability,
387
+ "overall_discriminability": float(overall),
388
+ "sizes": sizes,
389
+ }
390
+
391
+ return results
392
+
393
+
394
+ # ── Display ───────────────────────────────────────────────────────────────────
395
+
396
+ def print_results(results, cat_names):
397
+ first = next(iter(results.values()))
398
+ sizes = first["sizes"]
399
+ total = sum(sizes.values())
400
+
401
+ print(f"\nCategories ({len(cat_names)}, {total} total):")
402
+ for c in cat_names:
403
+ print(f" {c:30s} n={sizes[c]:5d} ({sizes[c]/total*100:5.1f}%)")
404
+
405
+ for rep_name, data in results.items():
406
+ print(f"\n{'='*80}")
407
+ print(f" {rep_name}")
408
+ print(f"{'='*80}")
409
+
410
+ # Top/bottom within
411
+ within_sorted = sorted(data["within"].items(), key=lambda x: -x[1])
412
+ print(f"\n Within-category similarity (top 5 / bottom 5):")
413
+ for c, v in within_sorted[:5]:
414
+ print(f" {c:30s} {v:.4f} (n={sizes[c]})")
415
+ print(f" ...")
416
+ for c, v in within_sorted[-5:]:
417
+ print(f" {c:30s} {v:.4f} (n={sizes[c]})")
418
+
419
+ print(f"\n Between-category mean: {data['between_mean']:.4f}")
420
+
421
+ # Discriminability ranked
422
+ disc_sorted = sorted(data["discriminability"].items(), key=lambda x: -x[1])
423
+ print(f"\n Discriminability (within βˆ’ weighted between):")
424
+ print(f" {'Top 5':>36s}")
425
+ for c, d in disc_sorted[:5]:
426
+ sign = "+" if d > 0 else ""
427
+ print(f" {c:30s} {sign}{d:.4f}")
428
+ print(f" {'Bottom 5':>36s}")
429
+ for c, d in disc_sorted[-5:]:
430
+ sign = "+" if d > 0 else ""
431
+ print(f" {c:30s} {sign}{d:.4f}")
432
+ print(f" {'OVERALL':30s} {'+' if data['overall_discriminability'] > 0 else ''}{data['overall_discriminability']:.4f}")
433
+
434
+
435
+ def plot_results(results, cat_names, output_dir=OUTPUT_DIR):
436
+ os.makedirs(output_dir, exist_ok=True)
437
+
438
+ for rep_name, data in results.items():
439
+ mat = data["matrix"]
440
+ n = len(cat_names)
441
+
442
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7), facecolor='#0a0a0a')
443
+
444
+ # Similarity matrix
445
+ im = ax1.imshow(mat, cmap='magma', vmin=mat.min() * 0.95, vmax=mat.max(), aspect='equal')
446
+ ax1.set_xticks(range(n))
447
+ ax1.set_yticks(range(n))
448
+ short = [c.replace("character_", "").replace("_", "\n") for c in cat_names]
449
+ ax1.set_xticklabels(short, fontsize=6, color='white', rotation=45, ha='right')
450
+ ax1.set_yticklabels(short, fontsize=6, color='white')
451
+ for i in range(n):
452
+ for j in range(n):
453
+ ax1.text(j, i, f'{mat[i,j]:.3f}', ha='center', va='center',
454
+ fontsize=5, color='white' if mat[i,j] < np.median(mat) else 'black')
455
+ ax1.set_title(f"{rep_name} β€” Similarity Matrix", color='white', fontsize=10, fontweight='bold')
456
+ ax1.tick_params(colors='white')
457
+ plt.colorbar(im, ax=ax1, fraction=0.046, pad=0.04)
458
+
459
+ # Discriminability bar chart
460
+ ax2.set_facecolor('#0a0a0a')
461
+ disc = data["discriminability"]
462
+ disc_sorted = sorted(disc.items(), key=lambda x: -x[1])
463
+ names_d = [x[0].replace("character_", "") for x in disc_sorted]
464
+ vals_d = [x[1] for x in disc_sorted]
465
+ colors = ['#00b894' if v > 0 else '#e17055' for v in vals_d]
466
+
467
+ ax2.barh(range(len(names_d)), vals_d, color=colors, edgecolor='white', linewidth=0.3)
468
+ ax2.set_yticks(range(len(names_d)))
469
+ ax2.set_yticklabels(names_d, fontsize=7, color='white')
470
+ ax2.axvline(0, color='white', linewidth=0.5, alpha=0.5)
471
+ ax2.axvline(data["overall_discriminability"], color='#fdcb6e',
472
+ linewidth=1, linestyle='--', alpha=0.8, label=f'overall={data["overall_discriminability"]:.4f}')
473
+ ax2.set_xlabel("Discriminability", color='white', fontsize=9)
474
+ ax2.set_title(f"{rep_name} β€” Discriminability", color='white', fontsize=10, fontweight='bold')
475
+ ax2.tick_params(colors='white', labelsize=7)
476
+ ax2.spines['bottom'].set_color('white')
477
+ ax2.spines['left'].set_color('white')
478
+ ax2.spines['top'].set_visible(False)
479
+ ax2.spines['right'].set_visible(False)
480
+ ax2.legend(fontsize=7, framealpha=0.7, facecolor='#1a1a2e', labelcolor='white')
481
+
482
+ safe_name = rep_name.replace(" ", "_").replace("(", "").replace(")", "")
483
+ path = os.path.join(output_dir, f"{safe_name}.png")
484
+ fig.savefig(path, dpi=150, bbox_inches='tight', facecolor=fig.get_facecolor())
485
+ plt.close(fig)
486
+ print(f"βœ“ Plot: {path}")
487
+
488
+
489
+ # ── Save ──────────────────────────────────────────────────────────────────────
490
+
491
+ def save_final(gates, patch_feats, categories, ids, results, cat_names):
492
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
493
+
494
+ # Features
495
+ tpath = os.path.join(OUTPUT_DIR, "geometric_features_50k.pt")
496
+ torch.save({
497
+ "gate_vectors": gates,
498
+ "patch_features": patch_feats,
499
+ "global_features": patch_feats.to(DEVICE).mean(dim=1).cpu(),
500
+ "categories": categories,
501
+ "ids": ids,
502
+ "cat_names": cat_names,
503
+ }, tpath)
504
+ print(f"βœ“ Saved: {tpath}")
505
+ print(f" gates: {gates.shape}, patch_feats: {patch_feats.shape}")
506
+
507
+ # Similarity JSON
508
+ out = {}
509
+ for rep_name, data in results.items():
510
+ out[rep_name] = {
511
+ "within": data["within"],
512
+ "between_mean": data["between_mean"],
513
+ "discriminability": data["discriminability"],
514
+ "overall_discriminability": data["overall_discriminability"],
515
+ "sizes": data["sizes"],
516
+ "matrix": data["matrix"].tolist(),
517
+ }
518
+ jpath = os.path.join(OUTPUT_DIR, "similarity_results_50k.json")
519
+ with open(jpath, "w") as f:
520
+ json.dump(out, f, indent=2)
521
+ print(f"βœ“ Saved: {jpath}")
522
+
523
+
524
+ # ── Main ────────────────────────────────────────��─────────────────────────────
525
+
526
+ def run_50k():
527
+ from datasets import load_dataset
528
+
529
+ # 1. Load dataset
530
+ print(f"Loading dataset: {DATASET_ID} / {SUBSET}...")
531
+ ds = load_dataset(DATASET_ID, SUBSET, split="train")
532
+ print(f"βœ“ {len(ds)} samples loaded")
533
+
534
+ # Show category distribution
535
+ cats = ds["generator_type"]
536
+ cat_counts = Counter(cats)
537
+ print(f"\nGenerator type distribution:")
538
+ for c, n in cat_counts.most_common():
539
+ print(f" {c:30s} {n:6d} ({n/len(ds)*100:5.1f}%)")
540
+
541
+ # 2. Load models
542
+ vae = load_vae()
543
+ model = load_model()
544
+
545
+ # 3. Stream encode + extract
546
+ gates, patch_feats, categories, ids = run_extraction(ds, vae, model)
547
+
548
+ # 4. Free VAE
549
+ del vae
550
+ gc.collect()
551
+ torch.cuda.empty_cache()
552
+ print("βœ“ Freed VAE memory")
553
+
554
+ # 5. Build category indices (with minimum size filter)
555
+ cat_counts_final = Counter(categories)
556
+ cat_names = sorted([c for c, n in cat_counts_final.items() if n >= MIN_CATEGORY_SIZE])
557
+ dropped = [c for c, n in cat_counts_final.items() if n < MIN_CATEGORY_SIZE]
558
+ if dropped:
559
+ print(f"\n⚠ Dropping {len(dropped)} categories with < {MIN_CATEGORY_SIZE} samples: {dropped}")
560
+
561
+ # Build index mapping (vectorized)
562
+ cat_indices = {}
563
+ cat_array = np.array(categories)
564
+ for c in cat_names:
565
+ cat_indices[c] = torch.from_numpy(np.where(cat_array == c)[0]).long()
566
+
567
+ total_used = sum(len(v) for v in cat_indices.values())
568
+ print(f"\nUsing {len(cat_names)} categories, {total_used}/{len(categories)} samples")
569
+
570
+ # 6. Build representations
571
+ print("\nBuilding representations...")
572
+ reps, global_feats = build_reps(gates, patch_feats)
573
+
574
+ # 7. Compute similarity
575
+ print("Computing category similarity (chunked)...")
576
+ sim_results = compute_similarity(reps, cat_indices, cat_names)
577
+
578
+ # 8. Display
579
+ print_results(sim_results, cat_names)
580
+
581
+ # 9. Plot
582
+ plot_results(sim_results, cat_names)
583
+
584
+ # 10. Save
585
+ save_final(gates, patch_feats, categories, ids, sim_results, cat_names)
586
+
587
+ return sim_results, gates, patch_feats, cat_indices, cat_names
588
+
589
+
590
+ # ── Run ───────────────────────────────────────────────────────────────────────
591
+ sim_results, gates, patch_feats, cat_indices, cat_names = run_50k()