Leacb4 committed on
Commit
3e2b688
·
verified ·
1 Parent(s): 26c2a79

Upload evaluation/sec536_embedding_structure.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. evaluation/sec536_embedding_structure.py +496 -749
evaluation/sec536_embedding_structure.py CHANGED
@@ -21,12 +21,7 @@ designed for, and tests zero-shot vision-language alignment.
21
  for items sharing a color but differing in category.
22
  Expected result: 1000/1000 pass.
23
 
24
- Test C1 Zero-shot image-to-text classification:
25
- Each image is used as a query; the highest-scoring text label (cosine in
26
- shared latent space) is the predicted class. Accuracy is computed across
27
- three datasets (Fashion-MNIST, KAGL Marqo, Internal).
28
-
29
- Test D — Subspace Decomposition Consistency:
30
  Encode a full description (e.g. "red dress in cotton"), a standalone color
31
  ("red"), and a standalone hierarchy ("dress"). Verify that:
32
  - The color subspace (first 16D) of the full embedding is more similar
@@ -35,6 +30,11 @@ designed for, and tests zero-shot vision-language alignment.
35
  similar to the hierarchy-only embedding than to the color-only embedding.
36
  Expected result: 1000/1000 pass.
37
 
 
 
 
 
 
38
  Paper reference: Section 5.3.6 and Table 4.
39
 
40
  Run directly:
@@ -51,6 +51,9 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
51
  from dataclasses import dataclass
52
  from pathlib import Path
53
  import random
 
 
 
54
  from typing import Dict, List, Optional, Sequence, Tuple
55
 
56
  import numpy as np
@@ -62,16 +65,39 @@ import torch.nn.functional as F
62
  from io import BytesIO
63
  from PIL import Image
64
  from torchvision import transforms
 
 
 
65
  from transformers import CLIPModel as CLIPModelTransformers
66
  from transformers import CLIPProcessor
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  @dataclass
70
  class RuntimeConfig:
71
- color_emb_dim: int = 16
72
- hierarchy_emb_dim: int = 64
73
- main_model_path: str = "models/gap_clip.pth"
74
- device: torch.device = torch.device("cpu")
 
75
 
76
  DEFAULT_NUM_EXAMPLES = 10000
77
  DEFAULT_NUM_PRINTED = 3
@@ -106,6 +132,7 @@ def resolve_runtime_config() -> RuntimeConfig:
106
 
107
  cfg.color_emb_dim = getattr(config, "color_emb_dim", cfg.color_emb_dim)
108
  cfg.hierarchy_emb_dim = getattr(config, "hierarchy_emb_dim", cfg.hierarchy_emb_dim)
 
109
  cfg.main_model_path = getattr(config, "main_model_path", cfg.main_model_path)
110
  cfg.device = getattr(config, "device", cfg.device)
111
  except Exception:
@@ -120,27 +147,50 @@ def resolve_runtime_config() -> RuntimeConfig:
120
 
121
 
122
  def load_main_model(device: torch.device, main_model_path: str) -> Tuple[CLIPModelTransformers, CLIPProcessor]:
123
- """Load GAP-CLIP (LAION CLIP + finetuned checkpoint) and processor.
124
-
125
- Delegates to utils.model_loader.load_gap_clip for consistent loading.
126
- """
127
- from evaluation.utils.model_loader import load_gap_clip # type: ignore
128
- return load_gap_clip(main_model_path, device)
129
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
- def get_text_embedding(
132
- model: CLIPModelTransformers, processor: CLIPProcessor, device: torch.device, text: str
133
- ) -> torch.Tensor:
134
- """Extract normalized text embedding for a single query."""
135
- text_inputs = processor(text=[text], padding=True, return_tensors="pt")
136
- text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
137
 
 
 
 
 
 
 
138
  with torch.no_grad():
139
- text_outputs = model.text_model(**text_inputs)
140
- text_features = model.text_projection(text_outputs.pooler_output)
141
- text_features = F.normalize(text_features, dim=-1)
142
 
143
- return text_features.squeeze(0)
 
 
 
144
 
145
 
146
  def cosine(a: torch.Tensor, b: torch.Tensor) -> float:
@@ -179,8 +229,7 @@ def run_test_a(
179
  cfg: RuntimeConfig,
180
  num_examples: int,
181
  num_printed: int,
182
- test_name: str = "Test A",
183
- ) -> Dict[str, bool]:
184
  """
185
  A: different colors + same hierarchy.
186
  Expect hierarchy subspace to be more similar than color subspace.
@@ -292,8 +341,7 @@ def run_test_b(
292
  cfg: RuntimeConfig,
293
  num_examples: int,
294
  num_printed: int,
295
- test_name: str = "Test B",
296
- ) -> Dict[str, bool]:
297
  """
298
  B: same color + different hierarchies.
299
  Expect similarity in first16 (color) to be higher than full512.
