Upload evaluation/sec536_embedding_structure.py with huggingface_hub
Browse files- evaluation/sec536_embedding_structure.py +496 -749
evaluation/sec536_embedding_structure.py
CHANGED
|
@@ -21,12 +21,7 @@ designed for, and tests zero-shot vision-language alignment.
|
|
| 21 |
for items sharing a color but differing in category.
|
| 22 |
Expected result: 1000/1000 pass.
|
| 23 |
|
| 24 |
-
Test
|
| 25 |
-
Each image is used as a query; the highest-scoring text label (cosine in
|
| 26 |
-
shared latent space) is the predicted class. Accuracy is computed across
|
| 27 |
-
three datasets (Fashion-MNIST, KAGL Marqo, Internal).
|
| 28 |
-
|
| 29 |
-
Test D — Subspace Decomposition Consistency:
|
| 30 |
Encode a full description (e.g. "red dress in cotton"), a standalone color
|
| 31 |
("red"), and a standalone hierarchy ("dress"). Verify that:
|
| 32 |
- The color subspace (first 16D) of the full embedding is more similar
|
|
@@ -35,6 +30,11 @@ designed for, and tests zero-shot vision-language alignment.
|
|
| 35 |
similar to the hierarchy-only embedding than to the color-only embedding.
|
| 36 |
Expected result: 1000/1000 pass.
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
Paper reference: Section 5.3.6 and Table 4.
|
| 39 |
|
| 40 |
Run directly:
|
|
@@ -51,6 +51,9 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
| 51 |
from dataclasses import dataclass
|
| 52 |
from pathlib import Path
|
| 53 |
import random
|
|
|
|
|
|
|
|
|
|
| 54 |
from typing import Dict, List, Optional, Sequence, Tuple
|
| 55 |
|
| 56 |
import numpy as np
|
|
@@ -62,16 +65,39 @@ import torch.nn.functional as F
|
|
| 62 |
from io import BytesIO
|
| 63 |
from PIL import Image
|
| 64 |
from torchvision import transforms
|
|
|
|
|
|
|
|
|
|
| 65 |
from transformers import CLIPModel as CLIPModelTransformers
|
| 66 |
from transformers import CLIPProcessor
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
@dataclass
|
| 70 |
class RuntimeConfig:
|
| 71 |
-
color_emb_dim: int =
|
| 72 |
-
hierarchy_emb_dim: int =
|
| 73 |
-
|
| 74 |
-
|
|
|
|
| 75 |
|
| 76 |
DEFAULT_NUM_EXAMPLES = 10000
|
| 77 |
DEFAULT_NUM_PRINTED = 3
|
|
@@ -106,6 +132,7 @@ def resolve_runtime_config() -> RuntimeConfig:
|
|
| 106 |
|
| 107 |
cfg.color_emb_dim = getattr(config, "color_emb_dim", cfg.color_emb_dim)
|
| 108 |
cfg.hierarchy_emb_dim = getattr(config, "hierarchy_emb_dim", cfg.hierarchy_emb_dim)
|
|
|
|
| 109 |
cfg.main_model_path = getattr(config, "main_model_path", cfg.main_model_path)
|
| 110 |
cfg.device = getattr(config, "device", cfg.device)
|
| 111 |
except Exception:
|
|
@@ -120,27 +147,50 @@ def resolve_runtime_config() -> RuntimeConfig:
|
|
| 120 |
|
| 121 |
|
| 122 |
def load_main_model(device: torch.device, main_model_path: str) -> Tuple[CLIPModelTransformers, CLIPProcessor]:
|
| 123 |
-
"""Load GAP-CLIP
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
-
def get_text_embedding(
|
| 132 |
-
model: CLIPModelTransformers, processor: CLIPProcessor, device: torch.device, text: str
|
| 133 |
-
) -> torch.Tensor:
|
| 134 |
-
"""Extract normalized text embedding for a single query."""
|
| 135 |
-
text_inputs = processor(text=[text], padding=True, return_tensors="pt")
|
| 136 |
-
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
with torch.no_grad():
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
|
| 143 |
-
|
|
|
|
|
|
|
|
|
|
| 144 |
|
| 145 |
|
| 146 |
def cosine(a: torch.Tensor, b: torch.Tensor) -> float:
|
|
@@ -179,8 +229,7 @@ def run_test_a(
|
|
| 179 |
cfg: RuntimeConfig,
|
| 180 |
num_examples: int,
|
| 181 |
num_printed: int,
|
| 182 |
-
test_name: str = "Test A",
|
| 183 |
-
) -> Dict[str, bool]:
|
| 184 |
"""
|
| 185 |
A: different colors + same hierarchy.
|
| 186 |
Expect hierarchy subspace to be more similar than color subspace.
|
|
@@ -292,8 +341,7 @@ def run_test_b(
|
|
| 292 |
cfg: RuntimeConfig,
|
| 293 |
num_examples: int,
|
| 294 |
num_printed: int,
|
| 295 |
-
test_name: str = "Test B",
|
| 296 |
-
) -> Dict[str, bool]:
|
| 297 |
"""
|
| 298 |
B: same color + different hierarchies.
|
| 299 |
Expect similarity in first16 (color) to be higher than full512.
|
|
@@ -398,16 +446,15 @@ def run_test_b(
|
|
| 398 |
|
| 399 |
|
| 400 |
|
| 401 |
-
def
|
| 402 |
model: CLIPModelTransformers,
|
| 403 |
processor: CLIPProcessor,
|
| 404 |
cfg: RuntimeConfig,
|
| 405 |
num_examples: int,
|
| 406 |
num_printed: int,
|
| 407 |
-
test_name: str = "Test
|
| 408 |
-
) -> Dict[str, object]:
|
| 409 |
"""
|
| 410 |
-
|
| 411 |
Encode a full description (e.g. "red dress in cotton"), a standalone color
|
| 412 |
("red"), and a standalone hierarchy ("dress"). Then verify:
|
| 413 |
- The color subspace (first 16D) of the full embedding aligns with the
|
|
@@ -568,36 +615,26 @@ def fashion_mnist_pixels_to_tensor(pixel_values: np.ndarray, image_size: int = 2
|
|
| 568 |
def get_image_embedding(
|
| 569 |
model: CLIPModelTransformers, processor: CLIPProcessor, device: torch.device, image_tensor: torch.Tensor
|
| 570 |
) -> torch.Tensor:
|
|
|
|
| 571 |
image_tensor = image_tensor.unsqueeze(0).to(device)
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
return image_features.squeeze(0)
|
| 577 |
|
| 578 |
|
| 579 |
def get_image_embedding_from_pil(
|
| 580 |
model: CLIPModelTransformers, processor: CLIPProcessor, device: torch.device, image: Image.Image
|
| 581 |
) -> torch.Tensor:
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
with torch.no_grad():
|
| 585 |
-
vision_outputs = model.vision_model(**image_inputs)
|
| 586 |
-
image_features = model.visual_projection(vision_outputs.pooler_output)
|
| 587 |
-
image_features = F.normalize(image_features, dim=-1)
|
| 588 |
-
return image_features.squeeze(0)
|
| 589 |
|
| 590 |
|
| 591 |
def get_text_embeddings_batch(
|
| 592 |
model: CLIPModelTransformers, processor: CLIPProcessor, device: torch.device, texts: List[str]
|
| 593 |
) -> torch.Tensor:
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
with torch.no_grad():
|
| 597 |
-
text_outputs = model.text_model(**text_inputs)
|
| 598 |
-
text_features = model.text_projection(text_outputs.pooler_output)
|
| 599 |
-
text_features = F.normalize(text_features, dim=-1)
|
| 600 |
-
return text_features
|
| 601 |
|
| 602 |
|
| 603 |
def get_prompt_ensembled_text_embeddings(
|
|
@@ -678,79 +715,187 @@ def get_adaptive_label_prior(labels: List[str]) -> Tuple[torch.Tensor, float]:
|
|
| 678 |
return probs, recommended_weight
|
| 679 |
|
| 680 |
|
| 681 |
-
def
|
| 682 |
-
model
|
| 683 |
-
processor
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
df = pd.read_csv(csv_path)
|
| 700 |
-
df = df.sample(n=min(num_examples, len(df)), random_state=42).reset_index(drop=True)
|
| 701 |
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
text_embs =
|
| 705 |
|
| 706 |
-
pixel_cols = [f"pixel{i}" for i in range(1, 785)]
|
| 707 |
-
rows: List[List[str]] = []
|
| 708 |
-
failed_rows: List[List[str]] = []
|
| 709 |
correct = 0
|
|
|
|
| 710 |
|
| 711 |
-
for
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
|
| 715 |
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
img_emb = get_image_embedding(model, processor, cfg.device, img_tensor)
|
| 719 |
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
predicted = candidate_labels[best_idx]
|
| 723 |
-
best_sim = sims[best_idx].item()
|
| 724 |
|
| 725 |
-
|
| 726 |
-
|
| 727 |
-
|
| 728 |
|
| 729 |
-
rows.append([
|
| 730 |
-
str(idx),
|
| 731 |
-
ground_truth,
|
| 732 |
-
predicted,
|
| 733 |
-
f"{best_sim:.4f}",
|
| 734 |
-
format_bool(ok),
|
| 735 |
-
])
|
| 736 |
-
if not ok:
|
| 737 |
-
failed_rows.append([
|
| 738 |
-
str(idx),
|
| 739 |
-
ground_truth,
|
| 740 |
-
predicted,
|
| 741 |
-
f"{best_sim:.4f}",
|
| 742 |
-
])
|
| 743 |
|
| 744 |
-
accuracy = correct / len(df)
|
| 745 |
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 752 |
|
| 753 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 754 |
|
| 755 |
|
| 756 |
def normalize_hierarchy_label(raw_label: str) -> str:
|
|
@@ -805,544 +950,203 @@ def normalize_hierarchy_label(raw_label: str) -> str:
|
|
| 805 |
"innerwear": "underwear",
|
| 806 |
"loungewear and nightwear": "underwear",
|
| 807 |
"saree": "dress",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 808 |
}
|
| 809 |
-
|
| 810 |
-
|
| 811 |
-
|
| 812 |
-
|
| 813 |
-
|
| 814 |
-
|
| 815 |
-
|
| 816 |
-
|
| 817 |
-
|
| 818 |
-
|
| 819 |
-
|
| 820 |
-
|
| 821 |
-
|
| 822 |
-
|
| 823 |
-
|
| 824 |
-
|
| 825 |
-
|
| 826 |
-
|
| 827 |
-
|
| 828 |
-
|
| 829 |
-
model_path = Path(getattr(_cfg, "hierarchy_model_path", "models/hierarchy_model.pth"))
|
| 830 |
-
if not model_path.exists():
|
| 831 |
-
return None
|
| 832 |
-
try:
|
| 833 |
-
checkpoint = torch.load(str(model_path), map_location=device)
|
| 834 |
-
hierarchy_classes = checkpoint.get("hierarchy_classes", [])
|
| 835 |
-
if not hierarchy_classes:
|
| 836 |
-
return None
|
| 837 |
-
_model = _HierarchyModel(
|
| 838 |
-
num_hierarchy_classes=len(hierarchy_classes),
|
| 839 |
-
embed_dim=getattr(_cfg, "hierarchy_emb_dim", 64),
|
| 840 |
-
).to(device)
|
| 841 |
-
_model.load_state_dict(checkpoint["model_state"])
|
| 842 |
-
_model.set_hierarchy_extractor(_HierarchyExtractor(hierarchy_classes, verbose=False))
|
| 843 |
-
_model.eval()
|
| 844 |
-
return _model
|
| 845 |
-
except Exception:
|
| 846 |
-
return None
|
| 847 |
-
|
| 848 |
-
|
| 849 |
-
def evaluate_zero_shot_gap(
|
| 850 |
-
model: CLIPModelTransformers,
|
| 851 |
-
processor: CLIPProcessor,
|
| 852 |
-
device: torch.device,
|
| 853 |
-
samples: List[Tuple[Image.Image, str]],
|
| 854 |
-
candidate_labels: List[str],
|
| 855 |
-
title_prefix: str,
|
| 856 |
-
num_printed: int,
|
| 857 |
-
color_emb_dim: int = 16,
|
| 858 |
-
hierarchy_emb_dim: int = 64,
|
| 859 |
-
hierarchy_model=None,
|
| 860 |
-
) -> Dict[str, Optional[float]]:
|
| 861 |
-
if len(samples) == 0:
|
| 862 |
-
print(f" Skipping {title_prefix}: no valid samples")
|
| 863 |
-
return {"accuracy_c1": None, "strategy": None}
|
| 864 |
-
|
| 865 |
-
# Strategy 1 (baseline prompt) and prompt-ensemble embeddings.
|
| 866 |
-
base_templates = ["a photo of a {label}"]
|
| 867 |
-
ensemble_templates = [
|
| 868 |
-
"a photo of a {label}",
|
| 869 |
-
"a product photo of a {label}",
|
| 870 |
-
"a studio photo of a {label}",
|
| 871 |
-
"a fashion item: {label}",
|
| 872 |
-
"an image of a {label}",
|
| 873 |
-
]
|
| 874 |
-
text_embs_single = get_prompt_ensembled_text_embeddings(
|
| 875 |
-
model=model,
|
| 876 |
-
processor=processor,
|
| 877 |
-
device=device,
|
| 878 |
-
labels=candidate_labels,
|
| 879 |
-
templates=base_templates,
|
| 880 |
-
)
|
| 881 |
-
text_embs_ensemble = get_prompt_ensembled_text_embeddings(
|
| 882 |
-
model=model,
|
| 883 |
-
processor=processor,
|
| 884 |
-
device=device,
|
| 885 |
-
labels=candidate_labels,
|
| 886 |
-
templates=ensemble_templates,
|
| 887 |
-
)
|
| 888 |
-
|
| 889 |
-
# Precompute image embeddings once for C1.
|
| 890 |
-
image_embs: List[torch.Tensor] = []
|
| 891 |
-
for image, _ in samples:
|
| 892 |
-
image_embs.append(get_image_embedding_from_pil(model, processor, device, image))
|
| 893 |
-
image_embs_tensor = torch.stack(image_embs, dim=0)
|
| 894 |
-
|
| 895 |
-
# Similarity matrices (N images x C labels)
|
| 896 |
-
sims_single = image_embs_tensor @ text_embs_single.T
|
| 897 |
-
sims_ensemble = image_embs_tensor @ text_embs_ensemble.T
|
| 898 |
-
|
| 899 |
-
# Calibration and prior terms.
|
| 900 |
-
class_bias = sims_ensemble.mean(dim=0, keepdim=True)
|
| 901 |
-
class_prior = get_internal_label_prior(candidate_labels).to(device)
|
| 902 |
-
log_prior = torch.log(class_prior + 1e-8).unsqueeze(0)
|
| 903 |
-
|
| 904 |
-
# Baseline inference-time strategies (full 512-d embedding).
|
| 905 |
-
strategy_scores: Dict[str, torch.Tensor] = {
|
| 906 |
-
"single_prompt": sims_single,
|
| 907 |
-
"prompt_ensemble": sims_ensemble,
|
| 908 |
-
"ensemble_plus_calibration": sims_ensemble - 0.2 * class_bias,
|
| 909 |
-
"ensemble_plus_prior": sims_ensemble + 0.15 * log_prior,
|
| 910 |
-
"ensemble_calibration_plus_prior": sims_ensemble - 0.2 * class_bias + 0.15 * log_prior,
|
| 911 |
-
}
|
| 912 |
-
|
| 913 |
-
# Extended prompt ensemble for broader category coverage.
|
| 914 |
-
extended_templates = [
|
| 915 |
-
"a photo of a {label}",
|
| 916 |
-
"a product photo of a {label}",
|
| 917 |
-
"a studio photo of a {label}",
|
| 918 |
-
"a fashion item: {label}",
|
| 919 |
-
"an image of a {label}",
|
| 920 |
-
"{label}",
|
| 921 |
-
"a picture of a {label}",
|
| 922 |
-
"this is a {label}",
|
| 923 |
-
"a fashion product: {label}",
|
| 924 |
-
"a {label} clothing item",
|
| 925 |
]
|
| 926 |
-
|
| 927 |
-
|
| 928 |
-
|
| 929 |
-
)
|
| 930 |
-
sims_extended = image_embs_tensor @ text_embs_extended.T
|
| 931 |
-
|
| 932 |
-
# Subspace: exclude color dimensions (keep hierarchy + residual).
|
| 933 |
-
hier_end = color_emb_dim + hierarchy_emb_dim
|
| 934 |
-
img_no_color = F.normalize(image_embs_tensor[:, color_emb_dim:], dim=-1)
|
| 935 |
-
text_ext_no_color = F.normalize(text_embs_extended[:, color_emb_dim:], dim=-1)
|
| 936 |
-
text_ens_no_color = F.normalize(text_embs_ensemble[:, color_emb_dim:], dim=-1)
|
| 937 |
-
sims_no_color = img_no_color @ text_ens_no_color.T
|
| 938 |
-
sims_no_color_ext = img_no_color @ text_ext_no_color.T
|
| 939 |
-
|
| 940 |
-
# Subspace: hierarchy-only dimensions.
|
| 941 |
-
img_hier = F.normalize(image_embs_tensor[:, color_emb_dim:hier_end], dim=-1)
|
| 942 |
-
text_ens_hier = F.normalize(text_embs_ensemble[:, color_emb_dim:hier_end], dim=-1)
|
| 943 |
-
text_ext_hier = F.normalize(text_embs_extended[:, color_emb_dim:hier_end], dim=-1)
|
| 944 |
-
sims_hier_ens = img_hier @ text_ens_hier.T
|
| 945 |
-
sims_hier_ext = img_hier @ text_ext_hier.T
|
| 946 |
-
|
| 947 |
-
# Adaptive prior (reduces influence for out-of-domain label sets).
|
| 948 |
-
adaptive_prior, adaptive_weight = get_adaptive_label_prior(candidate_labels)
|
| 949 |
-
adaptive_prior = adaptive_prior.to(device)
|
| 950 |
-
log_adaptive_prior = torch.log(adaptive_prior + 1e-8).unsqueeze(0)
|
| 951 |
-
|
| 952 |
-
class_bias_no_color = sims_no_color.mean(dim=0, keepdim=True)
|
| 953 |
-
|
| 954 |
-
strategy_scores.update({
|
| 955 |
-
"extended_ensemble": sims_extended,
|
| 956 |
-
"no_color_ensemble": sims_no_color,
|
| 957 |
-
"no_color_extended": sims_no_color_ext,
|
| 958 |
-
"hierarchy_only_ensemble": sims_hier_ens,
|
| 959 |
-
"hierarchy_only_extended": sims_hier_ext,
|
| 960 |
-
"no_color_calibrated": sims_no_color - 0.2 * class_bias_no_color,
|
| 961 |
-
"no_color_adaptive_prior": sims_no_color + adaptive_weight * log_adaptive_prior,
|
| 962 |
-
"no_color_ext_adaptive_prior": sims_no_color_ext + adaptive_weight * log_adaptive_prior,
|
| 963 |
-
"extended_adaptive_prior": sims_extended + adaptive_weight * log_adaptive_prior,
|
| 964 |
-
})
|
| 965 |
-
|
| 966 |
-
# Weighted embeddings: amplify hierarchy dims relative to residual.
|
| 967 |
-
for amp_factor in (2.0, 4.0):
|
| 968 |
-
weights = torch.ones(image_embs_tensor.shape[1], device=device)
|
| 969 |
-
weights[:color_emb_dim] = 0.0
|
| 970 |
-
weights[color_emb_dim:hier_end] = amp_factor
|
| 971 |
-
weighted_img = F.normalize(image_embs_tensor * weights.unsqueeze(0), dim=-1)
|
| 972 |
-
weighted_text = F.normalize(text_embs_extended * weights.unsqueeze(0), dim=-1)
|
| 973 |
-
tag = f"weighted_hier_{amp_factor:.0f}x"
|
| 974 |
-
strategy_scores[tag] = weighted_img @ weighted_text.T
|
| 975 |
-
|
| 976 |
-
# Hierarchy model direct strategy (uses dedicated hierarchy encoder).
|
| 977 |
-
if hierarchy_model is not None:
|
| 978 |
-
hier_text_embs: List[torch.Tensor] = []
|
| 979 |
-
known_label_mask: List[bool] = []
|
| 980 |
-
for label in candidate_labels:
|
| 981 |
-
try:
|
| 982 |
-
emb = hierarchy_model.get_text_embeddings(label).squeeze(0)
|
| 983 |
-
hier_text_embs.append(emb)
|
| 984 |
-
known_label_mask.append(True)
|
| 985 |
-
except (ValueError, Exception):
|
| 986 |
-
hier_text_embs.append(text_ext_hier[candidate_labels.index(label)])
|
| 987 |
-
known_label_mask.append(False)
|
| 988 |
-
hier_text_matrix = F.normalize(torch.stack(hier_text_embs).to(device), dim=-1)
|
| 989 |
-
sims_hier_model = img_hier @ hier_text_matrix.T
|
| 990 |
-
strategy_scores["hierarchy_model_direct"] = sims_hier_model
|
| 991 |
-
class_bias_hier = sims_hier_model.mean(dim=0, keepdim=True)
|
| 992 |
-
strategy_scores["hier_model_calibrated"] = sims_hier_model - 0.2 * class_bias_hier
|
| 993 |
-
strategy_scores["hier_model_adaptive_prior"] = sims_hier_model + adaptive_weight * log_adaptive_prior
|
| 994 |
-
|
| 995 |
-
# Hybrid: hierarchy model scores for known labels, CLIP for unknown.
|
| 996 |
-
hybrid_scores = sims_no_color_ext.clone()
|
| 997 |
-
for label_idx, is_known in enumerate(known_label_mask):
|
| 998 |
-
if is_known:
|
| 999 |
-
hybrid_scores[:, label_idx] = sims_hier_model[:, label_idx]
|
| 1000 |
-
strategy_scores["hybrid_hier_clip"] = hybrid_scores
|
| 1001 |
-
|
| 1002 |
-
# Blended: z-score-normalised mix of hierarchy and full-space scores.
|
| 1003 |
-
hier_mu = sims_hier_model.mean()
|
| 1004 |
-
hier_std = sims_hier_model.std() + 1e-8
|
| 1005 |
-
full_mu = sims_extended.mean()
|
| 1006 |
-
full_std = sims_extended.std() + 1e-8
|
| 1007 |
-
hier_z = (sims_hier_model - hier_mu) / hier_std
|
| 1008 |
-
full_z = (sims_extended - full_mu) / full_std
|
| 1009 |
-
for alpha in (0.3, 0.5, 0.7):
|
| 1010 |
-
strategy_scores[f"blend_hier_full_{alpha:.1f}"] = alpha * hier_z + (1 - alpha) * full_z
|
| 1011 |
-
|
| 1012 |
-
# Select best strategy for C1.
|
| 1013 |
-
|
| 1014 |
-
best_strategy_c1 = "single_prompt"
|
| 1015 |
-
best_acc_c1 = -1.0
|
| 1016 |
-
best_scores_c1 = sims_single
|
| 1017 |
-
|
| 1018 |
-
# Track per-strategy accuracies and weighted-F1 for fair comparison.
|
| 1019 |
-
all_strategy_acc_c1: Dict[str, float] = {}
|
| 1020 |
-
all_strategy_wf1_c1: Dict[str, float] = {}
|
| 1021 |
-
ground_truths = [gt for _, gt in samples]
|
| 1022 |
-
|
| 1023 |
-
for strategy_name, score_mat in strategy_scores.items():
|
| 1024 |
-
pred_idx = score_mat.argmax(dim=1).tolist()
|
| 1025 |
-
preds = [candidate_labels[i] for i in pred_idx]
|
| 1026 |
-
correct = sum(1 for p, g in zip(preds, ground_truths) if p == g)
|
| 1027 |
-
acc = correct / len(samples)
|
| 1028 |
-
wf1 = f1_score(ground_truths, preds, average="weighted", zero_division=0)
|
| 1029 |
-
|
| 1030 |
-
all_strategy_acc_c1[strategy_name] = acc
|
| 1031 |
-
all_strategy_wf1_c1[strategy_name] = wf1
|
| 1032 |
-
|
| 1033 |
-
if acc > best_acc_c1:
|
| 1034 |
-
best_acc_c1 = acc
|
| 1035 |
-
best_strategy_c1 = strategy_name
|
| 1036 |
-
best_scores_c1 = score_mat
|
| 1037 |
-
|
| 1038 |
-
best_wf1_c1 = all_strategy_wf1_c1[best_strategy_c1]
|
| 1039 |
-
print(f"{title_prefix} selected C1 strategy: {best_strategy_c1} (acc={best_acc_c1:.2%}, wF1={best_wf1_c1:.2%})")
|
| 1040 |
-
|
| 1041 |
-
# C1: image -> all texts (classification)
|
| 1042 |
-
rows: List[List[str]] = []
|
| 1043 |
-
correct = 0
|
| 1044 |
-
all_preds: List[str] = []
|
| 1045 |
-
|
| 1046 |
-
for idx, (_, ground_truth) in enumerate(samples):
|
| 1047 |
-
sims = best_scores_c1[idx]
|
| 1048 |
-
best_idx = int(sims.argmax().item())
|
| 1049 |
-
predicted = candidate_labels[best_idx]
|
| 1050 |
-
best_sim = float(sims[best_idx].item())
|
| 1051 |
-
|
| 1052 |
-
ok = predicted == ground_truth
|
| 1053 |
-
if ok:
|
| 1054 |
-
correct += 1
|
| 1055 |
-
all_preds.append(predicted)
|
| 1056 |
-
|
| 1057 |
-
rows.append([str(idx), ground_truth, predicted, f"{best_sim:.4f}", format_bool(ok)])
|
| 1058 |
-
|
| 1059 |
-
accuracy_c1 = correct / len(samples)
|
| 1060 |
-
wf1_c1 = f1_score(ground_truths, all_preds, average="weighted", zero_division=0)
|
| 1061 |
-
|
| 1062 |
-
print_table(
|
| 1063 |
-
f"{title_prefix} C1 image->texts (showing {min(num_printed, len(rows))}/{len(rows)} examples)",
|
| 1064 |
-
["#", "Ground Truth", "Predicted", "Best CosSim", "Result"],
|
| 1065 |
-
rows[:num_printed],
|
| 1066 |
-
)
|
| 1067 |
-
print(f"{title_prefix} C1 aggregate: {correct}/{len(samples)} correct (acc={accuracy_c1:.2%}, wF1={wf1_c1:.2%})")
|
| 1068 |
-
|
| 1069 |
-
return {
|
| 1070 |
-
"accuracy_c1": accuracy_c1,
|
| 1071 |
-
"wf1_c1": wf1_c1,
|
| 1072 |
-
"strategy": best_strategy_c1,
|
| 1073 |
-
"all_strategy_acc_c1": all_strategy_acc_c1,
|
| 1074 |
-
"all_strategy_wf1_c1": all_strategy_wf1_c1,
|
| 1075 |
-
}
|
| 1076 |
-
|
| 1077 |
|
| 1078 |
-
|
| 1079 |
-
baseline_model: CLIPModelTransformers,
|
| 1080 |
-
baseline_processor: CLIPProcessor,
|
| 1081 |
-
device: torch.device,
|
| 1082 |
-
samples: List[Tuple[Image.Image, str]],
|
| 1083 |
-
candidate_labels: List[str],
|
| 1084 |
-
title_prefix: str,
|
| 1085 |
-
num_printed: int,
|
| 1086 |
-
) -> Dict[str, Optional[float]]:
|
| 1087 |
-
if len(samples) == 0:
|
| 1088 |
-
print(f" Skipping baseline {title_prefix}: no valid samples")
|
| 1089 |
-
return {"accuracy_c1": None}
|
| 1090 |
-
|
| 1091 |
-
candidate_texts = [f"a photo of a {label}" for label in candidate_labels]
|
| 1092 |
-
text_inputs = baseline_processor(text=candidate_texts, return_tensors="pt", padding=True, truncation=True)
|
| 1093 |
-
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
|
| 1094 |
-
with torch.no_grad():
|
| 1095 |
-
text_embs = baseline_model.get_text_features(**text_inputs)
|
| 1096 |
-
text_embs = F.normalize(text_embs, dim=-1)
|
| 1097 |
-
|
| 1098 |
-
# Precompute image embeddings once for C1.
|
| 1099 |
-
image_embs: List[torch.Tensor] = []
|
| 1100 |
-
for image, _ in samples:
|
| 1101 |
-
image_inputs = baseline_processor(images=[image], return_tensors="pt")
|
| 1102 |
-
image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
|
| 1103 |
-
with torch.no_grad():
|
| 1104 |
-
img_emb = baseline_model.get_image_features(**image_inputs)
|
| 1105 |
-
img_emb = F.normalize(img_emb, dim=-1)
|
| 1106 |
-
image_embs.append(img_emb.squeeze(0))
|
| 1107 |
-
image_embs_tensor = torch.stack(image_embs, dim=0)
|
| 1108 |
-
|
| 1109 |
-
# C1: image -> all texts (classification)
|
| 1110 |
-
rows: List[List[str]] = []
|
| 1111 |
-
correct = 0
|
| 1112 |
-
all_preds: List[str] = []
|
| 1113 |
-
ground_truths = [gt for _, gt in samples]
|
| 1114 |
-
|
| 1115 |
-
for idx, (_, ground_truth) in enumerate(samples):
|
| 1116 |
-
img_emb = image_embs_tensor[idx].unsqueeze(0)
|
| 1117 |
-
sims = F.cosine_similarity(img_emb, text_embs, dim=1)
|
| 1118 |
-
best_idx = sims.argmax().item()
|
| 1119 |
-
predicted = candidate_labels[best_idx]
|
| 1120 |
-
best_sim = sims[best_idx].item()
|
| 1121 |
-
|
| 1122 |
-
ok = predicted == ground_truth
|
| 1123 |
-
if ok:
|
| 1124 |
-
correct += 1
|
| 1125 |
-
all_preds.append(predicted)
|
| 1126 |
-
|
| 1127 |
-
rows.append([str(idx), ground_truth, predicted, f"{best_sim:.4f}", format_bool(ok)])
|
| 1128 |
-
|
| 1129 |
-
accuracy_c1 = correct / len(samples)
|
| 1130 |
-
wf1_c1 = f1_score(ground_truths, all_preds, average="weighted", zero_division=0)
|
| 1131 |
-
baseline_title = f"Baseline {title_prefix}"
|
| 1132 |
-
print_table(
|
| 1133 |
-
f"{baseline_title} C1 image->texts (showing {min(num_printed, len(rows))}/{len(rows)} examples)",
|
| 1134 |
-
["#", "Ground Truth", "Predicted", "Best CosSim", "Result"],
|
| 1135 |
-
rows[:num_printed],
|
| 1136 |
-
)
|
| 1137 |
-
print(f"{baseline_title} C1 aggregate: {correct}/{len(samples)} correct (acc={accuracy_c1:.2%}, wF1={wf1_c1:.2%})")
|
| 1138 |
|
| 1139 |
-
return {"accuracy_c1": accuracy_c1, "wf1_c1": wf1_c1}
|
| 1140 |
|
| 1141 |
|
| 1142 |
-
|
| 1143 |
-
|
| 1144 |
-
|
| 1145 |
-
""
|
| 1146 |
-
|
| 1147 |
-
|
| 1148 |
-
return [], []
|
| 1149 |
-
df = pd.read_csv(FASHION_MNIST_CSV)
|
| 1150 |
-
df = df.sample(n=min(num_examples, len(df)), random_state=42).reset_index(drop=True)
|
| 1151 |
-
pixel_cols = [f"pixel{i}" for i in range(1, 785)]
|
| 1152 |
|
| 1153 |
-
|
| 1154 |
-
|
| 1155 |
-
for _, row in df.iterrows():
|
| 1156 |
-
label_id = int(row["label"])
|
| 1157 |
-
pixels = row[pixel_cols].values.astype(float)
|
| 1158 |
-
img_array = pixels.reshape(28, 28).astype(np.uint8)
|
| 1159 |
-
img_array = np.stack([img_array] * 3, axis=-1)
|
| 1160 |
-
image = Image.fromarray(img_array)
|
| 1161 |
-
baseline_samples.append((image, FASHION_MNIST_ORIGINAL_LABELS.get(label_id, "unknown")))
|
| 1162 |
-
gap_samples.append((image, FASHION_MNIST_LABELS.get(label_id, "unknown")))
|
| 1163 |
-
return baseline_samples, gap_samples
|
| 1164 |
|
| 1165 |
|
| 1166 |
-
def
|
| 1167 |
num_examples: int,
|
| 1168 |
-
) -> Tuple[List[Tuple[Image.Image, str]], List[Tuple[Image.Image, str]]]:
|
| 1169 |
-
"""Return (baseline_samples, gap_samples)
|
| 1170 |
-
try:
|
| 1171 |
-
from datasets import load_dataset # type: ignore
|
| 1172 |
-
except Exception:
|
| 1173 |
-
print(" Skipping KAGL Marqo: datasets package not available")
|
| 1174 |
-
return [], []
|
| 1175 |
|
| 1176 |
-
|
| 1177 |
-
|
| 1178 |
-
|
| 1179 |
-
|
| 1180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1181 |
|
| 1182 |
-
dataset = dataset.shuffle(seed=42).select(range(min(num_examples, len(dataset))))
|
| 1183 |
baseline_samples: List[Tuple[Image.Image, str]] = []
|
| 1184 |
gap_samples: List[Tuple[Image.Image, str]] = []
|
| 1185 |
-
for item in dataset:
|
| 1186 |
-
raw_label = item.get("category2")
|
| 1187 |
-
if raw_label is None:
|
| 1188 |
-
continue
|
| 1189 |
-
native_label = str(raw_label).strip()
|
| 1190 |
-
gap_label = normalize_hierarchy_label(native_label)
|
| 1191 |
-
image_obj = item.get("image")
|
| 1192 |
-
if image_obj is None:
|
| 1193 |
-
continue
|
| 1194 |
-
if hasattr(image_obj, "convert"):
|
| 1195 |
-
image = image_obj.convert("RGB")
|
| 1196 |
-
elif isinstance(image_obj, dict) and "bytes" in image_obj:
|
| 1197 |
-
image = Image.open(BytesIO(image_obj["bytes"])).convert("RGB")
|
| 1198 |
-
else:
|
| 1199 |
-
continue
|
| 1200 |
-
baseline_samples.append((image, native_label))
|
| 1201 |
-
gap_samples.append((image, gap_label))
|
| 1202 |
-
return baseline_samples, gap_samples
|
| 1203 |
-
|
| 1204 |
-
|
| 1205 |
-
def load_internal_samples(
|
| 1206 |
-
num_examples: int,
|
| 1207 |
-
) -> Tuple[List[Tuple[Image.Image, str]], List[Tuple[Image.Image, str]]]:
|
| 1208 |
-
"""Return (baseline_samples, gap_samples) — same labels for both on this dataset."""
|
| 1209 |
-
csv_file = Path(INTERNAL_DATASET_CSV)
|
| 1210 |
-
if not csv_file.exists():
|
| 1211 |
-
print(f" Skipping internal dataset: {INTERNAL_DATASET_CSV} not found")
|
| 1212 |
-
return [], []
|
| 1213 |
|
| 1214 |
-
|
| 1215 |
-
|
| 1216 |
-
print(" Skipping internal dataset: missing 'hierarchy' column")
|
| 1217 |
-
return [], []
|
| 1218 |
-
|
| 1219 |
-
df = df.dropna(subset=["hierarchy", "image_url"]).sample(frac=1.0, random_state=42)
|
| 1220 |
-
samples: List[Tuple[Image.Image, str]] = []
|
| 1221 |
-
|
| 1222 |
-
for _, row in df.iterrows():
|
| 1223 |
-
if len(samples) >= num_examples:
|
| 1224 |
break
|
| 1225 |
-
|
| 1226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1227 |
try:
|
| 1228 |
-
|
| 1229 |
-
response.raise_for_status()
|
| 1230 |
-
image = Image.open(BytesIO(response.content)).convert("RGB")
|
| 1231 |
-
samples.append((image, ground_truth))
|
| 1232 |
except Exception:
|
| 1233 |
continue
|
| 1234 |
-
return samples, samples
|
| 1235 |
-
|
| 1236 |
-
|
| 1237 |
-
def run_test_c_baseline_fashion_clip(
|
| 1238 |
-
device: torch.device,
|
| 1239 |
-
num_examples: int,
|
| 1240 |
-
num_printed: int,
|
| 1241 |
-
csv_path: str = FASHION_MNIST_CSV,
|
| 1242 |
-
) -> Dict[str, Optional[float]]:
|
| 1243 |
-
"""
|
| 1244 |
-
Same zero-shot protocol as Test C, but using baseline Fashion-CLIP.
|
| 1245 |
-
"""
|
| 1246 |
-
csv_file = Path(csv_path)
|
| 1247 |
-
if not csv_file.exists():
|
| 1248 |
-
print(f" Skipping Baseline Test C: {csv_path} not found")
|
| 1249 |
-
return {"accuracy": None}
|
| 1250 |
-
|
| 1251 |
-
print("\nLoading baseline model (patrickjohncyh/fashion-clip)...")
|
| 1252 |
-
baseline_name = "patrickjohncyh/fashion-clip"
|
| 1253 |
-
baseline_processor = CLIPProcessor.from_pretrained(baseline_name)
|
| 1254 |
-
baseline_model = CLIPModelTransformers.from_pretrained(baseline_name).to(device)
|
| 1255 |
-
baseline_model.eval()
|
| 1256 |
-
print("Baseline model loaded.")
|
| 1257 |
|
| 1258 |
-
|
| 1259 |
-
|
| 1260 |
-
|
| 1261 |
-
|
| 1262 |
-
|
| 1263 |
-
|
| 1264 |
-
text_inputs = baseline_processor(text=candidate_texts, return_tensors="pt", padding=True, truncation=True)
|
| 1265 |
-
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
|
| 1266 |
-
with torch.no_grad():
|
| 1267 |
-
text_embs = baseline_model.get_text_features(**text_inputs)
|
| 1268 |
-
text_embs = F.normalize(text_embs, dim=-1)
|
| 1269 |
-
|
| 1270 |
-
pixel_cols = [f"pixel{i}" for i in range(1, 785)]
|
| 1271 |
-
rows: List[List[str]] = []
|
| 1272 |
-
failed_rows: List[List[str]] = []
|
| 1273 |
-
correct = 0
|
| 1274 |
-
|
| 1275 |
-
for idx in range(len(df)):
|
| 1276 |
-
row = df.iloc[idx]
|
| 1277 |
-
label_id = int(row["label"])
|
| 1278 |
-
ground_truth = FASHION_MNIST_ORIGINAL_LABELS.get(label_id, "unknown")
|
| 1279 |
-
|
| 1280 |
-
pixels = row[pixel_cols].values.astype(float)
|
| 1281 |
-
img_array = pixels.reshape(28, 28).astype(np.uint8)
|
| 1282 |
-
img_array = np.stack([img_array] * 3, axis=-1)
|
| 1283 |
-
image = Image.fromarray(img_array)
|
| 1284 |
-
|
| 1285 |
-
image_inputs = baseline_processor(images=[image], return_tensors="pt")
|
| 1286 |
-
image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
|
| 1287 |
-
with torch.no_grad():
|
| 1288 |
-
img_emb = baseline_model.get_image_features(**image_inputs)
|
| 1289 |
-
img_emb = F.normalize(img_emb, dim=-1)
|
| 1290 |
-
|
| 1291 |
-
sims = F.cosine_similarity(img_emb, text_embs, dim=1)
|
| 1292 |
-
best_idx = sims.argmax().item()
|
| 1293 |
-
predicted = candidate_labels[best_idx]
|
| 1294 |
-
best_sim = sims[best_idx].item()
|
| 1295 |
-
|
| 1296 |
-
ok = predicted == ground_truth
|
| 1297 |
-
if ok:
|
| 1298 |
-
correct += 1
|
| 1299 |
-
|
| 1300 |
-
rows.append([str(idx), ground_truth, predicted, f"{best_sim:.4f}", format_bool(ok)])
|
| 1301 |
-
if not ok:
|
| 1302 |
-
failed_rows.append([str(idx), ground_truth, predicted, f"{best_sim:.4f}"])
|
| 1303 |
-
|
| 1304 |
-
accuracy = correct / len(df)
|
| 1305 |
-
|
| 1306 |
-
print_table(
|
| 1307 |
-
f"Baseline Test C (Fashion-CLIP): zero-shot (showing {min(num_printed, len(rows))}/{len(rows)} examples)",
|
| 1308 |
-
["#", "Ground Truth", "Predicted", "Best CosSim", "Result"],
|
| 1309 |
-
rows[:num_printed],
|
| 1310 |
-
)
|
| 1311 |
-
print(f"Baseline Test C aggregate: {correct}/{len(df)} correct ({accuracy:.2%})")
|
| 1312 |
-
|
| 1313 |
-
return {"accuracy": accuracy}
|
| 1314 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1315 |
|
| 1316 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1317 |
random.seed(42)
|
| 1318 |
cfg = resolve_runtime_config()
|
| 1319 |
-
model_path = Path(cfg.main_model_path)
|
| 1320 |
-
if not model_path.exists():
|
| 1321 |
-
raise FileNotFoundError(f"Main model checkpoint not found: {cfg.main_model_path}")
|
| 1322 |
|
| 1323 |
-
|
| 1324 |
-
|
| 1325 |
-
|
| 1326 |
-
|
| 1327 |
-
|
| 1328 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1329 |
|
| 1330 |
result_a: Optional[Dict[str, object]] = None
|
| 1331 |
result_b: Optional[Dict[str, object]] = None
|
| 1332 |
-
|
| 1333 |
baseline_result_a: Optional[Dict[str, object]] = None
|
| 1334 |
baseline_result_b: Optional[Dict[str, object]] = None
|
| 1335 |
-
|
| 1336 |
-
|
| 1337 |
-
baseline_processor
|
| 1338 |
-
|
| 1339 |
-
|
| 1340 |
-
|
| 1341 |
-
|
| 1342 |
-
|
| 1343 |
-
|
| 1344 |
-
|
| 1345 |
-
print("Baseline model loaded.")
|
| 1346 |
|
| 1347 |
if "A" in selected_tests:
|
| 1348 |
result_a = run_test_a(
|
|
@@ -1378,8 +1182,8 @@ def main(selected_tests: set[str]) -> None:
|
|
| 1378 |
num_printed=DEFAULT_NUM_PRINTED,
|
| 1379 |
test_name="Baseline Test B",
|
| 1380 |
)
|
| 1381 |
-
if "
|
| 1382 |
-
|
| 1383 |
model,
|
| 1384 |
processor,
|
| 1385 |
cfg,
|
|
@@ -1387,83 +1191,59 @@ def main(selected_tests: set[str]) -> None:
|
|
| 1387 |
num_printed=DEFAULT_NUM_PRINTED,
|
| 1388 |
)
|
| 1389 |
if baseline_model is not None and baseline_processor is not None:
|
| 1390 |
-
|
| 1391 |
baseline_model,
|
| 1392 |
baseline_processor,
|
| 1393 |
cfg,
|
| 1394 |
num_examples=DEFAULT_NUM_EXAMPLES,
|
| 1395 |
num_printed=DEFAULT_NUM_PRINTED,
|
| 1396 |
-
test_name="Baseline Test
|
| 1397 |
)
|
| 1398 |
|
| 1399 |
-
|
| 1400 |
-
gap_fixed_c1: Dict[str, Optional[float]] = {}
|
| 1401 |
-
gap_fixed_wf1_c1: Dict[str, Optional[float]] = {}
|
| 1402 |
-
gap_best_wf1_c1: Dict[str, Optional[float]] = {}
|
| 1403 |
-
gap_best_strategy: Dict[str, Optional[str]] = {}
|
| 1404 |
-
base_fixed_c1: Dict[str, Optional[float]] = {}
|
| 1405 |
-
base_fixed_wf1_c1: Dict[str, Optional[float]] = {}
|
| 1406 |
-
|
| 1407 |
-
if "C" in selected_tests:
|
| 1408 |
assert baseline_model is not None and baseline_processor is not None
|
| 1409 |
|
| 1410 |
-
|
| 1411 |
-
|
| 1412 |
-
|
| 1413 |
-
|
| 1414 |
-
|
| 1415 |
-
|
| 1416 |
-
|
| 1417 |
-
|
| 1418 |
-
"KAGL Marqo":
|
| 1419 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1420 |
}
|
| 1421 |
-
for dataset_name, (baseline_samples, gap_samples) in datasets_for_c.items():
|
| 1422 |
-
print(f"\n{'=' * 120}")
|
| 1423 |
-
print(f"Test C on {dataset_name}")
|
| 1424 |
-
print(f"{'=' * 120}")
|
| 1425 |
-
print(f"Valid samples: {len(baseline_samples)} (baseline), {len(gap_samples)} (GAP-CLIP)")
|
| 1426 |
-
|
| 1427 |
-
# Baseline uses dataset-native labels (matches published benchmarks).
|
| 1428 |
-
baseline_candidate_labels = sorted(set(label for _, label in baseline_samples))
|
| 1429 |
-
# GAP-CLIP uses its training-vocabulary labels.
|
| 1430 |
-
gap_candidate_labels = sorted(set(label for _, label in gap_samples))
|
| 1431 |
-
print(f"Baseline candidate labels ({len(baseline_candidate_labels)}): {baseline_candidate_labels}")
|
| 1432 |
-
print(f"GAP-CLIP candidate labels ({len(gap_candidate_labels)}): {gap_candidate_labels}")
|
| 1433 |
-
|
| 1434 |
-
# GAP-CLIP: full strategy search with its own label vocabulary.
|
| 1435 |
-
gap_metrics = evaluate_zero_shot_gap(
|
| 1436 |
-
model=model,
|
| 1437 |
-
processor=processor,
|
| 1438 |
-
device=cfg.device,
|
| 1439 |
-
samples=gap_samples,
|
| 1440 |
-
candidate_labels=gap_candidate_labels,
|
| 1441 |
-
title_prefix=f"Test C ({dataset_name}) GAP-CLIP",
|
| 1442 |
-
num_printed=DEFAULT_NUM_PRINTED,
|
| 1443 |
-
color_emb_dim=cfg.color_emb_dim,
|
| 1444 |
-
hierarchy_emb_dim=cfg.hierarchy_emb_dim,
|
| 1445 |
-
hierarchy_model=hierarchy_model_eval,
|
| 1446 |
-
)
|
| 1447 |
-
|
| 1448 |
-
# Baseline: single_prompt with native labels.
|
| 1449 |
-
baseline_metrics = evaluate_zero_shot_baseline(
|
| 1450 |
-
baseline_model=baseline_model,
|
| 1451 |
-
baseline_processor=baseline_processor,
|
| 1452 |
-
device=cfg.device,
|
| 1453 |
-
samples=baseline_samples,
|
| 1454 |
-
candidate_labels=baseline_candidate_labels,
|
| 1455 |
-
title_prefix=f"Test C ({dataset_name}) Baseline",
|
| 1456 |
-
num_printed=DEFAULT_NUM_PRINTED,
|
| 1457 |
-
)
|
| 1458 |
-
|
| 1459 |
-
# Store results.
|
| 1460 |
-
gap_fixed_c1[dataset_name] = gap_metrics.get("all_strategy_acc_c1", {}).get("single_prompt")
|
| 1461 |
-
gap_fixed_wf1_c1[dataset_name] = gap_metrics.get("all_strategy_wf1_c1", {}).get("single_prompt")
|
| 1462 |
-
gap_best_wf1_c1[dataset_name] = gap_metrics.get("wf1_c1")
|
| 1463 |
-
gap_best_strategy[dataset_name] = gap_metrics.get("strategy")
|
| 1464 |
-
base_fixed_c1[dataset_name] = baseline_metrics["accuracy_c1"]
|
| 1465 |
-
base_fixed_wf1_c1[dataset_name] = baseline_metrics.get("wf1_c1")
|
| 1466 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1467 |
print("\n" + "=" * 120)
|
| 1468 |
print("Final Summary")
|
| 1469 |
print("=" * 120)
|
|
@@ -1478,70 +1258,37 @@ def main(selected_tests: set[str]) -> None:
|
|
| 1478 |
print(f"Test B full512 accuracy: {float(result_b['accuracy_full512']):.2%}")
|
| 1479 |
if baseline_result_b is not None:
|
| 1480 |
print(f"Baseline Test B full512 accuracy: {float(baseline_result_b['accuracy_full512']):.2%}")
|
| 1481 |
-
if
|
| 1482 |
-
|
| 1483 |
-
|
| 1484 |
-
|
| 1485 |
-
|
| 1486 |
-
|
| 1487 |
-
|
| 1488 |
-
|
| 1489 |
-
|
| 1490 |
-
print("\n" + "-" * 120)
|
| 1491 |
-
print("Test C — Fair comparison (weighted F1): single_prompt for both models")
|
| 1492 |
-
print("-" * 120)
|
| 1493 |
-
fair_rows: List[List[str]] = []
|
| 1494 |
-
for ds in ["Fashion-MNIST", "KAGL Marqo", "Internal dataset"]:
|
| 1495 |
-
fair_rows.append([
|
| 1496 |
-
ds,
|
| 1497 |
-
_pct(gap_fixed_wf1_c1.get(ds)), _pct(base_fixed_wf1_c1.get(ds)),
|
| 1498 |
-
_gain(gap_fixed_wf1_c1.get(ds), base_fixed_wf1_c1.get(ds)),
|
| 1499 |
-
])
|
| 1500 |
-
print_table(
|
| 1501 |
-
"Test C: single_prompt weighted F1 for both",
|
| 1502 |
-
["Dataset", "GAP-CLIP (wF1)", "Baseline (wF1)", "Gain"],
|
| 1503 |
-
fair_rows,
|
| 1504 |
-
)
|
| 1505 |
-
|
| 1506 |
-
print("\n" + "-" * 120)
|
| 1507 |
-
print("Test C — GAP-CLIP best strategy vs Baseline")
|
| 1508 |
-
print("-" * 120)
|
| 1509 |
-
best_rows: List[List[str]] = []
|
| 1510 |
-
for ds in ["Fashion-MNIST", "KAGL Marqo", "Internal dataset"]:
|
| 1511 |
-
strat = gap_best_strategy.get(ds) or "N/A"
|
| 1512 |
-
best_rows.append([
|
| 1513 |
-
ds,
|
| 1514 |
-
f"{strat}",
|
| 1515 |
-
_pct(gap_best_wf1_c1.get(ds)), _pct(base_fixed_wf1_c1.get(ds)),
|
| 1516 |
-
_gain(gap_best_wf1_c1.get(ds), base_fixed_wf1_c1.get(ds)),
|
| 1517 |
-
])
|
| 1518 |
-
print_table(
|
| 1519 |
-
"Test C: GAP-CLIP best strategy vs Baseline single_prompt",
|
| 1520 |
-
["Dataset", "GAP-CLIP strategy", "GAP-CLIP (wF1)", "Baseline (wF1)", "Gain"],
|
| 1521 |
-
best_rows,
|
| 1522 |
-
)
|
| 1523 |
-
|
| 1524 |
-
if result_d is not None:
|
| 1525 |
-
print(f"Test D overall: {format_bool(bool(result_d['overall']))}")
|
| 1526 |
-
print(f" pass rate: {float(result_d['pass_rate']):.2%}")
|
| 1527 |
-
print(f" avg color_match={float(result_d['avg_color_match']):.4f} vs cross={float(result_d['avg_color_cross']):.4f}")
|
| 1528 |
-
print(f" avg hier_match={float(result_d['avg_hier_match']):.4f} vs cross={float(result_d['avg_hier_cross']):.4f}")
|
| 1529 |
-
if baseline_result_d is not None:
|
| 1530 |
-
print(f"Baseline Test D overall: {format_bool(bool(baseline_result_d['overall']))}")
|
| 1531 |
-
print(f" baseline pass rate: {float(baseline_result_d['pass_rate']):.2%}")
|
| 1532 |
|
| 1533 |
if result_a is not None:
|
| 1534 |
-
assert
|
|
|
|
|
|
|
| 1535 |
if result_b is not None:
|
| 1536 |
-
assert
|
| 1537 |
-
|
| 1538 |
-
|
| 1539 |
-
|
|
|
|
|
|
|
| 1540 |
)
|
| 1541 |
|
| 1542 |
print("\nAll embedding-structure tests passed.")
|
| 1543 |
|
| 1544 |
|
| 1545 |
if __name__ == "__main__":
|
| 1546 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1547 |
main(selected_tests)
|
|
|
|
| 21 |
for items sharing a color but differing in category.
|
| 22 |
Expected result: 1000/1000 pass.
|
| 23 |
|
| 24 |
+
Test C — Subspace Decomposition Consistency:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
Encode a full description (e.g. "red dress in cotton"), a standalone color
|
| 26 |
("red"), and a standalone hierarchy ("dress"). Verify that:
|
| 27 |
- The color subspace (first 16D) of the full embedding is more similar
|
|
|
|
| 30 |
similar to the hierarchy-only embedding than to the color-only embedding.
|
| 31 |
Expected result: 1000/1000 pass.
|
| 32 |
|
| 33 |
+
Test D — Zero-shot image-to-text classification:
|
| 34 |
+
Each image is used as a query; the highest-scoring text label (cosine in
|
| 35 |
+
shared latent space) is the predicted class. Accuracy is computed across
|
| 36 |
+
three datasets (Fashion-MNIST, KAGL Marqo, Internal).
|
| 37 |
+
|
| 38 |
Paper reference: Section 5.3.6 and Table 4.
|
| 39 |
|
| 40 |
Run directly:
|
|
|
|
| 51 |
from dataclasses import dataclass
|
| 52 |
from pathlib import Path
|
| 53 |
import random
|
| 54 |
+
import sys
|
| 55 |
+
|
| 56 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
| 57 |
from typing import Dict, List, Optional, Sequence, Tuple
|
| 58 |
|
| 59 |
import numpy as np
|
|
|
|
| 65 |
from io import BytesIO
|
| 66 |
from PIL import Image
|
| 67 |
from torchvision import transforms
|
| 68 |
+
from torchvision import datasets
|
| 69 |
+
from torch.utils.data import DataLoader
|
| 70 |
+
from tqdm import tqdm
|
| 71 |
from transformers import CLIPModel as CLIPModelTransformers
|
| 72 |
from transformers import CLIPProcessor
|
| 73 |
|
| 74 |
+
from training.hierarchy_model import HierarchyExtractor
|
| 75 |
+
|
| 76 |
+
try:
|
| 77 |
+
import config as project_config # type: ignore
|
| 78 |
+
except Exception:
|
| 79 |
+
project_config = None
|
| 80 |
+
|
| 81 |
+
DEFAULT_COLOR_EMB_DIM = getattr(project_config, "color_emb_dim", 16)
|
| 82 |
+
DEFAULT_HIERARCHY_EMB_DIM = getattr(project_config, "hierarchy_emb_dim", 64)
|
| 83 |
+
DEFAULT_MAIN_EMB_DIM = getattr(project_config, "main_emb_dim", 512)
|
| 84 |
+
DEFAULT_MAIN_MODEL_PATH = getattr(project_config, "main_model_path", "models/gap_clip.pth")
|
| 85 |
+
DEFAULT_DEVICE = getattr(project_config, "device", torch.device("cpu"))
|
| 86 |
+
|
| 87 |
+
_HIERARCHY_EXTRACTOR = HierarchyExtractor([
|
| 88 |
+
"accessories", "bodysuits", "bras", "coat", "dress", "jacket",
|
| 89 |
+
"legging", "pant", "polo", "shirt", "shoes", "short", "skirt",
|
| 90 |
+
"socks", "sweater", "swimwear", "top", "underwear",
|
| 91 |
+
], verbose=False)
|
| 92 |
+
|
| 93 |
|
| 94 |
@dataclass
|
| 95 |
class RuntimeConfig:
|
| 96 |
+
color_emb_dim: int = DEFAULT_COLOR_EMB_DIM
|
| 97 |
+
hierarchy_emb_dim: int = DEFAULT_HIERARCHY_EMB_DIM
|
| 98 |
+
main_emb_dim: int = DEFAULT_MAIN_EMB_DIM
|
| 99 |
+
main_model_path: str = DEFAULT_MAIN_MODEL_PATH
|
| 100 |
+
device: torch.device = DEFAULT_DEVICE
|
| 101 |
|
| 102 |
DEFAULT_NUM_EXAMPLES = 10000
|
| 103 |
DEFAULT_NUM_PRINTED = 3
|
|
|
|
| 132 |
|
| 133 |
cfg.color_emb_dim = getattr(config, "color_emb_dim", cfg.color_emb_dim)
|
| 134 |
cfg.hierarchy_emb_dim = getattr(config, "hierarchy_emb_dim", cfg.hierarchy_emb_dim)
|
| 135 |
+
cfg.main_emb_dim = getattr(config, "main_emb_dim", cfg.main_emb_dim)
|
| 136 |
cfg.main_model_path = getattr(config, "main_model_path", cfg.main_model_path)
|
| 137 |
cfg.device = getattr(config, "device", cfg.device)
|
| 138 |
except Exception:
|
|
|
|
| 147 |
|
| 148 |
|
| 149 |
def load_main_model(device: torch.device, main_model_path: str) -> Tuple[CLIPModelTransformers, CLIPProcessor]:
|
| 150 |
+
"""Load GAP-CLIP from local checkpoint path only."""
|
| 151 |
+
model_path = Path(main_model_path)
|
| 152 |
+
if not model_path.exists():
|
| 153 |
+
raise FileNotFoundError(f"Main model checkpoint not found: {main_model_path}")
|
| 154 |
+
|
| 155 |
+
clip_name = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
|
| 156 |
+
model = CLIPModelTransformers.from_pretrained(clip_name)
|
| 157 |
+
checkpoint = torch.load(str(model_path), map_location=device)
|
| 158 |
+
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
|
| 159 |
+
model.load_state_dict(checkpoint["model_state_dict"], strict=False)
|
| 160 |
+
else:
|
| 161 |
+
model.load_state_dict(checkpoint, strict=False)
|
| 162 |
+
model = model.to(device)
|
| 163 |
+
model.eval()
|
| 164 |
+
processor = CLIPProcessor.from_pretrained(clip_name)
|
| 165 |
+
return model, processor
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def encode_text(model, processor, text_queries, device):
|
| 169 |
+
"""Encode text queries into embeddings (unnormalized)."""
|
| 170 |
+
if isinstance(text_queries, str):
|
| 171 |
+
text_queries = [text_queries]
|
| 172 |
+
inputs = processor(text=text_queries, return_tensors="pt", padding=True, truncation=True)
|
| 173 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 174 |
+
with torch.no_grad():
|
| 175 |
+
text_features = model.get_text_features(**inputs)
|
| 176 |
+
return text_features
|
| 177 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
|
| 179 |
+
def encode_image(model, processor, images, device):
|
| 180 |
+
"""Encode images into embeddings (unnormalized)."""
|
| 181 |
+
if not isinstance(images, list):
|
| 182 |
+
images = [images]
|
| 183 |
+
inputs = processor(images=images, return_tensors="pt")
|
| 184 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 185 |
with torch.no_grad():
|
| 186 |
+
image_features = model.get_image_features(**inputs)
|
| 187 |
+
return image_features
|
| 188 |
+
|
| 189 |
|
| 190 |
+
def get_text_embedding(
|
| 191 |
+
model: CLIPModelTransformers, processor: CLIPProcessor, device: torch.device, text: str) -> torch.Tensor:
|
| 192 |
+
"""Normalized single text embedding (shape: [512])."""
|
| 193 |
+
return F.normalize(encode_text(model, processor, text, device), dim=-1).squeeze(0)
|
| 194 |
|
| 195 |
|
| 196 |
def cosine(a: torch.Tensor, b: torch.Tensor) -> float:
|
|
|
|
| 229 |
cfg: RuntimeConfig,
|
| 230 |
num_examples: int,
|
| 231 |
num_printed: int,
|
| 232 |
+
test_name: str = "Test A") -> Dict[str, bool]:
|
|
|
|
| 233 |
"""
|
| 234 |
A: different colors + same hierarchy.
|
| 235 |
Expect hierarchy subspace to be more similar than color subspace.
|
|
|
|
| 341 |
cfg: RuntimeConfig,
|
| 342 |
num_examples: int,
|
| 343 |
num_printed: int,
|
| 344 |
+
test_name: str = "Test B",) -> Dict[str, bool]:
|
|
|
|
| 345 |
"""
|
| 346 |
B: same color + different hierarchies.
|
| 347 |
Expect similarity in first16 (color) to be higher than full512.
|
|
|
|
| 446 |
|
| 447 |
|
| 448 |
|
| 449 |
+
def run_test_c(
|
| 450 |
model: CLIPModelTransformers,
|
| 451 |
processor: CLIPProcessor,
|
| 452 |
cfg: RuntimeConfig,
|
| 453 |
num_examples: int,
|
| 454 |
num_printed: int,
|
| 455 |
+
test_name: str = "Test C",) -> Dict[str, object]:
|
|
|
|
| 456 |
"""
|
| 457 |
+
C: Subspace Decomposition Consistency.
|
| 458 |
Encode a full description (e.g. "red dress in cotton"), a standalone color
|
| 459 |
("red"), and a standalone hierarchy ("dress"). Then verify:
|
| 460 |
- The color subspace (first 16D) of the full embedding aligns with the
|
|
|
|
| 615 |
def get_image_embedding(
|
| 616 |
model: CLIPModelTransformers, processor: CLIPProcessor, device: torch.device, image_tensor: torch.Tensor
|
| 617 |
) -> torch.Tensor:
|
| 618 |
+
"""Normalized image embedding from a preprocessed tensor (shape: [512])."""
|
| 619 |
image_tensor = image_tensor.unsqueeze(0).to(device)
|
| 620 |
+
# Convert tensor to PIL for encode_image
|
| 621 |
+
from torchvision.transforms.functional import to_pil_image
|
| 622 |
+
pil_img = to_pil_image(image_tensor.squeeze(0).cpu())
|
| 623 |
+
return F.normalize(encode_image(model, processor, pil_img, device), dim=-1).squeeze(0)
|
|
|
|
| 624 |
|
| 625 |
|
| 626 |
def get_image_embedding_from_pil(
|
| 627 |
model: CLIPModelTransformers, processor: CLIPProcessor, device: torch.device, image: Image.Image
|
| 628 |
) -> torch.Tensor:
|
| 629 |
+
"""Normalized image embedding from a PIL image (shape: [512])."""
|
| 630 |
+
return F.normalize(encode_image(model, processor, image, device), dim=-1).squeeze(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 631 |
|
| 632 |
|
| 633 |
def get_text_embeddings_batch(
|
| 634 |
model: CLIPModelTransformers, processor: CLIPProcessor, device: torch.device, texts: List[str]
|
| 635 |
) -> torch.Tensor:
|
| 636 |
+
"""Normalized text embeddings for a batch (shape: [N, 512])."""
|
| 637 |
+
return F.normalize(encode_text(model, processor, texts, device), dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 638 |
|
| 639 |
|
| 640 |
def get_prompt_ensembled_text_embeddings(
|
|
|
|
| 715 |
return probs, recommended_weight
|
| 716 |
|
| 717 |
|
| 718 |
+
def zero_shot_fashion_mnist(
|
| 719 |
+
model,
|
| 720 |
+
processor,
|
| 721 |
+
device,
|
| 722 |
+
batch_size: int = 64,
|
| 723 |
+
data_root: str = "./data") -> float:
|
| 724 |
+
"""Notebook-equivalent zero-shot accuracy on all Fashion-MNIST test samples."""
|
| 725 |
+
dataset = datasets.FashionMNIST(
|
| 726 |
+
root=data_root, train=False, download=True,
|
| 727 |
+
transform=transforms.Grayscale(num_output_channels=3),
|
| 728 |
+
)
|
| 729 |
+
loader = DataLoader(
|
| 730 |
+
dataset, batch_size=batch_size, shuffle=False,
|
| 731 |
+
collate_fn=lambda batch: (
|
| 732 |
+
[item[0] for item in batch],
|
| 733 |
+
torch.tensor([item[1] for item in batch]),
|
| 734 |
+
),
|
| 735 |
+
)
|
|
|
|
|
|
|
| 736 |
|
| 737 |
+
prompts = [f"a photo of a {label}" for label in dataset.classes]
|
| 738 |
+
text_embs = encode_text(model, processor, prompts, device).to(device).float()
|
| 739 |
+
text_embs = F.normalize(text_embs, dim=-1)
|
| 740 |
|
|
|
|
|
|
|
|
|
|
| 741 |
correct = 0
|
| 742 |
+
total = 0
|
| 743 |
|
| 744 |
+
for pil_images, labels in tqdm(loader, desc="Zero-shot Fashion-MNIST"):
|
| 745 |
+
img_embs = encode_image(model, processor, pil_images, device)
|
| 746 |
+
img_embs = img_embs.to(device).float()
|
| 747 |
+
img_embs = F.normalize(img_embs, dim=-1)
|
| 748 |
|
| 749 |
+
sim = img_embs @ text_embs.T
|
| 750 |
+
preds = sim.argmax(dim=-1).cpu()
|
|
|
|
| 751 |
|
| 752 |
+
correct += (preds == labels).sum().item()
|
| 753 |
+
total += labels.size(0)
|
|
|
|
|
|
|
| 754 |
|
| 755 |
+
accuracy = correct / total
|
| 756 |
+
print(f"Zero-shot accuracy on Fashion MNIST: {accuracy:.4f} ({correct}/{total})")
|
| 757 |
+
return accuracy
|
| 758 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 759 |
|
|
|
|
| 760 |
|
| 761 |
+
def zero_shot_kagl(
|
| 762 |
+
model,
|
| 763 |
+
processor,
|
| 764 |
+
device,
|
| 765 |
+
batch_size: int = 64,
|
| 766 |
+
num_examples: int = 10000,
|
| 767 |
+
) -> Optional[Dict[str, float]]:
|
| 768 |
+
"""Notebook-equivalent zero-shot accuracy/F1 on KAGL Marqo (category2)."""
|
| 769 |
+
try:
|
| 770 |
+
from datasets import load_dataset # type: ignore
|
| 771 |
+
except Exception:
|
| 772 |
+
print("Skipping zero_shot_kagl: datasets package not available")
|
| 773 |
+
return None
|
| 774 |
+
|
| 775 |
+
try:
|
| 776 |
+
dataset = load_dataset("Marqo/KAGL", split="data")
|
| 777 |
+
except Exception as exc:
|
| 778 |
+
print(f"Skipping zero_shot_kagl: failed to load dataset ({exc})")
|
| 779 |
+
return None
|
| 780 |
+
|
| 781 |
+
dataset = dataset.shuffle(seed=42).select(range(min(num_examples, len(dataset))))
|
| 782 |
+
|
| 783 |
+
pil_images: List[Image.Image] = []
|
| 784 |
+
labels_text: List[str] = []
|
| 785 |
+
for item in dataset:
|
| 786 |
+
raw_label = item.get("category2")
|
| 787 |
+
image_obj = item.get("image")
|
| 788 |
+
if raw_label is None or image_obj is None:
|
| 789 |
+
continue
|
| 790 |
+
|
| 791 |
+
if hasattr(image_obj, "convert"):
|
| 792 |
+
image = image_obj.convert("RGB")
|
| 793 |
+
elif isinstance(image_obj, dict) and "bytes" in image_obj:
|
| 794 |
+
image = Image.open(BytesIO(image_obj["bytes"])).convert("RGB")
|
| 795 |
+
else:
|
| 796 |
+
continue
|
| 797 |
+
pil_images.append(image)
|
| 798 |
+
labels_text.append(str(raw_label).strip())
|
| 799 |
+
|
| 800 |
+
if not pil_images:
|
| 801 |
+
print("Skipping zero_shot_kagl: no valid samples")
|
| 802 |
+
return None
|
| 803 |
|
| 804 |
+
candidate_labels = sorted(set(labels_text))
|
| 805 |
+
label_to_idx = {label: idx for idx, label in enumerate(candidate_labels)}
|
| 806 |
+
all_labels = np.array([label_to_idx[label] for label in labels_text], dtype=np.int64)
|
| 807 |
+
|
| 808 |
+
prompts = [f"a photo of a {label}" for label in candidate_labels]
|
| 809 |
+
text_embs = encode_text(model, processor, prompts, device).to(device).float()
|
| 810 |
+
text_embs = F.normalize(text_embs, dim=-1)
|
| 811 |
+
|
| 812 |
+
all_preds: List[np.ndarray] = []
|
| 813 |
+
for start in tqdm(range(0, len(pil_images), batch_size), desc="Zero-shot KAGL"):
|
| 814 |
+
batch_images = pil_images[start : start + batch_size]
|
| 815 |
+
img_embs = encode_image(model, processor, batch_images, device).to(device).float()
|
| 816 |
+
img_embs = F.normalize(img_embs, dim=-1)
|
| 817 |
+
sim = img_embs @ text_embs.T
|
| 818 |
+
preds = sim.argmax(dim=-1).cpu().numpy()
|
| 819 |
+
all_preds.append(preds)
|
| 820 |
+
|
| 821 |
+
pred_array = np.concatenate(all_preds, axis=0) if all_preds else np.array([], dtype=np.int64)
|
| 822 |
+
accuracy = float((pred_array == all_labels).mean()) if len(all_labels) else 0.0
|
| 823 |
+
weighted_f1 = f1_score(all_labels, pred_array, average="weighted") if len(all_labels) else 0.0
|
| 824 |
+
print(f"KAGL accuracy: {accuracy:.4f}")
|
| 825 |
+
print(f"KAGL weighted macro F1: {weighted_f1:.4f}")
|
| 826 |
+
return {"accuracy": accuracy, "weighted_f1": float(weighted_f1)}
|
| 827 |
+
|
| 828 |
+
|
| 829 |
+
def zero_shot_internal(
|
| 830 |
+
model,
|
| 831 |
+
processor,
|
| 832 |
+
device,
|
| 833 |
+
batch_size: int = 64,
|
| 834 |
+
num_examples: int = 10000,
|
| 835 |
+
csv_path: str = INTERNAL_DATASET_CSV,) -> Optional[Dict[str, float]]:
|
| 836 |
+
"""Notebook-equivalent zero-shot accuracy/F1 on internal dataset."""
|
| 837 |
+
csv_file = Path(csv_path)
|
| 838 |
+
if not csv_file.exists():
|
| 839 |
+
print(f"Skipping zero_shot_internal: {csv_path} not found")
|
| 840 |
+
return None
|
| 841 |
+
|
| 842 |
+
df = pd.read_csv(csv_file)
|
| 843 |
+
use_local = "local_image_path" in df.columns
|
| 844 |
+
required_cols = {"hierarchy", "local_image_path"} if use_local else {"hierarchy", "image_url"}
|
| 845 |
+
if not required_cols.issubset(df.columns):
|
| 846 |
+
print(f"Skipping zero_shot_internal: missing required columns {required_cols}")
|
| 847 |
+
return None
|
| 848 |
+
|
| 849 |
+
img_col = "local_image_path" if use_local else "image_url"
|
| 850 |
+
df = df.dropna(subset=["hierarchy", img_col]).sample(frac=1.0, random_state=42)
|
| 851 |
+
pil_images: List[Image.Image] = []
|
| 852 |
+
labels_text: List[str] = []
|
| 853 |
+
for _, row in df.iterrows():
|
| 854 |
+
if len(pil_images) >= num_examples:
|
| 855 |
+
break
|
| 856 |
+
try:
|
| 857 |
+
if use_local:
|
| 858 |
+
img_path = Path(str(row["local_image_path"]))
|
| 859 |
+
if not img_path.exists():
|
| 860 |
+
continue
|
| 861 |
+
image = Image.open(img_path).convert("RGB")
|
| 862 |
+
else:
|
| 863 |
+
response = requests.get(str(row["image_url"]), timeout=5)
|
| 864 |
+
response.raise_for_status()
|
| 865 |
+
image = Image.open(BytesIO(response.content)).convert("RGB")
|
| 866 |
+
except Exception:
|
| 867 |
+
continue
|
| 868 |
+
label = normalize_hierarchy_label(str(row["hierarchy"]))
|
| 869 |
+
pil_images.append(image)
|
| 870 |
+
labels_text.append(label)
|
| 871 |
+
|
| 872 |
+
if not pil_images:
|
| 873 |
+
print("Skipping zero_shot_internal: no valid samples")
|
| 874 |
+
return None
|
| 875 |
+
|
| 876 |
+
candidate_labels = sorted(set(labels_text))
|
| 877 |
+
label_to_idx = {label: idx for idx, label in enumerate(candidate_labels)}
|
| 878 |
+
all_labels = np.array([label_to_idx[label] for label in labels_text], dtype=np.int64)
|
| 879 |
+
|
| 880 |
+
prompts = [f"a photo of a {label}" for label in candidate_labels]
|
| 881 |
+
text_embs = encode_text(model, processor, prompts, device).to(device).float()
|
| 882 |
+
text_embs = F.normalize(text_embs, dim=-1)
|
| 883 |
+
|
| 884 |
+
all_preds: List[np.ndarray] = []
|
| 885 |
+
for start in tqdm(range(0, len(pil_images), batch_size), desc="Zero-shot Internal"):
|
| 886 |
+
batch_images = pil_images[start : start + batch_size]
|
| 887 |
+
img_embs = encode_image(model, processor, batch_images, device).to(device).float()
|
| 888 |
+
img_embs = F.normalize(img_embs, dim=-1)
|
| 889 |
+
sim = img_embs @ text_embs.T
|
| 890 |
+
preds = sim.argmax(dim=-1).cpu().numpy()
|
| 891 |
+
all_preds.append(preds)
|
| 892 |
+
|
| 893 |
+
pred_array = np.concatenate(all_preds, axis=0) if all_preds else np.array([], dtype=np.int64)
|
| 894 |
+
accuracy = float((pred_array == all_labels).mean()) if len(all_labels) else 0.0
|
| 895 |
+
weighted_f1 = f1_score(all_labels, pred_array, average="weighted") if len(all_labels) else 0.0
|
| 896 |
+
print(f"Internal accuracy: {accuracy:.4f}")
|
| 897 |
+
print(f"Internal weighted macro F1: {weighted_f1:.4f}")
|
| 898 |
+
return {"accuracy": accuracy, "weighted_f1": float(weighted_f1)}
|
| 899 |
|
| 900 |
|
| 901 |
def normalize_hierarchy_label(raw_label: str) -> str:
|
|
|
|
| 950 |
"innerwear": "underwear",
|
| 951 |
"loungewear and nightwear": "underwear",
|
| 952 |
"saree": "dress",
|
| 953 |
+
"boots": "shoes",
|
| 954 |
+
"outer": "coat",
|
| 955 |
+
"sunglasses": "accessories",
|
| 956 |
+
"scarf & tie": "accessories",
|
| 957 |
+
"scarf/tie": "accessories",
|
| 958 |
+
"belt": "accessories",
|
| 959 |
}
|
| 960 |
+
exact = synonyms.get(label, None)
|
| 961 |
+
if exact is not None:
|
| 962 |
+
return exact
|
| 963 |
+
|
| 964 |
+
# Phase 2: substring/regex fallback via HierarchyExtractor
|
| 965 |
+
# Handles Internal dataset's multi-word hierarchy strings like
|
| 966 |
+
# "womens wms woven shirts sleeveless linen" -> "shirt"
|
| 967 |
+
result = _HIERARCHY_EXTRACTOR.extract_hierarchy(label)
|
| 968 |
+
if result:
|
| 969 |
+
return result
|
| 970 |
+
|
| 971 |
+
# Phase 3: extra keywords for the ~9 labels HierarchyExtractor misses
|
| 972 |
+
_EXTRA_KEYWORDS = [
|
| 973 |
+
("capri", "pant"),
|
| 974 |
+
("denim", "pant"),
|
| 975 |
+
("skinny", "pant"),
|
| 976 |
+
("boyfriend", "pant"),
|
| 977 |
+
("graphic", "top"),
|
| 978 |
+
("longsleeve", "top"),
|
| 979 |
+
("leather", "jacket"),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 980 |
]
|
| 981 |
+
for keyword, category in _EXTRA_KEYWORDS:
|
| 982 |
+
if keyword in label:
|
| 983 |
+
return category
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 984 |
|
| 985 |
+
return label
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 986 |
|
|
|
|
| 987 |
|
| 988 |
|
| 989 |
+
# ModaNet 13 categories (category_id -> label)
|
| 990 |
+
MODANET_CATEGORIES = {
|
| 991 |
+
1: "bag", 2: "belt", 3: "boots", 4: "footwear", 5: "outer",
|
| 992 |
+
6: "dress", 7: "sunglasses", 8: "pants", 9: "top", 10: "shorts",
|
| 993 |
+
11: "skirt", 12: "headwear", 13: "scarf/tie",
|
| 994 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 995 |
|
| 996 |
+
MODANET_ANNOTATIONS_JSON = "data/modanet_instances_train.json"
|
| 997 |
+
MODANET_IMAGES_DIR = "data/modanet_images/images"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 998 |
|
| 999 |
|
| 1000 |
+
def load_modanet_samples(
|
| 1001 |
num_examples: int,
|
| 1002 |
+
) -> Tuple[List[Tuple[Image.Image, str]], List[Tuple[Image.Image, str]], List[Tuple[Image.Image, str]]]:
|
| 1003 |
+
"""Return (baseline_samples, gap_samples, color_samples) from ModaNet.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1004 |
|
| 1005 |
+
Loads from local COCO JSON annotations + image directory.
|
| 1006 |
+
Each image may have multiple annotations — we pick the largest bbox area.
|
| 1007 |
+
"""
|
| 1008 |
+
import json as _json
|
| 1009 |
+
|
| 1010 |
+
ann_path = Path(MODANET_ANNOTATIONS_JSON)
|
| 1011 |
+
img_dir = Path(MODANET_IMAGES_DIR)
|
| 1012 |
+
|
| 1013 |
+
if not ann_path.exists():
|
| 1014 |
+
print(f" Skipping ModaNet: annotations not found at {MODANET_ANNOTATIONS_JSON}")
|
| 1015 |
+
return [], [], []
|
| 1016 |
+
if not img_dir.exists():
|
| 1017 |
+
print(f" Skipping ModaNet: images directory not found at {MODANET_IMAGES_DIR}")
|
| 1018 |
+
return [], [], []
|
| 1019 |
+
|
| 1020 |
+
print(" Loading ModaNet annotations...")
|
| 1021 |
+
with open(ann_path) as f:
|
| 1022 |
+
coco = _json.load(f)
|
| 1023 |
+
|
| 1024 |
+
cat_map = {c["id"]: c["name"] for c in coco["categories"]}
|
| 1025 |
+
img_map = {img["id"]: img["file_name"] for img in coco["images"]}
|
| 1026 |
+
|
| 1027 |
+
# For each image, find the annotation with the largest area.
|
| 1028 |
+
best_per_image: Dict[int, Tuple[int, float]] = {} # image_id -> (category_id, area)
|
| 1029 |
+
for ann in coco["annotations"]:
|
| 1030 |
+
img_id = ann["image_id"]
|
| 1031 |
+
cat_id = ann["category_id"]
|
| 1032 |
+
area = ann.get("area", 0)
|
| 1033 |
+
if img_id not in best_per_image or area > best_per_image[img_id][1]:
|
| 1034 |
+
best_per_image[img_id] = (cat_id, area)
|
| 1035 |
+
|
| 1036 |
+
# Shuffle deterministically and load images.
|
| 1037 |
+
image_ids = list(best_per_image.keys())
|
| 1038 |
+
rng = random.Random(42)
|
| 1039 |
+
rng.shuffle(image_ids)
|
| 1040 |
|
|
|
|
| 1041 |
baseline_samples: List[Tuple[Image.Image, str]] = []
|
| 1042 |
gap_samples: List[Tuple[Image.Image, str]] = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1043 |
|
| 1044 |
+
for img_id in image_ids:
|
| 1045 |
+
if len(baseline_samples) >= num_examples:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1046 |
break
|
| 1047 |
+
file_name = img_map.get(img_id)
|
| 1048 |
+
if file_name is None:
|
| 1049 |
+
continue
|
| 1050 |
+
img_path = img_dir / file_name
|
| 1051 |
+
if not img_path.exists():
|
| 1052 |
+
continue
|
| 1053 |
try:
|
| 1054 |
+
image = Image.open(img_path).convert("RGB")
|
|
|
|
|
|
|
|
|
|
| 1055 |
except Exception:
|
| 1056 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1057 |
|
| 1058 |
+
cat_id, _ = best_per_image[img_id]
|
| 1059 |
+
native_label = cat_map.get(cat_id, "unknown")
|
| 1060 |
+
gap_label = normalize_hierarchy_label(native_label)
|
| 1061 |
+
baseline_samples.append((image, native_label))
|
| 1062 |
+
gap_samples.append((image, gap_label))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1063 |
|
| 1064 |
+
print(f" ModaNet: loaded {len(baseline_samples)} valid samples (from {len(best_per_image)} annotated images)")
|
| 1065 |
+
return baseline_samples, gap_samples, []
|
| 1066 |
+
|
| 1067 |
+
|
| 1068 |
+
def zero_shot_modanet(
    model,
    processor,
    device,
    batch_size: int = 64,
    num_examples: int = 10000,
    use_gap_labels: bool = True,
) -> Optional[Dict[str, float]]:
    """Zero-shot accuracy/F1 on ModaNet dataset.

    Each image is scored against text prompts ("a photo of a {label}") built
    from the set of labels present in the sample; the argmax over cosine
    similarity in the shared latent space is the predicted class.

    Args:
        model: vision-language model exposing the interface expected by
            ``encode_text`` / ``encode_image``.
        processor: matching tokenizer/image processor.
        device: torch device string or object for embedding computation.
        batch_size: number of images encoded per forward pass.
        num_examples: maximum number of ModaNet samples to load.
        use_gap_labels: if True, evaluate against GAP-normalized hierarchy
            labels; otherwise use the native ModaNet category names.

    Returns:
        Dict with ``accuracy`` and ``weighted_f1`` (F1 averaged with
        per-class support weights), or None if no valid samples were loaded.
    """
    baseline_samples, gap_samples, _ = load_modanet_samples(num_examples)
    samples = gap_samples if use_gap_labels else baseline_samples
    if not samples:
        print("Skipping zero_shot_modanet: no valid samples")
        return None

    pil_images = [img for img, _ in samples]
    labels_text = [label for _, label in samples]

    # Candidate label set is derived from the data itself (closed-set eval).
    candidate_labels = sorted(set(labels_text))
    label_to_idx = {label: idx for idx, label in enumerate(candidate_labels)}
    all_labels = np.array([label_to_idx[label] for label in labels_text], dtype=np.int64)

    # Encode all text prompts once; normalize so dot product == cosine sim.
    prompts = [f"a photo of a {label}" for label in candidate_labels]
    text_embs = encode_text(model, processor, prompts, device).to(device).float()
    text_embs = F.normalize(text_embs, dim=-1)

    all_preds: List[np.ndarray] = []
    for start in tqdm(range(0, len(pil_images), batch_size), desc="Zero-shot ModaNet"):
        batch_images = pil_images[start : start + batch_size]
        img_embs = encode_image(model, processor, batch_images, device).to(device).float()
        img_embs = F.normalize(img_embs, dim=-1)
        sim = img_embs @ text_embs.T
        preds = sim.argmax(dim=-1).cpu().numpy()
        all_preds.append(preds)

    pred_array = np.concatenate(all_preds, axis=0) if all_preds else np.array([], dtype=np.int64)
    # Guard against empty label arrays so .mean() / f1_score never see them.
    accuracy = float((pred_array == all_labels).mean()) if len(all_labels) else 0.0
    weighted_f1 = f1_score(all_labels, pred_array, average="weighted") if len(all_labels) else 0.0
    label_kind = "GAP" if use_gap_labels else "native"
    print(f"ModaNet ({label_kind}) accuracy: {accuracy:.4f}")
    # NOTE: average="weighted" is support-weighted F1, not macro F1 — the
    # message previously said "weighted macro F1", which conflated the two.
    print(f"ModaNet ({label_kind}) weighted F1: {weighted_f1:.4f}")
    return {"accuracy": accuracy, "weighted_f1": float(weighted_f1)}
| 1111 |
+
|
| 1112 |
+
def main(
|
| 1113 |
+
selected_tests: set[str],
|
| 1114 |
+
model=None,
|
| 1115 |
+
processor=None,
|
| 1116 |
+
baseline_model=None,
|
| 1117 |
+
baseline_processor=None,
|
| 1118 |
+
) -> None:
|
| 1119 |
random.seed(42)
|
| 1120 |
cfg = resolve_runtime_config()
|
|
|
|
|
|
|
|
|
|
| 1121 |
|
| 1122 |
+
if model is None or processor is None:
|
| 1123 |
+
model_path = Path(cfg.main_model_path)
|
| 1124 |
+
if not model_path.exists():
|
| 1125 |
+
raise FileNotFoundError(f"Main model checkpoint not found: {cfg.main_model_path}")
|
| 1126 |
+
print("Loading model...")
|
| 1127 |
+
print(f" device: {cfg.device}")
|
| 1128 |
+
print(f" checkpoint: {cfg.main_model_path}")
|
| 1129 |
+
print(f" dims: color={cfg.color_emb_dim}, hierarchy={cfg.hierarchy_emb_dim}, total={cfg.main_emb_dim}")
|
| 1130 |
+
model, processor = load_main_model(cfg.device, cfg.main_model_path)
|
| 1131 |
+
print("Model loaded.")
|
| 1132 |
+
else:
|
| 1133 |
+
print(f"Using pre-loaded GAP-CLIP model (dims: color={cfg.color_emb_dim}, hierarchy={cfg.hierarchy_emb_dim}, total={cfg.main_emb_dim})")
|
| 1134 |
|
| 1135 |
result_a: Optional[Dict[str, object]] = None
|
| 1136 |
result_b: Optional[Dict[str, object]] = None
|
| 1137 |
+
result_c: Optional[Dict[str, object]] = None
|
| 1138 |
baseline_result_a: Optional[Dict[str, object]] = None
|
| 1139 |
baseline_result_b: Optional[Dict[str, object]] = None
|
| 1140 |
+
baseline_result_c: Optional[Dict[str, object]] = None
|
| 1141 |
+
|
| 1142 |
+
if baseline_model is None or baseline_processor is None:
|
| 1143 |
+
if any(t in selected_tests for t in ("A", "B", "C", "D")):
|
| 1144 |
+
print("\nLoading baseline model (patrickjohncyh/fashion-clip)...")
|
| 1145 |
+
baseline_name = "patrickjohncyh/fashion-clip"
|
| 1146 |
+
baseline_processor = CLIPProcessor.from_pretrained(baseline_name)
|
| 1147 |
+
baseline_model = CLIPModelTransformers.from_pretrained(baseline_name).to(cfg.device)
|
| 1148 |
+
baseline_model.eval()
|
| 1149 |
+
print("Baseline model loaded.")
|
|
|
|
| 1150 |
|
| 1151 |
if "A" in selected_tests:
|
| 1152 |
result_a = run_test_a(
|
|
|
|
| 1182 |
num_printed=DEFAULT_NUM_PRINTED,
|
| 1183 |
test_name="Baseline Test B",
|
| 1184 |
)
|
| 1185 |
+
if "C" in selected_tests:
|
| 1186 |
+
result_c = run_test_c(
|
| 1187 |
model,
|
| 1188 |
processor,
|
| 1189 |
cfg,
|
|
|
|
| 1191 |
num_printed=DEFAULT_NUM_PRINTED,
|
| 1192 |
)
|
| 1193 |
if baseline_model is not None and baseline_processor is not None:
|
| 1194 |
+
baseline_result_c = run_test_c(
|
| 1195 |
baseline_model,
|
| 1196 |
baseline_processor,
|
| 1197 |
cfg,
|
| 1198 |
num_examples=DEFAULT_NUM_EXAMPLES,
|
| 1199 |
num_printed=DEFAULT_NUM_PRINTED,
|
| 1200 |
+
test_name="Baseline Test C",
|
| 1201 |
)
|
| 1202 |
|
| 1203 |
+
if "D" in selected_tests:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1204 |
assert baseline_model is not None and baseline_processor is not None
|
| 1205 |
|
| 1206 |
+
print("\n" + "=" * 120)
|
| 1207 |
+
print("Test D — Notebook-style zero-shot accuracy")
|
| 1208 |
+
print("=" * 120)
|
| 1209 |
+
d_results: Dict[str, Dict[str, Optional[Dict[str, float]]]] = {
|
| 1210 |
+
"Fashion-MNIST": {
|
| 1211 |
+
"gap": {"accuracy": zero_shot_fashion_mnist(model=model, processor=processor, device=cfg.device, batch_size=64)},
|
| 1212 |
+
"base": {"accuracy": zero_shot_fashion_mnist(model=baseline_model, processor=baseline_processor, device=cfg.device, batch_size=64)},
|
| 1213 |
+
},
|
| 1214 |
+
"KAGL Marqo": {
|
| 1215 |
+
"gap": zero_shot_kagl(model=model, processor=processor, device=cfg.device, batch_size=64, num_examples=DEFAULT_NUM_EXAMPLES),
|
| 1216 |
+
"base": zero_shot_kagl(model=baseline_model, processor=baseline_processor, device=cfg.device, batch_size=64, num_examples=DEFAULT_NUM_EXAMPLES),
|
| 1217 |
+
},
|
| 1218 |
+
"Internal dataset": {
|
| 1219 |
+
"gap": zero_shot_internal(model=model, processor=processor, device=cfg.device, batch_size=64, num_examples=DEFAULT_NUM_EXAMPLES),
|
| 1220 |
+
"base": zero_shot_internal(model=baseline_model, processor=baseline_processor, device=cfg.device, batch_size=64, num_examples=DEFAULT_NUM_EXAMPLES),
|
| 1221 |
+
},
|
| 1222 |
+
"ModaNet": {
|
| 1223 |
+
"gap": zero_shot_modanet(model=model, processor=processor, device=cfg.device, batch_size=64, num_examples=DEFAULT_NUM_EXAMPLES, use_gap_labels=True),
|
| 1224 |
+
"base": zero_shot_modanet(model=baseline_model, processor=baseline_processor, device=cfg.device, batch_size=64, num_examples=DEFAULT_NUM_EXAMPLES, use_gap_labels=True),
|
| 1225 |
+
},
|
| 1226 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1227 |
|
| 1228 |
+
print("\n" + "-" * 120)
|
| 1229 |
+
print("Test D summary")
|
| 1230 |
+
print("-" * 120)
|
| 1231 |
+
summary_rows: List[List[str]] = []
|
| 1232 |
+
for ds in ["Fashion-MNIST", "KAGL Marqo", "ModaNet", "Internal dataset"]:
|
| 1233 |
+
gap_result = d_results[ds]["gap"]
|
| 1234 |
+
base_result = d_results[ds]["base"]
|
| 1235 |
+
gap_acc = None if gap_result is None else gap_result.get("accuracy")
|
| 1236 |
+
base_acc = None if base_result is None else base_result.get("accuracy")
|
| 1237 |
+
summary_rows.append([
|
| 1238 |
+
ds,
|
| 1239 |
+
f"{gap_acc:.2%}" if gap_acc is not None else "N/A",
|
| 1240 |
+
f"{base_acc:.2%}" if base_acc is not None else "N/A",
|
| 1241 |
+
])
|
| 1242 |
+
print_table(
|
| 1243 |
+
"Test D — zero-shot accuracy (notebook protocol)",
|
| 1244 |
+
["Dataset", "GAP-CLIP", "Fashion-CLIP (baseline)"],
|
| 1245 |
+
summary_rows,
|
| 1246 |
+
)
|
| 1247 |
print("\n" + "=" * 120)
|
| 1248 |
print("Final Summary")
|
| 1249 |
print("=" * 120)
|
|
|
|
| 1258 |
print(f"Test B full512 accuracy: {float(result_b['accuracy_full512']):.2%}")
|
| 1259 |
if baseline_result_b is not None:
|
| 1260 |
print(f"Baseline Test B full512 accuracy: {float(baseline_result_b['accuracy_full512']):.2%}")
|
| 1261 |
+
if result_c is not None:
|
| 1262 |
+
print(f"Test C overall: {format_bool(bool(result_c['overall']))}")
|
| 1263 |
+
print(f" pass rate: {float(result_c['pass_rate']):.2%}")
|
| 1264 |
+
print(f" avg color_match={float(result_c['avg_color_match']):.4f} vs cross={float(result_c['avg_color_cross']):.4f}")
|
| 1265 |
+
print(f" avg hier_match={float(result_c['avg_hier_match']):.4f} vs cross={float(result_c['avg_hier_cross']):.4f}")
|
| 1266 |
+
if baseline_result_c is not None:
|
| 1267 |
+
print(f"Baseline Test C overall: {format_bool(bool(baseline_result_c['overall']))}")
|
| 1268 |
+
print(f" baseline pass rate: {float(baseline_result_c['pass_rate']):.2%}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1269 |
|
| 1270 |
if result_a is not None:
|
| 1271 |
+
assert float(result_a["pass_rate"]) >= 0.95, (
|
| 1272 |
+
f"Test A failed: pass rate {float(result_a['pass_rate']):.2%} < 95%."
|
| 1273 |
+
)
|
| 1274 |
if result_b is not None:
|
| 1275 |
+
assert float(result_b["pass_rate"]) >= 0.95, (
|
| 1276 |
+
f"Test B failed: pass rate {float(result_b['pass_rate']):.2%} < 95%."
|
| 1277 |
+
)
|
| 1278 |
+
if result_c is not None:
|
| 1279 |
+
assert float(result_c["pass_rate"]) >= 0.95, (
|
| 1280 |
+
f"Test C failed: subspace decomposition pass rate {float(result_c['pass_rate']):.2%} < 95%."
|
| 1281 |
)
|
| 1282 |
|
| 1283 |
print("\nAll embedding-structure tests passed.")
|
| 1284 |
|
| 1285 |
|
| 1286 |
if __name__ == "__main__":
    # Command-line entry point: pick which tests run and optionally shrink
    # the sample count for quick local runs.
    cli = argparse.ArgumentParser(description="Embedding structure evaluation")
    cli.add_argument("--tests", default="ABCD", help="Which tests to run, e.g. 'C' or 'ABCD'")
    cli.add_argument("--num-examples", type=int, default=None, help="Override DEFAULT_NUM_EXAMPLES")
    cli_args = cli.parse_args()
    if cli_args.num_examples is not None:
        # Rebind the module-level constant; callees read it at call time.
        DEFAULT_NUM_EXAMPLES = cli_args.num_examples
    main(set(cli_args.tests.upper()))