Leacb4 committed
Commit 463ac82 · verified · 1 Parent(s): ec5c397

Upload training/main_model.py with huggingface_hub

Files changed (1):
  1. training/main_model.py +75 -122
training/main_model.py CHANGED
@@ -22,7 +22,6 @@ import matplotlib.pyplot as plt
 from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
 import warnings
 from tqdm import tqdm
-import json
 import config
 
 # Suppress warnings
@@ -33,9 +32,9 @@ warnings.filterwarnings("ignore", category=UserWarning)
 # Loss Functions
 # -------------------------------
 
-def enhanced_contrastive_loss(text_features, image_features, attribute_features,
+def enhanced_contrastive_loss(text_features, image_features, attribute_features,
                               color_model, hierarchy_model, colors, hierarchies, temperature=0.07, alignment_weight=0.3,
-                              reference_text_features=None, reference_weight=0.1):
+                              reference_text_features=None, reference_image_features=None, reference_weight=0.1):
     """
     Enhanced contrastive loss with direct alignment between color/hierarchy models and main model.
 
@@ -104,40 +103,34 @@ def enhanced_contrastive_loss(text_features, image_features, attribute_features,
     main_hierarchy_text_norm = F.normalize(main_hierarchy_text, dim=-1)
     main_hierarchy_image_norm = F.normalize(main_hierarchy_image, dim=-1)
 
-    # Color alignment loss using MSE and cosine similarity
-    color_text_alignment_loss = F.mse_loss(main_color_text_norm, color_embeddings_norm)
-    color_image_alignment_loss = F.mse_loss(main_color_image_norm, color_embeddings_norm)
+    # Color alignment loss (cosine-only: more natural for normalized embeddings)
     color_text_cosine_loss = 1 - F.cosine_similarity(main_color_text_norm, color_embeddings_norm).mean()
     color_image_cosine_loss = 1 - F.cosine_similarity(main_color_image_norm, color_embeddings_norm).mean()
-
-    # Color alignment loss
-    color_alignment_loss = (
-        color_text_alignment_loss + color_image_alignment_loss +
-        color_text_cosine_loss + color_image_cosine_loss
-    ) / 4
-
-    # Hierarchy alignment loss using MSE and cosine similarity
-    hierarchy_text_alignment_loss = F.mse_loss(main_hierarchy_text_norm, hierarchy_embeddings_norm)
-    hierarchy_image_alignment_loss = F.mse_loss(main_hierarchy_image_norm, hierarchy_embeddings_norm)
+    color_alignment_loss = (color_text_cosine_loss + color_image_cosine_loss) / 2
+
+    # Hierarchy alignment loss (cosine-only)
     hierarchy_text_cosine_loss = 1 - F.cosine_similarity(main_hierarchy_text_norm, hierarchy_embeddings_norm).mean()
     hierarchy_image_cosine_loss = 1 - F.cosine_similarity(main_hierarchy_image_norm, hierarchy_embeddings_norm).mean()
-
-    # Hierarchy alignment loss
-    hierarchy_alignment_loss = (
-        hierarchy_text_alignment_loss + hierarchy_image_alignment_loss +
-        hierarchy_text_cosine_loss + hierarchy_image_cosine_loss
-    ) / 4
+    hierarchy_alignment_loss = (hierarchy_text_cosine_loss + hierarchy_image_cosine_loss) / 2
 
     # Combined alignment loss
     alignment_loss = (color_alignment_loss + hierarchy_alignment_loss) / 2
 
-    # Optional guidance to keep text space close to base CLIP (helps cross-domain generalization)
+    # Reference loss to keep embeddings close to base CLIP (preserves zero-shot capability)
     reference_loss = 0.0
     if reference_text_features is not None:
-        reference_loss = F.mse_loss(
+        text_ref_loss = F.mse_loss(
             F.normalize(text_features, dim=-1),
             F.normalize(reference_text_features, dim=-1)
         )
+        if reference_image_features is not None:
+            image_ref_loss = F.mse_loss(
+                F.normalize(image_features, dim=-1),
+                F.normalize(reference_image_features, dim=-1)
+            )
+            reference_loss = (text_ref_loss + image_ref_loss) / 2
+        else:
+            reference_loss = text_ref_loss
 
     # Combine losses
     total_loss = (1 - alignment_weight) * original_loss + alignment_weight * alignment_loss
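
A quick sanity check on why the MSE terms could be dropped: for L2-normalized vectors, ||a − b||² = 2 − 2·cos(a, b), so the MSE and cosine alignment terms measured the same quantity up to an affine factor, and averaging the four of them added no information. A minimal standalone sketch of the identity (not part of the committed file):

import torch
import torch.nn.functional as F

# For unit vectors, ||a - b||^2 = 2 - 2*cos(a, b): squared distance and
# cosine loss differ only by an affine factor.
a = F.normalize(torch.randn(4, 512), dim=-1)
b = F.normalize(torch.randn(4, 512), dim=-1)

sq_dist = ((a - b) ** 2).sum(dim=-1)       # ||a - b||^2, per row
cos = F.cosine_similarity(a, b, dim=-1)    # cos(a, b), per row

assert torch.allclose(sq_dist, 2 - 2 * cos, atol=1e-5)

Dropping the MSE terms therefore removes redundant signal rather than changing the objective.
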
@@ -148,12 +141,8 @@ def enhanced_contrastive_loss(text_features, image_features, attribute_features,
         'original_loss': original_loss.item(),
         'alignment_loss': alignment_loss.item(),
         'reference_loss': reference_loss if isinstance(reference_loss, float) else reference_loss.item(),
-        'color_text_alignment': color_text_alignment_loss.item(),
-        'color_image_alignment': color_image_alignment_loss.item(),
         'color_text_cosine': color_text_cosine_loss.item(),
         'color_image_cosine': color_image_cosine_loss.item(),
-        'hierarchy_text_alignment': hierarchy_text_alignment_loss.item(),
-        'hierarchy_image_alignment': hierarchy_image_alignment_loss.item(),
         'hierarchy_text_cosine': hierarchy_text_cosine_loss.item(),
         'hierarchy_image_cosine': hierarchy_image_cosine_loss.item()
     }
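
Note that these hunks only show reference_weight arriving as a parameter and reference_loss being reported in metrics; the line where reference_weight actually scales the loss sits outside the displayed context. Presumably the function finishes along these lines (a guess, not the committed code):

# Hypothetical continuation of enhanced_contrastive_loss; the diff does not
# show where reference_weight is applied, so this combination is an assumption.
total_loss = (1 - alignment_weight) * original_loss + alignment_weight * alignment_loss
if reference_text_features is not None:
    total_loss = total_loss + reference_weight * reference_loss
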
@@ -194,12 +183,8 @@ def train_one_epoch(model, train_loader, optimizer, feature_models, color_model,
         'original_loss': 0.0,
         'alignment_loss': 0.0,
         'reference_loss': 0.0,
-        'color_text_alignment': 0.0,
-        'color_image_alignment': 0.0,
         'color_text_cosine': 0.0,
         'color_image_cosine': 0.0,
-        'hierarchy_text_alignment': 0.0,
-        'hierarchy_image_alignment': 0.0,
         'hierarchy_text_cosine': 0.0,
         'hierarchy_image_cosine': 0.0
     }
@@ -216,19 +201,21 @@ def train_one_epoch(model, train_loader, optimizer, feature_models, color_model,
         text_inputs = clip_processor(text=texts, padding=True, return_tensors="pt")
         text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
 
-        # Optional reference text features to keep close to base CLIP
+        # Reference features to keep embeddings close to base CLIP
         reference_text_features = None
+        reference_image_features = None
         if reference_model is not None:
             with torch.no_grad():
                 reference_text_features = reference_model.get_text_features(**text_inputs)
+                reference_image_features = reference_model.get_image_features(pixel_values=images)
 
         # Forward pass
         optimizer.zero_grad()
         outputs = model(**text_inputs, pixel_values=images)
 
         text_features = outputs.text_embeds
         image_features = outputs.image_embeds
 
         # Get feature embeddings
         if hasattr(feature_models[config.color_column], 'get_color_name_embeddings'):
             color_features = feature_models[config.color_column].get_color_name_embeddings(colors)
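
The reference features come from a frozen copy of base CLIP. How reference_model is built is not part of this diff; a minimal sketch of the likely setup, where the helper and the checkpoint name are assumptions:

import torch
from transformers import CLIPModel as CLIPModel_transformers

def load_reference_clip(device):
    # Hypothetical helper; the checkpoint name is not confirmed by this commit.
    reference_model = CLIPModel_transformers.from_pretrained("openai/clip-vit-base-patch32")
    reference_model.to(device)
    reference_model.eval()                  # disable dropout
    for p in reference_model.parameters():
        p.requires_grad_(False)             # pure teacher, never updated
    return reference_model

With every parameter marked requires_grad=False, the get_text_features / get_image_features calls stay graph-free even where they are not wrapped in torch.no_grad(), as in valid_one_epoch below.
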
@@ -236,12 +223,14 @@ def train_one_epoch(model, train_loader, optimizer, feature_models, color_model,
             color_features = feature_models[config.color_column].get_text_embeddings(colors)
         hierarchy_features = feature_models[config.hierarchy_column].get_text_embeddings(hierarchy)
         concat_features = torch.cat((color_features, hierarchy_features), dim=1)
 
         # Calculate enhanced loss with hierarchy alignment
         loss, metrics = enhanced_contrastive_loss(
             text_features, image_features, concat_features,
             color_model, hierarchy_model, colors, hierarchy, temperature, alignment_weight,
-            reference_text_features=reference_text_features, reference_weight=reference_weight
+            reference_text_features=reference_text_features,
+            reference_image_features=reference_image_features,
+            reference_weight=reference_weight
         )
 
         # Backward pass
@@ -306,17 +295,19 @@ def valid_one_epoch(model, val_loader, feature_models, device, clip_processor, t
         text_inputs = clip_processor(text=texts, padding=True, return_tensors="pt")
         text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
 
-        # Optional reference text features
+        # Reference features to keep embeddings close to base CLIP
         reference_text_features = None
+        reference_image_features = None
         if reference_model is not None:
             reference_text_features = reference_model.get_text_features(**text_inputs)
+            reference_image_features = reference_model.get_image_features(pixel_values=images)
 
         # Forward pass
         outputs = model(**text_inputs, pixel_values=images)
 
         text_features = outputs.text_embeds
         image_features = outputs.image_embeds
 
         # Get feature embeddings
         if hasattr(feature_models[config.color_column], 'get_color_name_embeddings'):
             color_features = feature_models[config.color_column].get_color_name_embeddings(colors)
@@ -324,13 +315,15 @@ def valid_one_epoch(model, val_loader, feature_models, device, clip_processor, t
             color_features = feature_models[config.color_column].get_text_embeddings(colors)
         hierarchy_features = feature_models[config.hierarchy_column].get_text_embeddings(hierarchy)
         concat_features = torch.cat((color_features, hierarchy_features), dim=1)
 
         # Calculate loss with all required arguments
         loss, metrics = enhanced_contrastive_loss(
             text_features, image_features, concat_features,
             color_model, hierarchy_model, colors, hierarchy,
             temperature, alignment_weight,
-            reference_text_features=reference_text_features, reference_weight=reference_weight
+            reference_text_features=reference_text_features,
+            reference_image_features=reference_image_features,
+            reference_weight=reference_weight
         )
 
         total_loss += loss.item()
@@ -438,69 +431,28 @@ class CustomDataset(Dataset):
 def load_models():
     """
     Load color and hierarchy models from checkpoints.
 
-    This function loads the pre-trained color and hierarchy models along with
-    their tokenizers and extractors, and prepares them for use in main model training.
-
     Returns:
         Dictionary mapping model names to model instances:
         - 'color': ColorCLIP model instance
-        - 'hierarchy': Hierarchy model instance
+        - 'hierarchy': HierarchyModel instance
     """
-    from training.color_model import ColorCLIP, Tokenizer
-    from training.hierarchy_model import Model, HierarchyExtractor
-
-    # Initialize tokenizer first
-    tokenizer = Tokenizer()
-
-    # Load vocabulary if available
-    if os.path.exists(config.tokeniser_path):
-        with open(config.tokeniser_path, 'r') as f:
-            vocab_dict = json.load(f)
-        tokenizer.load_vocab(vocab_dict)
-        print(f"Tokenizer vocabulary loaded from {config.tokeniser_path}")
-    else:
-        print(f"Warning: {config.tokeniser_path} not found. Using default tokenizer.")
-
-    # Load trained model first to get correct vocab size
-    checkpoint = torch.load(config.color_model_path, map_location=config.device)
-
-    # Extract vocab size from the checkpoint's embedding layer
-    vocab_size_from_checkpoint = checkpoint['text_encoder.embedding.weight'].shape[0]
-    print(f"Vocab size from checkpoint: {vocab_size_from_checkpoint}")
-    print(f"Vocab size from tokenizer: {tokenizer.counter}")
-
-    # Use the larger of the two to ensure compatibility
-    vocab_size = max(vocab_size_from_checkpoint, tokenizer.counter)
-
-    # Initialize model with correct vocab size
-    color_model = ColorCLIP(vocab_size=vocab_size, embedding_dim=config.color_emb_dim).to(config.device)
-    color_model.tokenizer = tokenizer
-
-    # Load the checkpoint
-    color_model.load_state_dict(checkpoint)
-    print(f"Color model loaded from {config.color_model_path}")
-
+    from training.color_model import ColorCLIP
+    from training.hierarchy_model import HierarchyModel
+
+    # --- Color model ---
+    print("Loading ColorCLIP (CLIP-backbone) ...")
+    color_model = ColorCLIP.from_checkpoint(config.color_model_path, device=config.device)
     color_model.eval()
     color_model.name = config.color_column
 
-    # Load hierarchy model
-    hierarchy_checkpoint = torch.load(config.hierarchy_model_path, map_location=config.device)
-    hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', [])
-    hierarchy_model = Model(
-        num_hierarchy_classes=len(hierarchy_classes),
-        embed_dim=config.hierarchy_emb_dim
-    ).to(config.device)
-    hierarchy_model.load_state_dict(hierarchy_checkpoint['model_state'])
-
-    # Set up hierarchy extractor
-    hierarchy_extractor = HierarchyExtractor(hierarchy_classes, verbose=False)
-    hierarchy_model.set_hierarchy_extractor(hierarchy_extractor)
+    # --- Hierarchy model ---
+    print("Loading HierarchyModel (CLIP-backbone) ...")
+    hierarchy_model = HierarchyModel.from_checkpoint(config.hierarchy_model_path, device=config.device)
     hierarchy_model.eval()
     hierarchy_model.name = config.hierarchy_column
 
     feature_models = {model.name: model for model in [color_model, hierarchy_model]}
-
     return feature_models
 
 # -------------------------------
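
The refactor moves all checkpoint plumbing (tokenizer vocabulary, vocab-size reconciliation, hierarchy-extractor wiring) behind from_checkpoint constructors on the model classes themselves, shrinking load_models() from 69 lines to 28. Those constructors live in training/color_model.py and training/hierarchy_model.py, which this commit does not touch; a plausible sketch of the pattern, with the checkpoint layout assumed rather than taken from the diff:

import torch
import torch.nn as nn

class ColorCLIP(nn.Module):
    """Sketch only: the real class lives in training/color_model.py."""

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

    @classmethod
    def from_checkpoint(cls, path, device="cpu"):
        # Assumes a self-describing checkpoint ({"init_kwargs", "model_state"});
        # this layout is a guess, not confirmed by this commit.
        ckpt = torch.load(path, map_location=device)
        model = cls(**ckpt["init_kwargs"])
        model.load_state_dict(ckpt["model_state"])
        return model.to(device)

The caller no longer needs to know about Tokenizer, HierarchyExtractor, or raw state-dict keys such as 'text_encoder.embedding.weight'.
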
@@ -683,13 +635,14 @@ def train_model(model, train_loader, val_loader, feature_models, device,
     plt.grid(True, alpha=0.3)
 
     plt.tight_layout()
-    plt.savefig('training_curves.png', dpi=300, bbox_inches='tight')
+    curves_path = str(config.ROOT_DIR / "figures" / "training_curves.png")
+    plt.savefig(curves_path, dpi=300, bbox_inches='tight')
     plt.close()
 
     print(f"\nTraining completed!")
     print(f"Best validation loss: {best_val_loss:.4f}")
     print(f"Final model saved to: {save_path}")
-    print(f"Training curves saved to: training_curves.png")
+    print(f"Training curves saved to: {curves_path}")
 
     return train_losses, val_losses
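
One caveat with the new save path: plt.savefig raises FileNotFoundError when the target directory does not exist. Since config.ROOT_DIR is evidently a pathlib.Path (the / operator implies it), a guard like the following, not present in the committed file, would make the save robust on a fresh checkout:

from pathlib import Path
import config  # project-local module, as imported in main_model.py

# Hypothetical guard; not part of this commit.
figures_dir = Path(config.ROOT_DIR) / "figures"
figures_dir.mkdir(parents=True, exist_ok=True)  # no-op if it already exists
curves_path = str(figures_dir / "training_curves.png")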
 
@@ -699,43 +652,43 @@ def main():
 
 def main():
     print("="*80)
-    print("🚀 Training of the model with alignement color and hierarchy")
+    print("🚀 Training of the main model with color and hierarchy alignment")
     print("="*80)
 
-    # Configuration (optimized to reduce overfitting)
-    num_epochs = 20
-    learning_rate = 1.5e-5   # Reduced slightly to prevent overfitting
-    temperature = 0.09       # Increased from 0.07 for softer contrastive learning
-    alignment_weight = 0.2   # Reduced from 0.3 to prevent overfitting on alignment
-    weight_decay = 5e-4      # Increased weight decay for stronger regularization
-    batch_size = 32
-    subset_size = 20000      # Increased dataset size for better generalization
+    # Configuration (tuned for zero-shot + separation balance)
+    num_epochs = 10
+    learning_rate = 1.5e-5
+    temperature = 0.09
+    alignment_weight = 0.10  # reduced from 0.2: softer alignment preserves CLIP zero-shot
+    reference_weight = 0.25  # increased from 0.1: stronger regularization toward base CLIP
+    weight_decay = 1e-3      # increased from 5e-4: better generalization
+    batch_size = 128
+    subset_size = 100000
 
     # Load the data
     print(f"\n📂 Loading the data...")
     df = pd.read_csv(config.local_dataset_path)
     print(f"   Data downloaded: {len(df)} samples")
 
     # filter the rows with NaN values
     df_clean = df.dropna(subset=[config.column_local_image_path])
+    df_clean = df_clean[df_clean[config.column_local_image_path].astype(str).str.len() > 0]
     print(f"   After filtering NaN: {len(df_clean)} samples")
 
     # Creation of datasets
     dataset = CustomDataset(df_clean)
 
-    # Creation of a subset for a faster training
-    print(f"\n📊 Creation of a subset of {subset_size} samples...")
+    # Sample 100k for training
     subset_size = min(subset_size, len(dataset))
     train_size = int(0.8 * subset_size)
     val_size = subset_size - train_size
 
-    # Creation of a subset with random indexes but reproductibles
     np.random.seed(42)
     subset_indices = np.random.choice(len(dataset), subset_size, replace=False)
     subset_dataset = torch.utils.data.Subset(dataset, subset_indices)
 
     train_dataset, val_dataset = random_split(
         subset_dataset,
         [train_size, val_size],
         generator=torch.Generator().manual_seed(42)
     )
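
For scale, the new configuration works out as follows (assuming drop-last integer division for step counts; the DataLoader flags are not shown in this diff):

# Arithmetic on the values from this hunk.
subset_size = 100_000
train_size = int(0.8 * subset_size)   # 80_000 training samples
val_size = subset_size - train_size   # 20_000 validation samples
steps_per_epoch = train_size // 128   # 625 steps at batch_size = 128
total_steps = steps_per_epoch * 10    # 6_250 optimizer steps over 10 epochs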
@@ -798,13 +751,13 @@ def main():
         color_alignment_model=feature_models[config.color_column],
         weight_decay=weight_decay,
         reference_model=reference_clip,
-        reference_weight=0.1
+        reference_weight=reference_weight
     )
 
     print("\n" + "="*80)
     print("✅ Training finished!")
     print(f"   Model saved: {config.main_model_path}")
-    print(f"   Training curves: training_curves.png")
+    print(f"   Training curves: figures/training_curves.png")
     print("\n📊 Final results:")
     print(f"   Last train loss: {train_losses[-1]:.4f}")
     print(f"   Last validation loss: {val_losses[-1]:.4f}")