Leacb4 committed on
Commit
942267d
·
verified ·
1 Parent(s): 7a5206c

Upload evaluation/sec51_color_model_eval.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. evaluation/sec51_color_model_eval.py +143 -647
evaluation/sec51_color_model_eval.py CHANGED
@@ -20,31 +20,16 @@ Paper reference: Section 5.1, Table 1.
20
  """
21
 
22
  import os
23
- import json
24
- import hashlib
25
- import requests
26
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
27
  import sys
28
  from pathlib import Path
29
 
30
  import torch
31
- import pandas as pd
32
- import numpy as np
33
  import matplotlib.pyplot as plt
34
- import seaborn as sns
35
- import difflib
36
- from sklearn.metrics.pairwise import cosine_similarity
37
- from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
38
- from collections import defaultdict
39
- from tqdm import tqdm
40
- from torch.utils.data import Dataset, DataLoader
41
- from torchvision import transforms
42
- from PIL import Image
43
- from io import BytesIO
44
  import warnings
45
  warnings.filterwarnings('ignore')
46
- from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
47
- from huggingface_hub import hf_hub_download
48
 
49
  # Ensure project root is importable when running this file directly.
50
  PROJECT_ROOT = Path(__file__).resolve().parent.parent
@@ -54,220 +39,20 @@ if str(PROJECT_ROOT) not in sys.path:
54
  from config import (
55
  color_model_path,
56
  color_emb_dim,
57
- local_dataset_path,
58
- column_local_image_path,
59
- tokeniser_path,
60
- images_dir,
61
  )
62
- from training.color_model import ColorCLIP, Tokenizer
63
-
64
-
65
class KaggleDataset(Dataset):
    """Map-style dataset over a formatted Marqo/KAGL dataframe.

    Every item is an ``(image_tensor, text, color)`` triple read from the
    dataframe columns ``image_url``, ``text`` and ``color``.

    Args:
        dataframe: DataFrame providing ``image_url``, ``text`` and ``color``.
        image_size: Side length in pixels that images are resized to.
            Defaults to 224.
    """

    def __init__(self, dataframe, image_size=224):
        self.dataframe = dataframe
        self.image_size = image_size

        # The resize now honours ``image_size`` (it used to be hard-coded to
        # 224, silently ignoring the argument).
        # NOTE(review): ColorJitter is a random augmentation, so items are not
        # deterministic across reads; it is kept because the original pipeline
        # applied it -- confirm this is intended for evaluation.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        """Return the number of rows in the underlying dataframe."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(image_tensor, text, color)`` for the row at position ``idx``."""
        row = self.dataframe.iloc[idx]

        # The ``image_url`` column holds the image payload itself, in one of
        # three shapes: a dict with a 'bytes' entry, an object exposing
        # ``convert`` (treated as a PIL image), or raw encoded bytes.
        image_data = row['image_url']
        if isinstance(image_data, dict) and 'bytes' in image_data:
            image = Image.open(BytesIO(image_data['bytes'])).convert("RGB")
        elif hasattr(image_data, 'convert'):
            image = image_data.convert("RGB")
        else:
            image = Image.open(BytesIO(image_data)).convert("RGB")

        image = self.transform(image)
        return image, row['text'], row['color']
