Leacb4 committed on
Commit
f38441f
·
verified ·
1 Parent(s): bc323e6

Upload evaluation/annex92_color_heatmaps.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. evaluation/annex92_color_heatmaps.py +269 -51
evaluation/annex92_color_heatmaps.py CHANGED
@@ -24,6 +24,7 @@ import pandas as pd
24
  import numpy as np
25
  import matplotlib.pyplot as plt
26
  import seaborn as sns
 
27
  from sklearn.metrics.pairwise import cosine_similarity
28
  from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
29
  from sklearn.model_selection import train_test_split
@@ -42,6 +43,14 @@ PRIMARY_COLORS = [
42
  'orange', 'purple', 'brown', 'gray', 'black', 'white'
43
  ]
44
 
 
 
 
 
 
 
 
 
45
  class ColorEncoder:
46
  def __init__(self, main_model_path, device='mps'):
47
  self.device = torch.device(device)
@@ -63,6 +72,13 @@ class ColorEncoder:
63
 
64
  # Create processor
65
  self.processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
 
 
 
 
 
 
 
66
 
67
  # Load dataset
68
  self._load_dataset()
@@ -115,8 +131,14 @@ class ColorEncoder:
115
 
116
  return dataloader
117
 
118
- def extract_color_embeddings(self, dataloader, embedding_type='text', max_samples=10000):
119
- """Extract color embeddings (first 16 dimensions) from text or image"""
 
 
 
 
 
 
120
  all_embeddings = []
121
  all_colors = []
122
 
@@ -131,12 +153,20 @@ class ColorEncoder:
131
  images = images.to(self.device)
132
  images = images.expand(-1, 3, -1, -1) # Ensure 3 channels
133
 
134
- # Process text inputs
135
- text_inputs = self.processor(text=texts, padding=True, return_tensors="pt")
 
 
 
 
 
 
 
 
136
  text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
137
 
138
  # Forward pass through main model
139
- outputs = self.main_model(**text_inputs, pixel_values=images)
140
 
141
  # Extract embeddings based on type
142
  if embedding_type == 'text':
@@ -324,7 +354,19 @@ class ColorEncoder:
324
 
325
  return results
326
 
327
- def create_color_similarity_heatmap(self, embeddings, colors, embedding_type='text', save_path='evaluation/color_similarity_results/color_similarity_heatmap.png'):
 
 
 
 
 
 
 
 
 
 
 
 
328
  """
329
  Create a heatmap of similarities between encoded colors
330
  """
@@ -338,34 +380,114 @@ class ColorEncoder:
338
  if len(color_indices) > 0:
339
  color_embeddings = embeddings[color_indices]
340
  centroids[color] = np.mean(color_embeddings, axis=0)
 
 
 
 
 
 
 
 
341
 
342
  similarity_matrix = np.zeros((len(unique_colors), len(unique_colors)))
343
 
344
  for i, color1 in enumerate(unique_colors):
345
  for j, color2 in enumerate(unique_colors):
346
  if i == j:
 
347
  similarity_matrix[i, j] = 1.0
348
  else:
349
  similarity = cosine_similarity([centroids[color1]], [centroids[color2]])[0][0]
350
  similarity_matrix[i, j] = similarity
351
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
  plt.figure(figsize=(12, 10))
353
-
354
- sns.heatmap(
355
- similarity_matrix,
356
- annot=True,
357
- fmt='.3f',
358
- cmap='RdYlBu_r',
359
  xticklabels=unique_colors,
360
  yticklabels=unique_colors,
361
  square=True,
362
- cbar_kws={'label': 'Cosine Similarity'},
363
  linewidths=0.5,
364
- vmin=-0.6,
365
- vmax=1.0
366
  )
367
-
368
- plt.title(f'Color similarity ({embedding_type} embeddings)',
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
  fontsize=16, fontweight='bold', pad=20)
370
  plt.xlabel('Colors', fontsize=14, fontweight='bold')
371
  plt.ylabel('Colors', fontsize=14, fontweight='bold')
@@ -377,6 +499,99 @@ class ColorEncoder:
377
  print(f"💾 Heatmap saved: {save_path}")
378
 
379
  return plt.gcf(), similarity_matrix
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
 
381
 
382
 
@@ -518,38 +733,41 @@ if __name__ == "__main__":
518
  device=device
519
  )
520
 
521
- # Evaluate primary color classification
522
  results = color_encoder.evaluate_color_classification(
523
- color_encoder.val_df,
524
- max_samples=10000
525
  )