@@ -398,16 +446,15 @@ def run_test_b(
398
 
399
 
400
 
401
- def run_test_d(
402
  model: CLIPModelTransformers,
403
  processor: CLIPProcessor,
404
  cfg: RuntimeConfig,
405
  num_examples: int,
406
  num_printed: int,
407
- test_name: str = "Test D",
408
- ) -> Dict[str, object]:
409
  """
410
- D: Subspace Decomposition Consistency.
411
  Encode a full description (e.g. "red dress in cotton"), a standalone color
412
  ("red"), and a standalone hierarchy ("dress"). Then verify:
413
  - The color subspace (first 16D) of the full embedding aligns with the
@@ -568,36 +615,26 @@ def fashion_mnist_pixels_to_tensor(pixel_values: np.ndarray, image_size: int = 2
568
  def get_image_embedding(
569
  model: CLIPModelTransformers, processor: CLIPProcessor, device: torch.device, image_tensor: torch.Tensor
570
  ) -> torch.Tensor:
 
571
  image_tensor = image_tensor.unsqueeze(0).to(device)
572
- with torch.no_grad():
573
- vision_outputs = model.vision_model(pixel_values=image_tensor)
574
- image_features = model.visual_projection(vision_outputs.pooler_output)
575
- image_features = F.normalize(image_features, dim=-1)
576
- return image_features.squeeze(0)
577
 
578
 
579
  def get_image_embedding_from_pil(
580
  model: CLIPModelTransformers, processor: CLIPProcessor, device: torch.device, image: Image.Image
581
  ) -> torch.Tensor:
582
- image_inputs = processor(images=[image], return_tensors="pt")
583
- image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
584
- with torch.no_grad():
585
- vision_outputs = model.vision_model(**image_inputs)
586
- image_features = model.visual_projection(vision_outputs.pooler_output)
587
- image_features = F.normalize(image_features, dim=-1)
588
- return image_features.squeeze(0)
589
 
590
 
591
  def get_text_embeddings_batch(
592
  model: CLIPModelTransformers, processor: CLIPProcessor, device: torch.device, texts: List[str]
593
  ) -> torch.Tensor:
594
- text_inputs = processor(text=texts, padding=True, return_tensors="pt")
595
- text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
596
- with torch.no_grad():
597
- text_outputs = model.text_model(**text_inputs)
598
- text_features = model.text_projection(text_outputs.pooler_output)
599
- text_features = F.normalize(text_features, dim=-1)
600
- return text_features
601
 
602
 
603
  def get_prompt_ensembled_text_embeddings(
@@ -678,79 +715,187 @@ def get_adaptive_label_prior(labels: List[str]) -> Tuple[torch.Tensor, float]:
678
  return probs, recommended_weight
679
 
680
 
681
- def run_test_c(
682
- model: CLIPModelTransformers,
683
- processor: CLIPProcessor,
684
- cfg: RuntimeConfig,
685
- num_examples: int,
686
- num_printed: int,
687
- csv_path: str = FASHION_MNIST_CSV,
688
- ) -> Dict[str, object]:
689
- """
690
- C: Zero-shot image classification.
691
- For each image, compute cosine similarity against all candidate text labels
692
- and check whether the highest-scoring text matches the ground truth.
693
- """
694
- csv_file = Path(csv_path)
695
- if not csv_file.exists():
696
- print(f" Skipping Test C: {csv_path} not found")
697
- return {"overall": True, "accuracy": None}
698
-
699
- df = pd.read_csv(csv_path)
700
- df = df.sample(n=min(num_examples, len(df)), random_state=42).reset_index(drop=True)
701
 
702
- candidate_labels = sorted(set(FASHION_MNIST_ORIGINAL_LABELS.values()))
703
- candidate_texts = [f"a photo of a {label}" for label in candidate_labels]
704
- text_embs = get_text_embeddings_batch(model, processor, cfg.device, candidate_texts)
705
 
706
- pixel_cols = [f"pixel{i}" for i in range(1, 785)]
707
- rows: List[List[str]] = []
708
- failed_rows: List[List[str]] = []
709
  correct = 0
 
710
 
711
- for idx in range(len(df)):
712
- row = df.iloc[idx]
713
- label_id = int(row["label"])
714
- ground_truth = FASHION_MNIST_ORIGINAL_LABELS.get(label_id, "unknown")
715
 
716
- pixels = row[pixel_cols].values.astype(float)
717
- img_tensor = fashion_mnist_pixels_to_tensor(pixels)
718
- img_emb = get_image_embedding(model, processor, cfg.device, img_tensor)
719
 
720
- sims = F.cosine_similarity(img_emb.unsqueeze(0), text_embs, dim=1)
721
- best_idx = sims.argmax().item()
722
- predicted = candidate_labels[best_idx]
723
- best_sim = sims[best_idx].item()
724
 
725
- ok = predicted == ground_truth
726
- if ok:
727
- correct += 1
728
 
729
- rows.append([
730
- str(idx),
731
- ground_truth,
732
- predicted,
733
- f"{best_sim:.4f}",
734
- format_bool(ok),
735
- ])
736
- if not ok:
737
- failed_rows.append([
738
- str(idx),
739
- ground_truth,
740
- predicted,
741
- f"{best_sim:.4f}",
742
- ])
743
 
744
- accuracy = correct / len(df)
745
 
746
- print_table(
747
- f"Test C: Zero-shot image classification (showing {min(num_printed, len(rows))}/{len(rows)} examples)",
748
- ["#", "Ground Truth", "Predicted", "Best CosSim", "Result"],
749
- rows[:num_printed],
750
- )
751
- print(f"Test C aggregate: {correct}/{len(df)} correct ({accuracy:.2%})")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
752
 
753
- return {"overall": True, "accuracy": accuracy}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
754
 
755
 
756
  def normalize_hierarchy_label(raw_label: str) -> str:
@@ -805,544 +950,203 @@ def normalize_hierarchy_label(raw_label: str) -> str:
805
  "innerwear": "underwear",
806
  "loungewear and nightwear": "underwear",
807
  "saree": "dress",
 
 
 
 
 
 
808
  }
809
- return synonyms.get(label, label)
810
-
811
-
812
- def get_candidate_labels_from_internal_csv() -> List[str]:
813
- csv_file = Path(INTERNAL_DATASET_CSV)
814
- if csv_file.exists():
815
- df = pd.read_csv(INTERNAL_DATASET_CSV, usecols=["hierarchy"]).dropna()
816
- labels = sorted(set(normalize_hierarchy_label(v) for v in df["hierarchy"].astype(str)))
817
- if labels:
818
- return labels
819
- return sorted(set(FASHION_MNIST_LABELS.values()))
820
-
821
-
822
- def load_hierarchy_model_for_eval(device: torch.device):
823
- """Load the trained hierarchy model for evaluation strategies. Returns None on failure."""
824
- try:
825
- from training.hierarchy_model import Model as _HierarchyModel, HierarchyExtractor as _HierarchyExtractor
826
- import config as _cfg
827
- except ImportError:
828
- return None
829
- model_path = Path(getattr(_cfg, "hierarchy_model_path", "models/hierarchy_model.pth"))
830
- if not model_path.exists():
831
- return None
832
- try:
833
- checkpoint = torch.load(str(model_path), map_location=device)
834
- hierarchy_classes = checkpoint.get("hierarchy_classes", [])
835
- if not hierarchy_classes:
836
- return None
837
- _model = _HierarchyModel(
838
- num_hierarchy_classes=len(hierarchy_classes),
839
- embed_dim=getattr(_cfg, "hierarchy_emb_dim", 64),
840
- ).to(device)
841
- _model.load_state_dict(checkpoint["model_state"])
842
- _model.set_hierarchy_extractor(_HierarchyExtractor(hierarchy_classes, verbose=False))
843
- _model.eval()
844
- return _model
845
- except Exception:
846
- return None
847
-
848
-
849
- def evaluate_zero_shot_gap(
850
- model: CLIPModelTransformers,
851
- processor: CLIPProcessor,
852
- device: torch.device,
853
- samples: List[Tuple[Image.Image, str]],
854
- candidate_labels: List[str],
855
- title_prefix: str,
856
- num_printed: int,
857
- color_emb_dim: int = 16,
858
- hierarchy_emb_dim: int = 64,
859
- hierarchy_model=None,
860
- ) -> Dict[str, Optional[float]]:
861
- if len(samples) == 0:
862
- print(f" Skipping {title_prefix}: no valid samples")
863
- return {"accuracy_c1": None, "strategy": None}
864
-
865
- # Strategy 1 (baseline prompt) and prompt-ensemble embeddings.
866
- base_templates = ["a photo of a {label}"]
867
- ensemble_templates = [
868
- "a photo of a {label}",
869
- "a product photo of a {label}",
870
- "a studio photo of a {label}",
871
- "a fashion item: {label}",
872
- "an image of a {label}",
873
- ]
874
- text_embs_single = get_prompt_ensembled_text_embeddings(
875
- model=model,
876
- processor=processor,
877
- device=device,
878
- labels=candidate_labels,
879
- templates=base_templates,
880
- )
881
- text_embs_ensemble = get_prompt_ensembled_text_embeddings(
882
- model=model,
883
- processor=processor,
884
- device=device,
885
- labels=candidate_labels,
886
- templates=ensemble_templates,
887
- )
888
-
889
- # Precompute image embeddings once for C1.
890
- image_embs: List[torch.Tensor] = []
891
- for image, _ in samples:
892
- image_embs.append(get_image_embedding_from_pil(model, processor, device, image))
893
- image_embs_tensor = torch.stack(image_embs, dim=0)
894
-
895
- # Similarity matrices (N images x C labels)
896
- sims_single = image_embs_tensor @ text_embs_single.T
897
- sims_ensemble = image_embs_tensor @ text_embs_ensemble.T
898
-
899
- # Calibration and prior terms.
900
- class_bias = sims_ensemble.mean(dim=0, keepdim=True)
901
- class_prior = get_internal_label_prior(candidate_labels).to(device)
902
- log_prior = torch.log(class_prior + 1e-8).unsqueeze(0)
903
-
904
- # Baseline inference-time strategies (full 512-d embedding).
905
- strategy_scores: Dict[str, torch.Tensor] = {
906
- "single_prompt": sims_single,
907
- "prompt_ensemble": sims_ensemble,
908
- "ensemble_plus_calibration": sims_ensemble - 0.2 * class_bias,
909
- "ensemble_plus_prior": sims_ensemble + 0.15 * log_prior,
910
- "ensemble_calibration_plus_prior": sims_ensemble - 0.2 * class_bias + 0.15 * log_prior,
911
- }
912
-
913
- # Extended prompt ensemble for broader category coverage.
914
- extended_templates = [
915
- "a photo of a {label}",
916
- "a product photo of a {label}",
917
- "a studio photo of a {label}",
918
- "a fashion item: {label}",
919
- "an image of a {label}",
920
- "{label}",
921
- "a picture of a {label}",
922
- "this is a {label}",
923
- "a fashion product: {label}",
924
- "a {label} clothing item",
925
  ]
926
- text_embs_extended = get_prompt_ensembled_text_embeddings(
927
- model=model, processor=processor, device=device,
928
- labels=candidate_labels, templates=extended_templates,
929
- )
930
- sims_extended = image_embs_tensor @ text_embs_extended.T
931
-
932
- # Subspace: exclude color dimensions (keep hierarchy + residual).
933
- hier_end = color_emb_dim + hierarchy_emb_dim
934
- img_no_color = F.normalize(image_embs_tensor[:, color_emb_dim:], dim=-1)
935
- text_ext_no_color = F.normalize(text_embs_extended[:, color_emb_dim:], dim=-1)
936
- text_ens_no_color = F.normalize(text_embs_ensemble[:, color_emb_dim:], dim=-1)
937
- sims_no_color = img_no_color @ text_ens_no_color.T
938
- sims_no_color_ext = img_no_color @ text_ext_no_color.T
939
-
940
- # Subspace: hierarchy-only dimensions.
941
- img_hier = F.normalize(image_embs_tensor[:, color_emb_dim:hier_end], dim=-1)
942
- text_ens_hier = F.normalize(text_embs_ensemble[:, color_emb_dim:hier_end], dim=-1)
943
- text_ext_hier = F.normalize(text_embs_extended[:, color_emb_dim:hier_end], dim=-1)
944
- sims_hier_ens = img_hier @ text_ens_hier.T
945
- sims_hier_ext = img_hier @ text_ext_hier.T
946
-
947
- # Adaptive prior (reduces influence for out-of-domain label sets).
948
- adaptive_prior, adaptive_weight = get_adaptive_label_prior(candidate_labels)
949
- adaptive_prior = adaptive_prior.to(device)
950
- log_adaptive_prior = torch.log(adaptive_prior + 1e-8).unsqueeze(0)
951
-
952
- class_bias_no_color = sims_no_color.mean(dim=0, keepdim=True)
953
-
954
- strategy_scores.update({
955
- "extended_ensemble": sims_extended,
956
- "no_color_ensemble": sims_no_color,
957
- "no_color_extended": sims_no_color_ext,
958
- "hierarchy_only_ensemble": sims_hier_ens,
959
- "hierarchy_only_extended": sims_hier_ext,
960
- "no_color_calibrated": sims_no_color - 0.2 * class_bias_no_color,
961
- "no_color_adaptive_prior": sims_no_color + adaptive_weight * log_adaptive_prior,
962
- "no_color_ext_adaptive_prior": sims_no_color_ext + adaptive_weight * log_adaptive_prior,
963
- "extended_adaptive_prior": sims_extended + adaptive_weight * log_adaptive_prior,
964
- })
965
-
966
- # Weighted embeddings: amplify hierarchy dims relative to residual.
967
- for amp_factor in (2.0, 4.0):
968
- weights = torch.ones(image_embs_tensor.shape[1], device=device)
969
- weights[:color_emb_dim] = 0.0
970
- weights[color_emb_dim:hier_end] = amp_factor
971
- weighted_img = F.normalize(image_embs_tensor * weights.unsqueeze(0), dim=-1)
972
- weighted_text = F.normalize(text_embs_extended * weights.unsqueeze(0), dim=-1)
973
- tag = f"weighted_hier_{amp_factor:.0f}x"
974
- strategy_scores[tag] = weighted_img @ weighted_text.T
975
-
976
- # Hierarchy model direct strategy (uses dedicated hierarchy encoder).
977
- if hierarchy_model is not None:
978
- hier_text_embs: List[torch.Tensor] = []
979
- known_label_mask: List[bool] = []
980
- for label in candidate_labels:
981
- try:
982
- emb = hierarchy_model.get_text_embeddings(label).squeeze(0)
983
- hier_text_embs.append(emb)
984
- known_label_mask.append(True)
985
- except (ValueError, Exception):
986
- hier_text_embs.append(text_ext_hier[candidate_labels.index(label)])
987
- known_label_mask.append(False)
988
- hier_text_matrix = F.normalize(torch.stack(hier_text_embs).to(device), dim=-1)
989
- sims_hier_model = img_hier @ hier_text_matrix.T
990
- strategy_scores["hierarchy_model_direct"] = sims_hier_model
991
- class_bias_hier = sims_hier_model.mean(dim=0, keepdim=True)
992
- strategy_scores["hier_model_calibrated"] = sims_hier_model - 0.2 * class_bias_hier
993
- strategy_scores["hier_model_adaptive_prior"] = sims_hier_model + adaptive_weight * log_adaptive_prior
994
-
995
- # Hybrid: hierarchy model scores for known labels, CLIP for unknown.
996
- hybrid_scores = sims_no_color_ext.clone()
997
- for label_idx, is_known in enumerate(known_label_mask):
998
- if is_known:
999
- hybrid_scores[:, label_idx] = sims_hier_model[:, label_idx]
1000
- strategy_scores["hybrid_hier_clip"] = hybrid_scores
1001
-
1002
- # Blended: z-score-normalised mix of hierarchy and full-space scores.
1003
- hier_mu = sims_hier_model.mean()
1004
- hier_std = sims_hier_model.std() + 1e-8
1005
- full_mu = sims_extended.mean()
1006
- full_std = sims_extended.std() + 1e-8
1007
- hier_z = (sims_hier_model - hier_mu) / hier_std
1008
- full_z = (sims_extended - full_mu) / full_std
1009
- for alpha in (0.3, 0.5, 0.7):
1010
- strategy_scores[f"blend_hier_full_{alpha:.1f}"] = alpha * hier_z + (1 - alpha) * full_z
1011
-
1012
- # Select best strategy for C1.
1013
-
1014
- best_strategy_c1 = "single_prompt"
1015
- best_acc_c1 = -1.0
1016
- best_scores_c1 = sims_single
1017
-
1018
- # Track per-strategy accuracies and weighted-F1 for fair comparison.
1019
- all_strategy_acc_c1: Dict[str, float] = {}
1020
- all_strategy_wf1_c1: Dict[str, float] = {}
1021
- ground_truths = [gt for _, gt in samples]
1022
-
1023
- for strategy_name, score_mat in strategy_scores.items():
1024
- pred_idx = score_mat.argmax(dim=1).tolist()
1025
- preds = [candidate_labels[i] for i in pred_idx]
1026
- correct = sum(1 for p, g in zip(preds, ground_truths) if p == g)
1027
- acc = correct / len(samples)
1028
- wf1 = f1_score(ground_truths, preds, average="weighted", zero_division=0)
1029
-
1030
- all_strategy_acc_c1[strategy_name] = acc
1031
- all_strategy_wf1_c1[strategy_name] = wf1
1032
-
1033
- if acc > best_acc_c1:
1034
- best_acc_c1 = acc
1035
- best_strategy_c1 = strategy_name
1036
- best_scores_c1 = score_mat
1037
-
1038
- best_wf1_c1 = all_strategy_wf1_c1[best_strategy_c1]
1039
- print(f"{title_prefix} selected C1 strategy: {best_strategy_c1} (acc={best_acc_c1:.2%}, wF1={best_wf1_c1:.2%})")
1040
-
1041
- # C1: image -> all texts (classification)
1042
- rows: List[List[str]] = []
1043
- correct = 0
1044
- all_preds: List[str] = []
1045
-
1046
- for idx, (_, ground_truth) in enumerate(samples):
1047
- sims = best_scores_c1[idx]
1048
- best_idx = int(sims.argmax().item())
1049
- predicted = candidate_labels[best_idx]
1050
- best_sim = float(sims[best_idx].item())
1051
-
1052
- ok = predicted == ground_truth
1053
- if ok:
1054
- correct += 1
1055
- all_preds.append(predicted)
1056
-
1057
- rows.append([str(idx), ground_truth, predicted, f"{best_sim:.4f}", format_bool(ok)])
1058
-
1059
- accuracy_c1 = correct / len(samples)
1060
- wf1_c1 = f1_score(ground_truths, all_preds, average="weighted", zero_division=0)
1061
-
1062
- print_table(
1063
- f"{title_prefix} C1 image->texts (showing {min(num_printed, len(rows))}/{len(rows)} examples)",
1064
- ["#", "Ground Truth", "Predicted", "Best CosSim", "Result"],
1065
- rows[:num_printed],
1066
- )
1067
- print(f"{title_prefix} C1 aggregate: {correct}/{len(samples)} correct (acc={accuracy_c1:.2%}, wF1={wf1_c1:.2%})")
1068
-
1069
- return {
1070
- "accuracy_c1": accuracy_c1,
1071
- "wf1_c1": wf1_c1,
1072
- "strategy": best_strategy_c1,
1073
- "all_strategy_acc_c1": all_strategy_acc_c1,
1074
- "all_strategy_wf1_c1": all_strategy_wf1_c1,
1075
- }
1076
-
1077
 
1078
- def evaluate_zero_shot_baseline(
1079
- baseline_model: CLIPModelTransformers,
1080
- baseline_processor: CLIPProcessor,
1081
- device: torch.device,
1082
- samples: List[Tuple[Image.Image, str]],
1083
- candidate_labels: List[str],
1084
- title_prefix: str,
1085
- num_printed: int,
1086
- ) -> Dict[str, Optional[float]]:
1087
- if len(samples) == 0:
1088
- print(f" Skipping baseline {title_prefix}: no valid samples")
1089
- return {"accuracy_c1": None}
1090
-
1091
- candidate_texts = [f"a photo of a {label}" for label in candidate_labels]
1092
- text_inputs = baseline_processor(text=candidate_texts, return_tensors="pt", padding=True, truncation=True)
1093
- text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
1094
- with torch.no_grad():
1095
- text_embs = baseline_model.get_text_features(**text_inputs)
1096
- text_embs = F.normalize(text_embs, dim=-1)
1097
-
1098
- # Precompute image embeddings once for C1.
1099
- image_embs: List[torch.Tensor] = []
1100
- for image, _ in samples:
1101
- image_inputs = baseline_processor(images=[image], return_tensors="pt")
1102
- image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
1103
- with torch.no_grad():
1104
- img_emb = baseline_model.get_image_features(**image_inputs)
1105
- img_emb = F.normalize(img_emb, dim=-1)
1106
- image_embs.append(img_emb.squeeze(0))
1107
- image_embs_tensor = torch.stack(image_embs, dim=0)
1108
-
1109
- # C1: image -> all texts (classification)
1110
- rows: List[List[str]] = []
1111
- correct = 0
1112
- all_preds: List[str] = []
1113
- ground_truths = [gt for _, gt in samples]
1114
-
1115
- for idx, (_, ground_truth) in enumerate(samples):
1116
- img_emb = image_embs_tensor[idx].unsqueeze(0)
1117
- sims = F.cosine_similarity(img_emb, text_embs, dim=1)
1118
- best_idx = sims.argmax().item()
1119
- predicted = candidate_labels[best_idx]
1120
- best_sim = sims[best_idx].item()
1121
-
1122
- ok = predicted == ground_truth
1123
- if ok:
1124
- correct += 1
1125
- all_preds.append(predicted)
1126
-
1127
- rows.append([str(idx), ground_truth, predicted, f"{best_sim:.4f}", format_bool(ok)])
1128
-
1129
- accuracy_c1 = correct / len(samples)
1130
- wf1_c1 = f1_score(ground_truths, all_preds, average="weighted", zero_division=0)
1131
- baseline_title = f"Baseline {title_prefix}"
1132
- print_table(
1133
- f"{baseline_title} C1 image->texts (showing {min(num_printed, len(rows))}/{len(rows)} examples)",
1134
- ["#", "Ground Truth", "Predicted", "Best CosSim", "Result"],
1135
- rows[:num_printed],
1136
- )
1137
- print(f"{baseline_title} C1 aggregate: {correct}/{len(samples)} correct (acc={accuracy_c1:.2%}, wF1={wf1_c1:.2%})")
1138
 
1139
- return {"accuracy_c1": accuracy_c1, "wf1_c1": wf1_c1}
1140
 
1141
 
1142
- def load_fashion_mnist_samples(
1143
- num_examples: int,
1144
- ) -> Tuple[List[Tuple[Image.Image, str]], List[Tuple[Image.Image, str]]]:
1145
- """Return (baseline_samples, gap_samples) with native and GAP-CLIP labels."""
1146
- csv_file = Path(FASHION_MNIST_CSV)
1147
- if not csv_file.exists():
1148
- return [], []
1149
- df = pd.read_csv(FASHION_MNIST_CSV)
1150
- df = df.sample(n=min(num_examples, len(df)), random_state=42).reset_index(drop=True)
1151
- pixel_cols = [f"pixel{i}" for i in range(1, 785)]
1152
 
1153
- baseline_samples: List[Tuple[Image.Image, str]] = []
1154
- gap_samples: List[Tuple[Image.Image, str]] = []
1155
- for _, row in df.iterrows():
1156
- label_id = int(row["label"])
1157
- pixels = row[pixel_cols].values.astype(float)
1158
- img_array = pixels.reshape(28, 28).astype(np.uint8)
1159
- img_array = np.stack([img_array] * 3, axis=-1)
1160
- image = Image.fromarray(img_array)
1161
- baseline_samples.append((image, FASHION_MNIST_ORIGINAL_LABELS.get(label_id, "unknown")))
1162
- gap_samples.append((image, FASHION_MNIST_LABELS.get(label_id, "unknown")))
1163
- return baseline_samples, gap_samples
1164
 
1165
 
1166
- def load_kagl_marqo_samples(
1167
  num_examples: int,
1168
- ) -> Tuple[List[Tuple[Image.Image, str]], List[Tuple[Image.Image, str]]]:
1169
- """Return (baseline_samples, gap_samples) with native and GAP-CLIP labels."""
1170
- try:
1171
- from datasets import load_dataset # type: ignore
1172
- except Exception:
1173
- print(" Skipping KAGL Marqo: datasets package not available")
1174
- return [], []
1175
 
1176
- try:
1177
- dataset = load_dataset("Marqo/KAGL", split="data")
1178
- except Exception as exc:
1179
- print(f" Skipping KAGL Marqo: failed to load dataset ({exc})")
1180
- return [], []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1181
 
1182
- dataset = dataset.shuffle(seed=42).select(range(min(num_examples, len(dataset))))
1183
  baseline_samples: List[Tuple[Image.Image, str]] = []
1184
  gap_samples: List[Tuple[Image.Image, str]] = []
1185
- for item in dataset:
1186
- raw_label = item.get("category2")
1187
- if raw_label is None:
1188
- continue
1189
- native_label = str(raw_label).strip()
1190
- gap_label = normalize_hierarchy_label(native_label)
1191
- image_obj = item.get("image")
1192
- if image_obj is None:
1193
- continue
1194
- if hasattr(image_obj, "convert"):
1195
- image = image_obj.convert("RGB")
1196
- elif isinstance(image_obj, dict) and "bytes" in image_obj:
1197
- image = Image.open(BytesIO(image_obj["bytes"])).convert("RGB")
1198
- else:
1199
- continue
1200
- baseline_samples.append((image, native_label))
1201
- gap_samples.append((image, gap_label))
1202
- return baseline_samples, gap_samples
1203
-
1204
-
1205
- def load_internal_samples(
1206
- num_examples: int,
1207
- ) -> Tuple[List[Tuple[Image.Image, str]], List[Tuple[Image.Image, str]]]:
1208
- """Return (baseline_samples, gap_samples) — same labels for both on this dataset."""
1209
- csv_file = Path(INTERNAL_DATASET_CSV)
1210
- if not csv_file.exists():
1211
- print(f" Skipping internal dataset: {INTERNAL_DATASET_CSV} not found")
1212
- return [], []
1213
 
1214
- df = pd.read_csv(INTERNAL_DATASET_CSV)
1215
- if "hierarchy" not in df.columns:
1216
- print(" Skipping internal dataset: missing 'hierarchy' column")
1217
- return [], []
1218
-
1219
- df = df.dropna(subset=["hierarchy", "image_url"]).sample(frac=1.0, random_state=42)
1220
- samples: List[Tuple[Image.Image, str]] = []
1221
-
1222
- for _, row in df.iterrows():
1223
- if len(samples) >= num_examples:
1224
  break
1225
- ground_truth = normalize_hierarchy_label(str(row["hierarchy"]))
1226
- image_url = str(row["image_url"])
 
 
 
 
1227
  try:
1228
- response = requests.get(image_url, timeout=5)
1229
- response.raise_for_status()
1230
- image = Image.open(BytesIO(response.content)).convert("RGB")
1231
- samples.append((image, ground_truth))
1232
  except Exception:
1233
  continue
1234
- return samples, samples
1235
-
1236
-
1237
- def run_test_c_baseline_fashion_clip(
1238
- device: torch.device,
1239
- num_examples: int,
1240
- num_printed: int,
1241
- csv_path: str = FASHION_MNIST_CSV,
1242
- ) -> Dict[str, Optional[float]]:
1243
- """
1244
- Same zero-shot protocol as Test C, but using baseline Fashion-CLIP.
1245
- """
1246
- csv_file = Path(csv_path)
1247
- if not csv_file.exists():
1248
- print(f" Skipping Baseline Test C: {csv_path} not found")
1249
- return {"accuracy": None}
1250
-
1251
- print("\nLoading baseline model (patrickjohncyh/fashion-clip)...")
1252
- baseline_name = "patrickjohncyh/fashion-clip"
1253
- baseline_processor = CLIPProcessor.from_pretrained(baseline_name)
1254
- baseline_model = CLIPModelTransformers.from_pretrained(baseline_name).to(device)
1255
- baseline_model.eval()
1256
- print("Baseline model loaded.")
1257
 
1258
- df = pd.read_csv(csv_path)
1259
- df = df.sample(n=min(num_examples, len(df)), random_state=42).reset_index(drop=True)
1260
-
1261
- candidate_labels = sorted(set(FASHION_MNIST_ORIGINAL_LABELS.values()))
1262
- candidate_texts = [f"a photo of a {label}" for label in candidate_labels]
1263
-
1264
- text_inputs = baseline_processor(text=candidate_texts, return_tensors="pt", padding=True, truncation=True)
1265
- text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
1266
- with torch.no_grad():
1267
- text_embs = baseline_model.get_text_features(**text_inputs)
1268
- text_embs = F.normalize(text_embs, dim=-1)
1269
-
1270
- pixel_cols = [f"pixel{i}" for i in range(1, 785)]
1271
- rows: List[List[str]] = []
1272
- failed_rows: List[List[str]] = []
1273
- correct = 0
1274
-
1275
- for idx in range(len(df)):
1276
- row = df.iloc[idx]
1277
- label_id = int(row["label"])
1278
- ground_truth = FASHION_MNIST_ORIGINAL_LABELS.get(label_id, "unknown")
1279
-
1280
- pixels = row[pixel_cols].values.astype(float)
1281
- img_array = pixels.reshape(28, 28).astype(np.uint8)
1282
- img_array = np.stack([img_array] * 3, axis=-1)
1283
- image = Image.fromarray(img_array)
1284
-
1285
- image_inputs = baseline_processor(images=[image], return_tensors="pt")
1286
- image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
1287
- with torch.no_grad():
1288
- img_emb = baseline_model.get_image_features(**image_inputs)
1289
- img_emb = F.normalize(img_emb, dim=-1)
1290
-
1291
- sims = F.cosine_similarity(img_emb, text_embs, dim=1)
1292
- best_idx = sims.argmax().item()
1293
- predicted = candidate_labels[best_idx]
1294
- best_sim = sims[best_idx].item()
1295
-
1296
- ok = predicted == ground_truth
1297
- if ok:
1298
- correct += 1
1299
-
1300
- rows.append([str(idx), ground_truth, predicted, f"{best_sim:.4f}", format_bool(ok)])
1301
- if not ok:
1302
- failed_rows.append([str(idx), ground_truth, predicted, f"{best_sim:.4f}"])
1303
-
1304
- accuracy = correct / len(df)
1305
-
1306
- print_table(
1307
- f"Baseline Test C (Fashion-CLIP): zero-shot (showing {min(num_printed, len(rows))}/{len(rows)} examples)",
1308
- ["#", "Ground Truth", "Predicted", "Best CosSim", "Result"],
1309
- rows[:num_printed],
1310
- )
1311
- print(f"Baseline Test C aggregate: {correct}/{len(df)} correct ({accuracy:.2%})")
1312
-
1313
- return {"accuracy": accuracy}
1314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1315
 
1316
- def main(selected_tests: set[str]) -> None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1317
  random.seed(42)
1318
  cfg = resolve_runtime_config()
1319
- model_path = Path(cfg.main_model_path)
1320
- if not model_path.exists():
1321
- raise FileNotFoundError(f"Main model checkpoint not found: {cfg.main_model_path}")
1322
 
1323
- print("Loading model...")
1324
- print(f" device: {cfg.device}")
1325
- print(f" checkpoint: {cfg.main_model_path}")
1326
- print(f" dims: color={cfg.color_emb_dim}, hierarchy={cfg.hierarchy_emb_dim}, total=512")
1327
- model, processor = load_main_model(cfg.device, cfg.main_model_path)
1328
- print("Model loaded.")
 
 
 
 
 
 
1329
 
1330
  result_a: Optional[Dict[str, object]] = None
1331
  result_b: Optional[Dict[str, object]] = None
1332
- result_d: Optional[Dict[str, object]] = None
1333
  baseline_result_a: Optional[Dict[str, object]] = None
1334
  baseline_result_b: Optional[Dict[str, object]] = None
1335
- baseline_result_d: Optional[Dict[str, object]] = None
1336
-
1337
- baseline_processor = None
1338
- baseline_model = None
1339
- if any(t in selected_tests for t in ("A", "B", "C", "D")):
1340
- print("\nLoading baseline model (patrickjohncyh/fashion-clip)...")
1341
- baseline_name = "patrickjohncyh/fashion-clip"
1342
- baseline_processor = CLIPProcessor.from_pretrained(baseline_name)
1343
- baseline_model = CLIPModelTransformers.from_pretrained(baseline_name).to(cfg.device)
1344
- baseline_model.eval()
1345
- print("Baseline model loaded.")
1346
 
1347
  if "A" in selected_tests:
1348
  result_a = run_test_a(
@@ -1378,8 +1182,8 @@ def main(selected_tests: set[str]) -> None:
1378
  num_printed=DEFAULT_NUM_PRINTED,
1379
  test_name="Baseline Test B",
1380
  )
1381
- if "D" in selected_tests:
1382
- result_d = run_test_d(
1383
  model,
1384
  processor,
1385
  cfg,
@@ -1387,83 +1191,59 @@ def main(selected_tests: set[str]) -> None:
1387
  num_printed=DEFAULT_NUM_PRINTED,
1388
  )
1389
  if baseline_model is not None and baseline_processor is not None:
1390
- baseline_result_d = run_test_d(
1391
  baseline_model,
1392
  baseline_processor,
1393
  cfg,
1394
  num_examples=DEFAULT_NUM_EXAMPLES,
1395
  num_printed=DEFAULT_NUM_PRINTED,
1396
- test_name="Baseline Test D",
1397
  )
1398
 
1399
- # Collect results for fair comparison.
1400
- gap_fixed_c1: Dict[str, Optional[float]] = {}
1401
- gap_fixed_wf1_c1: Dict[str, Optional[float]] = {}
1402
- gap_best_wf1_c1: Dict[str, Optional[float]] = {}
1403
- gap_best_strategy: Dict[str, Optional[str]] = {}
1404
- base_fixed_c1: Dict[str, Optional[float]] = {}
1405
- base_fixed_wf1_c1: Dict[str, Optional[float]] = {}
1406
-
1407
- if "C" in selected_tests:
1408
  assert baseline_model is not None and baseline_processor is not None
1409
 
1410
- hierarchy_model_eval = load_hierarchy_model_for_eval(cfg.device)
1411
- if hierarchy_model_eval is not None:
1412
- print("Hierarchy model loaded for evaluation strategies.")
1413
- else:
1414
- print("Hierarchy model not available; subspace strategies will use CLIP-only fallback.")
1415
-
1416
- datasets_for_c = {
1417
- "Fashion-MNIST": load_fashion_mnist_samples(DEFAULT_NUM_EXAMPLES),
1418
- "KAGL Marqo": load_kagl_marqo_samples(DEFAULT_NUM_EXAMPLES),
1419
- "Internal dataset": load_internal_samples(min(DEFAULT_NUM_EXAMPLES, 200)),
 
 
 
 
 
 
 
 
 
 
1420
  }
1421
- for dataset_name, (baseline_samples, gap_samples) in datasets_for_c.items():
1422
- print(f"\n{'=' * 120}")
1423
- print(f"Test C on {dataset_name}")
1424
- print(f"{'=' * 120}")
1425
- print(f"Valid samples: {len(baseline_samples)} (baseline), {len(gap_samples)} (GAP-CLIP)")
1426
-
1427
- # Baseline uses dataset-native labels (matches published benchmarks).
1428
- baseline_candidate_labels = sorted(set(label for _, label in baseline_samples))
1429
- # GAP-CLIP uses its training-vocabulary labels.
1430
- gap_candidate_labels = sorted(set(label for _, label in gap_samples))
1431
- print(f"Baseline candidate labels ({len(baseline_candidate_labels)}): {baseline_candidate_labels}")
1432
- print(f"GAP-CLIP candidate labels ({len(gap_candidate_labels)}): {gap_candidate_labels}")
1433
-
1434
- # GAP-CLIP: full strategy search with its own label vocabulary.
1435
- gap_metrics = evaluate_zero_shot_gap(
1436
- model=model,
1437
- processor=processor,
1438
- device=cfg.device,
1439
- samples=gap_samples,
1440
- candidate_labels=gap_candidate_labels,
1441
- title_prefix=f"Test C ({dataset_name}) GAP-CLIP",
1442
- num_printed=DEFAULT_NUM_PRINTED,
1443
- color_emb_dim=cfg.color_emb_dim,
1444
- hierarchy_emb_dim=cfg.hierarchy_emb_dim,
1445
- hierarchy_model=hierarchy_model_eval,
1446
- )
1447
-
1448
- # Baseline: single_prompt with native labels.
1449
- baseline_metrics = evaluate_zero_shot_baseline(
1450
- baseline_model=baseline_model,
1451
- baseline_processor=baseline_processor,
1452
- device=cfg.device,
1453
- samples=baseline_samples,
1454
- candidate_labels=baseline_candidate_labels,
1455
- title_prefix=f"Test C ({dataset_name}) Baseline",
1456
- num_printed=DEFAULT_NUM_PRINTED,
1457
- )
1458
-
1459
- # Store results.
1460
- gap_fixed_c1[dataset_name] = gap_metrics.get("all_strategy_acc_c1", {}).get("single_prompt")
1461
- gap_fixed_wf1_c1[dataset_name] = gap_metrics.get("all_strategy_wf1_c1", {}).get("single_prompt")
1462
- gap_best_wf1_c1[dataset_name] = gap_metrics.get("wf1_c1")
1463
- gap_best_strategy[dataset_name] = gap_metrics.get("strategy")
1464
- base_fixed_c1[dataset_name] = baseline_metrics["accuracy_c1"]
1465
- base_fixed_wf1_c1[dataset_name] = baseline_metrics.get("wf1_c1")
1466
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1467
  print("\n" + "=" * 120)
1468
  print("Final Summary")
1469
  print("=" * 120)
@@ -1478,70 +1258,37 @@ def main(selected_tests: set[str]) -> None:
1478
  print(f"Test B full512 accuracy: {float(result_b['accuracy_full512']):.2%}")
1479
  if baseline_result_b is not None:
1480
  print(f"Baseline Test B full512 accuracy: {float(baseline_result_b['accuracy_full512']):.2%}")
1481
- if "C" in selected_tests:
1482
- def _pct(v: Optional[float]) -> str:
1483
- return f"{v:.2%}" if v is not None else "N/A"
1484
-
1485
- def _gain(gap_v: Optional[float], base_v: Optional[float]) -> str:
1486
- if gap_v is None or base_v is None or base_v == 0:
1487
- return "N/A"
1488
- return f"{((gap_v - base_v) / abs(base_v)) * 100:+.1f}%"
1489
-
1490
- print("\n" + "-" * 120)
1491
- print("Test C — Fair comparison (weighted F1): single_prompt for both models")
1492
- print("-" * 120)
1493
- fair_rows: List[List[str]] = []
1494
- for ds in ["Fashion-MNIST", "KAGL Marqo", "Internal dataset"]:
1495
- fair_rows.append([
1496
- ds,
1497
- _pct(gap_fixed_wf1_c1.get(ds)), _pct(base_fixed_wf1_c1.get(ds)),
1498
- _gain(gap_fixed_wf1_c1.get(ds), base_fixed_wf1_c1.get(ds)),
1499
- ])
1500
- print_table(
1501
- "Test C: single_prompt weighted F1 for both",
1502
- ["Dataset", "GAP-CLIP (wF1)", "Baseline (wF1)", "Gain"],
1503
- fair_rows,
1504
- )
1505
-
1506
- print("\n" + "-" * 120)
1507
- print("Test C — GAP-CLIP best strategy vs Baseline")
1508
- print("-" * 120)
1509
- best_rows: List[List[str]] = []
1510
- for ds in ["Fashion-MNIST", "KAGL Marqo", "Internal dataset"]:
1511
- strat = gap_best_strategy.get(ds) or "N/A"
1512
- best_rows.append([
1513
- ds,
1514
- f"{strat}",
1515
- _pct(gap_best_wf1_c1.get(ds)), _pct(base_fixed_wf1_c1.get(ds)),
1516
- _gain(gap_best_wf1_c1.get(ds), base_fixed_wf1_c1.get(ds)),
1517
- ])
1518
- print_table(
1519
- "Test C: GAP-CLIP best strategy vs Baseline single_prompt",
1520
- ["Dataset", "GAP-CLIP strategy", "GAP-CLIP (wF1)", "Baseline (wF1)", "Gain"],
1521
- best_rows,
1522
- )
1523
-
1524
- if result_d is not None:
1525
- print(f"Test D overall: {format_bool(bool(result_d['overall']))}")
1526
- print(f" pass rate: {float(result_d['pass_rate']):.2%}")
1527
- print(f" avg color_match={float(result_d['avg_color_match']):.4f} vs cross={float(result_d['avg_color_cross']):.4f}")
1528
- print(f" avg hier_match={float(result_d['avg_hier_match']):.4f} vs cross={float(result_d['avg_hier_cross']):.4f}")
1529
- if baseline_result_d is not None:
1530
- print(f"Baseline Test D overall: {format_bool(bool(baseline_result_d['overall']))}")
1531
- print(f" baseline pass rate: {float(baseline_result_d['pass_rate']):.2%}")
1532
 
1533
  if result_a is not None:
1534
- assert bool(result_a["overall"]), "Test A failed: hierarchy behavior did not match expected pattern."
 
 
1535
  if result_b is not None:
1536
- assert bool(result_b["overall"]), "Test B failed: first16 correlation was not consistently above full512."
1537
- if result_d is not None:
1538
- assert float(result_d["pass_rate"]) >= 0.95, (
1539
- f"Test D failed: subspace decomposition pass rate {float(result_d['pass_rate']):.2%} < 95%."
 
 
1540
  )
1541
 
1542
  print("\nAll embedding-structure tests passed.")
1543
 
1544
 
1545
  if __name__ == "__main__":
1546
- selected_tests = 'ABCD'
 
 
 
 
 
 
1547
  main(selected_tests)
 
21
  for items sharing a color but differing in category.
22
  Expected result: 1000/1000 pass.
23
 
24
+ Test C — Subspace Decomposition Consistency:
 
 
 
 
 
25
  Encode a full description (e.g. "red dress in cotton"), a standalone color
26
  ("red"), and a standalone hierarchy ("dress"). Verify that:
27
  - The color subspace (first 16D) of the full embedding is more similar
 
30
  similar to the hierarchy-only embedding than to the color-only embedding.
31
  Expected result: 1000/1000 pass.
32
 
33
+ Test D — Zero-shot image-to-text classification:
34
+ Each image is used as a query; the highest-scoring text label (cosine in
35
+ shared latent space) is the predicted class. Accuracy is computed across
36
+ three datasets (Fashion-MNIST, KAGL Marqo, Internal).
37
+
38
  Paper reference: Section 5.3.6 and Table 4.
39
 
40
  Run directly:
 
51
  from dataclasses import dataclass
52
  from pathlib import Path
53
  import random
54
+ import sys
55
+
56
+ sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
57
  from typing import Dict, List, Optional, Sequence, Tuple
58
 
59
  import numpy as np
 
65
  from io import BytesIO
66
  from PIL import Image
67
  from torchvision import transforms
68
+ from torchvision import datasets
69
+ from torch.utils.data import DataLoader
70
+ from tqdm import tqdm
71
  from transformers import CLIPModel as CLIPModelTransformers
72
  from transformers import CLIPProcessor
73
 
74
+ from training.hierarchy_model import HierarchyExtractor
75
+
76
+ try:
77
+ import config as project_config # type: ignore
78
+ except Exception:
79
+ project_config = None
80
+
81
+ DEFAULT_COLOR_EMB_DIM = getattr(project_config, "color_emb_dim", 16)
82
+ DEFAULT_HIERARCHY_EMB_DIM = getattr(project_config, "hierarchy_emb_dim", 64)
83
+ DEFAULT_MAIN_EMB_DIM = getattr(project_config, "main_emb_dim", 512)
84
+ DEFAULT_MAIN_MODEL_PATH = getattr(project_config, "main_model_path", "models/gap_clip.pth")
85
+ DEFAULT_DEVICE = getattr(project_config, "device", torch.device("cpu"))
86
+
87
+ _HIERARCHY_EXTRACTOR = HierarchyExtractor([
88
+ "accessories", "bodysuits", "bras", "coat", "dress", "jacket",
89
+ "legging", "pant", "polo", "shirt", "shoes", "short", "skirt",
90
+ "socks", "sweater", "swimwear", "top", "underwear",
91
+ ], verbose=False)
92
+
93
 
94
@dataclass
class RuntimeConfig:
    """Runtime settings consumed by the embedding-structure tests.

    Every field is initialised from the matching module-level ``DEFAULT_*``
    constant.
    """

    # Size of the color subspace.
    color_emb_dim: int = DEFAULT_COLOR_EMB_DIM
    # Size of the hierarchy subspace.
    hierarchy_emb_dim: int = DEFAULT_HIERARCHY_EMB_DIM
    # Size of the full embedding.
    main_emb_dim: int = DEFAULT_MAIN_EMB_DIM
    # Location of the GAP-CLIP checkpoint on disk.
    main_model_path: str = DEFAULT_MAIN_MODEL_PATH
    # Device the models run on.
    device: torch.device = DEFAULT_DEVICE
101
 
102
  DEFAULT_NUM_EXAMPLES = 10000
103
  DEFAULT_NUM_PRINTED = 3
 
132
 
133
  cfg.color_emb_dim = getattr(config, "color_emb_dim", cfg.color_emb_dim)
134
  cfg.hierarchy_emb_dim = getattr(config, "hierarchy_emb_dim", cfg.hierarchy_emb_dim)
135
+ cfg.main_emb_dim = getattr(config, "main_emb_dim", cfg.main_emb_dim)
136
  cfg.main_model_path = getattr(config, "main_model_path", cfg.main_model_path)
137
  cfg.device = getattr(config, "device", cfg.device)
138
  except Exception:
 
147
 
148
 
149
def load_main_model(device: torch.device, main_model_path: str) -> Tuple[CLIPModelTransformers, CLIPProcessor]:
    """Load GAP-CLIP from a local checkpoint on top of the LAION CLIP backbone.

    Args:
        device: Device used as ``map_location`` for the checkpoint and as the
            final location of the model.
        main_model_path: Path of the local checkpoint file.

    Returns:
        ``(model, processor)``; the model is on ``device`` and in eval mode.

    Raises:
        FileNotFoundError: If ``main_model_path`` does not exist.
    """
    model_path = Path(main_model_path)
    if not model_path.exists():
        raise FileNotFoundError(f"Main model checkpoint not found: {main_model_path}")

    clip_name = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
    model = CLIPModelTransformers.from_pretrained(clip_name)

    # NOTE(security): torch.load unpickles the file, so only trusted
    # checkpoints should be passed here.
    checkpoint = torch.load(str(model_path), map_location=device)

    # The weights are either stored under "model_state_dict" or the file is
    # the bare state dict itself.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    # strict=False tolerates key differences, but a silent mismatch would
    # evaluate the untouched backbone instead of GAP-CLIP, so report it.
    incompatible = model.load_state_dict(state_dict, strict=False)
    missing = list(incompatible.missing_keys)
    unexpected = list(incompatible.unexpected_keys)
    if missing or unexpected:
        print(
            f"Warning: checkpoint keys do not fully match the model "
            f"({len(missing)} missing, {len(unexpected)} unexpected); "
            f"missing parameters keep their pretrained backbone values."
        )

    model = model.to(device)
    model.eval()
    processor = CLIPProcessor.from_pretrained(clip_name)
    return model, processor
166
+
167
+
168
def encode_text(model, processor, text_queries, device):
    """Encode text queries into embeddings (unnormalized).

    Args:
        model: Object exposing ``get_text_features(**inputs)``.
        processor: Callable invoked as
            ``processor(text=..., return_tensors="pt", padding=True, truncation=True)``
            that returns a mapping of tensors.
        text_queries: A single string, or any iterable of strings (list,
            tuple, generator, ...).
        device: Device the processor outputs are moved to before the forward
            pass.

    Returns:
        The result of ``model.get_text_features``, computed under
        ``torch.no_grad()``.
    """
    if isinstance(text_queries, str):
        text_queries = [text_queries]
    elif not isinstance(text_queries, list):
        # Materialise tuples/generators/arrays so the processor always
        # receives a plain list of strings.
        text_queries = list(text_queries)

    inputs = processor(text=text_queries, return_tensors="pt", padding=True, truncation=True)
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():
        text_features = model.get_text_features(**inputs)
    return text_features
177
 
 
 
 
 
 
 
178
 
179
def encode_image(model, processor, images, device):
    """Encode images into embeddings (unnormalized).

    Args:
        model: Object exposing ``get_image_features(**inputs)``.
        processor: Callable invoked as
            ``processor(images=..., return_tensors="pt")`` that returns a
            mapping of tensors.
        images: A single image, or a list/tuple of images forming a batch.
        device: Device the processor outputs are moved to before the forward
            pass.

    Returns:
        The result of ``model.get_image_features``, computed under
        ``torch.no_grad()``.
    """
    if isinstance(images, tuple):
        # A tuple is a batch, not a single image: previously it was wrapped
        # as one element and handed to the processor as ``[tuple]``.
        images = list(images)
    elif not isinstance(images, list):
        images = [images]

    inputs = processor(images=images, return_tensors="pt")
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():
        image_features = model.get_image_features(**inputs)
    return image_features
188
+
189
 
190
def get_text_embedding(
    model: CLIPModelTransformers, processor: CLIPProcessor, device: torch.device, text: str
) -> torch.Tensor:
    """Return the L2-normalised embedding of one text, with the batch axis squeezed away."""
    raw_features = encode_text(model, processor, text, device)
    unit_features = F.normalize(raw_features, dim=-1)
    return unit_features.squeeze(0)
194
 
195
 
196
  def cosine(a: torch.Tensor, b: torch.Tensor) -> float:
 
229
  cfg: RuntimeConfig,
230
  num_examples: int,
231
  num_printed: int,
232
+ test_name: str = "Test A") -> Dict[str, bool]:
 
233
  """
234
  A: different colors + same hierarchy.
235
  Expect hierarchy subspace to be more similar than color subspace.
 
341
  cfg: RuntimeConfig,
342
  num_examples: int,
343
  num_printed: int,
344
+ test_name: str = "Test B",) -> Dict[str, bool]:
 