105
-
106
-
107
def load_kaggle_marqo_dataset(max_samples=5000):
    """Download Marqo/KAGL, keep usable rows and wrap them in a KaggleDataset.

    Rows without text or image are dropped, at most ``max_samples`` rows are
    drawn with a fixed random seed, colours are lower-cased with "grey"
    rewritten to "gray", and only the 11 training colours are kept.

    Args:
        max_samples: Upper bound on rows drawn before the colour filters run.

    Returns:
        A ``KaggleDataset`` over the filtered rows.
    """
    from datasets import load_dataset

    print("📊 Loading Kaggle KAGL dataset...")
    raw = load_dataset("Marqo/KAGL")["data"].to_pandas()
    print("✅ Dataset Kaggle loaded")
    print(f" Before filtering: {len(raw)} samples")
    print(f" Available columns: {list(raw.columns)}")

    usable = raw.dropna(subset=['text', 'image'])
    print(f" After removing missing text/image: {len(usable)} samples")

    # Sampling happens before the colour filters, so the final dataset can be
    # smaller than max_samples.
    if len(usable) > max_samples:
        usable = usable.sample(n=max_samples, random_state=42)
        print(f"📊 Randomly sampled {max_samples} samples from Kaggle dataset")

    normalized_color = usable['baseColour'].str.lower().str.replace("grey", "gray")
    kaggle_formatted = pd.DataFrame({
        'image_url': usable['image'],  # image payload, not an actual URL
        'text': usable['text'],
        'color': normalized_color,
    })

    rows_before = len(kaggle_formatted)
    kaggle_formatted = kaggle_formatted.dropna(subset=['color'])
    removed = rows_before - len(kaggle_formatted)
    if removed > 0:
        print(f" After removing missing colors: {len(kaggle_formatted)} samples (removed {removed} samples)")

    # Only the colours used during training are evaluated.
    valid_colors = ['beige', 'black', 'blue', 'brown', 'green', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
    rows_before = len(kaggle_formatted)
    kaggle_formatted = kaggle_formatted[kaggle_formatted['color'].isin(valid_colors)]
    print(f" After filtering for valid colors: {len(kaggle_formatted)} samples (removed {rows_before - len(kaggle_formatted)} samples)")
    print(f" Valid colors found: {sorted(kaggle_formatted['color'].unique())}")

    print(f" Final dataset size: {len(kaggle_formatted)} samples")

    print("🎨 Color distribution in Kaggle dataset:")
    for color, count in kaggle_formatted['color'].value_counts().items():
        print(f" {color}: {count} samples")

    return KaggleDataset(kaggle_formatted)
159
-
160
-
161
class LocalDataset(Dataset):
    """Map-style dataset over the local validation dataframe.

    Every item is an ``(image_tensor, text, color)`` triple. Images are read
    from the local path column when the file exists, otherwise fetched from
    ``image_url`` and cached on disk under ``images_dir``.

    Args:
        dataframe: DataFrame providing ``text`` and ``color`` plus a local
            image path column and/or ``image_url``.
        image_size: Side length in pixels that images are resized to.
            Defaults to 224.
    """

    def __init__(self, dataframe, image_size=224):
        self.dataframe = dataframe
        self.image_size = image_size

        # The resize now honours ``image_size`` (it used to be hard-coded to
        # 224, silently ignoring the argument).
        # NOTE(review): ColorJitter is a random augmentation, so items are not
        # deterministic across reads; it is kept because the original pipeline
        # applied it -- confirm this is intended for evaluation.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        """Return the number of rows in the underlying dataframe."""
        return len(self.dataframe)

    def _load_image(self, row):
        """Return the RGB image for ``row``; raise if no source is usable."""
        image_path = row.get(column_local_image_path) if hasattr(row, 'get') else None
        if isinstance(image_path, str) and image_path and os.path.exists(image_path):
            return Image.open(image_path).convert("RGB")

        image_url = row.get('image_url') if hasattr(row, 'get') else None
        if not (isinstance(image_url, str) and image_url):
            raise ValueError("No valid image_path or image_url")

        # Downloads are cached under a file name derived from the URL hash so
        # repeated runs do not fetch the same image again.
        cache_dir = Path(images_dir)
        cache_dir.mkdir(parents=True, exist_ok=True)
        url_hash = hashlib.md5(image_url.encode("utf-8")).hexdigest()
        cache_path = cache_dir / f"{url_hash}.jpg"
        if cache_path.exists():
            return Image.open(cache_path).convert("RGB")

        resp = requests.get(image_url, timeout=10)
        resp.raise_for_status()
        image = Image.open(BytesIO(resp.content)).convert("RGB")
        image.save(cache_path, "JPEG", quality=85, optimize=True)
        return image

    def __getitem__(self, idx):
        """Return ``(image_tensor, text, color)`` for the row at position ``idx``.

        Image loading is best-effort: on any failure a uniform gray
        placeholder is used so iteration continues, and the failure is
        reported instead of being swallowed silently.
        """
        row = self.dataframe.iloc[idx]

        try:
            image = self._load_image(row)
        except Exception as exc:  # deliberate best-effort boundary
            print(f"⚠️ Could not load image for sample {idx}: {exc}; using gray placeholder")
            image = Image.new('RGB', (self.image_size, self.image_size), color='gray')

        image = self.transform(image)
        return image, row['text'], row['color']
214
-
215
-
216
def load_local_validation_dataset(max_samples=5000):
    """Read the local validation CSV, filter it and wrap it in a LocalDataset.

    Rows without an image reference are dropped, only the 11 training colours
    are kept (when a ``color`` column exists), and at most ``max_samples``
    rows are drawn with a fixed random seed.

    Args:
        max_samples: Upper bound on the number of rows kept.

    Returns:
        A ``LocalDataset`` over the filtered rows.
    """
    print("📊 Loading local validation dataset...")

    frame = pd.read_csv(local_dataset_path)
    print(f"✅ Dataset loaded: {len(frame)} samples")

    # Use the configured local-path column when present, else fall back to URLs.
    if column_local_image_path in frame.columns:
        img_col = column_local_image_path
    else:
        img_col = 'image_url'
    df_clean = frame.dropna(subset=[img_col])
    print(f"📊 After filtering NaN image paths ({img_col}): {len(df_clean)} samples")

    # Only the colours used during training are evaluated.
    valid_colors = ['beige', 'black', 'blue', 'brown', 'green', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
    if 'color' in df_clean.columns:
        rows_before = len(df_clean)
        df_clean = df_clean[df_clean['color'].isin(valid_colors)]
        print(f"📊 After filtering for valid colors: {len(df_clean)} samples (removed {rows_before - len(df_clean)} samples)")
        print(f"🎨 Valid colors found: {sorted(df_clean['color'].unique())}")

    if len(df_clean) > max_samples:
        df_clean = df_clean.sample(n=max_samples, random_state=42)
        print(f"📊 Randomly sampled {max_samples} samples")

    print(f"📊 Using {len(df_clean)} samples for evaluation")

    if 'color' in df_clean.columns:
        print("🎨 Color distribution in sampled data:")
        color_counts = df_clean['color'].value_counts()
        print(f" Total unique colors: {len(color_counts)}")
        # Only the 15 most frequent colours are listed.
        for color, count in color_counts.head(15).items():
            print(f" {color}: {count} samples")

    return LocalDataset(df_clean)
252
-
253
-
254
def collate_fn_filter_none(batch):
    """Collate ``(image, text, color)`` samples, discarding ``None`` entries.

    Args:
        batch: Sequence of samples; each is ``None`` or a 3-tuple
            ``(image_tensor, text, color)``.

    Returns:
        ``(images, texts, colors)`` where ``images`` is the stacked tensor and
        the other two are lists. When nothing survives the filter, returns an
        empty tensor and two empty lists.
    """
    kept = [sample for sample in batch if sample is not None]
    dropped = len(batch) - len(kept)

    if dropped > 0:
        print(f"⚠️ Filtered out {dropped} None values from batch (original: {len(batch)}, filtered: {len(kept)})")

    if not kept:
        # Keep the three-element return shape even for an empty batch.
        print("⚠️ Empty batch after filtering None values")
        return torch.tensor([]), [], []

    image_list, texts, colors = map(list, zip(*kept))
    return torch.stack(image_list, dim=0), texts, colors
271
 
272
 
273
  class ColorEvaluator:
@@ -277,325 +62,54 @@ class ColorEvaluator:
277
  self,
278
  device='mps',
279
  directory="figures/confusion_matrices/cm_color",
280
- repo_id="Leacb4/gap-clip",
281
- cache_dir="./models_cache",
 
 
 
282
  ):
283
  self.device = torch.device(device)
284
  self.directory = directory
285
  self.color_emb_dim = color_emb_dim
286
- self.repo_id = repo_id
287
- self.cache_dir = cache_dir
 
288
  os.makedirs(self.directory, exist_ok=True)
289
-
290
- # Load baseline Fashion CLIP model
291
- print("📦 Loading baseline Fashion CLIP model...")
292
- patrick_model_name = "patrickjohncyh/fashion-clip"
293
- self.baseline_processor = CLIPProcessor.from_pretrained(patrick_model_name)
294
- self.baseline_model = CLIPModel_transformers.from_pretrained(patrick_model_name).to(self.device)
295
- self.baseline_model.eval()
296
- print("✅ Baseline Fashion CLIP model loaded successfully")
297
-
298
- # Load specialized color model (16D)
299
- self.color_model = None
300
- self.color_tokenizer = None
301
- self._load_color_model()
302
-
303
- def _load_color_model(self):
304
- """Load the specialized 16D color model and tokenizer."""
305
- if self.color_model is not None and self.color_tokenizer is not None:
306
- return
307
-
308
- local_model_exists = os.path.exists(color_model_path)
309
- local_tokenizer_exists = os.path.exists(tokeniser_path)
310
-
311
- if local_model_exists and local_tokenizer_exists:
312
- print("🎨 Loading specialized color model (16D) from local files...")
313
- state_dict = torch.load(color_model_path, map_location=self.device)
314
- with open(tokeniser_path, "r") as f:
315
- vocab = json.load(f)
316
- else:
317
- print("🎨 Local color model/tokenizer not found. Loading from Hugging Face...")
318
- print(f" Repo: {self.repo_id}")
319
- hf_model_path = hf_hub_download(
320
- repo_id=self.repo_id,
321
- filename="color_model.pt",
322
- cache_dir=self.cache_dir,
323
- )
324
- hf_vocab_path = hf_hub_download(
325
- repo_id=self.repo_id,
326
- filename="tokenizer_vocab.json",
327
- cache_dir=self.cache_dir,
328
- )
329
- state_dict = torch.load(hf_model_path, map_location=self.device)
330
- with open(hf_vocab_path, "r") as f:
331
- vocab = json.load(f)
332
-
333
- # Get vocab size from the embedding weight shape in checkpoint
334
- vocab_size = state_dict['text_encoder.embedding.weight'].shape[0]
335
- print(f" Detected vocab size from checkpoint: {vocab_size}")
336
-
337
- self.color_tokenizer = Tokenizer()
338
- self.color_tokenizer.load_vocab(vocab)
339
-
340
- # Create model with the vocab size from checkpoint (not from tokenizer)
341
- self.color_model = ColorCLIP(vocab_size=vocab_size, embedding_dim=self.color_emb_dim)
342
-
343
- # Load state dict
344
- self.color_model.load_state_dict(state_dict)
345
- self.color_model.to(self.device)
346
- self.color_model.eval()
347
- print("✅ Color model loaded successfully")
348
-
349
def _tokenize_color_texts(self, texts):
    """Tokenize ``texts`` and pack them into zero-padded tensors.

    Args:
        texts: Sequence of strings passed one by one to ``self.color_tokenizer``.

    Returns:
        ``(input_ids, lengths)``: a ``(len(texts), max_len)`` long tensor of
        token ids padded with zeros, and a ``(len(texts),)`` long tensor of
        token counts. Both the width and each length are at least 1, so an
        empty tokenization is reported as length 1.
    """
    token_lists = [self.color_tokenizer(text) for text in texts]
    width = max(map(len, token_lists), default=0) or 1

    input_ids = torch.zeros(len(texts), width, dtype=torch.long, device=self.device)
    # Start every length at 1 so empty tokenizations never yield a zero length.
    lengths = torch.ones(len(texts), dtype=torch.long, device=self.device)

    for row, toks in enumerate(token_lists):
        count = len(toks)
        if count == 0:
            continue
        input_ids[row, :count] = torch.tensor(toks, dtype=torch.long, device=self.device)
        lengths[row] = count

    return input_ids, lengths
366
-
367
def extract_color_embeddings(self, dataloader, embedding_type='text', max_samples=10000):
    """Collect colour-model embeddings and normalised colour labels.

    Args:
        dataloader: Yields ``(images, texts, colors)`` batches.
        embedding_type: ``'image'`` uses the image encoder; any other value
            uses the text encoder.
        max_samples: Iteration stops once at least this many samples have
            been processed (checked before each batch).

    Returns:
        ``(embeddings, colors)``: the per-batch embeddings stacked into one
        NumPy array, and the labels lower-cased, stripped and with "grey"
        rewritten to "gray".
    """
    self._load_color_model()
    wants_image = embedding_type == 'image'
    embedding_chunks = []
    color_labels = []
    seen = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} color embeddings"):
            if seen >= max_samples:
                break

            images, texts, colors = batch
            images = images.to(self.device).expand(-1, 3, -1, -1)

            if wants_image:
                embeddings = self.color_model.image_encoder(images)
            else:
                input_ids, lengths = self._tokenize_color_texts(texts)
                embeddings = self.color_model.text_encoder(input_ids, lengths)

            embedding_chunks.append(embeddings.cpu().numpy())
            color_labels.extend(
                str(c).lower().strip().replace("grey", "gray") for c in colors
            )
            seen += len(images)

            # Drop per-batch tensors before the next iteration to limit memory use.
            del images, embeddings
            if not wants_image:
                del input_ids, lengths
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return np.vstack(embedding_chunks), color_labels
404
-
405
def extract_baseline_embeddings_batch(self, dataloader, embedding_type='text', max_samples=10000):
    """Collect baseline-model embeddings and normalised colour labels.

    Each batch is run through ``self.baseline_model`` with both the processed
    texts and the images; the text or image embeddings are then kept.

    Args:
        dataloader: Yields ``(images, texts, colors)`` batches.
        embedding_type: ``'image'`` keeps ``image_embeds``; any other value
            keeps ``text_embeds``.
        max_samples: Iteration stops once at least this many samples have
            been processed (checked before each batch).

    Returns:
        ``(embeddings, colors)``: the per-batch embeddings stacked into one
        NumPy array, and the matching colour labels.
    """
    wants_image = embedding_type == 'image'
    all_embeddings = []
    all_colors = []
    sample_count = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"Extracting baseline {embedding_type} embeddings"):
            if sample_count >= max_samples:
                break

            images, texts, colors = batch
            images = images.to(self.device)
            images = images.expand(-1, 3, -1, -1)  # ensure 3 channels

            text_inputs = self.baseline_processor(text=texts, padding=True, return_tensors="pt")
            text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}

            outputs = self.baseline_model(**text_inputs, pixel_values=images)
            embeddings = outputs.image_embeds if wants_image else outputs.text_embeds

            all_embeddings.append(embeddings.cpu().numpy())
            # Normalise labels exactly as extract_color_embeddings does, so
            # baseline and trained-model results share one label vocabulary.
            all_colors.extend(
                str(c).lower().strip().replace("grey", "gray") for c in colors
            )

            sample_count += len(images)

            # Drop per-batch tensors before the next iteration to limit memory use.
            del images, text_inputs, outputs, embeddings
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return np.vstack(all_embeddings), all_colors
446
-
447
def compute_similarity_metrics(self, embeddings, labels):
    """Summarise how well ``embeddings`` separate by ``labels``.

    At most 5000 samples are used; larger inputs are subsampled at random
    without replacement.

    Args:
        embeddings: 2-D array with one embedding per row.
        labels: Sequence of labels aligned with ``embeddings``.

    Returns:
        Dict with the raw intra-/inter-class cosine similarities, their means,
        ``separation_score`` (intra mean minus inter mean), nearest-neighbour
        ``accuracy`` and ``centroid_accuracy``. Means and the separation score
        fall back to 0.0 when the corresponding list is empty.
    """
    cap = min(5000, len(embeddings))
    if len(embeddings) > cap:
        keep = np.random.choice(len(embeddings), cap, replace=False)
        embeddings = embeddings[keep]
        labels = [labels[i] for i in keep]

    similarities = cosine_similarity(embeddings)

    # Row indices of each class, keyed by label.
    label_array = np.array(labels)
    label_groups = {
        label: np.where(label_array == label)[0]
        for label in np.unique(label_array)
    }

    # Intra-class: every unordered pair within a class (upper triangle, no diagonal).
    intra_class_similarities = []
    for members in label_groups.values():
        if len(members) < 2:
            continue
        class_block = similarities[np.ix_(members, members)]
        upper = np.triu_indices_from(class_block, k=1)
        intra_class_similarities.extend(class_block[upper].tolist())

    # Inter-class: every pair drawn from two different classes.
    inter_class_similarities = []
    groups = list(label_groups.values())
    for position, first in enumerate(groups):
        for second in groups[position + 1:]:
            cross_block = similarities[np.ix_(first, second)]
            inter_class_similarities.extend(cross_block.flatten().tolist())

    nn_accuracy = self.compute_embedding_accuracy(embeddings, labels, similarities)
    centroid_accuracy = self.compute_centroid_accuracy(embeddings, labels)

    if intra_class_similarities:
        intra_mean = float(np.mean(intra_class_similarities))
    else:
        intra_mean = 0.0
    if inter_class_similarities:
        inter_mean = float(np.mean(inter_class_similarities))
    else:
        inter_mean = 0.0
    if intra_class_similarities and inter_class_similarities:
        separation = float(np.mean(intra_class_similarities) - np.mean(inter_class_similarities))
    else:
        separation = 0.0

    return {
        'intra_class_similarities': intra_class_similarities,
        'inter_class_similarities': inter_class_similarities,
        'intra_class_mean': intra_mean,
        'inter_class_mean': inter_mean,
        'separation_score': separation,
        'accuracy': nn_accuracy,
        'centroid_accuracy': centroid_accuracy,
    }