526
-
527
- if results:
528
- print(f"\n✅ Primary color encoding and confusion matrix generation completed!")
529
- print(f"📊 Results saved in 'evaluation/color_evaluation_results/' directory")
530
- print(f"🎨 Text Primary Color Accuracy: {results['text']['accuracy']*100:.1f}%")
531
- print(f"🖼️ Image Primary Color Accuracy: {results['image']['accuracy']*100:.1f}%")
532
-
533
- # NOUVELLE SECTION: Analyse des similarités
534
- print(f"\n🎨 Starting Color Similarity Analysis...")
535
- similarity_results = color_encoder.create_color_similarity_analysis(results)
536
-
537
- print(f"\n✅ Color similarity analysis completed!")
538
- print(f"📊 Similarity heatmaps saved in 'evaluation/color_similarity_results/' directory")
539
-
540
- # Show some sample predictions
541
- print(f"\n📝 Sample Text Predictions:")
542
- for i in range(min(10, len(results['text']['true_colors']))):
543
- true_color = results['text']['true_colors'][i]
544
- pred_color = results['text']['predicted_colors'][i]
545
- status = "✓" if true_color == pred_color else "✗"
546
- print(f" {status} True: {true_color:>8} | Predicted: {pred_color:>8}")
547
-
548
- print(f"\n🖼️ Sample Image Predictions:")
549
- for i in range(min(10, len(results['image']['true_colors']))):
550
- true_color = results['image']['true_colors'][i]
551
- pred_color = results['image']['predicted_colors'][i]
552
- status = "✓" if true_color == pred_color else "✗"
553
- print(f" {status} True: {true_color:>8} | Predicted: {pred_color:>8}")
554
- else:
555
- print("❌ No results generated - check if primary colors exist in dataset")
 
 
 
 
24
  import numpy as np
25
  import matplotlib.pyplot as plt
26
  import seaborn as sns
27
+ from matplotlib.colors import TwoSlopeNorm
28
  from sklearn.metrics.pairwise import cosine_similarity
29
  from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
30
  from sklearn.model_selection import train_test_split
 
43
  'orange', 'purple', 'brown', 'gray', 'black', 'white'
44
  ]
45
 
46
+ # Fashion-CLIP baseline (used for "baseline" heatmaps).
47
+ BASELINE_MODEL_NAME = "patrickjohncyh/fashion-clip"
48
+
49
+ # Degradation strength for the similarity heatmaps.
50
+ # Higher values mix each color centroid more strongly towards the global centroid,
51
+ # which increases cross-color confusion ("more degraded colors").
52
+ COLOR_CENTROID_DEGRADATION_STRENGTH = 0.30
53
+
54
  class ColorEncoder:
55
  def __init__(self, main_model_path, device='mps'):
56
  self.device = torch.device(device)
 
72
 
73
  # Create processor
74
  self.processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
75
+
76
+ # Load baseline Fashion-CLIP model (for baseline heatmaps).
77
+ print(f"📦 Loading Baseline Fashion-CLIP model from {BASELINE_MODEL_NAME} ...")
78
+ self.baseline_model = CLIPModel_transformers.from_pretrained(BASELINE_MODEL_NAME).to(self.device)
79
+ self.baseline_model.eval()
80
+ self.baseline_processor = CLIPProcessor.from_pretrained(BASELINE_MODEL_NAME)
81
+ print("✅ Baseline Fashion-CLIP model loaded successfully")
82
 
83
  # Load dataset
84
  self._load_dataset()
 
131
 
132
  return dataloader
133
 
134
+ def extract_color_embeddings(self, dataloader, embedding_type='text', model_kind='main', max_samples=10000):
135
+ """
136
+ Extract color embeddings (first 16 dimensions) from text or image.
137
+
138
+ model_kind:
139
+ - "main": GAP-CLIP specialized checkpoint (self.main_model)
140
+ - "baseline": Fashion-CLIP baseline (self.baseline_model)
141
+ """
142
  all_embeddings = []
143
  all_colors = []
144
 
 
153
  images = images.to(self.device)
154
  images = images.expand(-1, 3, -1, -1) # Ensure 3 channels
155
 
156
+ # Select model/processor.
157
+ if model_kind == 'baseline':
158
+ model = self.baseline_model
159
+ processor = self.baseline_processor
160
+ else:
161
+ model = self.main_model
162
+ processor = self.processor
163
+
164
+ # Process text inputs.
165
+ text_inputs = processor(text=texts, padding=True, return_tensors="pt")
166
  text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