345
  """
346
  B: same color + different hierarchies.
347
  Expect similarity in first16 (color) to be higher than full512.
 
446
 
447
 
448
 
449
+ def run_test_c(
450
  model: CLIPModelTransformers,
451
  processor: CLIPProcessor,
452
  cfg: RuntimeConfig,
453
  num_examples: int,
454
  num_printed: int,
455
+ test_name: str = "Test C",) -> Dict[str, object]:
 
456
  """
457
+ C: Subspace Decomposition Consistency.
458
  Encode a full description (e.g. "red dress in cotton"), a standalone color
459
  ("red"), and a standalone hierarchy ("dress"). Then verify:
460
  - The color subspace (first 16D) of the full embedding aligns with the
 
615
def get_image_embedding(
    model: CLIPModelTransformers, processor: CLIPProcessor, device: torch.device, image_tensor: torch.Tensor
) -> torch.Tensor:
    """Return the L2-normalised embedding of one image given as a tensor.

    The tensor is converted to a PIL image and passed through ``encode_image``,
    so the processor performs its own preprocessing on it.

    Args:
        model: CLIP model used for the forward pass.
        processor: Processor paired with ``model``.
        device: Device the forward pass runs on.
        image_tensor: Image tensor accepted by torchvision's ``to_pil_image``.

    Returns:
        The normalised embedding with the batch axis squeezed away.
    """
    from torchvision.transforms.functional import to_pil_image

    # encode_image moves the processor output to `device` itself, so the
    # tensor only has to be on the CPU for the PIL conversion; the former
    # unsqueeze -> to(device) -> squeeze -> cpu round trip was redundant.
    # NOTE(review): to_pil_image scales float tensors assuming values in
    # [0, 1]; confirm callers do not pass mean/std-normalised tensors.
    pil_img = to_pil_image(image_tensor.cpu())
    return F.normalize(encode_image(model, processor, pil_img, device), dim=-1).squeeze(0)
 