495
 
496
def compute_embedding_accuracy(self, embeddings, labels, similarities):
    """Return leave-one-out nearest-neighbour accuracy.

    Each sample is assigned the label of its most similar *other* sample.

    Args:
        embeddings: Sequence of embeddings; only its length is used.
        labels: Sequence of labels aligned with ``embeddings``.
        similarities: Square pairwise similarity matrix; it is not modified.

    Returns:
        Fraction of samples whose nearest neighbour has the same label, or
        0.0 when fewer than two samples exist (no neighbour to compare with).
    """
    total_predictions = len(labels)
    if total_predictions < 2:
        return 0.0

    correct_predictions = 0
    for i in range(len(embeddings)):
        # Work on a float copy so the caller's matrix stays untouched.
        row = np.array(similarities[i], dtype=float)
        # Mask self-similarity with -inf: -1 is a legal cosine value, so the
        # old sentinel could let a sample win as its own neighbour.
        row[i] = -np.inf
        nearest_neighbor_idx = int(np.argmax(row))
        if labels[nearest_neighbor_idx] == labels[i]:
            correct_predictions += 1

    return correct_predictions / total_predictions
509
-
510
def compute_centroid_accuracy(self, embeddings, labels):
    """Return nearest-centroid classification accuracy.

    A centroid (mean embedding) is computed per label from the same samples
    that are then classified, and each sample is assigned the label of the
    centroid with the highest cosine similarity.

    Args:
        embeddings: 2-D NumPy array with one embedding per row.
        labels: Sequence of labels aligned with ``embeddings``.

    Returns:
        Fraction of samples assigned their own label, or 0.0 for empty input.
    """
    # Guard first: with no labels the centroid stack below would be empty
    # and np.vstack would raise.
    if len(labels) == 0:
        return 0.0

    # First-seen order keeps tie-breaking reproducible (set order is not).
    unique_labels = list(dict.fromkeys(labels))
    # Built once; dtype=object keeps the comparison below element-wise.
    label_array = np.asarray(labels, dtype=object)

    centroid_matrix = np.vstack([
        np.mean(embeddings[label_array == label], axis=0)
        for label in unique_labels
    ])

    similarities = cosine_similarity(embeddings, centroid_matrix)
    predicted_indices = np.argmax(similarities, axis=1)

    correct_predictions = sum(
        unique_labels[idx] == true for idx, true in zip(predicted_indices, labels)
    )
    return correct_predictions / len(labels)