167
 
168
  # Forward pass through main model
169
+ outputs = model(**text_inputs, pixel_values=images)
170
 
171
  # Extract embeddings based on type
172
  if embedding_type == 'text':
 
354
 
355
  return results
356
 
357
+ def create_color_similarity_heatmap(
358
+ self,
359
+ embeddings,
360
+ colors,
361
+ embedding_type='text',
362
+ save_path='evaluation/color_similarity_results/color_similarity_heatmap.png',
363
+ centroid_degradation_strength: float = 0.0,
364
+ heatmap_metric: str = "similarity",
365
+ annot: bool = True,
366
+ mask_diagonal: bool = True,
367
+ contrast_percentiles: tuple[float, float] = (5.0, 95.0),
368
+ print_stats: bool = True,
369
+ ):
370
  """
371
  Create a heatmap of similarities between encoded colors
372
  """
 
380
  if len(color_indices) > 0:
381
  color_embeddings = embeddings[color_indices]
382
  centroids[color] = np.mean(color_embeddings, axis=0)
383
+
384
+ # Degrade colors by mixing each color centroid toward the global centroid.
385
+ # This increases cross-color similarity and visually "degrades" the color separation.
386
+ centroid_degradation_strength = float(centroid_degradation_strength)
387
+ if centroid_degradation_strength > 0 and len(centroids) > 1:
388
+ global_centroid = np.mean(np.stack(list(centroids.values())), axis=0)
389
+ for c in centroids:
390
+ centroids[c] = (1 - centroid_degradation_strength) * centroids[c] + centroid_degradation_strength * global_centroid
391
 
392
  similarity_matrix = np.zeros((len(unique_colors), len(unique_colors)))
393
 
394
  for i, color1 in enumerate(unique_colors):
395
  for j, color2 in enumerate(unique_colors):
396
  if i == j:
397
+ # Cosine between a vector and itself is 1 (centroids are fixed points).
398
  similarity_matrix[i, j] = 1.0
399
  else:
400
  similarity = cosine_similarity([centroids[color1]], [centroids[color2]])[0][0]
401
  similarity_matrix[i, j] = similarity
402
+
403
+ # For visualization: masking diagonal + using off-diagonal auto-contrast
404
+ # makes cross-color differences much more visible.
405
+ n = len(unique_colors)
406
+ mask = np.eye(n, dtype=bool) if mask_diagonal else np.zeros((n, n), dtype=bool)
407
+
408
+ if print_stats:
409
+ off_diag_similarity = similarity_matrix[~mask]
410
+ # Most similar off-diagonal pair = where the model confuses colors most.
411
+ masked_similarity = np.where(mask, -np.inf, similarity_matrix)
412
+ max_i, max_j = np.unravel_index(np.argmax(masked_similarity), similarity_matrix.shape)
413
+ # Least similar off-diagonal pair = most separated colors.
414
+ masked_similarity_min = np.where(mask, np.inf, similarity_matrix)
415
+ min_i, min_j = np.unravel_index(np.argmin(masked_similarity_min), similarity_matrix.shape)
416
+ print(
417
+ f"📈 {embedding_type.upper()} | off-diagonal cosine similarity: "
418
+ f"mean={float(off_diag_similarity.mean()):.3f}, std={float(off_diag_similarity.std()):.3f}"
419
+ )
420
+ print(
421
+ f"📍 {embedding_type.upper()} | most similar pair: "
422
+ f"{unique_colors[max_i]} ↔ {unique_colors[max_j]} = {float(similarity_matrix[max_i, max_j]):.3f}"
423
+ )
424
+ print(
425
+ f"📍 {embedding_type.upper()} | least similar pair: "
426
+ f"{unique_colors[min_i]} ↔ {unique_colors[min_j]} = {float(similarity_matrix[min_i, min_j]):.3f}"
427
+ )
428
+
429
+ if heatmap_metric == "similarity":
430
+ plot_matrix = similarity_matrix
431
+ cbar_label = "Cosine Similarity"
432
+ cmap = "RdYlBu_r"
433
+ # Use off-diagonal values to compute contrast.
434
+ off_diag_vals = plot_matrix[~mask]
435
+ elif heatmap_metric == "separation":
436
+ # Higher values => colors are less similar (more separated).
437
+ plot_matrix = 1.0 - similarity_matrix
438
+ cbar_label = "Separation (1 - Cosine Similarity)"
439
+ cmap = "magma"
440
+ off_diag_vals = plot_matrix[~mask]
441
+ else:
442
+ raise ValueError(f"Unsupported heatmap_metric: {heatmap_metric}")
443
+
444
+ # Robust auto-contrast: percentiles avoid single extreme values dominating.
445
+ lo_p, hi_p = contrast_percentiles
446
+ vmin = float(np.percentile(off_diag_vals, lo_p)) if off_diag_vals.size > 0 else None
447
+ vmax = float(np.percentile(off_diag_vals, hi_p)) if off_diag_vals.size > 0 else None
448
+
449
  plt.figure(figsize=(12, 10))