624
 
625
 
626
def get_image_embedding_from_pil(
    model: CLIPModelTransformers, processor: CLIPProcessor, device: torch.device, image: Image.Image
) -> torch.Tensor:
    """Return the L2-normalised embedding of one PIL image, with the batch axis squeezed away."""
    raw_features = encode_image(model, processor, image, device)
    unit_features = F.normalize(raw_features, dim=-1)
    return unit_features.squeeze(0)
 
 
 
 
 
631
 
632
 
633
def get_text_embeddings_batch(
    model: CLIPModelTransformers, processor: CLIPProcessor, device: torch.device, texts: List[str]
) -> torch.Tensor:
    """Return L2-normalised embeddings for a batch of texts, one row per text."""
    raw_features = encode_text(model, processor, texts, device)
    return F.normalize(raw_features, dim=-1)
 
 
 
 
 
638
 
639
 
640
  def get_prompt_ensembled_text_embeddings(
 
715
  return probs, recommended_weight
716
 
717
 
718
def zero_shot_fashion_mnist(
    model,
    processor,
    device,
    batch_size: int = 64,
    data_root: str = "./data",
) -> float:
    """Zero-shot accuracy over the whole Fashion-MNIST test split.

    Every class name is turned into the prompt ``"a photo of a <class>"``;
    each image is assigned the class whose prompt embedding has the highest
    cosine similarity with the image embedding.

    Args:
        model: Model handed to ``encode_text`` / ``encode_image``.
        processor: Processor handed to ``encode_text`` / ``encode_image``.
        device: Device used for the embeddings.
        batch_size: Number of images per forward pass.
        data_root: Directory where torchvision stores/downloads the dataset.

    Returns:
        Fraction of correctly classified test images.
    """

    def _collate(batch):
        # Keep the images as a list of PIL objects; only the targets become a tensor.
        pil_batch = [sample[0] for sample in batch]
        targets = torch.tensor([sample[1] for sample in batch])
        return pil_batch, targets

    test_set = datasets.FashionMNIST(
        root=data_root,
        train=False,
        download=True,
        transform=transforms.Grayscale(num_output_channels=3),
    )
    loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, collate_fn=_collate)

    class_prompts = [f"a photo of a {label}" for label in test_set.classes]
    text_embs = F.normalize(
        encode_text(model, processor, class_prompts, device).to(device).float(), dim=-1
    )

    correct = 0
    total = 0
    for pil_images, labels in tqdm(loader, desc="Zero-shot Fashion-MNIST"):
        img_embs = F.normalize(
            encode_image(model, processor, pil_images, device).to(device).float(), dim=-1
        )
        # Both sides are unit vectors, so the matrix product is the cosine similarity.
        preds = (img_embs @ text_embs.T).argmax(dim=-1).cpu()
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    accuracy = correct / total
    print(f"Zero-shot accuracy on Fashion MNIST: {accuracy:.4f} ({correct}/{total})")
    return accuracy