533
-
534
def predict_labels_from_embeddings(self, embeddings, labels):
    """Predict a label per embedding by nearest class centroid.

    Centroids are the mean embeddings of each non-``None`` label, computed
    from the same samples that are then classified.

    Args:
        embeddings: 2-D NumPy array with one embedding per row.
        labels: Sequence of labels aligned with ``embeddings``; ``None``
            entries are ignored when building centroids.

    Returns:
        List with one predicted label per embedding, or a list of ``None``
        when no non-``None`` label exists.
    """
    # First-seen order keeps tie-breaking reproducible (set order is not).
    centroid_labels = [label for label in dict.fromkeys(labels) if label is not None]
    if not centroid_labels:
        return [None] * len(embeddings)

    # Built once; dtype=object keeps the comparison below element-wise even
    # when labels mix strings and None.
    label_array = np.asarray(labels, dtype=object)

    centroid_matrix = np.vstack([
        np.mean(embeddings[label_array == label], axis=0)
        for label in centroid_labels
    ])

    similarities = cosine_similarity(embeddings, centroid_matrix)
    predicted_indices = np.argmax(similarities, axis=1)
    return [centroid_labels[idx] for idx in predicted_indices]
561
-
562
def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix", label_type="Label"):
    """Draw a confusion-matrix heatmap and return it with its accuracy.

    Args:
        true_labels: List of ground-truth labels.
        predicted_labels: List of predicted labels, aligned with ``true_labels``.
        title: First line of the plot title; the accuracy is appended below it.
        label_type: Noun used in the axis captions.

    Returns:
        ``(figure, accuracy, cm)``: the current matplotlib figure, the
        accuracy score and the confusion matrix.
    """
    # Axis order: sorted union of every label seen in either list.
    classes = sorted(set(true_labels + predicted_labels))
    cm = confusion_matrix(true_labels, predicted_labels, labels=classes)
    accuracy = accuracy_score(true_labels, predicted_labels)

    plt.figure(figsize=(12, 10))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=classes,
        yticklabels=classes,
    )
    plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)')
    plt.ylabel(f'True {label_type}')
    plt.xlabel(f'Predicted {label_type}')
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.tight_layout()

    return plt.gcf(), accuracy, cm
576
 
577
  def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings", label_type="Label"):
578
  """
579
  Evaluate classification performance and create confusion matrix.
580
-
581
  Args:
582
  embeddings: Embeddings
583
  labels: True labels
584
  embedding_type: Type of embeddings for display
585
  label_type: Type of labels (Color)
586
- full_embeddings: Optional full 512-dim embeddings for ensemble (if None, uses only embeddings)
587
- ensemble_weight: Weight for embeddings in ensemble (0.0 = only full, 1.0 = only embeddings)
588
  """
589
-
590
- predictions = self.predict_labels_from_embeddings(embeddings, labels)
591
- title_suffix = ""
592
-
593
  # Filter out None values from labels and predictions
594
- valid_indices = [i for i, (label, pred) in enumerate(zip(labels, predictions))
595
  if label is not None and pred is not None]
596
-
597
  if len(valid_indices) == 0:
598
- print(f"⚠️ Warning: No valid labels/predictions found (all are None)")
599
  return {
600
  'accuracy': 0.0,
601
  'predictions': predictions,
@@ -603,12 +117,12 @@ class ColorEvaluator:
603
  'classification_report': None,
604
  'figure': None,
605
  }
606
-
607
  filtered_labels = [labels[i] for i in valid_indices]
608
  filtered_predictions = [predictions[i] for i in valid_indices]
609
-
610
  accuracy = accuracy_score(filtered_labels, filtered_predictions)
611
- fig, acc, cm = self.create_confusion_matrix(
612
  filtered_labels, filtered_predictions,
613
  embedding_type,
614
  label_type
@@ -631,27 +145,31 @@ class ColorEvaluator:
631
  print(f"Max samples: {max_samples}")
632
  print(f"{'='*60}")
633
 
634
- kaggle_dataset = load_kaggle_marqo_dataset(max_samples)
635
  if kaggle_dataset is None:
636
- print("Failed to load KAGL dataset")
637
  return None
638
 
639
  dataloader = DataLoader(kaggle_dataset, batch_size=8, shuffle=False, num_workers=0, collate_fn=collate_fn_filter_none)
640
-
641
  results = {}
642
 
643
- # ========== EXTRACT BASELINE EMBEDDINGS ==========
644
- print("\n📦 Extracting baseline embeddings...")
645
- text_full_embeddings, text_colors_full = self.extract_color_embeddings(dataloader, embedding_type='text', max_samples=max_samples)
646
- image_full_embeddings, image_colors_full = self.extract_color_embeddings(dataloader, embedding_type='image', max_samples=max_samples)
647
- text_color_metrics = self.compute_similarity_metrics(text_full_embeddings, text_colors_full)
 
 
 
 
648
  text_color_class = self.evaluate_classification_performance(
649
  text_full_embeddings, text_colors_full,
650
  "KAGL Marqo, text, color confusion matrix", "Color",
651
  )
652
  text_color_metrics.update(text_color_class)
653
  results['text_color'] = text_color_metrics
654
- image_color_metrics = self.compute_similarity_metrics(image_full_embeddings, image_colors_full)
655
  image_color_class = self.evaluate_classification_performance(
656
  image_full_embeddings, image_colors_full,
657
  "KAGL Marqo, image, color confusion matrix", "Color",
@@ -681,20 +199,22 @@ class ColorEvaluator:
681
  print(f"Max samples: {max_samples}")
682
  print(f"{'='*60}")
683
 
684
- local_dataset = load_local_validation_dataset(max_samples)
685
  dataloader = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0)
686
 
687
  results = {}
688
 
689
  # ========== COLOR EVALUATION ==========
690
- print("\n🎨 COLOR EVALUATION ")
691
  print("=" * 50)
692
-
693
  # Text color embeddings
694
- print("\n📝 Extracting text color embeddings...")
695
- text_color_embeddings, text_colors = self.extract_color_embeddings(dataloader, 'text', max_samples)
 
 
696
  print(f" Text color embeddings shape: {text_color_embeddings.shape}")
697
- text_color_metrics = self.compute_similarity_metrics(text_color_embeddings, text_colors)
698
  text_color_class = self.evaluate_classification_performance(
699
  text_color_embeddings, text_colors, "Test Dataset, text, color confusion matrix", "Color"
700
  )
@@ -705,10 +225,12 @@ class ColorEvaluator:
705
  torch.cuda.empty_cache() if torch.cuda.is_available() else None
706
 
707
  # Image color embeddings
708
- print("\n🖼️ Extracting image color embeddings...")
709
- image_color_embeddings, image_colors = self.extract_color_embeddings(dataloader, 'image', max_samples)
 
 
710
  print(f" Image color embeddings shape: {image_color_embeddings.shape}")
711
- image_color_metrics = self.compute_similarity_metrics(image_color_embeddings, image_colors)
712
  image_color_class = self.evaluate_classification_performance(
713
  image_color_embeddings, image_colors, "Test Dataset, image, color confusion matrix", "Color"
714
  )
@@ -736,24 +258,27 @@ class ColorEvaluator:
736
  print("Evaluating Baseline Fashion CLIP on KAGL Marqo Dataset")
737
  print(f"Max samples: {max_samples}")
738
  print(f"{'='*60}")
739
-
740
  # Load KAGL Marqo dataset
741
- kaggle_dataset = load_kaggle_marqo_dataset(max_samples)
742
  if kaggle_dataset is None:
743
- print("Failed to load KAGL dataset")
744
  return None
745
-
746
  # Create dataloader
747
  dataloader = DataLoader(kaggle_dataset, batch_size=8, shuffle=False, num_workers=0, collate_fn=collate_fn_filter_none)
748
-
749
  results = {}
750
-
751
  # Evaluate text embeddings
752
- print("\n📝 Extracting baseline text embeddings from KAGL Marqo...")
753
- text_embeddings, text_colors = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples)
 
 
 