450
+
451
+ heatmap_kwargs = dict(
452
+ data=plot_matrix,
453
+ mask=mask,
454
+ annot=annot,
455
+ fmt=".3f" if annot else "",
456
  xticklabels=unique_colors,
457
  yticklabels=unique_colors,
458
  square=True,
459
+ cbar_kws={"label": cbar_label},
460
  linewidths=0.5,
 
 
461
  )
462
+
463
+ if heatmap_metric == "similarity":
464
+ # Diverging scale centered at 0 to emphasize "opposite" directions.
465
+ if vmin is not None and vmax is not None and vmin != vmax:
466
+ # TwoSlopeNorm requires: vmin < vcenter < vmax
467
+ if vmin < 0.0 < vmax:
468
+ vcenter = 0.0
469
+ else:
470
+ # If all values are one-sided (e.g. all positive), pick midpoint.
471
+ vcenter = (vmin + vmax) / 2.0
472
+
473
+ if vmin < vcenter < vmax:
474
+ norm = TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)
475
+ heatmap_kwargs["norm"] = norm
476
+ else:
477
+ heatmap_kwargs["vmin"] = vmin
478
+ heatmap_kwargs["vmax"] = vmax
479
+ else:
480
+ heatmap_kwargs["vmin"] = vmin
481
+ heatmap_kwargs["vmax"] = vmax
482
+ else:
483
+ # Sequential scale for "separation" (>=0).
484
+ heatmap_kwargs["vmin"] = vmin
485
+ heatmap_kwargs["vmax"] = vmax
486
+
487
+ sns.heatmap(cmap=cmap, **heatmap_kwargs)
488
+
489
+ title_suffix = "separation" if heatmap_metric == "separation" else "similarity"
490
+ plt.title(f"Color {title_suffix} ({embedding_type} embeddings)",
491
  fontsize=16, fontweight='bold', pad=20)
492
  plt.xlabel('Colors', fontsize=14, fontweight='bold')
493
  plt.ylabel('Colors', fontsize=14, fontweight='bold')
 
499
  print(f"💾 Heatmap saved: {save_path}")
500
 
501
  return plt.gcf(), similarity_matrix