758
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
759
 
 
760
 
761
def zero_shot_kagl(
    model,
    processor,
    device,
    batch_size: int = 64,
    num_examples: int = 10000,
) -> Optional[Dict[str, float]]:
    """Zero-shot accuracy and weighted F1 on KAGL Marqo, labelled by ``category2``.

    The dataset is shuffled with a fixed seed and truncated to
    ``num_examples`` rows. The candidate labels are the distinct ``category2``
    values of the kept samples, each wrapped in the prompt
    ``"a photo of a <label>"``.

    Args:
        model: Model handed to ``encode_text`` / ``encode_image``.
        processor: Processor handed to ``encode_text`` / ``encode_image``.
        device: Device used for the embeddings.
        batch_size: Number of images per forward pass.
        num_examples: Maximum number of dataset rows to consider.

    Returns:
        ``{"accuracy": ..., "weighted_f1": ...}``, or ``None`` when the
        ``datasets`` package or the dataset cannot be loaded, or no row
        yields a usable image/label pair.
    """
    try:
        from datasets import load_dataset  # type: ignore
    except Exception:
        print("Skipping zero_shot_kagl: datasets package not available")
        return None

    try:
        dataset = load_dataset("Marqo/KAGL", split="data")
    except Exception as exc:
        print(f"Skipping zero_shot_kagl: failed to load dataset ({exc})")
        return None

    dataset = dataset.shuffle(seed=42).select(range(min(num_examples, len(dataset))))

    pil_images: List[Image.Image] = []
    labels_text: List[str] = []
    for item in dataset:
        raw_label = item.get("category2")
        image_obj = item.get("image")
        if raw_label is None or image_obj is None:
            continue

        # The image field is either PIL-like or a dict carrying raw bytes.
        if hasattr(image_obj, "convert"):
            image = image_obj.convert("RGB")
        elif isinstance(image_obj, dict) and "bytes" in image_obj:
            image = Image.open(BytesIO(image_obj["bytes"])).convert("RGB")
        else:
            continue
        pil_images.append(image)
        labels_text.append(str(raw_label).strip())

    if not pil_images:
        print("Skipping zero_shot_kagl: no valid samples")
        return None

    candidate_labels = sorted(set(labels_text))
    label_to_idx = {label: idx for idx, label in enumerate(candidate_labels)}
    all_labels = np.array([label_to_idx[label] for label in labels_text], dtype=np.int64)

    prompts = [f"a photo of a {label}" for label in candidate_labels]
    text_embs = encode_text(model, processor, prompts, device).to(device).float()
    text_embs = F.normalize(text_embs, dim=-1)

    all_preds: List[np.ndarray] = []
    for start in tqdm(range(0, len(pil_images), batch_size), desc="Zero-shot KAGL"):
        batch_images = pil_images[start : start + batch_size]
        img_embs = encode_image(model, processor, batch_images, device).to(device).float()
        img_embs = F.normalize(img_embs, dim=-1)
        # Unit vectors on both sides: the product is the cosine similarity.
        sim = img_embs @ text_embs.T
        all_preds.append(sim.argmax(dim=-1).cpu().numpy())

    # pil_images is non-empty here, so at least one batch was predicted.
    pred_array = np.concatenate(all_preds, axis=0)
    accuracy = float((pred_array == all_labels).mean())
    weighted_f1 = float(f1_score(all_labels, pred_array, average="weighted"))
    print(f"KAGL accuracy: {accuracy:.4f}")
    # The score is the support-weighted average, not the macro average.
    print(f"KAGL weighted F1: {weighted_f1:.4f}")
    return {"accuracy": accuracy, "weighted_f1": weighted_f1}