754
  print(f" Baseline text embeddings shape: {text_embeddings.shape} (using all {text_embeddings.shape[1]} dimensions)")
755
- text_color_metrics = self.compute_similarity_metrics(text_embeddings, text_colors)
756
-
757
  text_color_classification = self.evaluate_classification_performance(
758
  text_embeddings, text_colors, "KAGL Marqo, text, color confusion matrix", "Color"
759
  )
@@ -761,17 +286,20 @@ class ColorEvaluator:
761
  results['text'] = {
762
  'color': text_color_metrics
763
  }
764
-
765
  # Clear memory
766
  del text_embeddings
767
  torch.cuda.empty_cache() if torch.cuda.is_available() else None
768
-
769
  # Evaluate image embeddings
770
- print("\n🖼️ Extracting baseline image embeddings from KAGL Marqo...")
771
- image_embeddings, image_colors = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples)
 
 
 
772
  print(f" Baseline image embeddings shape: {image_embeddings.shape} (using all {image_embeddings.shape[1]} dimensions)")
773
- image_color_metrics = self.compute_similarity_metrics(image_embeddings, image_colors)
774
-
775
  image_color_classification = self.evaluate_classification_performance(
776
  image_embeddings, image_colors, "KAGL Marqo, image, color confusion matrix", "Color"
777
  )
@@ -779,11 +307,11 @@ class ColorEvaluator:
779
  results['image'] = {
780
  'color': image_color_metrics
781
  }
782
-
783
  # Clear memory
784
  del image_embeddings
785
  torch.cuda.empty_cache() if torch.cuda.is_available() else None
786
-
787
  # ========== SAVE VISUALIZATIONS ==========
788
  os.makedirs(self.directory, exist_ok=True)
789
  for key in ['text', 'image']:
@@ -795,7 +323,7 @@ class ColorEvaluator:
795
  bbox_inches='tight',
796
  )
797
  plt.close(figure)
798
-
799
  return results
800
 
801
  def evaluate_baseline_local_validation(self, max_samples=5000):
@@ -804,24 +332,27 @@ class ColorEvaluator:
804
  print("Evaluating Baseline Fashion CLIP on Local Validation Dataset")
805
  print(f"Max samples: {max_samples}")
806
  print(f"{'='*60}")
807
-
808
  # Load local validation dataset
809
- local_dataset = load_local_validation_dataset(max_samples)
810
  if local_dataset is None:
811
- print("Failed to load local validation dataset")
812
  return None
813
-
814
  # Create dataloader
815
  dataloader = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0)
816
-
817
  results = {}
818
-
819
  # Evaluate text embeddings
820
- print("\n📝 Extracting baseline text embeddings from Local Validation...")
821
- text_embeddings, text_colors = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples)
 
 
 
822
  print(f" Baseline text embeddings shape: {text_embeddings.shape} (using all {text_embeddings.shape[1]} dimensions)")
823
- text_color_metrics = self.compute_similarity_metrics(text_embeddings, text_colors)
824
-
825
  text_color_classification = self.evaluate_classification_performance(
826
  text_embeddings, text_colors, "Test Dataset, text, color confusion matrix", "Color"
827
  )
@@ -829,17 +360,20 @@ class ColorEvaluator:
829
  results['text'] = {
830
  'color': text_color_metrics
831
  }
832
-
833
  # Clear memory
834
  del text_embeddings
835
  torch.cuda.empty_cache() if torch.cuda.is_available() else None
836
-
837
  # Evaluate image embeddings
838
- print("\n🖼️ Extracting baseline image embeddings from Local Validation...")
839
- image_embeddings, image_colors = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples)
 
 
 
840
  print(f" Baseline image embeddings shape: {image_embeddings.shape} (using all {image_embeddings.shape[1]} dimensions)")
841
- image_color_metrics = self.compute_similarity_metrics(image_embeddings, image_colors)
842
-
843
  image_color_classification = self.evaluate_classification_performance(
844
  image_embeddings, image_colors, "Test Dataset, image, color confusion matrix", "Color"
845
  )
@@ -847,11 +381,11 @@ class ColorEvaluator:
847
  results['image'] = {
848
  'color': image_color_metrics
849
  }
850
-
851
  # Clear memory
852
  del image_embeddings
853
  torch.cuda.empty_cache() if torch.cuda.is_available() else None
854
-
855
  # ========== SAVE VISUALIZATIONS ==========
856
  os.makedirs(self.directory, exist_ok=True)
857
  for key in ['text', 'image']:
@@ -863,27 +397,17 @@ class ColorEvaluator:
863
  bbox_inches='tight',
864
  )
865
  plt.close(figure)
866
-
867
  return results
868
 
869
  def analyze_baseline_vs_trained_performance(self, results_trained, results_baseline, dataset_name):
870
- """
871
- Analyse et explique pourquoi la baseline peut performer mieux que le modèle entraîné
872
-
873
- Raisons possibles:
874
- 1. Capacité dimensionnelle: Baseline utilise toutes les dimensions (512), modèle entraîné utilise seulement des sous-espaces (17 ou 64 dims)
875
- 2. Distribution shift: Dataset de validation différent de celui d'entraînement
876
- 3. Overfitting: Modèle trop spécialisé sur le dataset d'entraînement
877
- 4. Généralisation: Baseline pré-entraînée sur un dataset plus large et diversifié
878
- 5. Perte d'information: Spécialisation excessive peut causer perte d'information générale
879
- """
880
  print(f"\n{'='*60}")
881
- print(f"📊 ANALYSE: Baseline vs Modèle Entraîné - {dataset_name}")
882
  print(f"{'='*60}")