502
+
503
+ def generate_similarity_heatmaps(
504
+ self,
505
+ dataloader,
506
+ model_kind: str,
507
+ max_samples: int,
508
+ centroid_degradation_strength: float,
509
+ ):
510
+ """
511
+ Generate and save similarity heatmaps (text + image) for a given model kind.
512
+ """
513
+ if model_kind not in {'main', 'baseline'}:
514
+ raise ValueError(f"Unsupported model_kind: {model_kind}")
515
+
516
+ os.makedirs('evaluation/color_similarity_results', exist_ok=True)
517
+
518
+ print(f"\n🎨 Generating similarity heatmaps for model_kind={model_kind} "
519
+ f"(degradation_strength={centroid_degradation_strength})...")
520
+
521
+ # Text heatmap.
522
+ text_embeddings, text_colors = self.extract_color_embeddings(
523
+ dataloader,
524
+ embedding_type='text',
525
+ model_kind=model_kind,
526
+ max_samples=max_samples,
527
+ )
528
+ main_or_baseline = 'gap_clip' if model_kind == 'main' else 'fashion_clip_baseline'
529
+ text_save_path = (
530
+ 'evaluation/color_similarity_results/text_color_similarity_heatmap.png'
531
+ if model_kind == 'main'
532
+ else f'evaluation/color_similarity_results/{main_or_baseline}_text_color_similarity_heatmap.png'
533
+ )
534
+ text_fig, _ = self.create_color_similarity_heatmap(
535
+ text_embeddings,
536
+ text_colors,
537
+ embedding_type='text',
538
+ save_path=text_save_path,
539
+ centroid_degradation_strength=centroid_degradation_strength,
540
+ )
541
+ plt.close(text_fig)
542
+
543
+ # Text separation heatmap (more visually sensitive than raw similarity).
544
+ text_sep_save_path = (
545
+ 'evaluation/color_similarity_results/text_color_separation_heatmap.png'
546
+ if model_kind == 'main'
547
+ else f'evaluation/color_similarity_results/{main_or_baseline}_text_color_separation_heatmap.png'
548
+ )
549
+ text_sep_fig, _ = self.create_color_similarity_heatmap(
550
+ text_embeddings,
551
+ text_colors,
552
+ embedding_type='text',
553
+ save_path=text_sep_save_path,
554
+ centroid_degradation_strength=centroid_degradation_strength,
555
+ heatmap_metric="separation",
556
+ )
557
+ plt.close(text_sep_fig)
558
+
559
+ # Image heatmap.
560
+ image_embeddings, image_colors = self.extract_color_embeddings(
561
+ dataloader,
562
+ embedding_type='image',
563
+ model_kind=model_kind,
564
+ max_samples=max_samples,
565
+ )
566
+ image_save_path = (
567
+ 'evaluation/color_similarity_results/image_color_similarity_heatmap.png'
568
+ if model_kind == 'main'
569
+ else f'evaluation/color_similarity_results/{main_or_baseline}_image_color_similarity_heatmap.png'
570
+ )
571
+ image_fig, _ = self.create_color_similarity_heatmap(
572
+ image_embeddings,
573
+ image_colors,
574
+ embedding_type='image',
575
+ save_path=image_save_path,
576
+ centroid_degradation_strength=centroid_degradation_strength,
577
+ )
578
+ plt.close(image_fig)
579
+
580
+ # Image separation heatmap.
581
+ image_sep_save_path = (
582
+ 'evaluation/color_similarity_results/image_color_separation_heatmap.png'
583
+ if model_kind == 'main'
584
+ else f'evaluation/color_similarity_results/{main_or_baseline}_image_color_separation_heatmap.png'
585
+ )
586
+ image_sep_fig, _ = self.create_color_similarity_heatmap(
587
+ image_embeddings,
588
+ image_colors,
589
+ embedding_type='image',
590
+ save_path=image_sep_save_path,
591
+ centroid_degradation_strength=centroid_degradation_strength,
592
+ heatmap_metric="separation",
593
+ )
594
+ plt.close(image_sep_fig)
595
 
596
 
597
 
 
733
  device=device
734
  )
735
 
736
+ # Evaluate primary color classification for the main model (keeps previous behavior).
737
  results = color_encoder.evaluate_color_classification(
738
+ color_encoder.val_df,
739
+ max_samples=10000,
740
  )
741
+
742
+ if not results:
743
+ print(" No results generated - check if primary colors exist in dataset")
744
+ raise SystemExit(1)
745
+
746
+ print(f"\n✅ Primary color encoding and confusion matrix generation completed!")
747
+ print(f"📊 Results saved in 'evaluation/color_evaluation_results/' directory")
748
+ print(f"🎨 Text Primary Color Accuracy: {results['text']['accuracy']*100:.1f}%")
749
+ print(f"🖼️ Image Primary Color Accuracy: {results['image']['accuracy']*100:.1f}%")
750
+
751
+ # Heatmaps with additional centroid degradation (main model + baseline).
752
+ dataloader = color_encoder.create_dataloader(color_encoder.val_df, batch_size=8)
753
+ max_samples = 10000
754
+ centroid_degradation_strength = COLOR_CENTROID_DEGRADATION_STRENGTH
755
+
756
+ # Your model (GAP-CLIP main checkpoint): overwrites the existing heatmap filenames.
757
+ color_encoder.generate_similarity_heatmaps(
758
+ dataloader=dataloader,
759
+ model_kind='main',
760
+ max_samples=max_samples,
761
+ centroid_degradation_strength=centroid_degradation_strength,
762
+ )
763
+
764
+ # Baseline Fashion-CLIP: saved as fashion_clip_baseline_* heatmaps.
765
+ color_encoder.generate_similarity_heatmaps(
766
+ dataloader=dataloader,
767
+ model_kind='baseline',
768
+ max_samples=max_samples,
769
+ centroid_degradation_strength=centroid_degradation_strength,
770
+ )
771
+
772
+ print("\n✅ Color similarity analysis completed!")
773
+ print("📊 Similarity heatmaps saved in 'evaluation/color_similarity_results/' directory")