827
+
828
+
829
def zero_shot_internal(
    model,
    processor,
    device,
    batch_size: int = 64,
    num_examples: int = 10000,
    csv_path: str = INTERNAL_DATASET_CSV,
) -> Optional[Dict[str, float]]:
    """Zero-shot accuracy and weighted F1 on the internal dataset CSV.

    Images come from the ``local_image_path`` column when it exists, otherwise
    they are downloaded from ``image_url``. Labels are the ``hierarchy`` column
    passed through ``normalize_hierarchy_label``. Rows whose image cannot be
    loaded are skipped.

    Args:
        model: Model handed to ``encode_text`` / ``encode_image``.
        processor: Processor handed to ``encode_text`` / ``encode_image``.
        device: Device used for the embeddings.
        batch_size: Number of images per forward pass.
        num_examples: Maximum number of successfully loaded samples.
        csv_path: Path of the CSV file describing the dataset.

    Returns:
        ``{"accuracy": ..., "weighted_f1": ...}``, or ``None`` when the CSV is
        missing, lacks the required columns, or yields no usable sample.
    """
    csv_file = Path(csv_path)
    if not csv_file.exists():
        print(f"Skipping zero_shot_internal: {csv_path} not found")
        return None

    df = pd.read_csv(csv_file)
    use_local = "local_image_path" in df.columns
    required_cols = {"hierarchy", "local_image_path"} if use_local else {"hierarchy", "image_url"}
    if not required_cols.issubset(df.columns):
        print(f"Skipping zero_shot_internal: missing required columns {required_cols}")
        return None

    img_col = "local_image_path" if use_local else "image_url"
    # Fixed-seed shuffle so the evaluated subset is reproducible.
    df = df.dropna(subset=["hierarchy", img_col]).sample(frac=1.0, random_state=42)
    pil_images: List[Image.Image] = []
    labels_text: List[str] = []
    for _, row in df.iterrows():
        if len(pil_images) >= num_examples:
            break
        try:
            if use_local:
                img_path = Path(str(row["local_image_path"]))
                if not img_path.exists():
                    continue
                # Close the file as soon as the RGB copy has been made.
                with Image.open(img_path) as raw_image:
                    image = raw_image.convert("RGB")
            else:
                response = requests.get(str(row["image_url"]), timeout=5)
                response.raise_for_status()
                image = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception:
            # Best effort: unreadable or unreachable images are skipped.
            continue
        label = normalize_hierarchy_label(str(row["hierarchy"]))
        pil_images.append(image)
        labels_text.append(label)

    if not pil_images:
        print("Skipping zero_shot_internal: no valid samples")
        return None

    candidate_labels = sorted(set(labels_text))
    label_to_idx = {label: idx for idx, label in enumerate(candidate_labels)}
    all_labels = np.array([label_to_idx[label] for label in labels_text], dtype=np.int64)

    prompts = [f"a photo of a {label}" for label in candidate_labels]
    text_embs = encode_text(model, processor, prompts, device).to(device).float()
    text_embs = F.normalize(text_embs, dim=-1)

    all_preds: List[np.ndarray] = []
    for start in tqdm(range(0, len(pil_images), batch_size), desc="Zero-shot Internal"):
        batch_images = pil_images[start : start + batch_size]
        img_embs = encode_image(model, processor, batch_images, device).to(device).float()
        img_embs = F.normalize(img_embs, dim=-1)
        # Unit vectors on both sides: the product is the cosine similarity.
        sim = img_embs @ text_embs.T
        all_preds.append(sim.argmax(dim=-1).cpu().numpy())

    # pil_images is non-empty here, so at least one batch was predicted.
    pred_array = np.concatenate(all_preds, axis=0)
    accuracy = float((pred_array == all_labels).mean())
    weighted_f1 = float(f1_score(all_labels, pred_array, average="weighted"))
    print(f"Internal accuracy: {accuracy:.4f}")
    # The score is the support-weighted average, not the macro average.
    print(f"Internal weighted F1: {weighted_f1:.4f}")
    return {"accuracy": accuracy, "weighted_f1": weighted_f1}