883
-
884
- # Comparer les métriques pour chaque type d'embedding
885
  comparisons = []
886
-
887
  # Text Color
888
  trained_color_text_acc = results_trained.get('text_color', {}).get('accuracy', 0)
889
  baseline_color_text_acc = results_baseline.get('text', {}).get('color', {}).get('accuracy', 0)
@@ -894,10 +418,10 @@ class ColorEvaluator:
894
  'trained': trained_color_text_acc,
895
  'baseline': baseline_color_text_acc,
896
  'diff': diff,
897
- 'trained_dims': '0-15 (16 dims)',
898
- 'baseline_dims': 'All dimensions (512 dims)'
899
  })
900
-
901
  # Image Color
902
  trained_color_img_acc = results_trained.get('image_color', {}).get('accuracy', 0)
903
  baseline_color_img_acc = results_baseline.get('image', {}).get('color', {}).get('accuracy', 0)
@@ -908,8 +432,8 @@ class ColorEvaluator:
908
  'trained': trained_color_img_acc,
909
  'baseline': baseline_color_img_acc,
910
  'diff': diff,
911
- 'trained_dims': '0-15 (16 dims)',
912
- 'baseline_dims': 'All dimensions (512 dims)'
913
  })
914
 
915
  return comparisons
@@ -924,39 +448,11 @@ if __name__ == "__main__":
924
  max_samples = 10000
925
  local_max_samples = 10000
926
 
927
- evaluator = ColorEvaluator(device=device, directory=directory, repo_id="Leacb4/gap-clip")
928
-
929
- # # Evaluate KAGL Marqo (skipped — CMs already generated)
930
- # print("\n" + "="*60)
931
- # print("🚀 Starting evaluation of KAGL Marqo with Color embeddings")
932
- # print("="*60)
933
- # results_kaggle = evaluator.evaluate_kaggle_marqo(max_samples=max_samples)
934
- #
935
- # print(f"\n{'='*60}")
936
- # print("KAGL MARQO EVALUATION SUMMARY")
937
- # print(f"{'='*60}")
938
- #
939
- # print("\n🎨 COLOR CLASSIFICATION RESULTS:")
940
- # print(f" Text - NN Acc: {results_kaggle['text_color']['accuracy']*100:.1f}% | Centroid Acc: {results_kaggle['text_color']['centroid_accuracy']*100:.1f}% | Separation: {results_kaggle['text_color']['separation_score']:.4f}")
941
- # print(f" Image - NN Acc: {results_kaggle['image_color']['accuracy']*100:.1f}% | Centroid Acc: {results_kaggle['image_color']['centroid_accuracy']*100:.1f}% | Separation: {results_kaggle['image_color']['separation_score']:.4f}")
942
- #
943
- # # Evaluate Baseline Fashion CLIP on KAGL Marqo
944
- # print("\n" + "="*60)
945
- # print("🚀 Starting evaluation of Baseline Fashion CLIP on KAGL Marqo")
946
- # print("="*60)
947
- # results_baseline_kaggle = evaluator.evaluate_baseline_kaggle_marqo(max_samples=max_samples)
948
- #
949
- # print(f"\n{'='*60}")
950
- # print("BASELINE KAGL MARQO EVALUATION SUMMARY")
951
- # print(f"{'='*60}")
952
- #
953
- # print("\n🎨 COLOR CLASSIFICATION RESULTS (Baseline):")
954
- # print(f" Text - NN Acc: {results_baseline_kaggle['text']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_kaggle['text']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_kaggle['text']['color']['separation_score']:.4f}")
955
- # print(f" Image - NN Acc: {results_baseline_kaggle['image']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_kaggle['image']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_kaggle['image']['color']['separation_score']:.4f}")
956
 
957
  # Evaluate Local Validation Dataset
958
  print("\n" + "="*60)
959
- print("🚀 Starting evaluation of Local Validation Dataset with Color embeddings")
960
  print("="*60)
961
  results_local = evaluator.evaluate_local_validation(max_samples=local_max_samples)
962
 
@@ -964,25 +460,25 @@ if __name__ == "__main__":
964
  print(f"\n{'='*60}")
965
  print("LOCAL VALIDATION DATASET EVALUATION SUMMARY")
966
  print(f"{'='*60}")
967
-
968
- print("\n🎨 COLOR CLASSIFICATION RESULTS:")
969
  print(f" Text - NN Acc: {results_local['text_color']['accuracy']*100:.1f}% | Centroid Acc: {results_local['text_color']['centroid_accuracy']*100:.1f}% | Separation: {results_local['text_color']['separation_score']:.4f}")
970
  print(f" Image - NN Acc: {results_local['image_color']['accuracy']*100:.1f}% | Centroid Acc: {results_local['image_color']['centroid_accuracy']*100:.1f}% | Separation: {results_local['image_color']['separation_score']:.4f}")
971
-
972
  # Evaluate Baseline Fashion CLIP on Local Validation
973
  print("\n" + "="*60)
974
- print("🚀 Starting evaluation of Baseline Fashion CLIP on Local Validation")
975
  print("="*60)
976
  results_baseline_local = evaluator.evaluate_baseline_local_validation(max_samples=local_max_samples)
977
-
978
  if results_baseline_local is not None:
979
  print(f"\n{'='*60}")
980
  print("BASELINE LOCAL VALIDATION EVALUATION SUMMARY")
981
  print(f"{'='*60}")
982
-
983
- print("\n🎨 COLOR CLASSIFICATION RESULTS (Baseline):")
984
  print(f" Text - NN Acc: {results_baseline_local['text']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['text']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['text']['color']['separation_score']:.4f}")
985
  print(f" Image - NN Acc: {results_baseline_local['image']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['image']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['image']['color']['separation_score']:.4f}")
986
-
987
-
988
- print(f"\n✅ Evaluation completed! Check '{directory}/' for visualization files.")
 
20
  """
21
 
22
  import os
 
 
 
23
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
24
  import sys
25
  from pathlib import Path
26
 
27
  import torch
 
 
28
  import matplotlib.pyplot as plt
29
+ from sklearn.metrics import classification_report, accuracy_score
30
+ from torch.utils.data import DataLoader
 
 
 
 
 
 
 
 
31
  import warnings
32
  warnings.filterwarnings('ignore')
 
 
33
 
34
  # Ensure project root is importable when running this file directly.
35
  PROJECT_ROOT = Path(__file__).resolve().parent.parent
 
39
  from config import (
40
  color_model_path,
41
  color_emb_dim,
42
+ main_emb_dim,
 
 
 
43
  )
44
+ from utils.datasets import (
45
+ load_kaggle_marqo_dataset,
46
+ load_local_validation_dataset,
47
+ collate_fn_filter_none,
48
+ )
49
+ from utils.embeddings import extract_clip_embeddings, extract_color_model_embeddings
50
+ from utils.metrics import (
51
+ compute_similarity_metrics,
52
+ predict_labels_from_embeddings,
53
+ create_confusion_matrix,
54
+ )
55
+ from utils.model_loader import load_color_model, load_baseline_fashion_clip
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
 
58
  class ColorEvaluator:
 
62
  self,
63
  device='mps',
64
  directory="figures/confusion_matrices/cm_color",
65
+ baseline_model=None,
66
+ baseline_processor=None,
67
+ color_model=None,
68
+ kaggle_raw_df=None,
69
+ local_raw_df=None,
70
  ):
71
  self.device = torch.device(device)
72
  self.directory = directory
73
  self.color_emb_dim = color_emb_dim
74
+ self.main_emb_dim = main_emb_dim
75
+ self.kaggle_raw_df = kaggle_raw_df
76
+ self.local_raw_df = local_raw_df
77
  os.makedirs(self.directory, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
+ # Load baseline Fashion CLIP model (or reuse pre-loaded)
80
+ if baseline_model is not None and baseline_processor is not None:
81
+ self.baseline_model = baseline_model
82
+ self.baseline_processor = baseline_processor
83
+ else:
84
+ print("Loading baseline Fashion CLIP model...")
85
+ self.baseline_model, self.baseline_processor = load_baseline_fashion_clip(self.device)
86
+ print("Baseline Fashion CLIP model loaded successfully")
 
87
 
88
+ # Load specialized color model (or reuse pre-loaded)
89
+ if color_model is not None:
90
+ self.color_model = color_model
91
+ else:
92
+ self.color_model, _ = load_color_model(color_model_path, self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings", label_type="Label"):
95
  """
96
  Evaluate classification performance and create confusion matrix.
97
+
98
  Args:
99
  embeddings: Embeddings
100
  labels: True labels
101
  embedding_type: Type of embeddings for display
102
  label_type: Type of labels (Color)
 
 
103
  """
104
+
105
+ predictions = predict_labels_from_embeddings(embeddings, labels)
106
+
 
107
  # Filter out None values from labels and predictions
108
+ valid_indices = [i for i, (label, pred) in enumerate(zip(labels, predictions))
109
  if label is not None and pred is not None]
110
+
111
  if len(valid_indices) == 0:
112
+ print(f"Warning: No valid labels/predictions found (all are None)")
113
  return {
114
  'accuracy': 0.0,
115
  'predictions': predictions,
 
117
  'classification_report': None,
118
  'figure': None,
119
  }
120
+
121
  filtered_labels = [labels[i] for i in valid_indices]
122
  filtered_predictions = [predictions[i] for i in valid_indices]
123
+
124
  accuracy = accuracy_score(filtered_labels, filtered_predictions)
125
+ fig, _, cm = create_confusion_matrix(
126
  filtered_labels, filtered_predictions,
127
  embedding_type,
128
  label_type
 
145
  print(f"Max samples: {max_samples}")
146
  print(f"{'='*60}")
147
 
148
+ kaggle_dataset = load_kaggle_marqo_dataset(max_samples, raw_df=self.kaggle_raw_df)
149
  if kaggle_dataset is None:
150
+ print("Failed to load KAGL dataset")
151
  return None
152
 
153
  dataloader = DataLoader(kaggle_dataset, batch_size=8, shuffle=False, num_workers=0, collate_fn=collate_fn_filter_none)
154
+
155
  results = {}
156
 
157
+ # ========== EXTRACT COLOR MODEL EMBEDDINGS ==========
158
+ print("\nExtracting color model embeddings...")
159
+ text_full_embeddings, text_colors_full = extract_color_model_embeddings(
160
+ self.color_model, dataloader, self.device, embedding_type='text', max_samples=max_samples
161
+ )
162
+ image_full_embeddings, image_colors_full = extract_color_model_embeddings(
163
+ self.color_model, dataloader, self.device, embedding_type='image', max_samples=max_samples
164
+ )
165
+ text_color_metrics = compute_similarity_metrics(text_full_embeddings, text_colors_full)
166
  text_color_class = self.evaluate_classification_performance(
167
  text_full_embeddings, text_colors_full,
168
  "KAGL Marqo, text, color confusion matrix", "Color",
169
  )
170
  text_color_metrics.update(text_color_class)
171
  results['text_color'] = text_color_metrics
172
+ image_color_metrics = compute_similarity_metrics(image_full_embeddings, image_colors_full)
173
  image_color_class = self.evaluate_classification_performance(
174
  image_full_embeddings, image_colors_full,
175
  "KAGL Marqo, image, color confusion matrix", "Color",
 
199
  print(f"Max samples: {max_samples}")
200
  print(f"{'='*60}")
201
 
202
+ local_dataset = load_local_validation_dataset(max_samples, raw_df=self.local_raw_df)
203
  dataloader = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0)
204
 
205
  results = {}
206
 
207
  # ========== COLOR EVALUATION ==========
208
+ print("\nCOLOR EVALUATION")
209
  print("=" * 50)
210
+
211
  # Text color embeddings
212
+ print("\nExtracting text color embeddings...")
213
+ text_color_embeddings, text_colors = extract_color_model_embeddings(
214
+ self.color_model, dataloader, self.device, embedding_type='text', max_samples=max_samples
215
+ )
216
  print(f" Text color embeddings shape: {text_color_embeddings.shape}")
217
+ text_color_metrics = compute_similarity_metrics(text_color_embeddings, text_colors)
218
  text_color_class = self.evaluate_classification_performance(
219
  text_color_embeddings, text_colors, "Test Dataset, text, color confusion matrix", "Color"
220
  )
 
225
  torch.cuda.empty_cache() if torch.cuda.is_available() else None
226
 
227
  # Image color embeddings
228
+ print("\nExtracting image color embeddings...")
229
+ image_color_embeddings, image_colors = extract_color_model_embeddings(
230
+ self.color_model, dataloader, self.device, embedding_type='image', max_samples=max_samples
231
+ )
232
  print(f" Image color embeddings shape: {image_color_embeddings.shape}")
233
+ image_color_metrics = compute_similarity_metrics(image_color_embeddings, image_colors)
234
  image_color_class = self.evaluate_classification_performance(
235
  image_color_embeddings, image_colors, "Test Dataset, image, color confusion matrix", "Color"
236
  )
 
258
  print("Evaluating Baseline Fashion CLIP on KAGL Marqo Dataset")
259
  print(f"Max samples: {max_samples}")
260
  print(f"{'='*60}")
261
+
262
  # Load KAGL Marqo dataset
263
+ kaggle_dataset = load_kaggle_marqo_dataset(max_samples, raw_df=self.kaggle_raw_df)
264
  if kaggle_dataset is None:
265
+ print("Failed to load KAGL dataset")
266
  return None
267
+
268
  # Create dataloader
269
  dataloader = DataLoader(kaggle_dataset, batch_size=8, shuffle=False, num_workers=0, collate_fn=collate_fn_filter_none)
270
+
271
  results = {}
272
+
273
  # Evaluate text embeddings
274
+ print("\nExtracting baseline text embeddings from KAGL Marqo...")
275
+ text_embeddings, text_colors, _ = extract_clip_embeddings(
276
+ self.baseline_model, self.baseline_processor, dataloader, self.device,
277
+ embedding_type='text', max_samples=max_samples
278
+ )
279
  print(f" Baseline text embeddings shape: {text_embeddings.shape} (using all {text_embeddings.shape[1]} dimensions)")
280
+ text_color_metrics = compute_similarity_metrics(text_embeddings, text_colors)
281
+
282
  text_color_classification = self.evaluate_classification_performance(
283
  text_embeddings, text_colors, "KAGL Marqo, text, color confusion matrix", "Color"
284
  )
 
286
  results['text'] = {
287
  'color': text_color_metrics
288
  }
289
+
290
  # Clear memory
291
  del text_embeddings
292
  torch.cuda.empty_cache() if torch.cuda.is_available() else None
293
+
294
  # Evaluate image embeddings
295
+ print("\nExtracting baseline image embeddings from KAGL Marqo...")
296
+ image_embeddings, image_colors, _ = extract_clip_embeddings(
297
+ self.baseline_model, self.baseline_processor, dataloader, self.device,
298
+ embedding_type='image', max_samples=max_samples
299
+ )
300
  print(f" Baseline image embeddings shape: {image_embeddings.shape} (using all {image_embeddings.shape[1]} dimensions)")
301
+ image_color_metrics = compute_similarity_metrics(image_embeddings, image_colors)
302
+
303
  image_color_classification = self.evaluate_classification_performance(
304
  image_embeddings, image_colors, "KAGL Marqo, image, color confusion matrix", "Color"
305
  )
 
307
  results['image'] = {
308
  'color': image_color_metrics
309
  }
310
+
311
  # Clear memory
312
  del image_embeddings
313
  torch.cuda.empty_cache() if torch.cuda.is_available() else None
314
+
315
  # ========== SAVE VISUALIZATIONS ==========
316
  os.makedirs(self.directory, exist_ok=True)
317
  for key in ['text', 'image']:
 
323
  bbox_inches='tight',
324
  )
325
  plt.close(figure)
326
+
327
  return results
328
 
329
  def evaluate_baseline_local_validation(self, max_samples=5000):
 
332
  print("Evaluating Baseline Fashion CLIP on Local Validation Dataset")
333
  print(f"Max samples: {max_samples}")
334
  print(f"{'='*60}")
335
+
336
  # Load local validation dataset
337
+ local_dataset = load_local_validation_dataset(max_samples, raw_df=self.local_raw_df)
338
  if local_dataset is None:
339
+ print("Failed to load local validation dataset")
340
  return None
341
+
342
  # Create dataloader
343
  dataloader = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0)
344
+
345
  results = {}
346
+
347
  # Evaluate text embeddings
348
+ print("\nExtracting baseline text embeddings from Local Validation...")
349
+ text_embeddings, text_colors, _ = extract_clip_embeddings(
350
+ self.baseline_model, self.baseline_processor, dataloader, self.device,
351
+ embedding_type='text', max_samples=max_samples
352
+ )
353
  print(f" Baseline text embeddings shape: {text_embeddings.shape} (using all {text_embeddings.shape[1]} dimensions)")
354
+ text_color_metrics = compute_similarity_metrics(text_embeddings, text_colors)
355
+
356
  text_color_classification = self.evaluate_classification_performance(
357
  text_embeddings, text_colors, "Test Dataset, text, color confusion matrix", "Color"
358
  )
 
360
  results['text'] = {
361
  'color': text_color_metrics
362
  }
363
+
364
  # Clear memory
365
  del text_embeddings
366
  torch.cuda.empty_cache() if torch.cuda.is_available() else None
367
+
368
  # Evaluate image embeddings
369
+ print("\nExtracting baseline image embeddings from Local Validation...")
370
+ image_embeddings, image_colors, _ = extract_clip_embeddings(
371
+ self.baseline_model, self.baseline_processor, dataloader, self.device,
372
+ embedding_type='image', max_samples=max_samples
373
+ )
374
  print(f" Baseline image embeddings shape: {image_embeddings.shape} (using all {image_embeddings.shape[1]} dimensions)")
375
+ image_color_metrics = compute_similarity_metrics(image_embeddings, image_colors)
376
+
377
  image_color_classification = self.evaluate_classification_performance(
378
  image_embeddings, image_colors, "Test Dataset, image, color confusion matrix", "Color"
379
  )
 
381
  results['image'] = {
382
  'color': image_color_metrics
383
  }
384
+
385
  # Clear memory
386
  del image_embeddings
387
  torch.cuda.empty_cache() if torch.cuda.is_available() else None
388
+
389
  # ========== SAVE VISUALIZATIONS ==========
390
  os.makedirs(self.directory, exist_ok=True)
391
  for key in ['text', 'image']:
 
397
  bbox_inches='tight',
398
  )
399
  plt.close(figure)
400
+
401
  return results
402
 
403
  def analyze_baseline_vs_trained_performance(self, results_trained, results_baseline, dataset_name):
404
+ """Analyse baseline vs trained model performance."""
 
 
 
 
 
 
 
 
 
405
  print(f"\n{'='*60}")
406
+ print(f"ANALYSE: Baseline vs Trained - {dataset_name}")
407
  print(f"{'='*60}")
408
+
 
409
  comparisons = []
410
+
411
  # Text Color
412
  trained_color_text_acc = results_trained.get('text_color', {}).get('accuracy', 0)
413
  baseline_color_text_acc = results_baseline.get('text', {}).get('color', {}).get('accuracy', 0)
 
418
  'trained': trained_color_text_acc,
419
  'baseline': baseline_color_text_acc,
420
  'diff': diff,
421
+ 'trained_dims': f'0-{self.color_emb_dim - 1} ({self.color_emb_dim} dims)',
422
+ 'baseline_dims': f'All dimensions ({self.main_emb_dim} dims)'
423
  })
424
+
425
  # Image Color
426
  trained_color_img_acc = results_trained.get('image_color', {}).get('accuracy', 0)
427
  baseline_color_img_acc = results_baseline.get('image', {}).get('color', {}).get('accuracy', 0)
 
432
  'trained': trained_color_img_acc,
433
  'baseline': baseline_color_img_acc,
434
  'diff': diff,
435
+ 'trained_dims': f'0-{self.color_emb_dim - 1} ({self.color_emb_dim} dims)',
436
+ 'baseline_dims': f'All dimensions ({self.main_emb_dim} dims)'
437
  })
438
 
439
  return comparisons
 
448
  max_samples = 10000
449
  local_max_samples = 10000
450
 
451
+ evaluator = ColorEvaluator(device=device, directory=directory)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
452
 
453
  # Evaluate Local Validation Dataset
454
  print("\n" + "="*60)
455
+ print("Starting evaluation of Local Validation Dataset with Color embeddings")
456
  print("="*60)
457
  results_local = evaluator.evaluate_local_validation(max_samples=local_max_samples)
458
 
 
460
  print(f"\n{'='*60}")
461
  print("LOCAL VALIDATION DATASET EVALUATION SUMMARY")
462
  print(f"{'='*60}")
463
+
464
+ print("\nCOLOR CLASSIFICATION RESULTS:")
465
  print(f" Text - NN Acc: {results_local['text_color']['accuracy']*100:.1f}% | Centroid Acc: {results_local['text_color']['centroid_accuracy']*100:.1f}% | Separation: {results_local['text_color']['separation_score']:.4f}")
466
  print(f" Image - NN Acc: {results_local['image_color']['accuracy']*100:.1f}% | Centroid Acc: {results_local['image_color']['centroid_accuracy']*100:.1f}% | Separation: {results_local['image_color']['separation_score']:.4f}")
467
+
468
  # Evaluate Baseline Fashion CLIP on Local Validation
469
  print("\n" + "="*60)
470
+ print("Starting evaluation of Baseline Fashion CLIP on Local Validation")
471
  print("="*60)
472
  results_baseline_local = evaluator.evaluate_baseline_local_validation(max_samples=local_max_samples)
473
+
474
  if results_baseline_local is not None:
475
  print(f"\n{'='*60}")
476
  print("BASELINE LOCAL VALIDATION EVALUATION SUMMARY")
477
  print(f"{'='*60}")
478
+
479
+ print("\nCOLOR CLASSIFICATION RESULTS (Baseline):")
480
  print(f" Text - NN Acc: {results_baseline_local['text']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['text']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['text']['color']['separation_score']:.4f}")
481
  print(f" Image - NN Acc: {results_baseline_local['image']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['image']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['image']['color']['separation_score']:.4f}")
482
+
483
+
484
+ print(f"\nEvaluation completed! Check '{directory}/' for visualization files.")