899
 
900
 
901
  def normalize_hierarchy_label(raw_label: str) -> str:
 
950
  "innerwear": "underwear",
951
  "loungewear and nightwear": "underwear",
952
  "saree": "dress",
953
+ "boots": "shoes",
954
+ "outer": "coat",
955
+ "sunglasses": "accessories",
956
+ "scarf & tie": "accessories",
957
+ "scarf/tie": "accessories",
958
+ "belt": "accessories",
959
  }
960
+ exact = synonyms.get(label, None)
961
+ if exact is not None:
962
+ return exact
963
+
964
+ # Phase 2: substring/regex fallback via HierarchyExtractor
965
+ # Handles Internal dataset's multi-word hierarchy strings like
966
+ # "womens wms woven shirts sleeveless linen" -> "shirt"
967
+ result = _HIERARCHY_EXTRACTOR.extract_hierarchy(label)
968
+ if result:
969
+ return result
970
+
971
+ # Phase 3: extra keywords for the ~9 labels HierarchyExtractor misses
972
+ _EXTRA_KEYWORDS = [
973
+ ("capri", "pant"),
974
+ ("denim", "pant"),
975
+ ("skinny", "pant"),
976
+ ("boyfriend", "pant"),
977
+ ("graphic", "top"),
978
+ ("longsleeve", "top"),
979
+ ("leather", "jacket"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
980
  ]
981
+ for keyword, category in _EXTRA_KEYWORDS:
982
+ if keyword in label:
983
+ return category
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
984
 
985
+ return label
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
986
 
 
987
 
988
 
989
+ # ModaNet 13 categories (category_id -> label)
990
+ MODANET_CATEGORIES = {
991
+ 1: "bag", 2: "belt", 3: "boots", 4: "footwear", 5: "outer",
992
+ 6: "dress", 7: "sunglasses", 8: "pants", 9: "top", 10: "shorts",
993
+ 11: "skirt", 12: "headwear", 13: "scarf/tie",
994
+ }
 
 
 
 
995
 
996
+ MODANET_ANNOTATIONS_JSON = "data/modanet_instances_train.json"
997
+ MODANET_IMAGES_DIR = "data/modanet_images/images"
 
 
 
 
 
 
 
 
 
998
 
999
 
1000
def load_modanet_samples(
    num_examples: int,
) -> Tuple[List[Tuple[Image.Image, str]], List[Tuple[Image.Image, str]], List[Tuple[Image.Image, str]]]:
    """Return ``(baseline_samples, gap_samples, color_samples)`` from ModaNet.

    Reads the local COCO-style annotation JSON and the image directory. Every
    image is labelled with the category of its largest-area annotation.
    ``baseline_samples`` carry the category name from the annotation file;
    ``gap_samples`` carry the same images with the label passed through
    ``normalize_hierarchy_label``. The third element is always an empty list.

    Args:
        num_examples: Maximum number of images to load.

    Returns:
        Three lists of ``(image, label)`` pairs; all empty when the annotation
        file or the image directory is missing.
    """
    import json as _json

    ann_path = Path(MODANET_ANNOTATIONS_JSON)
    img_dir = Path(MODANET_IMAGES_DIR)

    if not ann_path.exists():
        print(f"  Skipping ModaNet: annotations not found at {MODANET_ANNOTATIONS_JSON}")
        return [], [], []
    if not img_dir.exists():
        print(f"  Skipping ModaNet: images directory not found at {MODANET_IMAGES_DIR}")
        return [], [], []

    print("  Loading ModaNet annotations...")
    # Explicit encoding so parsing does not depend on the platform default.
    with open(ann_path, encoding="utf-8") as f:
        coco = _json.load(f)

    cat_map = {c["id"]: c["name"] for c in coco["categories"]}
    img_map = {img["id"]: img["file_name"] for img in coco["images"]}

    # For each image keep the annotation with the largest area.
    best_per_image: Dict[int, Tuple[int, float]] = {}  # image_id -> (category_id, area)
    for ann in coco["annotations"]:
        img_id = ann["image_id"]
        area = ann.get("area", 0)
        if img_id not in best_per_image or area > best_per_image[img_id][1]:
            best_per_image[img_id] = (ann["category_id"], area)

    # Shuffle deterministically, then load images until the quota is met.
    image_ids = list(best_per_image)
    rng = random.Random(42)
    rng.shuffle(image_ids)

    baseline_samples: List[Tuple[Image.Image, str]] = []
    gap_samples: List[Tuple[Image.Image, str]] = []

    for img_id in image_ids:
        if len(baseline_samples) >= num_examples:
            break
        file_name = img_map.get(img_id)
        if file_name is None:
            continue
        img_path = img_dir / file_name
        if not img_path.exists():
            continue
        try:
            # Close the file as soon as the RGB copy has been made.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert("RGB")
        except Exception:
            # Best effort: unreadable images are skipped.
            continue

        cat_id, _ = best_per_image[img_id]
        native_label = cat_map.get(cat_id, "unknown")
        gap_label = normalize_hierarchy_label(native_label)
        baseline_samples.append((image, native_label))
        gap_samples.append((image, gap_label))

    print(f"  ModaNet: loaded {len(baseline_samples)} valid samples (from {len(best_per_image)} annotated images)")
    return baseline_samples, gap_samples, []
1066
+
1067
+
1068
def zero_shot_modanet(
    model,
    processor,
    device,
    batch_size: int = 64,
    num_examples: int = 10000,
    use_gap_labels: bool = True,
) -> Optional[Dict[str, float]]:
    """Zero-shot accuracy and weighted F1 on the ModaNet dataset.

    Samples come from ``load_modanet_samples``. The candidate labels are the
    distinct labels of the selected sample list, each wrapped in the prompt
    ``"a photo of a <label>"``.

    Args:
        model: Model handed to ``encode_text`` / ``encode_image``.
        processor: Processor handed to ``encode_text`` / ``encode_image``.
        device: Device used for the embeddings.
        batch_size: Number of images per forward pass.
        num_examples: Maximum number of images to load.
        use_gap_labels: Evaluate on the normalised labels when true, on the
            annotation-file category names otherwise.

    Returns:
        ``{"accuracy": ..., "weighted_f1": ...}``, or ``None`` when no sample
        could be loaded.
    """
    baseline_samples, gap_samples, _ = load_modanet_samples(num_examples)
    samples = gap_samples if use_gap_labels else baseline_samples
    if not samples:
        print("Skipping zero_shot_modanet: no valid samples")
        return None

    pil_images = [img for img, _ in samples]
    labels_text = [label for _, label in samples]

    candidate_labels = sorted(set(labels_text))
    label_to_idx = {label: idx for idx, label in enumerate(candidate_labels)}
    all_labels = np.array([label_to_idx[label] for label in labels_text], dtype=np.int64)

    prompts = [f"a photo of a {label}" for label in candidate_labels]
    text_embs = encode_text(model, processor, prompts, device).to(device).float()
    text_embs = F.normalize(text_embs, dim=-1)

    all_preds: List[np.ndarray] = []
    for start in tqdm(range(0, len(pil_images), batch_size), desc="Zero-shot ModaNet"):
        batch_images = pil_images[start : start + batch_size]
        img_embs = encode_image(model, processor, batch_images, device).to(device).float()
        img_embs = F.normalize(img_embs, dim=-1)
        # Unit vectors on both sides: the product is the cosine similarity.
        sim = img_embs @ text_embs.T
        all_preds.append(sim.argmax(dim=-1).cpu().numpy())

    # samples is non-empty here, so at least one batch was predicted.
    pred_array = np.concatenate(all_preds, axis=0)
    accuracy = float((pred_array == all_labels).mean())
    weighted_f1 = float(f1_score(all_labels, pred_array, average="weighted"))
    label_kind = "GAP" if use_gap_labels else "native"
    print(f"ModaNet ({label_kind}) accuracy: {accuracy:.4f}")
    # The score is the support-weighted average, not the macro average.
    print(f"ModaNet ({label_kind}) weighted F1: {weighted_f1:.4f}")
    return {"accuracy": accuracy, "weighted_f1": weighted_f1}
1110
+
1111
+
1112
+ def main(
1113
+ selected_tests: set[str],
1114
+ model=None,
1115
+ processor=None,
1116
+ baseline_model=None,
1117
+ baseline_processor=None,
1118
+ ) -> None:
1119
  random.seed(42)
1120
  cfg = resolve_runtime_config()
 
 
 
1121
 
1122
+ if model is None or processor is None:
1123
+ model_path = Path(cfg.main_model_path)
1124
+ if not model_path.exists():
1125
+ raise FileNotFoundError(f"Main model checkpoint not found: {cfg.main_model_path}")
1126
+ print("Loading model...")
1127
+ print(f" device: {cfg.device}")
1128
+ print(f" checkpoint: {cfg.main_model_path}")
1129
+ print(f" dims: color={cfg.color_emb_dim}, hierarchy={cfg.hierarchy_emb_dim}, total={cfg.main_emb_dim}")
1130
+ model, processor = load_main_model(cfg.device, cfg.main_model_path)
1131
+ print("Model loaded.")
1132
+ else:
1133
+ print(f"Using pre-loaded GAP-CLIP model (dims: color={cfg.color_emb_dim}, hierarchy={cfg.hierarchy_emb_dim}, total={cfg.main_emb_dim})")
1134
 
1135
  result_a: Optional[Dict[str, object]] = None
1136
  result_b: Optional[Dict[str, object]] = None
1137
+ result_c: Optional[Dict[str, object]] = None
1138
  baseline_result_a: Optional[Dict[str, object]] = None
1139
  baseline_result_b: Optional[Dict[str, object]] = None
1140
+ baseline_result_c: Optional[Dict[str, object]] = None
1141
+
1142
+ if baseline_model is None or baseline_processor is None:
1143
+ if any(t in selected_tests for t in ("A", "B", "C", "D")):
1144
+ print("\nLoading baseline model (patrickjohncyh/fashion-clip)...")
1145
+ baseline_name = "patrickjohncyh/fashion-clip"
1146
+ baseline_processor = CLIPProcessor.from_pretrained(baseline_name)
1147
+ baseline_model = CLIPModelTransformers.from_pretrained(baseline_name).to(cfg.device)
1148
+ baseline_model.eval()
1149
+ print("Baseline model loaded.")
 
1150
 
1151
  if "A" in selected_tests:
1152
  result_a = run_test_a(
 
1182
  num_printed=DEFAULT_NUM_PRINTED,
1183
  test_name="Baseline Test B",
1184
  )
1185
+ if "C" in selected_tests:
1186
+ result_c = run_test_c(
1187
  model,
1188
  processor,
1189
  cfg,
 
1191
  num_printed=DEFAULT_NUM_PRINTED,
1192
  )
1193
  if baseline_model is not None and baseline_processor is not None:
1194
+ baseline_result_c = run_test_c(
1195
  baseline_model,
1196
  baseline_processor,
1197
  cfg,
1198
  num_examples=DEFAULT_NUM_EXAMPLES,
1199
  num_printed=DEFAULT_NUM_PRINTED,
1200
+ test_name="Baseline Test C",
1201
  )
1202
 
1203
+ if "D" in selected_tests:
 
 
 
 
 
 
 
 
1204
  assert baseline_model is not None and baseline_processor is not None
1205
 
1206
+ print("\n" + "=" * 120)
1207
+ print("Test D Notebook-style zero-shot accuracy")
1208
+ print("=" * 120)
1209
+ d_results: Dict[str, Dict[str, Optional[Dict[str, float]]]] = {
1210
+ "Fashion-MNIST": {
1211
+ "gap": {"accuracy": zero_shot_fashion_mnist(model=model, processor=processor, device=cfg.device, batch_size=64)},
1212
+ "base": {"accuracy": zero_shot_fashion_mnist(model=baseline_model, processor=baseline_processor, device=cfg.device, batch_size=64)},
1213
+ },
1214
+ "KAGL Marqo": {
1215
+ "gap": zero_shot_kagl(model=model, processor=processor, device=cfg.device, batch_size=64, num_examples=DEFAULT_NUM_EXAMPLES),
1216
+ "base": zero_shot_kagl(model=baseline_model, processor=baseline_processor, device=cfg.device, batch_size=64, num_examples=DEFAULT_NUM_EXAMPLES),
1217
+ },
1218
+ "Internal dataset": {
1219
+ "gap": zero_shot_internal(model=model, processor=processor, device=cfg.device, batch_size=64, num_examples=DEFAULT_NUM_EXAMPLES),
1220
+ "base": zero_shot_internal(model=baseline_model, processor=baseline_processor, device=cfg.device, batch_size=64, num_examples=DEFAULT_NUM_EXAMPLES),
1221
+ },
1222
+ "ModaNet": {
1223
+ "gap": zero_shot_modanet(model=model, processor=processor, device=cfg.device, batch_size=64, num_examples=DEFAULT_NUM_EXAMPLES, use_gap_labels=True),
1224
+ "base": zero_shot_modanet(model=baseline_model, processor=baseline_processor, device=cfg.device, batch_size=64, num_examples=DEFAULT_NUM_EXAMPLES, use_gap_labels=True),
1225
+ },
1226
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1227
 
1228
+ print("\n" + "-" * 120)
1229
+ print("Test D summary")
1230
+ print("-" * 120)
1231
+ summary_rows: List[List[str]] = []
1232
+ for ds in ["Fashion-MNIST", "KAGL Marqo", "ModaNet", "Internal dataset"]:
1233
+ gap_result = d_results[ds]["gap"]
1234
+ base_result = d_results[ds]["base"]
1235
+ gap_acc = None if gap_result is None else gap_result.get("accuracy")
1236
+ base_acc = None if base_result is None else base_result.get("accuracy")
1237
+ summary_rows.append([
1238
+ ds,
1239
+ f"{gap_acc:.2%}" if gap_acc is not None else "N/A",
1240
+ f"{base_acc:.2%}" if base_acc is not None else "N/A",
1241
+ ])
1242
+ print_table(
1243
+ "Test D — zero-shot accuracy (notebook protocol)",
1244
+ ["Dataset", "GAP-CLIP", "Fashion-CLIP (baseline)"],
1245
+ summary_rows,
1246
+ )
1247
  print("\n" + "=" * 120)
1248
  print("Final Summary")
1249
  print("=" * 120)
 
1258
  print(f"Test B full512 accuracy: {float(result_b['accuracy_full512']):.2%}")
1259
  if baseline_result_b is not None:
1260
  print(f"Baseline Test B full512 accuracy: {float(baseline_result_b['accuracy_full512']):.2%}")
1261
+ if result_c is not None:
1262
+ print(f"Test C overall: {format_bool(bool(result_c['overall']))}")
1263
+ print(f" pass rate: {float(result_c['pass_rate']):.2%}")
1264
+ print(f" avg color_match={float(result_c['avg_color_match']):.4f} vs cross={float(result_c['avg_color_cross']):.4f}")
1265
+ print(f" avg hier_match={float(result_c['avg_hier_match']):.4f} vs cross={float(result_c['avg_hier_cross']):.4f}")
1266
+ if baseline_result_c is not None:
1267
+ print(f"Baseline Test C overall: {format_bool(bool(baseline_result_c['overall']))}")
1268
+ print(f" baseline pass rate: {float(baseline_result_c['pass_rate']):.2%}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1269
 
1270
  if result_a is not None:
1271
+ assert float(result_a["pass_rate"]) >= 0.95, (
1272
+ f"Test A failed: pass rate {float(result_a['pass_rate']):.2%} < 95%."
1273
+ )
1274
  if result_b is not None:
1275
+ assert float(result_b["pass_rate"]) >= 0.95, (
1276
+ f"Test B failed: pass rate {float(result_b['pass_rate']):.2%} < 95%."
1277
+ )
1278
+ if result_c is not None:
1279
+ assert float(result_c["pass_rate"]) >= 0.95, (
1280
+ f"Test C failed: subspace decomposition pass rate {float(result_c['pass_rate']):.2%} < 95%."
1281
  )
1282
 
1283
  print("\nAll embedding-structure tests passed.")
1284
 
1285
 
1286
  if __name__ == "__main__":
1287
+ parser = argparse.ArgumentParser(description="Embedding structure evaluation")
1288
+ parser.add_argument("--tests", default="ABCD", help="Which tests to run, e.g. 'C' or 'ABCD'")
1289
+ parser.add_argument("--num-examples", type=int, default=None, help="Override DEFAULT_NUM_EXAMPLES")
1290
+ args = parser.parse_args()
1291
+ if args.num_examples is not None:
1292
+ DEFAULT_NUM_EXAMPLES = args.num_examples
1293
+ selected_tests = set(args.tests.upper())
1294
  main(selected_tests)