Update repository with restructured codebase
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +34 -0
- MODEL_CARD.md +68 -0
- README.md +906 -42
- __init__.py +45 -0
- config.py +65 -206
- data/{dowload_images_data.py → download_images.py} +3 -3
- data/get_csv_from_chunks.py +62 -0
- evaluation/.DS_Store +0 -0
- evaluation/0_shot_classification.py +0 -512
- evaluation/{heatmap_color_similarities.py → annex92_color_heatmaps.py} +20 -0
- evaluation/{tsne_images.py → annex93_tsne.py} +24 -3
- evaluation/annex94_search_demo.py +425 -0
- evaluation/basic_test_generalized.py +0 -425
- evaluation/fashion_search.py +0 -365
- evaluation/hierarchy_evaluation.py +0 -1842
- evaluation/run_all_evaluations.py +186 -287
- evaluation/{color_evaluation.py → sec51_color_model_eval.py} +189 -71
- evaluation/sec52_category_model_eval.py +1212 -0
- evaluation/{main_model_evaluation.py → sec533_clip_nn_accuracy.py} +58 -288
- evaluation/sec5354_separation_semantic.py +329 -0
- evaluation/sec536_embedding_structure.py +1460 -0
- evaluation/utils/.DS_Store +0 -0
- evaluation/utils/__init__.py +1 -0
- evaluation/utils/datasets.py +389 -0
- evaluation/utils/metrics.py +208 -0
- evaluation/utils/model_loader.py +221 -0
- example_usage.py +2 -2
- figures/.DS_Store +0 -0
- color_model.pt → figures/baseline_blue_pant.png +2 -2
- hierarchy_model.pth → figures/baseline_red_dress.png +2 -2
- figures/confusion_matrices/.DS_Store +0 -0
- gap_clip.pth → figures/confusion_matrices/cm_color/kaggle_baseline_image_color_confusion_matrix.png +2 -2
- figures/confusion_matrices/cm_color/kaggle_baseline_text_color_confusion_matrix.png +3 -0
- figures/confusion_matrices/cm_color/kaggle_image_color_confusion_matrix.png +3 -0
- figures/confusion_matrices/cm_color/kaggle_text_color_confusion_matrix.png +3 -0
- figures/confusion_matrices/cm_color/local_baseline_image_color_confusion_matrix.png +3 -0
- figures/confusion_matrices/cm_color/local_baseline_text_color_confusion_matrix.png +3 -0
- figures/confusion_matrices/cm_color/local_image_color_confusion_matrix.png +3 -0
- figures/confusion_matrices/cm_color/local_text_color_confusion_matrix.png +3 -0
- figures/confusion_matrices/cm_hierarchy/baseline_image_hierarchy_confusion_matrix.png +3 -0
- figures/confusion_matrices/cm_hierarchy/baseline_internal_image_hierarchy_confusion_matrix.png +3 -0
- figures/confusion_matrices/cm_hierarchy/baseline_internal_text_hierarchy_confusion_matrix.png +3 -0
- figures/confusion_matrices/cm_hierarchy/baseline_kagl_marqo_image_hierarchy_confusion_matrix.png +3 -0
- figures/confusion_matrices/cm_hierarchy/baseline_kagl_marqo_text_hierarchy_confusion_matrix.png +3 -0
- figures/confusion_matrices/cm_hierarchy/baseline_text_hierarchy_confusion_matrix.png +3 -0
- figures/confusion_matrices/cm_hierarchy/gap_clip_image_hierarchy_confusion_matrix.png +3 -0
- figures/confusion_matrices/cm_hierarchy/gap_clip_internal_image_hierarchy_confusion_matrix.png +3 -0
- figures/confusion_matrices/cm_hierarchy/gap_clip_internal_text_hierarchy_confusion_matrix.png +3 -0
- figures/confusion_matrices/cm_hierarchy/gap_clip_kagl_marqo_image_hierarchy_confusion_matrix.png +3 -0
- figures/confusion_matrices/cm_hierarchy/gap_clip_kagl_marqo_text_hierarchy_confusion_matrix.png +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,37 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
figures/baseline_blue_pant.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
figures/baseline_red_dress.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
figures/confusion_matrices/cm_color/kaggle_baseline_image_color_confusion_matrix.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
figures/confusion_matrices/cm_color/kaggle_baseline_text_color_confusion_matrix.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
figures/confusion_matrices/cm_color/kaggle_image_color_confusion_matrix.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
figures/confusion_matrices/cm_color/kaggle_text_color_confusion_matrix.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
figures/confusion_matrices/cm_color/local_baseline_image_color_confusion_matrix.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
figures/confusion_matrices/cm_color/local_baseline_text_color_confusion_matrix.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
figures/confusion_matrices/cm_color/local_image_color_confusion_matrix.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
figures/confusion_matrices/cm_color/local_text_color_confusion_matrix.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
figures/confusion_matrices/cm_hierarchy/baseline_image_hierarchy_confusion_matrix.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
figures/confusion_matrices/cm_hierarchy/baseline_internal_image_hierarchy_confusion_matrix.png filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
figures/confusion_matrices/cm_hierarchy/baseline_internal_text_hierarchy_confusion_matrix.png filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
figures/confusion_matrices/cm_hierarchy/baseline_kagl_marqo_image_hierarchy_confusion_matrix.png filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
figures/confusion_matrices/cm_hierarchy/baseline_kagl_marqo_text_hierarchy_confusion_matrix.png filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
figures/confusion_matrices/cm_hierarchy/baseline_text_hierarchy_confusion_matrix.png filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
figures/confusion_matrices/cm_hierarchy/gap_clip_image_hierarchy_confusion_matrix.png filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
figures/confusion_matrices/cm_hierarchy/gap_clip_internal_image_hierarchy_confusion_matrix.png filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
figures/confusion_matrices/cm_hierarchy/gap_clip_internal_text_hierarchy_confusion_matrix.png filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
figures/confusion_matrices/cm_hierarchy/gap_clip_kagl_marqo_image_hierarchy_confusion_matrix.png filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
figures/confusion_matrices/cm_hierarchy/gap_clip_kagl_marqo_text_hierarchy_confusion_matrix.png filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
figures/confusion_matrices/cm_hierarchy/gap_clip_text_hierarchy_confusion_matrix.png filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
figures/gapclip_blue_pant.png filter=lfs diff=lfs merge=lfs -text
|
| 59 |
+
figures/gapclip_red_dress.png filter=lfs diff=lfs merge=lfs -text
|
| 60 |
+
figures/heatmap.png filter=lfs diff=lfs merge=lfs -text
|
| 61 |
+
figures/heatmap_baseline.jpg filter=lfs diff=lfs merge=lfs -text
|
| 62 |
+
figures/red_dress.png filter=lfs diff=lfs merge=lfs -text
|
| 63 |
+
figures/scheme.png filter=lfs diff=lfs merge=lfs -text
|
| 64 |
+
figures/training_curves.png filter=lfs diff=lfs merge=lfs -text
|
| 65 |
+
figures/tsne_baseline.png filter=lfs diff=lfs merge=lfs -text
|
| 66 |
+
figures/tsne_hierarchy_baseline.png filter=lfs diff=lfs merge=lfs -text
|
| 67 |
+
figures/tsne_hierarchy_our.png filter=lfs diff=lfs merge=lfs -text
|
| 68 |
+
figures/tsne_model.png filter=lfs diff=lfs merge=lfs -text
|
| 69 |
+
paper/paper.pdf filter=lfs diff=lfs merge=lfs -text
|
MODEL_CARD.md
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language: en
|
| 3 |
+
tags:
|
| 4 |
+
- fashion
|
| 5 |
+
- clip
|
| 6 |
+
- multimodal
|
| 7 |
+
- image-search
|
| 8 |
+
- text-search
|
| 9 |
+
- embeddings
|
| 10 |
+
- contrastive-learning
|
| 11 |
+
license: mit
|
| 12 |
+
datasets:
|
| 13 |
+
- custom
|
| 14 |
+
metrics:
|
| 15 |
+
- accuracy
|
| 16 |
+
- cosine-similarity
|
| 17 |
+
library_name: transformers
|
| 18 |
+
---
|
| 19 |
+
|
| 20 |
+
# GAP-CLIP: Guaranteed Attribute Positioning in CLIP Embeddings
|
| 21 |
+
|
| 22 |
+
This model is part of the GAP-CLIP project for fashion search with guaranteed attribute positioning.
|
| 23 |
+
|
| 24 |
+
## Model Description
|
| 25 |
+
|
| 26 |
+
GAP-CLIP is a multi-modal search model for fashion that combines:
|
| 27 |
+
- **Color embeddings** (16 dimensions): Specialized for color representation
|
| 28 |
+
- **Hierarchy embeddings** (64 dimensions): Specialized for category classification
|
| 29 |
+
- **General CLIP embeddings** (432 dimensions): General visual-semantic understanding
|
| 30 |
+
|
| 31 |
+
**Total embedding size**: 512 dimensions
|
| 32 |
+
|
| 33 |
+
## Quick Start
|
| 34 |
+
|
| 35 |
+
```python
|
| 36 |
+
from transformers import CLIPProcessor, CLIPModel
|
| 37 |
+
from huggingface_hub import hf_hub_download
|
| 38 |
+
import torch
|
| 39 |
+
|
| 40 |
+
# Load model
|
| 41 |
+
model = CLIPModel.from_pretrained("Leacb4/gap-clip")
|
| 42 |
+
processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
|
| 43 |
+
|
| 44 |
+
# Process text
|
| 45 |
+
text = "red dress"
|
| 46 |
+
inputs = processor(text=[text], return_tensors="pt", padding=True)
|
| 47 |
+
text_features = model.get_text_features(**inputs)
|
| 48 |
+
|
| 49 |
+
# Extract subspaces
|
| 50 |
+
color_emb = text_features[:, :16] # Color dimensions
|
| 51 |
+
hierarchy_emb = text_features[:, 16:80] # Hierarchy dimensions
|
| 52 |
+
general_emb = text_features[:, 80:] # General CLIP dimensions
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
## Citation
|
| 56 |
+
|
| 57 |
+
```bibtex
|
| 58 |
+
@misc{gap-clip-2024,
|
| 59 |
+
title={GAP-CLIP: Guaranteed Attribute Positioning in CLIP Embeddings for Fashion Search},
|
| 60 |
+
author={Sarfati, Lea Attia},
|
| 61 |
+
year={2024},
|
| 62 |
+
url={https://huggingface.co/Leacb4/gap-clip}
|
| 63 |
+
}
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
## License
|
| 67 |
+
|
| 68 |
+
MIT License - See LICENSE file for details.
|
README.md
CHANGED
|
@@ -1,68 +1,932 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
---
|
| 19 |
|
| 20 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
##
|
| 25 |
|
| 26 |
-
|
| 27 |
-
- **Color embeddings** (16 dimensions): Specialized for color representation
|
| 28 |
-
- **Hierarchy embeddings** (64 dimensions): Specialized for category classification
|
| 29 |
-
- **General CLIP embeddings** (432 dimensions): General visual-semantic understanding
|
| 30 |
|
| 31 |
-
**
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
```python
|
| 36 |
-
from transformers import CLIPProcessor, CLIPModel
|
| 37 |
-
from huggingface_hub import hf_hub_download
|
| 38 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
-
#
|
| 41 |
-
|
| 42 |
-
|
| 43 |
|
| 44 |
-
#
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
text_features = model.get_text_features(**inputs)
|
| 48 |
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
```
|
| 54 |
|
| 55 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
```bibtex
|
| 58 |
@misc{gap-clip-2024,
|
| 59 |
title={GAP-CLIP: Guaranteed Attribute Positioning in CLIP Embeddings for Fashion Search},
|
| 60 |
author={Sarfati, Lea Attia},
|
| 61 |
year={2024},
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
}
|
| 64 |
```
|
| 65 |
|
| 66 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
-
|
|
|
|
| 1 |
+
# GAP-CLIP: Guaranteed Attribute Positioning in CLIP Embeddings
|
| 2 |
+
|
| 3 |
+
[](https://www.python.org/downloads/)
|
| 4 |
+
[](https://pytorch.org/)
|
| 5 |
+
[](https://opensource.org/licenses/MIT)
|
| 6 |
+
[](https://huggingface.co/Leacb4/gap-clip)
|
| 7 |
+
|
| 8 |
+
**Advanced multimodal fashion search model combining specialized color embeddings, hierarchical category embeddings, and CLIP for intelligent fashion item retrieval.**
|
| 9 |
+
|
| 10 |
---
|
| 11 |
+
|
| 12 |
+
## 🚀 Quick Start
|
| 13 |
+
|
| 14 |
+
### Installation (< 1 minute)
|
| 15 |
+
|
| 16 |
+
```bash
|
| 17 |
+
# Clone the repository
|
| 18 |
+
git clone https://github.com/Leacb4/gap-clip.git
|
| 19 |
+
cd gap-clip
|
| 20 |
+
|
| 21 |
+
# Install package with pip
|
| 22 |
+
pip install -e .
|
| 23 |
+
|
| 24 |
+
# Or just install dependencies
|
| 25 |
+
pip install -r requirements.txt
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
### Try It Now (< 2 minutes)
|
| 29 |
+
|
| 30 |
+
```python
|
| 31 |
+
from example_usage import load_models_from_hf
|
| 32 |
+
|
| 33 |
+
# Load pre-trained models from Hugging Face
|
| 34 |
+
models = load_models_from_hf("Leacb4/gap-clip")
|
| 35 |
+
|
| 36 |
+
# Search with text
|
| 37 |
+
import torch.nn.functional as F
|
| 38 |
+
text_query = "red summer dress"
|
| 39 |
+
text_inputs = models['processor'](text=[text_query], padding=True, return_tensors="pt")
|
| 40 |
+
text_inputs = {k: v.to(models['device']) for k, v in text_inputs.items()}
|
| 41 |
+
|
| 42 |
+
with torch.no_grad():
|
| 43 |
+
text_features = models['main_model'](**text_inputs).text_embeds
|
| 44 |
+
|
| 45 |
+
# Extract specialized embeddings
|
| 46 |
+
color_emb = text_features[:, :16] # Color (dims 0-15)
|
| 47 |
+
category_emb = text_features[:, 16:80] # Category (dims 16-79)
|
| 48 |
+
general_emb = text_features[:, 80:] # General CLIP (dims 80-511)
|
| 49 |
+
|
| 50 |
+
print(f"✅ Successfully extracted embeddings!")
|
| 51 |
+
print(f" Color: {color_emb.shape}, Category: {category_emb.shape}, General: {general_emb.shape}")
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
---
|
| 55 |
|
| 56 |
+
## 📋 Description
|
| 57 |
+
|
| 58 |
+
This project implements an advanced fashion search system based on CLIP, with three specialized models:
|
| 59 |
+
|
| 60 |
+
1. **Color Model** (`color_model.pt`) : Specialized CLIP model for extracting reduced-size color embeddings from text and images
|
| 61 |
+
2. **Hierarchy Model** (`hierarchy_model.pth`) : Model for classifying and encoding reduced-size categorical hierarchy of fashion items
|
| 62 |
+
3. **Main CLIP Model** (`gap_clip.pth`) : Main CLIP model based on LAION, trained with color and hierarchy embeddings
|
| 63 |
+
|
| 64 |
+
### Architecture
|
| 65 |
+
|
| 66 |
+
The main model's embedding structure:
|
| 67 |
+
- **Dimensions 0-15** (16 dims): Color embeddings aligned with specialized color model
|
| 68 |
+
- **Dimensions 16-79** (64 dims): Hierarchy embeddings aligned with specialized hierarchy model
|
| 69 |
+
- **Dimensions 80-511** (432 dims): Standard CLIP embeddings for general visual-semantic understanding
|
| 70 |
+
|
| 71 |
+
**Total: 512 dimensions** per embedding (text or image)
|
| 72 |
+
|
| 73 |
+
**Key Innovation**: The first 80 dimensions are explicitly trained to align with specialized models through direct MSE and cosine similarity losses, ensuring guaranteed attribute positioning (GAP) while maintaining full CLIP capabilities in the remaining dimensions.
|
| 74 |
+
|
| 75 |
+
### Loss Functions
|
| 76 |
+
|
| 77 |
+
**1. Enhanced Contrastive Loss** (`enhanced_contrastive_loss`):
|
| 78 |
+
|
| 79 |
+
Combines multiple objectives:
|
| 80 |
+
- **Original Triple Loss**: Text-image-attributes contrastive learning
|
| 81 |
+
- **Color Alignment**: Forces dims 0-15 to match color model embeddings
|
| 82 |
+
- **Hierarchy Alignment**: Forces dims 16-79 to match hierarchy model embeddings
|
| 83 |
+
- **Reference Loss**: Optional regularization to stay close to base CLIP
|
| 84 |
+
|
| 85 |
+
**2. Alignment Components**:
|
| 86 |
+
```python
|
| 87 |
+
# Color alignment (text & image)
|
| 88 |
+
color_text_mse = F.mse_loss(main_color_dims, color_model_emb)
|
| 89 |
+
color_text_cosine = 1 - F.cosine_similarity(main_color_dims, color_model_emb).mean()
|
| 90 |
+
|
| 91 |
+
# Hierarchy alignment (text & image)
|
| 92 |
+
hierarchy_text_mse = F.mse_loss(main_hierarchy_dims, hierarchy_model_emb)
|
| 93 |
+
hierarchy_text_cosine = 1 - F.cosine_similarity(main_hierarchy_dims, hierarchy_model_emb).mean()
|
| 94 |
+
|
| 95 |
+
# Combined alignment
|
| 96 |
+
alignment_loss = (color_alignment + hierarchy_alignment) / 2
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
**3. Final Loss**:
|
| 100 |
+
```python
|
| 101 |
+
total_loss = (1 - α) * contrastive_loss + α * alignment_loss + β * reference_loss
|
| 102 |
+
```
|
| 103 |
+
Where:
|
| 104 |
+
- α (alignment_weight) = 0.2 : Balances contrastive and alignment objectives
|
| 105 |
+
- β (reference_weight) = 0.1 : Keeps text space close to base CLIP
|
| 106 |
+
|
| 107 |
+
## 🚀 Installation
|
| 108 |
+
|
| 109 |
+
### Prerequisites
|
| 110 |
+
|
| 111 |
+
- Python 3.8 or higher
|
| 112 |
+
- PyTorch 2.0+ (with CUDA for GPU support, optional but recommended)
|
| 113 |
+
- 16GB RAM minimum (32GB recommended for training)
|
| 114 |
+
- ~5GB disk space for models and data
|
| 115 |
+
|
| 116 |
+
### Method 1: Install as Package (Recommended)
|
| 117 |
+
|
| 118 |
+
```bash
|
| 119 |
+
# Clone repository
|
| 120 |
+
git clone https://github.com/Leacb4/gap-clip.git
|
| 121 |
+
cd gap-clip
|
| 122 |
+
|
| 123 |
+
# Install in development mode
|
| 124 |
+
pip install -e .
|
| 125 |
+
|
| 126 |
+
# Or install with optional dependencies
|
| 127 |
+
pip install -e ".[dev]" # With development tools
|
| 128 |
+
pip install -e ".[optuna]" # With hyperparameter optimization
|
| 129 |
+
pip install -e ".[all]" # With all extras
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
### Method 2: Install Dependencies Only
|
| 133 |
+
|
| 134 |
+
```bash
|
| 135 |
+
pip install -r requirements.txt
|
| 136 |
+
```
|
| 137 |
+
|
| 138 |
+
### Method 3: From Hugging Face (Model Only)
|
| 139 |
+
|
| 140 |
+
```python
|
| 141 |
+
from example_usage import load_models_from_hf
|
| 142 |
+
models = load_models_from_hf("Leacb4/gap-clip")
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
### Main Dependencies
|
| 146 |
+
|
| 147 |
+
| Package | Version | Purpose |
|
| 148 |
+
|---------|---------|---------|
|
| 149 |
+
| `torch` | ≥2.0.0 | Deep learning framework |
|
| 150 |
+
| `transformers` | ≥4.30.0 | Hugging Face CLIP models |
|
| 151 |
+
| `huggingface-hub` | ≥0.16.0 | Model download/upload |
|
| 152 |
+
| `pillow` | ≥9.0.0 | Image processing |
|
| 153 |
+
| `pandas` | ≥1.5.0 | Data manipulation |
|
| 154 |
+
| `scikit-learn` | ≥1.3.0 | ML metrics & evaluation |
|
| 155 |
+
| `tqdm` | ≥4.65.0 | Progress bars |
|
| 156 |
+
| `matplotlib` | ≥3.7.0 | Visualization |
|
| 157 |
+
|
| 158 |
+
### Verify Installation
|
| 159 |
+
|
| 160 |
+
```python
|
| 161 |
+
# Test that everything works
|
| 162 |
+
import config
|
| 163 |
+
config.print_config()
|
| 164 |
+
|
| 165 |
+
# Check device
|
| 166 |
+
print(f"Using device: {config.device}")
|
| 167 |
+
```
|
| 168 |
+
|
| 169 |
+
## 📁 Project Structure
|
| 170 |
+
|
| 171 |
+
```
|
| 172 |
+
.
|
| 173 |
+
├── config.py # Configuration for paths and parameters
|
| 174 |
+
├── example_usage.py # Usage examples and HuggingFace loading
|
| 175 |
+
├── setup.py # Package installation
|
| 176 |
+
├── __init__.py # Package initialization
|
| 177 |
+
├── README.md # This documentation
|
| 178 |
+
├── MODEL_CARD.md # Hugging Face model card
|
| 179 |
+
│
|
| 180 |
+
├── paper/ # Scientific paper
|
| 181 |
+
│ ├── latex_paper.ltx # LaTeX source
|
| 182 |
+
│ └── paper.pdf # Compiled PDF
|
| 183 |
+
│
|
| 184 |
+
├── figures/ # Paper figures
|
| 185 |
+
│ ├── scheme.png # Architecture diagram
|
| 186 |
+
│ ├── heatmap_baseline.jpg # Baseline color heatmap
|
| 187 |
+
│ ├── heatmap.png # GAP-CLIP color heatmap
|
| 188 |
+
│ ├── tsne_*.png # t-SNE visualizations
|
| 189 |
+
│ ├── red_dress.png # Search demo example
|
| 190 |
+
│ ├── blue_jeans.png # Search demo example
|
| 191 |
+
│ ├── optuna_param_importances.png # Optuna importance plot
|
| 192 |
+
│ └── training_curves.png # Training loss curves
|
| 193 |
+
│
|
| 194 |
+
├── training/ # Model training code
|
| 195 |
+
│ ├── main_model.py # Main GAP-CLIP model with enhanced loss
|
| 196 |
+
│ ├── hierarchy_model.py # Hierarchy/category model
|
| 197 |
+
│ ├── train_main_model.py # Training with Optuna-optimized params
|
| 198 |
+
│ └── optuna_optimisation.py # Hyperparameter optimization
|
| 199 |
+
│
|
| 200 |
+
├── evaluation/ # Paper evaluation scripts
|
| 201 |
+
│ ├── run_all_evaluations.py # Orchestrates all evaluations
|
| 202 |
+
│ ├── sec51_color_model_eval.py # Section 5.1 - Color model
|
| 203 |
+
│ ├── sec52_category_model_eval.py # Section 5.2 - Category model
|
| 204 |
+
│ ├── sec533_clip_nn_accuracy.py # Section 5.3.3 - Classification
|
| 205 |
+
│ ├── sec5354_separation_semantic.py # Sections 5.3.4-5.3.5
|
| 206 |
+
│ ├── sec536_embedding_structure.py # Section 5.3.6 - Structure tests
|
| 207 |
+
│ ├── annex92_color_heatmaps.py # Annex - Color heatmaps
|
| 208 |
+
│ ├── annex93_tsne.py # Annex - t-SNE visualizations
|
| 209 |
+
│ ├── annex94_search_demo.py # Annex - Search demo
|
| 210 |
+
│ └── utils/ # Shared evaluation utilities
|
| 211 |
+
│ ├── datasets.py # Dataset loaders
|
| 212 |
+
│ ├── metrics.py # Metrics (separation, accuracy)
|
| 213 |
+
│ └── model_loader.py # Model loading helpers
|
| 214 |
+
│
|
| 215 |
+
├── data/ # Data preparation
|
| 216 |
+
│ ├── download_images.py # Download dataset images
|
| 217 |
+
│ └── get_csv_from_chunks.py # Merge CSV chunks
|
| 218 |
+
│
|
| 219 |
+
├── models/ # Trained model weights
|
| 220 |
+
│ ├── color_model.pt # Color model checkpoint
|
| 221 |
+
│ ├── hierarchy_model.pth # Hierarchy model checkpoint
|
| 222 |
+
│ └── gap_clip.pth # Main GAP-CLIP checkpoint
|
| 223 |
+
│
|
| 224 |
+
└── optuna/ # Optuna optimization artifacts
|
| 225 |
+
├── optuna_results.txt # Best hyperparameters
|
| 226 |
+
├── optuna_study.pkl # Saved study
|
| 227 |
+
├── optuna_optimization_history.png
|
| 228 |
+
└── optuna_param_importances.png
|
| 229 |
+
```
|
| 230 |
+
|
| 231 |
+
### Key Files Description
|
| 232 |
+
|
| 233 |
+
**Core Model Files** (in `training/`):
|
| 234 |
+
- `main_model.py`: GAP-CLIP implementation with enhanced contrastive loss
|
| 235 |
+
- `hierarchy_model.py`: ResNet18-based hierarchy classification model (64 dims)
|
| 236 |
+
- `train_main_model.py`: Training with Optuna-optimized hyperparameters
|
| 237 |
+
- `optuna_optimisation.py`: Hyperparameter search with Optuna
|
| 238 |
+
|
| 239 |
+
**Configuration & Setup**:
|
| 240 |
+
- `config.py`: Configuration with type hints, auto device detection, validation
|
| 241 |
+
- `setup.py`: Package installer with CLI entry points
|
| 242 |
+
- `__init__.py`: Package initialization for easy imports
|
| 243 |
+
|
| 244 |
+
**Evaluation Suite** (in `evaluation/`):
|
| 245 |
+
- Scripts prefixed `sec5*` correspond to paper sections 5.1–5.3.6
|
| 246 |
+
- Scripts prefixed `annex9*` generate annex figures (heatmaps, t-SNE, search demo)
|
| 247 |
+
- `run_all_evaluations.py`: Orchestrates all paper evaluations
|
| 248 |
+
- `utils/`: Shared datasets, metrics, and model loading
|
| 249 |
+
|
| 250 |
+
**CLI Commands**:
|
| 251 |
+
After installation with `pip install -e .`, you can use:
|
| 252 |
+
```bash
|
| 253 |
+
gap-clip-train # Start training
|
| 254 |
+
gap-clip-example # Run usage examples
|
| 255 |
+
```
|
| 256 |
+
|
| 257 |
+
## 🔧 Configuration
|
| 258 |
+
|
| 259 |
+
Main parameters are defined in `config.py` (✨ completely rewritten with improvements):
|
| 260 |
+
|
| 261 |
+
```python
|
| 262 |
+
import config
|
| 263 |
+
|
| 264 |
+
# Automatic device detection (CUDA > MPS > CPU)
|
| 265 |
+
device = config.device # Automatically selects best available device
|
| 266 |
+
|
| 267 |
+
# Embedding dimensions
|
| 268 |
+
color_emb_dim = config.color_emb_dim # 16 dims (0-15)
|
| 269 |
+
hierarchy_emb_dim = config.hierarchy_emb_dim # 64 dims (16-79)
|
| 270 |
+
main_emb_dim = config.main_emb_dim # 512 dims total
|
| 271 |
+
|
| 272 |
+
# Default training hyperparameters
|
| 273 |
+
batch_size = config.DEFAULT_BATCH_SIZE # 32
|
| 274 |
+
learning_rate = config.DEFAULT_LEARNING_RATE # 1.5e-5
|
| 275 |
+
temperature = config.DEFAULT_TEMPERATURE # 0.09
|
| 276 |
+
|
| 277 |
+
# Utility functions
|
| 278 |
+
config.print_config() # Print current configuration
|
| 279 |
+
config.validate_paths() # Validate that all files exist
|
| 280 |
+
```
|
| 281 |
+
|
| 282 |
+
### New Features in config.py ✨
|
| 283 |
+
|
| 284 |
+
- **Automatic device detection**: Selects CUDA > MPS > CPU automatically
|
| 285 |
+
- **Type hints**: Full type annotations for better IDE support
|
| 286 |
+
- **Validation**: `validate_paths()` checks all model files exist
|
| 287 |
+
- **Print utility**: `print_config()` shows current settings
|
| 288 |
+
- **Constants**: Pre-defined default hyperparameters
|
| 289 |
+
- **Documentation**: Comprehensive docstrings for all settings
|
| 290 |
+
|
| 291 |
+
### Model Paths
|
| 292 |
+
|
| 293 |
+
Default paths configured in `config.py`:
|
| 294 |
+
- `models/color_model.pt` : Trained color model checkpoint
|
| 295 |
+
- `models/hierarchy_model.pth` : Trained hierarchy model checkpoint
|
| 296 |
+
- `models/gap_clip.pth` : Main GAP-CLIP model checkpoint
|
| 297 |
+
- `tokenizer_vocab.json` : Tokenizer vocabulary for color model
|
| 298 |
+
- `data.csv` : Training/validation dataset
|
| 299 |
+
|
| 300 |
+
### Dataset Format
|
| 301 |
+
|
| 302 |
+
The training dataset CSV should contain:
|
| 303 |
+
- `text`: Text description of the fashion item
|
| 304 |
+
- `color`: Color label (e.g., "red", "blue", "black")
|
| 305 |
+
- `hierarchy`: Category label (e.g., "dress", "shirt", "shoes")
|
| 306 |
+
- `local_image_path`: Path to the image file
|
| 307 |
+
|
| 308 |
+
Example:
|
| 309 |
+
```csv
|
| 310 |
+
text,color,hierarchy,local_image_path
|
| 311 |
+
"red summer dress with floral pattern",red,dress,data/images/001.jpg
|
| 312 |
+
"blue denim jeans casual style",blue,jeans,data/images/002.jpg
|
| 313 |
+
```
|
| 314 |
+
|
| 315 |
+
## 📦 Usage
|
| 316 |
+
|
| 317 |
+
### 1. Load Models from Hugging Face
|
| 318 |
+
|
| 319 |
+
If your models are already uploaded to Hugging Face:
|
| 320 |
+
|
| 321 |
+
```python
|
| 322 |
+
from example_usage import load_models_from_hf
|
| 323 |
+
|
| 324 |
+
# Load all models
|
| 325 |
+
models = load_models_from_hf("your-username/your-model")
|
| 326 |
+
|
| 327 |
+
color_model = models['color_model']
|
| 328 |
+
hierarchy_model = models['hierarchy_model']
|
| 329 |
+
main_model = models['main_model']
|
| 330 |
+
processor = models['processor']
|
| 331 |
+
device = models['device']
|
| 332 |
+
```
|
| 333 |
+
|
| 334 |
+
### 2. Text Search
|
| 335 |
+
|
| 336 |
+
```python
|
| 337 |
+
import torch
|
| 338 |
+
from transformers import CLIPProcessor
|
| 339 |
+
|
| 340 |
+
# Prepare text query
|
| 341 |
+
text_query = "red dress"
|
| 342 |
+
text_inputs = processor(text=[text_query], padding=True, return_tensors="pt")
|
| 343 |
+
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
|
| 344 |
+
|
| 345 |
+
# Get main model embeddings
|
| 346 |
+
with torch.no_grad():
|
| 347 |
+
outputs = main_model(**text_inputs)
|
| 348 |
+
text_features = outputs.text_embeds
|
| 349 |
+
|
| 350 |
+
# Get specialized embeddings
|
| 351 |
+
color_emb = color_model.get_text_embeddings([text_query])
|
| 352 |
+
hierarchy_emb = hierarchy_model.get_text_embeddings([text_query])
|
| 353 |
+
```
|
| 354 |
+
|
| 355 |
+
### 3. Image Search
|
| 356 |
+
|
| 357 |
+
```python
|
| 358 |
+
from PIL import Image
|
| 359 |
+
|
| 360 |
+
# Load image
|
| 361 |
+
image = Image.open("path/to/image.jpg").convert("RGB")
|
| 362 |
+
image_inputs = processor(images=[image], return_tensors="pt")
|
| 363 |
+
image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
|
| 364 |
+
|
| 365 |
+
# Get embeddings
|
| 366 |
+
with torch.no_grad():
|
| 367 |
+
outputs = main_model(**image_inputs)
|
| 368 |
+
image_features = outputs.image_embeds
|
| 369 |
+
```
|
| 370 |
+
|
| 371 |
+
### 4. Using the Example Script
|
| 372 |
+
|
| 373 |
+
The `example_usage.py` provides ready-to-use examples for loading and using GAP-CLIP:
|
| 374 |
+
|
| 375 |
+
```bash
|
| 376 |
+
# Load from HuggingFace and search with text
|
| 377 |
+
python example_usage.py \
|
| 378 |
+
--repo-id Leacb4/gap-clip \
|
| 379 |
+
--text "red summer dress"
|
| 380 |
+
|
| 381 |
+
# Search with image
|
| 382 |
+
python example_usage.py \
|
| 383 |
+
--repo-id Leacb4/gap-clip \
|
| 384 |
+
--image path/to/image.jpg
|
| 385 |
+
|
| 386 |
+
# Both text and image
|
| 387 |
+
python example_usage.py \
|
| 388 |
+
--repo-id Leacb4/gap-clip \
|
| 389 |
+
--text "blue denim jeans" \
|
| 390 |
+
--image path/to/image.jpg
|
| 391 |
+
```
|
| 392 |
+
|
| 393 |
+
This script demonstrates:
|
| 394 |
+
- Loading models from HuggingFace Hub
|
| 395 |
+
- Extracting text and image embeddings
|
| 396 |
+
- Accessing color and hierarchy subspaces
|
| 397 |
+
- Measuring alignment quality with specialized models
|
| 398 |
+
|
| 399 |
+
## 🎯 Model Training
|
| 400 |
+
|
| 401 |
+
### Train the Color Model
|
| 402 |
+
|
| 403 |
+
```python
|
| 404 |
+
from color_model import ColorCLIP, train_color_model
|
| 405 |
+
|
| 406 |
+
# Configuration
|
| 407 |
+
model = ColorCLIP(vocab_size=10000, embedding_dim=16)
|
| 408 |
+
# ... dataset configuration ...
|
| 409 |
+
|
| 410 |
+
# Training
|
| 411 |
+
train_color_model(model, train_loader, val_loader, num_epochs=20)
|
| 412 |
+
```
|
| 413 |
+
|
| 414 |
+
### Train the Hierarchy Model
|
| 415 |
+
|
| 416 |
+
```python
|
| 417 |
+
from training.hierarchy_model import Model as HierarchyModel, train_hierarchy_model
|
| 418 |
+
|
| 419 |
+
# Configuration
|
| 420 |
+
model = HierarchyModel(num_hierarchy_classes=10, embed_dim=64)
|
| 421 |
+
# ... dataset configuration ...
|
| 422 |
+
|
| 423 |
+
# Training
|
| 424 |
+
train_hierarchy_model(model, train_loader, val_loader, num_epochs=20)
|
| 425 |
+
```
|
| 426 |
+
|
| 427 |
+
### Train the Main CLIP Model
|
| 428 |
+
|
| 429 |
+
The main model trains with both specialized models using an enhanced contrastive loss.
|
| 430 |
|
| 431 |
+
**Option 1: Train with optimized hyperparameters (recommended)**:
|
| 432 |
+
```bash
|
| 433 |
+
python -m training.train_main_model
|
| 434 |
+
```
|
| 435 |
+
This uses hyperparameters optimized with Optuna (Trial 29, validation loss ~0.1129).
|
| 436 |
+
|
| 437 |
+
**Option 2: Train with default parameters**:
|
| 438 |
+
```bash
|
| 439 |
+
python -m training.main_model
|
| 440 |
+
```
|
| 441 |
+
This runs the main training loop with manually configured parameters.
|
| 442 |
+
|
| 443 |
+
**Default Training Parameters** (in `training/main_model.py`):
|
| 444 |
+
- `num_epochs = 20` : Number of training epochs
|
| 445 |
+
- `learning_rate = 1.5e-5` : Learning rate with AdamW optimizer
|
| 446 |
+
- `temperature = 0.09` : Temperature for softer contrastive learning
|
| 447 |
+
- `alignment_weight = 0.2` : Weight for color/hierarchy alignment loss
|
| 448 |
+
- `weight_decay = 5e-4` : L2 regularization to prevent overfitting
|
| 449 |
+
- `batch_size = 32` : Batch size
|
| 450 |
+
- `subset_size = 20000` : Dataset size for better generalization
|
| 451 |
+
- `reference_weight = 0.1` : Weight for base CLIP regularization
|
| 452 |
+
|
| 453 |
+
**Enhanced Loss Function**:
|
| 454 |
+
|
| 455 |
+
The training uses `enhanced_contrastive_loss` which combines:
|
| 456 |
+
|
| 457 |
+
1. **Triple Contrastive Loss** (weighted):
|
| 458 |
+
- Text-Image alignment (70%)
|
| 459 |
+
- Text-Attributes alignment (15%)
|
| 460 |
+
- Image-Attributes alignment (15%)
|
| 461 |
+
|
| 462 |
+
2. **Direct Alignment Loss** (combines color & hierarchy):
|
| 463 |
+
- MSE loss between main model color dims (0-15) and color model embeddings
|
| 464 |
+
- MSE loss between main model hierarchy dims (16-79) and hierarchy model embeddings
|
| 465 |
+
- Cosine similarity losses for both color and hierarchy
|
| 466 |
+
- Applied to both text and image embeddings
|
| 467 |
+
|
| 468 |
+
3. **Reference Model Loss** (optional):
|
| 469 |
+
- Keeps text embeddings close to base CLIP
|
| 470 |
+
- Improves cross-domain generalization
|
| 471 |
+
|
| 472 |
+
**Training Features**:
|
| 473 |
+
- Enhanced data augmentation (rotation, color jitter, blur, affine transforms)
|
| 474 |
+
- Gradient clipping (max_norm=1.0) to prevent exploding gradients
|
| 475 |
+
- ReduceLROnPlateau scheduler (patience=3, factor=0.5)
|
| 476 |
+
- Early stopping (patience=7)
|
| 477 |
+
- Automatic best model saving with checkpoints
|
| 478 |
+
- Detailed metrics logging (alignment losses, cosine similarities)
|
| 479 |
+
- Overfitting detection and warnings
|
| 480 |
+
- Training curves visualization with 3 plots (losses, overfitting gap, comparison)
|
| 481 |
+
|
| 482 |
+
### Hyperparameter Optimization
|
| 483 |
+
|
| 484 |
+
The project includes Optuna-based hyperparameter optimization:
|
| 485 |
+
|
| 486 |
+
```bash
|
| 487 |
+
python -m training.optuna_optimisation
|
| 488 |
+
```
|
| 489 |
+
|
| 490 |
+
This optimizes:
|
| 491 |
+
- Learning rate
|
| 492 |
+
- Temperature for contrastive loss
|
| 493 |
+
- Alignment weight
|
| 494 |
+
- Weight decay
|
| 495 |
+
|
| 496 |
+
Results are saved in `optuna/optuna_study.pkl` and visualizations in `optuna/optuna_optimization_history.png` and `optuna/optuna_param_importances.png`.
|
| 497 |
+
|
| 498 |
+
The best hyperparameters from Optuna optimization are used in `training/train_main_model.py`.
|
| 499 |
|
| 500 |
+
## 📊 Models
|
| 501 |
|
| 502 |
+
### Color Model
|
|
|
|
|
|
|
|
|
|
| 503 |
|
| 504 |
+
- **Architecture** : ResNet18 (image encoder) + Embedding (text encoder)
|
| 505 |
+
- **Embedding dimension** : 16
|
| 506 |
+
- **Trained on** : Fashion data with color annotations
|
| 507 |
+
- **Usage** : Extract color embeddings from text or images
|
| 508 |
|
| 509 |
+
### Hierarchy Model
|
| 510 |
+
|
| 511 |
+
- **Architecture** : ResNet18 (image encoder) + Embedding (hierarchy encoder)
|
| 512 |
+
- **Embedding dimension** : 64
|
| 513 |
+
- **Hierarchy classes** : shirt, dress, pant, shoe, bag, etc.
|
| 514 |
+
- **Usage** : Classify and encode categorical hierarchy
|
| 515 |
+
|
| 516 |
+
### Main CLIP Model (GAP-CLIP)
|
| 517 |
+
|
| 518 |
+
- **Architecture** : CLIP ViT-B/32 (LAION)
|
| 519 |
+
- **Base Model** : `laion/CLIP-ViT-B-32-laion2B-s34B-b79K`
|
| 520 |
+
- **Training Approach** : Enhanced contrastive loss with direct attribute alignment
|
| 521 |
+
- **Embedding Dimensions** : 512 total
|
| 522 |
+
- Color subspace: dims 0-15 (16 dims)
|
| 523 |
+
- Hierarchy subspace: dims 16-79 (64 dims)
|
| 524 |
+
- General CLIP: dims 80-511 (432 dims)
|
| 525 |
+
- **Training Dataset** : 20,000 fashion items with color and hierarchy annotations
|
| 526 |
+
- **Validation Split** : 80/20 train-validation split
|
| 527 |
+
- **Optimizer** : AdamW with weight decay (5e-4)
|
| 528 |
+
- **Best Checkpoint** : Automatically saved based on validation loss
|
| 529 |
+
- **Features** :
|
| 530 |
+
- Multi-modal text-image search
|
| 531 |
+
- Guaranteed attribute positioning (GAP) in specific dimensions
|
| 532 |
+
- Direct alignment with specialized color and hierarchy models
|
| 533 |
+
- Maintains general CLIP capabilities for cross-domain tasks
|
| 534 |
+
- Reduced overfitting through augmentation and regularization
|
| 535 |
+
|
| 536 |
+
## 🔍 Advanced Usage Examples
|
| 537 |
+
|
| 538 |
+
### Search with Combined Embeddings
|
| 539 |
|
| 540 |
```python
|
|
|
|
|
|
|
| 541 |
import torch
|
| 542 |
+
import torch.nn.functional as F
|
| 543 |
+
|
| 544 |
+
# Text query
|
| 545 |
+
text_query = "red dress"
|
| 546 |
+
text_inputs = processor(text=[text_query], padding=True, return_tensors="pt")
|
| 547 |
+
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
|
| 548 |
+
|
| 549 |
+
# Main model embeddings
|
| 550 |
+
with torch.no_grad():
|
| 551 |
+
outputs = main_model(**text_inputs)
|
| 552 |
+
text_features = outputs.text_embeds # Shape: [1, 512]
|
| 553 |
+
|
| 554 |
+
# Extract specialized embeddings from main model
|
| 555 |
+
main_color_emb = text_features[:, :16] # Color dimensions (0-15)
|
| 556 |
+
main_hierarchy_emb = text_features[:, 16:80] # Hierarchy dimensions (16-79)
|
| 557 |
+
main_clip_emb = text_features[:, 80:] # General CLIP dimensions (80-511)
|
| 558 |
|
| 559 |
+
# Compare with specialized models
|
| 560 |
+
color_emb = color_model.get_text_embeddings([text_query])
|
| 561 |
+
hierarchy_emb = hierarchy_model.get_text_embeddings([text_query])
|
| 562 |
|
| 563 |
+
# Measure alignment quality
|
| 564 |
+
color_similarity = F.cosine_similarity(color_emb, main_color_emb, dim=1)
|
| 565 |
+
hierarchy_similarity = F.cosine_similarity(hierarchy_emb, main_hierarchy_emb, dim=1)
|
|
|
|
| 566 |
|
| 567 |
+
print(f"Color alignment: {color_similarity.item():.4f}")
|
| 568 |
+
print(f"Hierarchy alignment: {hierarchy_similarity.item():.4f}")
|
| 569 |
+
|
| 570 |
+
# For search, you can use different strategies:
|
| 571 |
+
# 1. Use full embeddings for general search
|
| 572 |
+
# 2. Use color subspace for color-specific search
|
| 573 |
+
# 3. Use hierarchy subspace for category search
|
| 574 |
+
# 4. Weighted combination of subspaces
|
| 575 |
```
|
| 576 |
|
| 577 |
+
### Search in an Image Database
|
| 578 |
+
|
| 579 |
+
```python
|
| 580 |
+
import numpy as np
|
| 581 |
+
import torch
|
| 582 |
+
import torch.nn.functional as F
|
| 583 |
+
from tqdm import tqdm
|
| 584 |
+
|
| 585 |
+
# Step 1: Pre-compute image embeddings (do this once)
|
| 586 |
+
image_paths = [...] # List of image paths
|
| 587 |
+
image_features_list = []
|
| 588 |
+
|
| 589 |
+
print("Computing image embeddings...")
|
| 590 |
+
for img_path in tqdm(image_paths):
|
| 591 |
+
image = Image.open(img_path).convert("RGB")
|
| 592 |
+
image_inputs = processor(images=[image], return_tensors="pt")
|
| 593 |
+
image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
|
| 594 |
+
|
| 595 |
+
with torch.no_grad():
|
| 596 |
+
outputs = main_model(**image_inputs)
|
| 597 |
+
features = outputs.image_embeds # Shape: [1, 512]
|
| 598 |
+
image_features_list.append(features.cpu())
|
| 599 |
+
|
| 600 |
+
# Stack all features
|
| 601 |
+
image_features = torch.cat(image_features_list, dim=0) # Shape: [N, 512]
|
| 602 |
+
|
| 603 |
+
# Step 2: Search with text query
|
| 604 |
+
query = "red dress"
|
| 605 |
+
text_inputs = processor(text=[query], padding=True, return_tensors="pt")
|
| 606 |
+
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
|
| 607 |
+
|
| 608 |
+
with torch.no_grad():
|
| 609 |
+
outputs = main_model(**text_inputs)
|
| 610 |
+
text_features = outputs.text_embeds # Shape: [1, 512]
|
| 611 |
+
|
| 612 |
+
# Step 3: Calculate similarities
|
| 613 |
+
# Normalize embeddings for cosine similarity
|
| 614 |
+
text_features_norm = F.normalize(text_features, dim=-1)
|
| 615 |
+
image_features_norm = F.normalize(image_features.to(device), dim=-1)
|
| 616 |
+
|
| 617 |
+
# Compute cosine similarities
|
| 618 |
+
similarities = (text_features_norm @ image_features_norm.T).squeeze(0) # Shape: [N]
|
| 619 |
+
|
| 620 |
+
# Step 4: Get top-k results
|
| 621 |
+
top_k = 10
|
| 622 |
+
top_scores, top_indices = similarities.topk(top_k, largest=True)
|
| 623 |
+
|
| 624 |
+
# Display results
|
| 625 |
+
print(f"\nTop {top_k} results for query: '{query}'")
|
| 626 |
+
for i, (idx, score) in enumerate(zip(top_indices, top_scores)):
|
| 627 |
+
print(f"{i+1}. {image_paths[idx]} (similarity: {score.item():.4f})")
|
| 628 |
+
|
| 629 |
+
# Optional: Filter by color or hierarchy
|
| 630 |
+
# Extract color embeddings from query
|
| 631 |
+
query_color_emb = text_features[:, :16]
|
| 632 |
+
# Extract hierarchy embeddings from query
|
| 633 |
+
query_hierarchy_emb = text_features[:, 16:80]
|
| 634 |
+
# Use these for more targeted search
|
| 635 |
+
```
|
| 636 |
+
|
| 637 |
+
## 📝 Evaluation
|
| 638 |
+
|
| 639 |
+
### Running All Evaluations
|
| 640 |
+
|
| 641 |
+
Use the orchestrator script to run all paper evaluations:
|
| 642 |
+
|
| 643 |
+
```bash
|
| 644 |
+
python evaluation/run_all_evaluations.py
|
| 645 |
+
```
|
| 646 |
+
|
| 647 |
+
Or run specific sections:
|
| 648 |
+
```bash
|
| 649 |
+
python evaluation/run_all_evaluations.py --steps sec51,sec52
|
| 650 |
+
```
|
| 651 |
+
|
| 652 |
+
**Available steps**:
|
| 653 |
+
| Step | Paper Section | Description |
|
| 654 |
+
|------|--------------|-------------|
|
| 655 |
+
| `sec51` | §5.1 | Color model accuracy (Table 1) |
|
| 656 |
+
| `sec52` | §5.2 | Category model confusion matrices (Table 2) |
|
| 657 |
+
| `sec533` | §5.3.3 | NN classification accuracy (Table 3) |
|
| 658 |
+
| `sec5354` | §5.3.4-5 | Separation & zero-shot semantic eval |
|
| 659 |
+
| `sec536` | §5.3.6 | Embedding structure Tests A/B/C (Table 4) |
|
| 660 |
+
| `annex92` | Annex 9.2 | Color similarity heatmaps |
|
| 661 |
+
| `annex93` | Annex 9.3 | t-SNE visualizations |
|
| 662 |
+
| `annex94` | Annex 9.4 | Fashion search demo |
|
| 663 |
+
|
| 664 |
+
**Evaluation Datasets**:
|
| 665 |
+
1. **Internal dataset** (~50,000 samples) — Fashion items with color and category annotations
|
| 666 |
+
2. **KAGL Marqo** (HuggingFace dataset) — Real-world fashion e-commerce data
|
| 667 |
+
3. **Fashion-MNIST** (~10,000 samples) — Standard benchmark with 10 categories
|
| 668 |
+
|
| 669 |
+
**Evaluation Metrics**:
|
| 670 |
+
- Nearest-neighbor classification accuracy
|
| 671 |
+
- Centroid-based classification accuracy
|
| 672 |
+
- Separation score (intra-class vs inter-class cosine similarity)
|
| 673 |
+
- Confusion matrices (text and image modalities)
|
| 674 |
+
|
| 675 |
+
**Baseline Comparison**: All evaluations compare GAP-CLIP against `patrickjohncyh/fashion-clip`.
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
## 📊 Performance & Results
|
| 679 |
+
|
| 680 |
+
The evaluation framework tests GAP-CLIP across three datasets with comparison to the Fashion-CLIP baseline.
|
| 681 |
+
|
| 682 |
+
### Evaluation Metrics
|
| 683 |
+
|
| 684 |
+
**Color Classification** (dimensions 0-15):
|
| 685 |
+
- Nearest Neighbor Accuracy
|
| 686 |
+
- Centroid-based Accuracy
|
| 687 |
+
- Separation Score (class separability)
|
| 688 |
+
|
| 689 |
+
**Hierarchy Classification** (dimensions 16-79):
|
| 690 |
+
- Nearest Neighbor Accuracy
|
| 691 |
+
- Centroid-based Accuracy
|
| 692 |
+
- Separation Score
|
| 693 |
+
|
| 694 |
+
### Datasets Used for Evaluation
|
| 695 |
+
|
| 696 |
+
1. **Fashion-MNIST**: 10,000 grayscale fashion item images
|
| 697 |
+
- 10 categories (T-shirt, Trouser, Pullover, Dress, Coat, Sandal, Shirt, Sneaker, Bag, Ankle boot)
|
| 698 |
+
- Mapped to model's hierarchy classes
|
| 699 |
+
|
| 700 |
+
2. **KAGL Marqo Dataset**: Real-world fashion images from HuggingFace
|
| 701 |
+
- Diverse fashion items with rich metadata
|
| 702 |
+
- Color and category annotations
|
| 703 |
+
- Realistic product images
|
| 704 |
+
|
| 705 |
+
3. **Local Validation Set**: Custom validation dataset
|
| 706 |
+
- Fashion items with local image paths
|
| 707 |
+
- Annotated with colors and hierarchies
|
| 708 |
+
- Domain-specific evaluation
|
| 709 |
+
|
| 710 |
+
### Comparative Analysis
|
| 711 |
+
|
| 712 |
+
The evaluation includes:
|
| 713 |
+
- **Baseline comparison**: GAP-CLIP vs `patrickjohncyh/fashion-clip`
|
| 714 |
+
- **Subspace analysis**: Dedicated dimensions (0-79) vs full space (0-511)
|
| 715 |
+
- **Cross-dataset generalization**: Performance consistency across datasets
|
| 716 |
+
- **Alignment quality**: How well specialized dimensions match expert models
|
| 717 |
+
|
| 718 |
+
All visualizations (confusion matrices, t-SNE plots, heatmaps) are automatically saved in the analysis directory.
|
| 719 |
+
|
| 720 |
+
## 📄 Citation
|
| 721 |
+
|
| 722 |
+
If you use GAP-CLIP in your research, please cite:
|
| 723 |
|
| 724 |
```bibtex
|
| 725 |
@misc{gap-clip-2024,
|
| 726 |
title={GAP-CLIP: Guaranteed Attribute Positioning in CLIP Embeddings for Fashion Search},
|
| 727 |
author={Sarfati, Lea Attia},
|
| 728 |
year={2024},
|
| 729 |
+
note={A multi-loss framework combining contrastive learning with direct attribute alignment},
|
| 730 |
+
howpublished={\url{https://huggingface.co/Leacb4/gap-clip}},
|
| 731 |
+
abstract={GAP-CLIP introduces a novel training approach that guarantees specific embedding
|
| 732 |
+
dimensions encode color (dims 0-15) and hierarchy (dims 16-79) information through
|
| 733 |
+
direct alignment with specialized models, while maintaining full CLIP capabilities
|
| 734 |
+
in the remaining dimensions (80-511).}
|
| 735 |
}
|
| 736 |
```
|
| 737 |
|
| 738 |
+
### Key Contributions
|
| 739 |
+
|
| 740 |
+
- **Guaranteed Attribute Positioning**: Specific dimensions reliably encode color and hierarchy
|
| 741 |
+
- **Multi-Loss Training**: Combines contrastive learning with MSE and cosine alignment losses
|
| 742 |
+
- **Specialized Model Alignment**: Direct supervision from expert color and hierarchy models
|
| 743 |
+
- **Preserved Generalization**: Maintains base CLIP capabilities for cross-domain tasks
|
| 744 |
+
- **Comprehensive Evaluation**: Tested across multiple datasets with baseline comparisons
|
| 745 |
+
|
| 746 |
+
## ❓ FAQ & Troubleshooting
|
| 747 |
+
|
| 748 |
+
### Q: What are the minimum hardware requirements?
|
| 749 |
+
|
| 750 |
+
**A**:
|
| 751 |
+
- **GPU**: Recommended for training (CUDA or MPS). CPU training is very slow.
|
| 752 |
+
- **RAM**: Minimum 16GB, recommended 32GB for training
|
| 753 |
+
- **Storage**: ~5GB for models and datasets
|
| 754 |
+
|
| 755 |
+
### Q: Why are my embeddings not aligned?
|
| 756 |
+
|
| 757 |
+
**A**: Check that:
|
| 758 |
+
1. You're using the correct dimension ranges (0-15 for color, 16-79 for hierarchy)
|
| 759 |
+
2. The model was trained with alignment_weight > 0
|
| 760 |
+
3. Color and hierarchy models were properly loaded during training
|
| 761 |
+
|
| 762 |
+
### Q: How do I use only the color or hierarchy subspace for search?
|
| 763 |
+
|
| 764 |
+
**A**:
|
| 765 |
+
```python
|
| 766 |
+
# Extract and use only color embeddings
|
| 767 |
+
text_color_emb = text_features[:, :16]
|
| 768 |
+
image_color_emb = image_features[:, :16]
|
| 769 |
+
color_similarity = F.cosine_similarity(text_color_emb, image_color_emb)
|
| 770 |
+
|
| 771 |
+
# Extract and use only hierarchy embeddings
|
| 772 |
+
text_hierarchy_emb = text_features[:, 16:80]
|
| 773 |
+
image_hierarchy_emb = image_features[:, 16:80]
|
| 774 |
+
hierarchy_similarity = F.cosine_similarity(text_hierarchy_emb, image_hierarchy_emb)
|
| 775 |
+
```
|
| 776 |
+
|
| 777 |
+
### Q: Can I add more attributes beyond color and hierarchy?
|
| 778 |
+
|
| 779 |
+
**A**: Yes! The architecture is extensible:
|
| 780 |
+
1. Train a new specialized model for your attribute
|
| 781 |
+
2. Reserve additional dimensions in the embedding space
|
| 782 |
+
3. Add alignment losses for these dimensions in `enhanced_contrastive_loss`
|
| 783 |
+
4. Update `config.py` with new dimension ranges
|
| 784 |
+
|
| 785 |
+
### Q: How do I evaluate on my own dataset?
|
| 786 |
+
|
| 787 |
+
**A**:
|
| 788 |
+
1. Format your dataset as CSV with columns: `text`, `color`, `hierarchy`, `local_image_path`
|
| 789 |
+
2. Update `config.local_dataset_path` in `config.py`
|
| 790 |
+
3. Run the evaluation: `python evaluation/run_all_evaluations.py`
|
| 791 |
+
|
| 792 |
+
### Q: Training loss is decreasing but validation loss is increasing. What should I do?
|
| 793 |
+
|
| 794 |
+
**A**: This indicates overfitting. Try:
|
| 795 |
+
- Increase `weight_decay` (e.g., from 5e-4 to 1e-3)
|
| 796 |
+
- Reduce `alignment_weight` (e.g., from 0.2 to 0.1)
|
| 797 |
+
- Increase dataset size (`subset_size`)
|
| 798 |
+
- Add more data augmentation in `CustomDataset`
|
| 799 |
+
- Enable or increase early stopping patience
|
| 800 |
+
|
| 801 |
+
### Q: Can I fine-tune GAP-CLIP on a specific domain?
|
| 802 |
+
|
| 803 |
+
**A**: Yes! Load the checkpoint and continue training:
|
| 804 |
+
```python
|
| 805 |
+
checkpoint = torch.load('models/gap_clip.pth')
|
| 806 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
| 807 |
+
# Continue training with your domain-specific data
|
| 808 |
+
```
|
| 809 |
+
|
| 810 |
+
## 🧪 Testing & Evaluation
|
| 811 |
+
|
| 812 |
+
### Quick Test
|
| 813 |
+
|
| 814 |
+
```bash
|
| 815 |
+
# Test configuration
|
| 816 |
+
python -c "import config; config.print_config()"
|
| 817 |
+
|
| 818 |
+
# Test model loading
|
| 819 |
+
python example_usage.py --repo-id Leacb4/gap-clip --text "red dress"
|
| 820 |
+
```
|
| 821 |
+
|
| 822 |
+
### Full Evaluation Suite
|
| 823 |
+
|
| 824 |
+
```bash
|
| 825 |
+
# Run all evaluations
|
| 826 |
+
cd evaluation
|
| 827 |
+
python run_all_evaluations.py --repo-id Leacb4/gap-clip
|
| 828 |
+
|
| 829 |
+
# Results will be saved to evaluation_results/ with:
|
| 830 |
+
# - summary.json: Detailed metrics
|
| 831 |
+
# - summary_comparison.png: Visual comparison
|
| 832 |
+
```
|
| 833 |
+
|
| 834 |
+
## 🐛 Known Issues & Fixes
|
| 835 |
+
|
| 836 |
+
### Fixed Issues ✨
|
| 837 |
+
|
| 838 |
+
1. **Color model image loading bug** (Fixed in `color_model.py`)
|
| 839 |
+
- Previous: `Image.open(config.column_local_image_path)`
|
| 840 |
+
- Fixed: `Image.open(img_path)` - Now correctly gets path from dataframe
|
| 841 |
+
|
| 842 |
+
2. **Function naming in training** (Fixed in `training/main_model.py` and `training/train_main_model.py`)
|
| 843 |
+
- Previous: `train_one_epoch_enhanced`
|
| 844 |
+
- Fixed: `train_one_epoch` - Consistent naming
|
| 845 |
+
|
| 846 |
+
3. **Device compatibility** (Improved in `config.py`)
|
| 847 |
+
- Now automatically detects and selects best device (CUDA > MPS > CPU)
|
| 848 |
+
|
| 849 |
+
## 🎓 Learning Resources
|
| 850 |
+
|
| 851 |
+
### Documentation Files
|
| 852 |
+
|
| 853 |
+
- **README.md** (this file): Complete project documentation
|
| 854 |
+
- **paper/latex_paper.ltx**: Scientific paper (LaTeX source)
|
| 855 |
+
- **MODEL_CARD.md**: Hugging Face model card
|
| 856 |
+
|
| 857 |
+
### Code Examples
|
| 858 |
+
|
| 859 |
+
- **example_usage.py**: Basic usage with Hugging Face Hub
|
| 860 |
+
- **evaluation/annex94_search_demo.py**: Interactive search demo
|
| 861 |
+
- **evaluation/annex93_tsne.py**: t-SNE visualization
|
| 862 |
+
|
| 863 |
+
## 🤝 Contributing
|
| 864 |
+
|
| 865 |
+
We welcome contributions! Here's how:
|
| 866 |
+
|
| 867 |
+
1. **Report bugs**: Open an issue with detailed description
|
| 868 |
+
2. **Suggest features**: Describe your idea in an issue
|
| 869 |
+
3. **Submit PR**: Fork, create branch, commit, and open pull request
|
| 870 |
+
4. **Improve docs**: Help make documentation clearer
|
| 871 |
+
|
| 872 |
+
### Development Setup
|
| 873 |
+
|
| 874 |
+
```bash
|
| 875 |
+
# Install with dev dependencies
|
| 876 |
+
pip install -e ".[dev]"
|
| 877 |
+
|
| 878 |
+
# Run tests (if available)
|
| 879 |
+
pytest
|
| 880 |
+
|
| 881 |
+
# Format code
|
| 882 |
+
black .
|
| 883 |
+
flake8 .
|
| 884 |
+
```
|
| 885 |
+
|
| 886 |
+
## 📊 Project Statistics
|
| 887 |
+
|
| 888 |
+
- **Language**: Python 3.8+
|
| 889 |
+
- **Framework**: PyTorch 2.0+
|
| 890 |
+
- **Models**: 3 specialized models (color, hierarchy, main)
|
| 891 |
+
- **Embedding Size**: 512 dimensions
|
| 892 |
+
- **Training Data**: 20,000+ fashion items
|
| 893 |
+
- **Lines of Code**: 5,000+ (including documentation)
|
| 894 |
+
- **Documentation**: Comprehensive docstrings and guides
|
| 895 |
+
|
| 896 |
+
## 🔗 Links
|
| 897 |
+
|
| 898 |
+
- **Hugging Face Hub**: [Leacb4/gap-clip](https://huggingface.co/Leacb4/gap-clip)
|
| 899 |
+
- **GitHub**: [github.com/Leacb4/gap-clip](https://github.com/Leacb4/gap-clip)
|
| 900 |
+
- **Contact**: lea.attia@gmail.com
|
| 901 |
+
|
| 902 |
+
## 📧 Contact & Support
|
| 903 |
+
|
| 904 |
+
**Author**: Lea Attia Sarfati
|
| 905 |
+
**Email**: lea.attia@gmail.com
|
| 906 |
+
**Hugging Face**: [@Leacb4](https://huggingface.co/Leacb4)
|
| 907 |
+
|
| 908 |
+
For questions, issues, or suggestions:
|
| 909 |
+
- 🐛 **Bug reports**: Open an issue on GitHub
|
| 910 |
+
- 💡 **Feature requests**: Open an issue with [Feature Request] tag
|
| 911 |
+
- 📧 **Direct contact**: lea.attia@gmail.com
|
| 912 |
+
- 💬 **Discussions**: Hugging Face Discussions
|
| 913 |
+
|
| 914 |
+
---
|
| 915 |
+
|
| 916 |
+
## 📜 License
|
| 917 |
+
|
| 918 |
+
This project is licensed under the MIT License - see the LICENSE file for details.
|
| 919 |
+
|
| 920 |
+
## 🙏 Acknowledgments
|
| 921 |
+
|
| 922 |
+
- LAION team for the base CLIP model
|
| 923 |
+
- Hugging Face for transformers library and model hosting
|
| 924 |
+
- PyTorch team for the deep learning framework
|
| 925 |
+
- Fashion-MNIST dataset creators
|
| 926 |
+
- All contributors and users of this project
|
| 927 |
+
|
| 928 |
+
---
|
| 929 |
+
|
| 930 |
+
**⭐ If you find this project useful, please consider giving it a star on GitHub!**
|
| 931 |
|
| 932 |
+
**📢 Version**: 1.0.0 | **Status**: Production Ready ✅ | **Last Updated**: December 2024
|
__init__.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GAP-CLIP: Guaranteed Attribute Positioning in CLIP Embeddings
|
| 3 |
+
==============================================================
|
| 4 |
+
|
| 5 |
+
A multimodal fashion search model that combines color embeddings,
|
| 6 |
+
hierarchical category embeddings, and general CLIP capabilities.
|
| 7 |
+
|
| 8 |
+
Main Components:
|
| 9 |
+
- ColorCLIP: Specialized color embedding model (16 dims)
|
| 10 |
+
- HierarchyModel: Category classification model (64 dims)
|
| 11 |
+
- GAP-CLIP: Main CLIP model with aligned subspaces (512 dims)
|
| 12 |
+
|
| 13 |
+
Quick Start:
|
| 14 |
+
>>> from gap_clip import load_models_from_hf
|
| 15 |
+
>>> models = load_models_from_hf("Leacb4/gap-clip")
|
| 16 |
+
>>> # Use models for search...
|
| 17 |
+
|
| 18 |
+
For more information, see the README.md file or visit:
|
| 19 |
+
https://huggingface.co/Leacb4/gap-clip
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
__version__ = "1.0.0"
|
| 23 |
+
__author__ = "Lea Attia Sarfati"
|
| 24 |
+
__email__ = "lea.attia@gmail.com"
|
| 25 |
+
|
| 26 |
+
# Import main components for easy access
|
| 27 |
+
try:
|
| 28 |
+
from .color_model import ColorCLIP, Tokenizer
|
| 29 |
+
from .training.hierarchy_model import Model as HierarchyModel, HierarchyExtractor
|
| 30 |
+
from .example_usage import load_models_from_hf, example_search
|
| 31 |
+
import config
|
| 32 |
+
|
| 33 |
+
__all__ = [
|
| 34 |
+
'ColorCLIP',
|
| 35 |
+
'Tokenizer',
|
| 36 |
+
'HierarchyModel',
|
| 37 |
+
'HierarchyExtractor',
|
| 38 |
+
'load_models_from_hf',
|
| 39 |
+
'example_search',
|
| 40 |
+
'config',
|
| 41 |
+
'__version__',
|
| 42 |
+
]
|
| 43 |
+
except ImportError:
|
| 44 |
+
# If imports fail, it's ok - the package can still be used
|
| 45 |
+
__all__ = ['__version__']
|
config.py
CHANGED
|
@@ -1,216 +1,75 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
used throughout the GAP-CLIP project. It provides a single source of truth
|
| 7 |
-
for model paths, embedding dimensions, dataset locations, and device settings.
|
| 8 |
-
|
| 9 |
-
Key Configuration Categories:
|
| 10 |
-
- Model paths: Paths to trained model checkpoints
|
| 11 |
-
- Data paths: Dataset locations and CSV files
|
| 12 |
-
- Embedding dimensions: Size of color and hierarchy embeddings
|
| 13 |
-
- Column names: CSV column identifiers for data loading
|
| 14 |
-
- Device: Hardware accelerator configuration (CUDA, MPS, or CPU)
|
| 15 |
-
|
| 16 |
-
Usage:
|
| 17 |
-
>>> import config
|
| 18 |
-
>>> model_path = config.main_model_path
|
| 19 |
-
>>> device = config.device
|
| 20 |
-
>>> color_dim = config.color_emb_dim
|
| 21 |
-
|
| 22 |
-
Author: Lea Attia Sarfati
|
| 23 |
-
Project: GAP-CLIP (Guaranteed Attribute Positioning in CLIP Embeddings)
|
| 24 |
"""
|
| 25 |
|
| 26 |
-
from
|
|
|
|
|
|
|
| 27 |
import torch
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
# MODEL PATHS
|
| 32 |
-
# =============================================================================
|
| 33 |
-
# Paths to trained model checkpoints used for inference and fine-tuning
|
| 34 |
-
|
| 35 |
-
#: Path to the trained color model checkpoint (ColorCLIP)
|
| 36 |
-
#: This model extracts 16-dimensional color embeddings from images and text
|
| 37 |
-
color_model_path: Final[str] = "models/color_model.pt"
|
| 38 |
-
|
| 39 |
-
#: Path to the trained hierarchy model checkpoint
|
| 40 |
-
#: This model extracts 64-dimensional category embeddings (e.g., dress, shirt, shoes)
|
| 41 |
-
hierarchy_model_path: Final[str] = "models/hierarchy_model.pth"
|
| 42 |
-
|
| 43 |
-
#: Path to the main GAP-CLIP model checkpoint
|
| 44 |
-
#: This is the primary 512-dimensional CLIP model with aligned color and hierarchy subspaces
|
| 45 |
-
main_model_path: Final[str] = "models/gap_clip.pth"
|
| 46 |
-
|
| 47 |
-
#: Path to the tokenizer vocabulary JSON file
|
| 48 |
-
#: Used by the color model's text encoder for tokenization
|
| 49 |
-
tokeniser_path: Final[str] = "tokenizer_vocab.json"
|
| 50 |
-
|
| 51 |
-
# =============================================================================
|
| 52 |
-
# DATASET PATHS
|
| 53 |
-
# =============================================================================
|
| 54 |
-
# Paths to training, validation, and test datasets
|
| 55 |
-
|
| 56 |
-
#: Path to the main training dataset with local image paths
|
| 57 |
-
#: CSV format with columns: text, color, hierarchy, local_image_path
|
| 58 |
-
local_dataset_path: Final[str] = "data/data_with_local_paths.csv"
|
| 59 |
-
|
| 60 |
-
#: Path to Fashion-MNIST test dataset for evaluation
|
| 61 |
-
#: Used for zero-shot classification benchmarking
|
| 62 |
-
fashion_mnist_test_path: Final[str] = "data/fashion-mnist_test.csv"
|
| 63 |
-
|
| 64 |
-
#: Directory containing image files for the dataset
|
| 65 |
-
images_dir: Final[str] = "data/images"
|
| 66 |
-
|
| 67 |
-
#: Directory for evaluation scripts and results
|
| 68 |
-
evaluation_directory: Final[str] = "evaluation/"
|
| 69 |
-
|
| 70 |
-
# =============================================================================
|
| 71 |
-
# CSV COLUMN NAMES
|
| 72 |
-
# =============================================================================
|
| 73 |
-
# Column identifiers used in dataset CSV files
|
| 74 |
-
|
| 75 |
-
#: Column name for local file paths to images
|
| 76 |
-
column_local_image_path: Final[str] = "local_image_path"
|
| 77 |
-
|
| 78 |
-
#: Column name for image URLs (when using remote images)
|
| 79 |
-
column_url_image: Final[str] = "image_url"
|
| 80 |
-
|
| 81 |
-
#: Column name for text descriptions of fashion items
|
| 82 |
-
text_column: Final[str] = "text"
|
| 83 |
-
|
| 84 |
-
#: Column name for color labels (e.g., "red", "blue", "black")
|
| 85 |
-
color_column: Final[str] = "color"
|
| 86 |
-
|
| 87 |
-
#: Column name for hierarchy/category labels (e.g., "dress", "shirt", "shoes")
|
| 88 |
-
hierarchy_column: Final[str] = "hierarchy"
|
| 89 |
-
|
| 90 |
-
# =============================================================================
|
| 91 |
-
# EMBEDDING DIMENSIONS
|
| 92 |
-
# =============================================================================
|
| 93 |
-
# Dimensionality of various embedding spaces
|
| 94 |
-
|
| 95 |
-
#: Dimension of color embeddings (positions 0-15 in main model)
|
| 96 |
-
#: These dimensions are explicitly trained to encode color information
|
| 97 |
-
color_emb_dim: Final[int] = 16
|
| 98 |
-
|
| 99 |
-
#: Dimension of hierarchy embeddings (positions 16-79 in main model)
|
| 100 |
-
#: These dimensions are explicitly trained to encode category information
|
| 101 |
-
hierarchy_emb_dim: Final[int] = 64
|
| 102 |
-
|
| 103 |
-
#: Total dimension of main CLIP embeddings
|
| 104 |
-
#: Structure: [color (16) | hierarchy (64) | general CLIP (432)] = 512
|
| 105 |
-
main_emb_dim: Final[int] = 512
|
| 106 |
-
|
| 107 |
-
#: Dimension of general CLIP embeddings (remaining dimensions after color and hierarchy)
|
| 108 |
-
general_clip_dim: Final[int] = main_emb_dim - color_emb_dim - hierarchy_emb_dim
|
| 109 |
-
|
| 110 |
-
# =============================================================================
|
| 111 |
-
# DEVICE CONFIGURATION
|
| 112 |
-
# =============================================================================
|
| 113 |
-
# Hardware accelerator settings for model training and inference
|
| 114 |
-
|
| 115 |
-
def get_device() -> torch.device:
|
| 116 |
-
"""
|
| 117 |
-
Automatically detect and return the best available device.
|
| 118 |
-
|
| 119 |
-
Priority order:
|
| 120 |
-
1. CUDA (NVIDIA GPU) if available
|
| 121 |
-
2. MPS (Apple Silicon) if available
|
| 122 |
-
3. CPU as fallback
|
| 123 |
-
|
| 124 |
-
Returns:
|
| 125 |
-
torch.device: The device to use for tensor operations
|
| 126 |
-
|
| 127 |
-
Examples:
|
| 128 |
-
>>> device = get_device()
|
| 129 |
-
>>> model = model.to(device)
|
| 130 |
-
"""
|
| 131 |
if torch.cuda.is_available():
|
| 132 |
return torch.device("cuda")
|
| 133 |
-
|
| 134 |
return torch.device("mps")
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
#
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
#
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
def validate_paths() -> bool:
|
| 170 |
-
"""
|
| 171 |
-
Validate that all critical paths exist and are accessible.
|
| 172 |
-
|
| 173 |
-
Returns:
|
| 174 |
-
bool: True if all paths exist, False otherwise
|
| 175 |
-
|
| 176 |
-
Raises:
|
| 177 |
-
FileNotFoundError: If critical model files are missing
|
| 178 |
-
"""
|
| 179 |
-
critical_paths = [
|
| 180 |
-
color_model_path,
|
| 181 |
-
hierarchy_model_path,
|
| 182 |
-
main_model_path,
|
| 183 |
-
tokeniser_path
|
| 184 |
-
]
|
| 185 |
-
|
| 186 |
-
missing_paths = [p for p in critical_paths if not os.path.exists(p)]
|
| 187 |
-
|
| 188 |
-
if missing_paths:
|
| 189 |
-
print(f"⚠️ Warning: Missing files: {', '.join(missing_paths)}")
|
| 190 |
-
return False
|
| 191 |
-
|
| 192 |
-
return True
|
| 193 |
|
| 194 |
def print_config() -> None:
|
| 195 |
-
"""
|
| 196 |
-
Print a formatted summary of the current configuration.
|
| 197 |
-
|
| 198 |
-
Useful for debugging and logging training runs.
|
| 199 |
-
"""
|
| 200 |
-
print("=" * 80)
|
| 201 |
print("GAP-CLIP Configuration")
|
| 202 |
-
print("
|
| 203 |
-
print(f"
|
| 204 |
-
print(f"
|
| 205 |
-
print(f"
|
| 206 |
-
print(f"
|
| 207 |
-
print(f"
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
+
Project configuration for GAP-CLIP scripts.
|
| 3 |
+
|
| 4 |
+
This module provides default paths, column names, and runtime constants used by
|
| 5 |
+
training/evaluation scripts. Values can be edited locally as needed.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from pathlib import Path
|
| 11 |
import torch
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _detect_device() -> torch.device:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
if torch.cuda.is_available():
|
| 16 |
return torch.device("cuda")
|
| 17 |
+
if torch.backends.mps.is_available():
|
| 18 |
return torch.device("mps")
|
| 19 |
+
return torch.device("cpu")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
ROOT_DIR = Path(__file__).resolve().parent
|
| 23 |
+
|
| 24 |
+
# Runtime/device
|
| 25 |
+
device = _detect_device()
|
| 26 |
+
|
| 27 |
+
# Embedding dimensions
|
| 28 |
+
color_emb_dim = 16
|
| 29 |
+
hierarchy_emb_dim = 64
|
| 30 |
+
main_emb_dim = 512
|
| 31 |
+
|
| 32 |
+
# Default training hyperparameters
|
| 33 |
+
DEFAULT_BATCH_SIZE = 32
|
| 34 |
+
DEFAULT_LEARNING_RATE = 1.5e-5
|
| 35 |
+
DEFAULT_TEMPERATURE = 0.09
|
| 36 |
+
|
| 37 |
+
# Data columns
|
| 38 |
+
text_column = "text"
|
| 39 |
+
color_column = "color"
|
| 40 |
+
hierarchy_column = "hierarchy"
|
| 41 |
+
column_local_image_path = "local_image_path"
|
| 42 |
+
column_url_image = "image_url"
|
| 43 |
+
|
| 44 |
+
# Paths
|
| 45 |
+
local_dataset_path = str(ROOT_DIR / "data" / "data.csv")
|
| 46 |
+
color_model_path = str(ROOT_DIR / "models" / "color_model.pt")
|
| 47 |
+
hierarchy_model_path = str(ROOT_DIR / "models" / "hierarchy_model.pth")
|
| 48 |
+
main_model_path = str(ROOT_DIR / "models" / "gap_clip.pth")
|
| 49 |
+
tokeniser_path = str(ROOT_DIR / "tokenizer_vocab.json")
|
| 50 |
+
images_dir = str(ROOT_DIR / "data" / "images")
|
| 51 |
+
fashion_mnist_csv = str(ROOT_DIR / "data" / "fashion-mnist_test.csv")
|
| 52 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
def print_config() -> None:
|
| 55 |
+
"""Pretty-print core configuration."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
print("GAP-CLIP Configuration")
|
| 57 |
+
print(f" device: {device}")
|
| 58 |
+
print(f" dims: color={color_emb_dim}, hierarchy={hierarchy_emb_dim}, total={main_emb_dim}")
|
| 59 |
+
print(f" dataset: {local_dataset_path}")
|
| 60 |
+
print(f" color model: {color_model_path}")
|
| 61 |
+
print(f" hierarchy model: {hierarchy_model_path}")
|
| 62 |
+
print(f" main model: {main_model_path}")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def validate_paths() -> dict[str, bool]:
|
| 66 |
+
"""Return path existence checks for key files."""
|
| 67 |
+
checks = {
|
| 68 |
+
"local_dataset_path": Path(local_dataset_path).exists(),
|
| 69 |
+
"color_model_path": Path(color_model_path).exists(),
|
| 70 |
+
"hierarchy_model_path": Path(hierarchy_model_path).exists(),
|
| 71 |
+
"main_model_path": Path(main_model_path).exists(),
|
| 72 |
+
"tokeniser_path": Path(tokeniser_path).exists(),
|
| 73 |
+
}
|
| 74 |
+
return checks
|
| 75 |
+
|
data/{dowload_images_data.py → download_images.py}
RENAMED
|
@@ -20,7 +20,7 @@ from threading import Lock
|
|
| 20 |
import config
|
| 21 |
|
| 22 |
class ImageDownloader:
|
| 23 |
-
def __init__(self, df, images_dir=
|
| 24 |
"""
|
| 25 |
Initialize the image downloader.
|
| 26 |
|
|
@@ -202,7 +202,7 @@ def main():
|
|
| 202 |
# Create the downloader
|
| 203 |
downloader = ImageDownloader(
|
| 204 |
df=df,
|
| 205 |
-
images_dir=
|
| 206 |
max_workers=8,
|
| 207 |
timeout=10
|
| 208 |
)
|
|
@@ -211,7 +211,7 @@ def main():
|
|
| 211 |
df_with_paths = downloader.download_all_images()
|
| 212 |
|
| 213 |
print("\n🎉 DOWNLOAD COMPLETED!")
|
| 214 |
-
print("💡 You can now use the local images
|
| 215 |
|
| 216 |
if __name__ == "__main__":
|
| 217 |
main()
|
|
|
|
| 20 |
import config
|
| 21 |
|
| 22 |
class ImageDownloader:
|
| 23 |
+
def __init__(self, df, images_dir="data/images", max_workers=8, timeout=10):
|
| 24 |
"""
|
| 25 |
Initialize the image downloader.
|
| 26 |
|
|
|
|
| 202 |
# Create the downloader
|
| 203 |
downloader = ImageDownloader(
|
| 204 |
df=df,
|
| 205 |
+
images_dir="data/images",
|
| 206 |
max_workers=8,
|
| 207 |
timeout=10
|
| 208 |
)
|
|
|
|
| 211 |
df_with_paths = downloader.download_all_images()
|
| 212 |
|
| 213 |
print("\n🎉 DOWNLOAD COMPLETED!")
|
| 214 |
+
print("💡 You can now use the local images.")
|
| 215 |
|
| 216 |
if __name__ == "__main__":
|
| 217 |
main()
|
data/get_csv_from_chunks.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Script to combine multiple CSV files into a single DataFrame.
|
| 3 |
+
This file allows merging multiple CSV files (chunks) into a single pandas DataFrame.
|
| 4 |
+
It is useful when data is split into multiple files for easier processing
|
| 5 |
+
and needs to be combined into a single dataset for training or evaluation.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import glob
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
def create_single_dataframe_from_chunks(chunks_directory, pattern='*.csv'):
|
| 13 |
+
"""
|
| 14 |
+
Create a single pandas DataFrame by combining multiple CSV chunks.
|
| 15 |
+
|
| 16 |
+
Parameters:
|
| 17 |
+
-----------
|
| 18 |
+
chunks_directory : str
|
| 19 |
+
Directory containing the CSV chunk files
|
| 20 |
+
pattern : str, default='*.csv'
|
| 21 |
+
Pattern to match the CSV files
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
--------
|
| 25 |
+
pandas.DataFrame
|
| 26 |
+
Combined DataFrame from all CSV chunks
|
| 27 |
+
"""
|
| 28 |
+
# Get a list of all CSV files in the directory that match the pattern
|
| 29 |
+
csv_files = glob.glob(os.path.join(chunks_directory, pattern))
|
| 30 |
+
|
| 31 |
+
# Check if any files were found
|
| 32 |
+
if not csv_files:
|
| 33 |
+
raise ValueError(f"No CSV files found in {chunks_directory} matching pattern {pattern}")
|
| 34 |
+
|
| 35 |
+
print(f"Found {len(csv_files)} CSV files to combine")
|
| 36 |
+
|
| 37 |
+
# Create an empty list to store individual DataFrames
|
| 38 |
+
dfs = []
|
| 39 |
+
|
| 40 |
+
# Read each CSV file and append it to the list
|
| 41 |
+
for file in csv_files:
|
| 42 |
+
print(f"Reading {file}...")
|
| 43 |
+
chunk_df = pd.read_csv(file)
|
| 44 |
+
dfs.append(chunk_df)
|
| 45 |
+
print(f"Added chunk with shape {chunk_df.shape}")
|
| 46 |
+
|
| 47 |
+
# Combine all DataFrames into one
|
| 48 |
+
combined_df = pd.concat(dfs, ignore_index=True)
|
| 49 |
+
|
| 50 |
+
print(f"Created combined DataFrame with shape {combined_df.shape}")
|
| 51 |
+
|
| 52 |
+
return combined_df
|
| 53 |
+
|
| 54 |
+
# Example usage
|
| 55 |
+
if __name__ == "__main__":
|
| 56 |
+
# Replace with your chunks directory
|
| 57 |
+
chunks_dir = "data"
|
| 58 |
+
|
| 59 |
+
# Create the combined DataFrame
|
| 60 |
+
df = create_single_dataframe_from_chunks(chunks_dir)
|
| 61 |
+
df.to_csv("data/data_gil.csv", index=False)
|
| 62 |
+
|
evaluation/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
evaluation/0_shot_classification.py
DELETED
|
@@ -1,512 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Zero-shot classification evaluation on a new dataset.
|
| 3 |
-
This file evaluates the main model's performance on unseen data by performing
|
| 4 |
-
zero-shot classification. It compares three methods: color-to-color classification,
|
| 5 |
-
text-to-text, and image-to-text. It generates confusion matrices and classification reports
|
| 6 |
-
for each method to analyze the model's generalization capability.
|
| 7 |
-
"""
|
| 8 |
-
|
| 9 |
-
import os
|
| 10 |
-
# Set environment variable to disable tokenizers parallelism warnings
|
| 11 |
-
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 12 |
-
|
| 13 |
-
import torch
|
| 14 |
-
import torch.nn.functional as F
|
| 15 |
-
import numpy as np
|
| 16 |
-
import pandas as pd
|
| 17 |
-
from torch.utils.data import Dataset
|
| 18 |
-
import matplotlib.pyplot as plt
|
| 19 |
-
from PIL import Image
|
| 20 |
-
from torchvision import transforms
|
| 21 |
-
from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
|
| 22 |
-
import warnings
|
| 23 |
-
import config
|
| 24 |
-
from tqdm import tqdm
|
| 25 |
-
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
|
| 26 |
-
import seaborn as sns
|
| 27 |
-
from color_model import CLIPModel as ColorModel
|
| 28 |
-
from hierarchy_model import Model, HierarchyExtractor
|
| 29 |
-
|
| 30 |
-
# Suppress warnings
|
| 31 |
-
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 32 |
-
warnings.filterwarnings("ignore", category=UserWarning)
|
| 33 |
-
|
| 34 |
-
def load_trained_model(model_path, device):
|
| 35 |
-
"""
|
| 36 |
-
Load the trained CLIP model from checkpoint
|
| 37 |
-
"""
|
| 38 |
-
print(f"Loading trained model from: {model_path}")
|
| 39 |
-
|
| 40 |
-
# Load checkpoint
|
| 41 |
-
checkpoint = torch.load(model_path, map_location=device)
|
| 42 |
-
|
| 43 |
-
# Create the base CLIP model
|
| 44 |
-
model = CLIPModel_transformers.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
|
| 45 |
-
|
| 46 |
-
# Load the trained weights
|
| 47 |
-
model.load_state_dict(checkpoint['model_state_dict'])
|
| 48 |
-
model = model.to(device)
|
| 49 |
-
model.eval()
|
| 50 |
-
|
| 51 |
-
print(f"✅ Model loaded successfully!")
|
| 52 |
-
print(f"📊 Training epoch: {checkpoint['epoch']}")
|
| 53 |
-
print(f"📉 Best validation loss: {checkpoint['best_val_loss']:.4f}")
|
| 54 |
-
|
| 55 |
-
return model, checkpoint
|
| 56 |
-
|
| 57 |
-
def load_feature_models(device):
|
| 58 |
-
"""Load feature models (color and hierarchy)"""
|
| 59 |
-
|
| 60 |
-
# Load color model (embed_dim=16)
|
| 61 |
-
color_checkpoint = torch.load(config.color_model_path, map_location=device, weights_only=True)
|
| 62 |
-
color_model = ColorModel(embed_dim=config.color_emb_dim).to(device)
|
| 63 |
-
color_model.load_state_dict(color_checkpoint)
|
| 64 |
-
color_model.eval()
|
| 65 |
-
color_model.name = 'color'
|
| 66 |
-
|
| 67 |
-
# Load hierarchy model (embed_dim=64)
|
| 68 |
-
hierarchy_checkpoint = torch.load(config.hierarchy_model_path, map_location=device)
|
| 69 |
-
hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', [])
|
| 70 |
-
hierarchy_model = Model(
|
| 71 |
-
num_hierarchy_classes=len(hierarchy_classes),
|
| 72 |
-
embed_dim=config.hierarchy_emb_dim
|
| 73 |
-
).to(device)
|
| 74 |
-
hierarchy_model.load_state_dict(hierarchy_checkpoint['model_state'])
|
| 75 |
-
|
| 76 |
-
# Set up hierarchy extractor
|
| 77 |
-
hierarchy_extractor = HierarchyExtractor(hierarchy_classes, verbose=False)
|
| 78 |
-
hierarchy_model.set_hierarchy_extractor(hierarchy_extractor)
|
| 79 |
-
hierarchy_model.eval()
|
| 80 |
-
hierarchy_model.name = 'hierarchy'
|
| 81 |
-
|
| 82 |
-
feature_models = {model.name: model for model in [color_model, hierarchy_model]}
|
| 83 |
-
return feature_models
|
| 84 |
-
|
| 85 |
-
def get_image_embedding(model, image, device):
|
| 86 |
-
"""Get image embedding from the trained model"""
|
| 87 |
-
model.eval()
|
| 88 |
-
with torch.no_grad():
|
| 89 |
-
# Ensure image has 3 channels
|
| 90 |
-
if image.dim() == 3 and image.size(0) == 1:
|
| 91 |
-
image = image.expand(3, -1, -1)
|
| 92 |
-
elif image.dim() == 4 and image.size(1) == 1:
|
| 93 |
-
image = image.expand(-1, 3, -1, -1)
|
| 94 |
-
|
| 95 |
-
# Add batch dimension if missing
|
| 96 |
-
if image.dim() == 3:
|
| 97 |
-
image = image.unsqueeze(0) # Add batch dimension: (C, H, W) -> (1, C, H, W)
|
| 98 |
-
|
| 99 |
-
image = image.to(device)
|
| 100 |
-
|
| 101 |
-
# Use vision model directly to get image embeddings
|
| 102 |
-
vision_outputs = model.vision_model(pixel_values=image)
|
| 103 |
-
image_features = model.visual_projection(vision_outputs.pooler_output)
|
| 104 |
-
|
| 105 |
-
return F.normalize(image_features, dim=-1)
|
| 106 |
-
|
| 107 |
-
def get_text_embedding(model, text, processor, device):
|
| 108 |
-
"""Get text embedding from the trained model"""
|
| 109 |
-
model.eval()
|
| 110 |
-
with torch.no_grad():
|
| 111 |
-
text_inputs = processor(text=text, padding=True, return_tensors="pt")
|
| 112 |
-
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
|
| 113 |
-
|
| 114 |
-
# Use text model directly to get text embeddings
|
| 115 |
-
text_outputs = model.text_model(**text_inputs)
|
| 116 |
-
text_features = model.text_projection(text_outputs.pooler_output)
|
| 117 |
-
|
| 118 |
-
return F.normalize(text_features, dim=-1)
|
| 119 |
-
|
| 120 |
-
def evaluate_custom_csv_accuracy(model, dataset, processor, method='similarity'):
|
| 121 |
-
"""
|
| 122 |
-
Evaluate the accuracy of the model on your custom CSV using text-to-text similarity
|
| 123 |
-
|
| 124 |
-
Args:
|
| 125 |
-
model: The trained CLIP model
|
| 126 |
-
dataset: CustomCSVDataset
|
| 127 |
-
processor: CLIPProcessor
|
| 128 |
-
method: 'similarity' or 'classification'
|
| 129 |
-
"""
|
| 130 |
-
print(f"\n📊 === Evaluation of the accuracy on custom CSV (TEXT-TO-TEXT method) ===")
|
| 131 |
-
|
| 132 |
-
model.eval()
|
| 133 |
-
|
| 134 |
-
# Get all unique colors for classification
|
| 135 |
-
all_colors = set()
|
| 136 |
-
for i in range(len(dataset)):
|
| 137 |
-
_, _, color = dataset[i]
|
| 138 |
-
all_colors.add(color)
|
| 139 |
-
|
| 140 |
-
color_list = sorted(list(all_colors))
|
| 141 |
-
print(f"🎨 Colors found: {color_list}")
|
| 142 |
-
|
| 143 |
-
true_labels = []
|
| 144 |
-
predicted_labels = []
|
| 145 |
-
|
| 146 |
-
# Pre-calculate the embeddings of the color descriptions
|
| 147 |
-
print("🔄 Pre-calculating the embeddings of the colors...")
|
| 148 |
-
color_embeddings = {}
|
| 149 |
-
for color in color_list:
|
| 150 |
-
color_emb = get_text_embedding(model, color, processor)
|
| 151 |
-
color_embeddings[color] = color_emb
|
| 152 |
-
|
| 153 |
-
print("🔄 Evaluation in progress...")
|
| 154 |
-
correct_predictions = 0
|
| 155 |
-
|
| 156 |
-
for idx in tqdm(range(len(dataset)), desc="Evaluation"):
|
| 157 |
-
image, text, true_color = dataset[idx]
|
| 158 |
-
|
| 159 |
-
# Get text embedding instead of image embedding
|
| 160 |
-
text_emb = get_text_embedding(model, text, processor)
|
| 161 |
-
|
| 162 |
-
# Calculate the similarity with each possible color
|
| 163 |
-
best_similarity = -1
|
| 164 |
-
predicted_color = color_list[0]
|
| 165 |
-
|
| 166 |
-
for color, color_emb in color_embeddings.items():
|
| 167 |
-
similarity = F.cosine_similarity(text_emb, color_emb, dim=1).item()
|
| 168 |
-
if similarity > best_similarity:
|
| 169 |
-
best_similarity = similarity
|
| 170 |
-
predicted_color = color
|
| 171 |
-
|
| 172 |
-
true_labels.append(true_color)
|
| 173 |
-
predicted_labels.append(predicted_color)
|
| 174 |
-
|
| 175 |
-
if true_color == predicted_color:
|
| 176 |
-
correct_predictions += 1
|
| 177 |
-
|
| 178 |
-
# Calculate the accuracy
|
| 179 |
-
accuracy = accuracy_score(true_labels, predicted_labels)
|
| 180 |
-
|
| 181 |
-
print(f"\n✅ Results of evaluation:")
|
| 182 |
-
print(f"🎯 Global accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
|
| 183 |
-
print(f"📊 Correct predictions: {correct_predictions}/{len(true_labels)}")
|
| 184 |
-
|
| 185 |
-
return true_labels, predicted_labels, accuracy
|
| 186 |
-
|
| 187 |
-
def evaluate_custom_csv_accuracy_image(model, dataset, processor, method='similarity'):
|
| 188 |
-
"""
|
| 189 |
-
Evaluate the accuracy of the model on your custom CSV using image-to-text similarity
|
| 190 |
-
|
| 191 |
-
Args:
|
| 192 |
-
model: The trained CLIP model
|
| 193 |
-
dataset: CustomCSVDataset with images loaded
|
| 194 |
-
processor: CLIPProcessor
|
| 195 |
-
method: 'similarity' or 'classification'
|
| 196 |
-
"""
|
| 197 |
-
print(f"\n📊 === Evaluation of the accuracy on custom CSV (IMAGE-TO-TEXT method) ===")
|
| 198 |
-
|
| 199 |
-
model.eval()
|
| 200 |
-
|
| 201 |
-
# Get all unique colors for classification
|
| 202 |
-
all_colors = set()
|
| 203 |
-
for i in range(len(dataset)):
|
| 204 |
-
_, _, color = dataset[i]
|
| 205 |
-
all_colors.add(color)
|
| 206 |
-
|
| 207 |
-
color_list = sorted(list(all_colors))
|
| 208 |
-
print(f"🎨 Colors found: {color_list}")
|
| 209 |
-
|
| 210 |
-
true_labels = []
|
| 211 |
-
predicted_labels = []
|
| 212 |
-
|
| 213 |
-
# Pre-calculate the embeddings of the color descriptions
|
| 214 |
-
print("🔄 Pre-calculating the embeddings of the colors...")
|
| 215 |
-
color_embeddings = {}
|
| 216 |
-
for color in color_list:
|
| 217 |
-
color_emb = get_text_embedding(model, color, processor)
|
| 218 |
-
color_embeddings[color] = color_emb
|
| 219 |
-
|
| 220 |
-
print("🔄 Evaluation in progress...")
|
| 221 |
-
correct_predictions = 0
|
| 222 |
-
|
| 223 |
-
for idx in tqdm(range(len(dataset)), desc="Evaluation"):
|
| 224 |
-
image, text, true_color = dataset[idx]
|
| 225 |
-
|
| 226 |
-
# Get image embedding (this is the key difference from text-to-text)
|
| 227 |
-
image_emb = get_image_embedding(model, image, processor)
|
| 228 |
-
|
| 229 |
-
# Calculate the similarity with each possible color
|
| 230 |
-
best_similarity = -1
|
| 231 |
-
predicted_color = color_list[0]
|
| 232 |
-
|
| 233 |
-
for color, color_emb in color_embeddings.items():
|
| 234 |
-
similarity = F.cosine_similarity(image_emb, color_emb, dim=1).item()
|
| 235 |
-
if similarity > best_similarity:
|
| 236 |
-
best_similarity = similarity
|
| 237 |
-
predicted_color = color
|
| 238 |
-
|
| 239 |
-
true_labels.append(true_color)
|
| 240 |
-
predicted_labels.append(predicted_color)
|
| 241 |
-
|
| 242 |
-
if true_color == predicted_color:
|
| 243 |
-
correct_predictions += 1
|
| 244 |
-
|
| 245 |
-
# Calculate the accuracy
|
| 246 |
-
accuracy = accuracy_score(true_labels, predicted_labels)
|
| 247 |
-
|
| 248 |
-
print(f"\n✅ Results of evaluation:")
|
| 249 |
-
print(f"🎯 Global accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
|
| 250 |
-
print(f"📊 Correct predictions: {correct_predictions}/{len(true_labels)}")
|
| 251 |
-
|
| 252 |
-
return true_labels, predicted_labels, accuracy
|
| 253 |
-
|
| 254 |
-
def evaluate_custom_csv_accuracy_color_only(model, dataset, processor):
|
| 255 |
-
"""
|
| 256 |
-
Evaluate the accuracy by encoding ONLY the color (not the full text)
|
| 257 |
-
This tests if the embedding space is consistent for colors
|
| 258 |
-
|
| 259 |
-
Args:
|
| 260 |
-
model: The trained CLIP model
|
| 261 |
-
dataset: CustomCSVDataset
|
| 262 |
-
processor: CLIPProcessor
|
| 263 |
-
"""
|
| 264 |
-
print(f"\n📊 === Evaluation of the accuracy on custom CSV (COLOR-TO-COLOR method) ===")
|
| 265 |
-
print("🔬 This test encodes ONLY the color name, not the full text")
|
| 266 |
-
|
| 267 |
-
model.eval()
|
| 268 |
-
|
| 269 |
-
# Get all unique colors for classification
|
| 270 |
-
all_colors = set()
|
| 271 |
-
for i in range(len(dataset)):
|
| 272 |
-
_, _, color = dataset[i]
|
| 273 |
-
all_colors.add(color)
|
| 274 |
-
|
| 275 |
-
color_list = sorted(list(all_colors))
|
| 276 |
-
print(f"🎨 Colors found: {color_list}")
|
| 277 |
-
|
| 278 |
-
true_labels = []
|
| 279 |
-
predicted_labels = []
|
| 280 |
-
|
| 281 |
-
# Pre-calculate the embeddings of the color descriptions
|
| 282 |
-
print("🔄 Pre-calculating the embeddings of the colors...")
|
| 283 |
-
color_embeddings = {}
|
| 284 |
-
for color in color_list:
|
| 285 |
-
color_emb = get_text_embedding(model, color, processor)
|
| 286 |
-
color_embeddings[color] = color_emb
|
| 287 |
-
|
| 288 |
-
print("🔄 Evaluation in progress...")
|
| 289 |
-
correct_predictions = 0
|
| 290 |
-
|
| 291 |
-
for idx in tqdm(range(len(dataset)), desc="Evaluation"):
|
| 292 |
-
image, text, true_color = dataset[idx]
|
| 293 |
-
|
| 294 |
-
# KEY DIFFERENCE: Get embedding of the TRUE COLOR only (not the full text)
|
| 295 |
-
true_color_emb = get_text_embedding(model, true_color, processor)
|
| 296 |
-
|
| 297 |
-
# Calculate the similarity with each possible color
|
| 298 |
-
best_similarity = -1
|
| 299 |
-
predicted_color = color_list[0]
|
| 300 |
-
|
| 301 |
-
for color, color_emb in color_embeddings.items():
|
| 302 |
-
similarity = F.cosine_similarity(true_color_emb, color_emb, dim=1).item()
|
| 303 |
-
if similarity > best_similarity:
|
| 304 |
-
best_similarity = similarity
|
| 305 |
-
predicted_color = color
|
| 306 |
-
|
| 307 |
-
true_labels.append(true_color)
|
| 308 |
-
predicted_labels.append(predicted_color)
|
| 309 |
-
|
| 310 |
-
if true_color == predicted_color:
|
| 311 |
-
correct_predictions += 1
|
| 312 |
-
|
| 313 |
-
# Calculate the accuracy
|
| 314 |
-
accuracy = accuracy_score(true_labels, predicted_labels)
|
| 315 |
-
|
| 316 |
-
print(f"\n✅ Results of evaluation:")
|
| 317 |
-
print(f"🎯 Global accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
|
| 318 |
-
print(f"📊 Correct predictions: {correct_predictions}/{len(true_labels)}")
|
| 319 |
-
|
| 320 |
-
return true_labels, predicted_labels, accuracy
|
| 321 |
-
|
| 322 |
-
def search_custom_csv_by_text(model, dataset, query, processor, top_k=5):
|
| 323 |
-
"""Search in your CSV by text query"""
|
| 324 |
-
print(f"\n🔍 Search in custom CSV: '{query}'")
|
| 325 |
-
|
| 326 |
-
# Get the embedding of the query
|
| 327 |
-
query_emb = get_text_embedding(model, query, processor)
|
| 328 |
-
|
| 329 |
-
similarities = []
|
| 330 |
-
|
| 331 |
-
print("🔄 Calculating similarities...")
|
| 332 |
-
for idx in tqdm(range(len(dataset)), desc="Processing"):
|
| 333 |
-
image, text, color, _, image_path = dataset[idx]
|
| 334 |
-
|
| 335 |
-
# Get the embedding of the image
|
| 336 |
-
image_emb = get_image_embedding(model, image, processor)
|
| 337 |
-
|
| 338 |
-
# Calculer la similarité
|
| 339 |
-
similarity = F.cosine_similarity(query_emb, image_emb, dim=1).item()
|
| 340 |
-
|
| 341 |
-
similarities.append((idx, similarity, text, color, color, image_path))
|
| 342 |
-
|
| 343 |
-
# Trier par similarité
|
| 344 |
-
similarities.sort(key=lambda x: x[1], reverse=True)
|
| 345 |
-
|
| 346 |
-
return similarities[:top_k]
|
| 347 |
-
|
| 348 |
-
def plot_confusion_matrix(true_labels, predicted_labels, save_path=None, title_suffix="text"):
|
| 349 |
-
"""
|
| 350 |
-
Display and save the confusion matrix
|
| 351 |
-
"""
|
| 352 |
-
print("\n📈 === Generation of the confusion matrix ===")
|
| 353 |
-
|
| 354 |
-
# Calculate the confusion matrix
|
| 355 |
-
cm = confusion_matrix(true_labels, predicted_labels)
|
| 356 |
-
|
| 357 |
-
# Get unique labels in sorted order
|
| 358 |
-
unique_labels = sorted(set(true_labels + predicted_labels))
|
| 359 |
-
|
| 360 |
-
# Calculate accuracy
|
| 361 |
-
accuracy = accuracy_score(true_labels, predicted_labels)
|
| 362 |
-
|
| 363 |
-
# Calculate the percentages and round to integers
|
| 364 |
-
cm_percent = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100
|
| 365 |
-
cm_percent = np.around(cm_percent).astype(int)
|
| 366 |
-
|
| 367 |
-
# Create the figure
|
| 368 |
-
plt.figure(figsize=(12, 10))
|
| 369 |
-
|
| 370 |
-
# Confusion matrix with percentages and labels (no decimal points)
|
| 371 |
-
sns.heatmap(cm_percent,
|
| 372 |
-
annot=True,
|
| 373 |
-
fmt='d',
|
| 374 |
-
cmap='Blues',
|
| 375 |
-
cbar_kws={'label': 'Percentage (%)'},
|
| 376 |
-
xticklabels=unique_labels,
|
| 377 |
-
yticklabels=unique_labels)
|
| 378 |
-
|
| 379 |
-
plt.title(f"Confusion Matrix for {title_suffix} - new data - accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)", fontsize=16)
|
| 380 |
-
plt.xlabel('Predictions', fontsize=12)
|
| 381 |
-
plt.ylabel('True colors', fontsize=12)
|
| 382 |
-
plt.xticks(rotation=45, ha='right')
|
| 383 |
-
plt.yticks(rotation=0)
|
| 384 |
-
plt.tight_layout()
|
| 385 |
-
|
| 386 |
-
if save_path:
|
| 387 |
-
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
| 388 |
-
print(f"💾 Confusion matrix saved: {save_path}")
|
| 389 |
-
|
| 390 |
-
plt.show()
|
| 391 |
-
|
| 392 |
-
return cm
|
| 393 |
-
|
| 394 |
-
class CustomCSVDataset(Dataset):
|
| 395 |
-
def __init__(self, dataframe, image_size=224, load_images=True):
|
| 396 |
-
self.dataframe = dataframe
|
| 397 |
-
self.image_size = image_size
|
| 398 |
-
self.load_images = load_images
|
| 399 |
-
|
| 400 |
-
# Define image transformations
|
| 401 |
-
self.transform = transforms.Compose([
|
| 402 |
-
transforms.Resize((image_size, image_size)),
|
| 403 |
-
transforms.ToTensor(),
|
| 404 |
-
transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
|
| 405 |
-
std=[0.26862954, 0.26130258, 0.27577711])
|
| 406 |
-
])
|
| 407 |
-
|
| 408 |
-
def __len__(self):
|
| 409 |
-
return len(self.dataframe)
|
| 410 |
-
|
| 411 |
-
def __getitem__(self, idx):
|
| 412 |
-
row = self.dataframe.iloc[idx]
|
| 413 |
-
text = row[config.text_column]
|
| 414 |
-
colors = row[config.color_column]
|
| 415 |
-
|
| 416 |
-
if self.load_images and config.column_local_image_path in row:
|
| 417 |
-
# Load the actual image
|
| 418 |
-
try:
|
| 419 |
-
image = Image.open(row[config.column_local_image_path]).convert('RGB')
|
| 420 |
-
image = self.transform(image)
|
| 421 |
-
except Exception as e:
|
| 422 |
-
print(f"Warning: Could not load image {row.get(config.column_local_image_path, 'unknown')}: {e}")
|
| 423 |
-
image = torch.zeros(3, self.image_size, self.image_size)
|
| 424 |
-
else:
|
| 425 |
-
# Return dummy image if not loading images
|
| 426 |
-
image = torch.zeros(3, self.image_size, self.image_size)
|
| 427 |
-
|
| 428 |
-
return image, text, colors
|
| 429 |
-
|
| 430 |
-
if __name__ == "__main__":
|
| 431 |
-
"""Main function with evaluation"""
|
| 432 |
-
print("🚀 === Test and Evaluation of the model on new dataset ===")
|
| 433 |
-
|
| 434 |
-
# Load model
|
| 435 |
-
print("🔧 Loading the model...")
|
| 436 |
-
model, checkpoint = load_trained_model(config.main_model_path, config.device)
|
| 437 |
-
|
| 438 |
-
# Create processor
|
| 439 |
-
processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
|
| 440 |
-
|
| 441 |
-
# Load new dataset
|
| 442 |
-
print("📊 Loading the new dataset...")
|
| 443 |
-
df = pd.read_csv(config.local_dataset_path) # replace local_dataset_path with a new df
|
| 444 |
-
|
| 445 |
-
print("\n" + "="*80)
|
| 446 |
-
print("🎨 COLOR-TO-COLOR CLASSIFICATION (Control Test)")
|
| 447 |
-
print("="*80)
|
| 448 |
-
|
| 449 |
-
# Create dataset without loading images
|
| 450 |
-
dataset_color = CustomCSVDataset(df, load_images=False)
|
| 451 |
-
|
| 452 |
-
# 0. Evaluation encoding ONLY the color (control test)
|
| 453 |
-
true_labels_color, predicted_labels_color, accuracy_color = evaluate_custom_csv_accuracy_color_only(
|
| 454 |
-
model, dataset_color, processor
|
| 455 |
-
)
|
| 456 |
-
|
| 457 |
-
# Confusion matrix for color-only
|
| 458 |
-
confusion_matrix_color = plot_confusion_matrix(
|
| 459 |
-
true_labels_color, predicted_labels_color,
|
| 460 |
-
save_path="confusion_matrix_color_only.png",
|
| 461 |
-
title_suffix="color-only"
|
| 462 |
-
)
|
| 463 |
-
|
| 464 |
-
print("\n" + "="*80)
|
| 465 |
-
print("📝 TEXT-TO-TEXT CLASSIFICATION")
|
| 466 |
-
print("="*80)
|
| 467 |
-
|
| 468 |
-
# Create dataset without loading images for text-to-text
|
| 469 |
-
dataset_text = CustomCSVDataset(df, load_images=False)
|
| 470 |
-
|
| 471 |
-
# 1. Evaluation of the accuracy (text-to-text)
|
| 472 |
-
true_labels_text, predicted_labels_text, accuracy_text = evaluate_custom_csv_accuracy(
|
| 473 |
-
model, dataset_text, processor, method='similarity'
|
| 474 |
-
)
|
| 475 |
-
|
| 476 |
-
# 2. Confusion matrix for text
|
| 477 |
-
confusion_matrix_text = plot_confusion_matrix(
|
| 478 |
-
true_labels_text, predicted_labels_text,
|
| 479 |
-
save_path="confusion_matrix_text.png",
|
| 480 |
-
title_suffix="text"
|
| 481 |
-
)
|
| 482 |
-
|
| 483 |
-
print("\n" + "="*80)
|
| 484 |
-
print("🖼️ IMAGE-TO-TEXT CLASSIFICATION")
|
| 485 |
-
print("="*80)
|
| 486 |
-
|
| 487 |
-
# Create dataset with images loaded for image-to-text
|
| 488 |
-
dataset_image = CustomCSVDataset(df, load_images=True)
|
| 489 |
-
|
| 490 |
-
# 3. Evaluation of the accuracy (image-to-text)
|
| 491 |
-
true_labels_image, predicted_labels_image, accuracy_image = evaluate_custom_csv_accuracy_image(
|
| 492 |
-
model, dataset_image, processor, method='similarity'
|
| 493 |
-
)
|
| 494 |
-
|
| 495 |
-
# 4. Confusion matrix for images
|
| 496 |
-
confusion_matrix_image = plot_confusion_matrix(
|
| 497 |
-
true_labels_image, predicted_labels_image,
|
| 498 |
-
save_path="confusion_matrix_image.png",
|
| 499 |
-
title_suffix="image"
|
| 500 |
-
)
|
| 501 |
-
|
| 502 |
-
# 5. Summary comparison
|
| 503 |
-
print("\n" + "="*80)
|
| 504 |
-
print("📊 SUMMARY")
|
| 505 |
-
print("="*80)
|
| 506 |
-
print(f"🎨 Color-to-Color Accuracy (Control): {accuracy_color:.4f} ({accuracy_color*100:.2f}%)")
|
| 507 |
-
print(f"📝 Text-to-Text Accuracy: {accuracy_text:.4f} ({accuracy_text*100:.2f}%)")
|
| 508 |
-
print(f"🖼️ Image-to-Text Accuracy: {accuracy_image:.4f} ({accuracy_image*100:.2f}%)")
|
| 509 |
-
print(f"\n📊 Analysis:")
|
| 510 |
-
print(f" • Loss from full text vs color-only: {abs(accuracy_color - accuracy_text):.4f} ({abs(accuracy_color - accuracy_text)*100:.2f}%)")
|
| 511 |
-
print(f" • Difference text vs image: {abs(accuracy_text - accuracy_image):.4f} ({abs(accuracy_text - accuracy_image)*100:.2f}%)")
|
| 512 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
evaluation/{heatmap_color_similarities.py → annex92_color_heatmaps.py}
RENAMED
|
@@ -1,3 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import torch
|
| 3 |
import pandas as pd
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Annex 9.2 Pairwise Colour Similarity Heatmaps
|
| 3 |
+
===============================================
|
| 4 |
+
|
| 5 |
+
Generates the colour-similarity heatmaps shown in **Annex 9.2** of the paper.
|
| 6 |
+
|
| 7 |
+
For each model (GAP-CLIP and the Fashion-CLIP baseline) the script:
|
| 8 |
+
|
| 9 |
+
1. Embeds a fixed set of colour-name text prompts ("a red garment", …).
|
| 10 |
+
2. Computes pairwise cosine similarities across the 13 primary colours.
|
| 11 |
+
3. Renders a seaborn heatmap where the diagonal is intra-colour similarity
|
| 12 |
+
and off-diagonal cells show cross-colour confusion.
|
| 13 |
+
|
| 14 |
+
The heatmaps provide an intuitive visual complement to the quantitative
|
| 15 |
+
separation scores reported in §5.1 (Table 1).
|
| 16 |
+
|
| 17 |
+
See also:
|
| 18 |
+
- §5.1 (``sec51_color_model_eval.py``) – quantitative colour accuracy
|
| 19 |
+
- Annex 9.3 (``annex93_tsne.py``) – t-SNE scatter plots
|
| 20 |
+
"""
|
| 21 |
import os
|
| 22 |
import torch
|
| 23 |
import pandas as pd
|
evaluation/{tsne_images.py → annex93_tsne.py}
RENAMED
|
@@ -1,7 +1,28 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
import math
|
|
@@ -462,7 +483,7 @@ if __name__ == "__main__":
|
|
| 462 |
output_hierarchy = "tsne_hierarchy_space.png"
|
| 463 |
|
| 464 |
print("📥 Loading the dataset...")
|
| 465 |
-
df = pd.read_csv("data/
|
| 466 |
df = filter_valid_rows(df)
|
| 467 |
print(f"Total len if the dataset: {len(df)}")
|
| 468 |
df = prepare_dataframe(df, sample_size, per_color_limit, min_per_hierarchy)
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
Annex 9.3 t-SNE Embedding Visualisations
|
| 4 |
+
==========================================
|
| 5 |
+
|
| 6 |
+
Produces the t-SNE scatter plots shown in **Annex 9.3** of the paper.
|
| 7 |
+
|
| 8 |
+
The script loads the local validation dataset, encodes each image with the
|
| 9 |
+
main GAP-CLIP model (and, optionally, the CLIP baseline), then reduces the
|
| 10 |
+
512-D embeddings to 2-D via t-SNE and renders:
|
| 11 |
+
|
| 12 |
+
* **Colour overlay** – points coloured by garment colour, convex hulls drawn
|
| 13 |
+
around each colour cluster.
|
| 14 |
+
* **Hierarchy overlay** – points coloured by clothing category (top, bottom,
|
| 15 |
+
shoes, …), convex hulls drawn around each category cluster.
|
| 16 |
+
* **Per-hierarchy colour scatter** – one subplot per category, showing how
|
| 17 |
+
colours are distributed within each category.
|
| 18 |
+
|
| 19 |
+
These plots complement the quantitative separation scores in §5.3.6 and
|
| 20 |
+
provide an intuitive sanity check that the dedicated embedding dimensions
|
| 21 |
+
(0–15 for colour, 16–79 for hierarchy) encode the intended structure.
|
| 22 |
+
|
| 23 |
+
See also:
|
| 24 |
+
- §5.3.6 (``sec536_embedding_structure.py``) – quantitative Tests A/B/C
|
| 25 |
+
- Annex 9.2 (``annex92_color_heatmaps.py``) – pairwise colour heatmaps
|
| 26 |
"""
|
| 27 |
|
| 28 |
import math
|
|
|
|
| 483 |
output_hierarchy = "tsne_hierarchy_space.png"
|
| 484 |
|
| 485 |
print("📥 Loading the dataset...")
|
| 486 |
+
df = pd.read_csv("data/data.csv")
|
| 487 |
df = filter_valid_rows(df)
|
| 488 |
print(f"Total len if the dataset: {len(df)}")
|
| 489 |
df = prepare_dataframe(df, sample_size, per_color_limit, min_per_hierarchy)
|
evaluation/annex94_search_demo.py
ADDED
|
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Annex 9.4 — Search Engine Demo
|
| 4 |
+
===============================
|
| 5 |
+
|
| 6 |
+
Interactive fashion search engine using pre-computed GAP-CLIP text embeddings.
|
| 7 |
+
Demonstrates real-world retrieval quality by accepting free-text queries and
|
| 8 |
+
returning the most similar items from the internal dataset, with images and
|
| 9 |
+
similarity scores displayed in a grid layout.
|
| 10 |
+
|
| 11 |
+
Run directly:
|
| 12 |
+
python annex94_search_demo.py
|
| 13 |
+
|
| 14 |
+
Paper reference: Section 9.4 (Appendix), Figure 5.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import numpy as np
|
| 19 |
+
import pandas as pd
|
| 20 |
+
from PIL import Image
|
| 21 |
+
import matplotlib.pyplot as plt
|
| 22 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 23 |
+
from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
|
| 24 |
+
import warnings
|
| 25 |
+
import os
|
| 26 |
+
import sys
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
from typing import List, Optional
|
| 29 |
+
|
| 30 |
+
# Ensure project root is importable when running this file directly.
|
| 31 |
+
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 32 |
+
if str(PROJECT_ROOT) not in sys.path:
|
| 33 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 34 |
+
|
| 35 |
+
# Import custom models
|
| 36 |
+
try:
|
| 37 |
+
from training.color_model import CLIPModel as ColorModel
|
| 38 |
+
except ModuleNotFoundError:
|
| 39 |
+
ColorModel = None
|
| 40 |
+
from training.hierarchy_model import Model as HierarchyModel, HierarchyExtractor
|
| 41 |
+
import config
|
| 42 |
+
|
| 43 |
+
warnings.filterwarnings("ignore")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class FashionSearchEngine:
|
| 47 |
+
"""
|
| 48 |
+
Fashion search engine using multi-modal embeddings with category emphasis
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
def __init__(
|
| 52 |
+
self, top_k: int = 10, max_items: int = 10000, use_baseline: bool = False
|
| 53 |
+
):
|
| 54 |
+
"""
|
| 55 |
+
Initialize the fashion search engine
|
| 56 |
+
Args:
|
| 57 |
+
top_k: Number of top results to return
|
| 58 |
+
max_items: Maximum number of items to process (for faster initialization)
|
| 59 |
+
use_baseline: If True, use the Fashion-CLIP baseline instead of GAP-CLIP.
|
| 60 |
+
"""
|
| 61 |
+
self.device = config.device
|
| 62 |
+
self.top_k = top_k
|
| 63 |
+
self.max_items = max_items
|
| 64 |
+
self.color_dim = config.color_emb_dim
|
| 65 |
+
self.hierarchy_dim = config.hierarchy_emb_dim
|
| 66 |
+
self.use_baseline = use_baseline
|
| 67 |
+
|
| 68 |
+
# Load models
|
| 69 |
+
self._load_models()
|
| 70 |
+
|
| 71 |
+
# Load dataset
|
| 72 |
+
self._load_dataset()
|
| 73 |
+
|
| 74 |
+
# Pre-compute embeddings for all items
|
| 75 |
+
self._precompute_embeddings()
|
| 76 |
+
|
| 77 |
+
print("✅ Fashion Search Engine ready!")
|
| 78 |
+
|
| 79 |
+
def _load_models(self):
|
| 80 |
+
"""Load all required models"""
|
| 81 |
+
print("📦 Loading models...")
|
| 82 |
+
|
| 83 |
+
# Load color model (optional for search in this script).
|
| 84 |
+
self.color_model = None
|
| 85 |
+
color_model_path = getattr(config, "color_model_path", None)
|
| 86 |
+
if ColorModel is None:
|
| 87 |
+
print("⚠️ color_model.py not found; continuing without color model.")
|
| 88 |
+
elif not color_model_path or not Path(color_model_path).exists():
|
| 89 |
+
print("⚠️ color model checkpoint not found; continuing without color model.")
|
| 90 |
+
else:
|
| 91 |
+
color_checkpoint = torch.load(
|
| 92 |
+
color_model_path, map_location=self.device, weights_only=True
|
| 93 |
+
)
|
| 94 |
+
self.color_model = ColorModel(embed_dim=self.color_dim).to(self.device)
|
| 95 |
+
self.color_model.load_state_dict(color_checkpoint)
|
| 96 |
+
self.color_model.eval()
|
| 97 |
+
|
| 98 |
+
# Load hierarchy model
|
| 99 |
+
hierarchy_checkpoint = torch.load(
|
| 100 |
+
config.hierarchy_model_path, map_location=self.device
|
| 101 |
+
)
|
| 102 |
+
self.hierarchy_classes = hierarchy_checkpoint.get("hierarchy_classes", [])
|
| 103 |
+
self.hierarchy_model = HierarchyModel(
|
| 104 |
+
num_hierarchy_classes=len(self.hierarchy_classes),
|
| 105 |
+
embed_dim=self.hierarchy_dim,
|
| 106 |
+
).to(self.device)
|
| 107 |
+
self.hierarchy_model.load_state_dict(hierarchy_checkpoint["model_state"])
|
| 108 |
+
|
| 109 |
+
# Set hierarchy extractor
|
| 110 |
+
hierarchy_extractor = HierarchyExtractor(self.hierarchy_classes, verbose=False)
|
| 111 |
+
self.hierarchy_model.set_hierarchy_extractor(hierarchy_extractor)
|
| 112 |
+
self.hierarchy_model.eval()
|
| 113 |
+
|
| 114 |
+
# Load main CLIP model (baseline or fine-tuned GAP-CLIP)
|
| 115 |
+
if self.use_baseline:
|
| 116 |
+
baseline_name = "patrickjohncyh/fashion-clip"
|
| 117 |
+
print(f"📦 Loading baseline Fashion-CLIP model ({baseline_name})...")
|
| 118 |
+
self.main_model = CLIPModel_transformers.from_pretrained(baseline_name).to(
|
| 119 |
+
self.device
|
| 120 |
+
)
|
| 121 |
+
self.main_model.eval()
|
| 122 |
+
self.clip_processor = CLIPProcessor.from_pretrained(baseline_name)
|
| 123 |
+
else:
|
| 124 |
+
self.main_model = CLIPModel_transformers.from_pretrained(
|
| 125 |
+
"laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
|
| 126 |
+
)
|
| 127 |
+
checkpoint = torch.load(config.main_model_path, map_location=self.device)
|
| 128 |
+
if "model_state_dict" in checkpoint:
|
| 129 |
+
self.main_model.load_state_dict(checkpoint["model_state_dict"])
|
| 130 |
+
else:
|
| 131 |
+
self.main_model.load_state_dict(checkpoint)
|
| 132 |
+
|
| 133 |
+
self.main_model.to(self.device)
|
| 134 |
+
self.main_model.eval()
|
| 135 |
+
self.clip_processor = CLIPProcessor.from_pretrained(
|
| 136 |
+
"laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
model_label = "Fashion-CLIP baseline" if self.use_baseline else "GAP-CLIP"
|
| 140 |
+
print(
|
| 141 |
+
f"✅ Models loaded ({model_label}) - Colors: {self.color_dim}D, Hierarchy: {self.hierarchy_dim}D"
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
def _load_dataset(self):
|
| 145 |
+
"""Load the fashion dataset.
|
| 146 |
+
|
| 147 |
+
Tries ``config.local_dataset_path`` first. If it doesn't exist,
|
| 148 |
+
falls back to ``data/data.csv`` (the raw catalogue without
|
| 149 |
+
``local_image_path``).
|
| 150 |
+
"""
|
| 151 |
+
print("📊 Loading dataset...")
|
| 152 |
+
dataset_path = config.local_dataset_path
|
| 153 |
+
if not Path(dataset_path).exists():
|
| 154 |
+
fallback = Path(config.ROOT_DIR) / "data" / "data.csv"
|
| 155 |
+
if fallback.exists():
|
| 156 |
+
print(f"⚠️ {dataset_path} not found, falling back to {fallback}")
|
| 157 |
+
dataset_path = str(fallback)
|
| 158 |
+
else:
|
| 159 |
+
raise FileNotFoundError(
|
| 160 |
+
f"Neither {config.local_dataset_path} nor {fallback} found."
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
self.df = pd.read_csv(dataset_path)
|
| 164 |
+
|
| 165 |
+
# If local_image_path column is missing, create an empty one so the
|
| 166 |
+
# rest of the pipeline can proceed (text-only search still works).
|
| 167 |
+
if config.column_local_image_path not in self.df.columns:
|
| 168 |
+
self.df[config.column_local_image_path] = ""
|
| 169 |
+
|
| 170 |
+
self.df_clean = self.df.dropna(subset=[config.text_column])
|
| 171 |
+
print(f"✅ {len(self.df_clean)} items loaded for search")
|
| 172 |
+
|
| 173 |
+
def _precompute_embeddings(self):
|
| 174 |
+
"""Pre-compute text embeddings using stratified sampling (up to 20 items per color-category)."""
|
| 175 |
+
print("🔄 Pre-computing embeddings with stratified sampling...")
|
| 176 |
+
|
| 177 |
+
sampled_df = self.df_clean.groupby(
|
| 178 |
+
[config.color_column, config.hierarchy_column],
|
| 179 |
+
).apply(lambda g: g.sample(n=min(20, len(g)), replace=False))
|
| 180 |
+
sampled_df = sampled_df.reset_index(drop=True)
|
| 181 |
+
|
| 182 |
+
all_embeddings = []
|
| 183 |
+
all_texts = []
|
| 184 |
+
all_colors = []
|
| 185 |
+
all_hierarchies = []
|
| 186 |
+
all_images = []
|
| 187 |
+
all_urls = []
|
| 188 |
+
|
| 189 |
+
batch_size = 32
|
| 190 |
+
from tqdm import tqdm
|
| 191 |
+
|
| 192 |
+
total_batches = (len(sampled_df) + batch_size - 1) // batch_size
|
| 193 |
+
|
| 194 |
+
for i in tqdm(
|
| 195 |
+
range(0, len(sampled_df), batch_size),
|
| 196 |
+
desc="Computing embeddings",
|
| 197 |
+
total=total_batches,
|
| 198 |
+
):
|
| 199 |
+
batch = sampled_df.iloc[i : i + batch_size]
|
| 200 |
+
texts = batch[config.text_column].tolist()
|
| 201 |
+
|
| 202 |
+
all_texts.extend(texts)
|
| 203 |
+
all_colors.extend(batch[config.color_column].tolist())
|
| 204 |
+
all_hierarchies.extend(batch[config.hierarchy_column].tolist())
|
| 205 |
+
all_images.extend(batch[config.column_local_image_path].tolist())
|
| 206 |
+
all_urls.extend(batch[config.column_url_image].tolist())
|
| 207 |
+
|
| 208 |
+
with torch.no_grad():
|
| 209 |
+
text_inputs = self.clip_processor(
|
| 210 |
+
text=texts,
|
| 211 |
+
padding=True,
|
| 212 |
+
truncation=True,
|
| 213 |
+
max_length=77,
|
| 214 |
+
return_tensors="pt",
|
| 215 |
+
)
|
| 216 |
+
text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
|
| 217 |
+
dummy_images = torch.zeros(len(texts), 3, 224, 224).to(self.device)
|
| 218 |
+
outputs = self.main_model(**text_inputs, pixel_values=dummy_images)
|
| 219 |
+
embeddings = outputs.text_embeds.cpu().numpy()
|
| 220 |
+
all_embeddings.extend(embeddings)
|
| 221 |
+
|
| 222 |
+
self.all_embeddings = np.array(all_embeddings)
|
| 223 |
+
self.all_texts = all_texts
|
| 224 |
+
self.all_colors = all_colors
|
| 225 |
+
self.all_hierarchies = all_hierarchies
|
| 226 |
+
self.all_images = all_images
|
| 227 |
+
self.all_urls = all_urls
|
| 228 |
+
|
| 229 |
+
print(f"✅ Pre-computed embeddings for {len(self.all_embeddings)} items")
|
| 230 |
+
|
| 231 |
+
def search_by_text(
|
| 232 |
+
self, query_text: str, filter_category: Optional[str] = None
|
| 233 |
+
) -> List[dict]:
|
| 234 |
+
"""Search for clothing items using a text query.
|
| 235 |
+
|
| 236 |
+
Args:
|
| 237 |
+
query_text: Free-text description (e.g. "red summer dress").
|
| 238 |
+
filter_category: Optional category filter (e.g. "dress").
|
| 239 |
+
|
| 240 |
+
Returns:
|
| 241 |
+
List of result dicts with keys: rank, image_path, text, color,
|
| 242 |
+
hierarchy, similarity, index, url.
|
| 243 |
+
"""
|
| 244 |
+
print(f"🔍 Searching for: '{query_text}'")
|
| 245 |
+
|
| 246 |
+
with torch.no_grad():
|
| 247 |
+
text_inputs = self.clip_processor(
|
| 248 |
+
text=[query_text], padding=True, return_tensors="pt"
|
| 249 |
+
)
|
| 250 |
+
text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
|
| 251 |
+
dummy_image = torch.zeros(1, 3, 224, 224).to(self.device)
|
| 252 |
+
outputs = self.main_model(**text_inputs, pixel_values=dummy_image)
|
| 253 |
+
query_embedding = outputs.text_embeds.cpu().numpy()
|
| 254 |
+
|
| 255 |
+
similarities = cosine_similarity(query_embedding, self.all_embeddings)[0]
|
| 256 |
+
top_indices = np.argsort(similarities)[::-1][: self.top_k * 2]
|
| 257 |
+
|
| 258 |
+
results = []
|
| 259 |
+
for idx in top_indices:
|
| 260 |
+
if similarities[idx] > -0.5:
|
| 261 |
+
if (
|
| 262 |
+
filter_category
|
| 263 |
+
and filter_category.lower() not in self.all_hierarchies[idx].lower()
|
| 264 |
+
):
|
| 265 |
+
continue
|
| 266 |
+
results.append(
|
| 267 |
+
{
|
| 268 |
+
"rank": len(results) + 1,
|
| 269 |
+
"image_path": self.all_images[idx],
|
| 270 |
+
"text": self.all_texts[idx],
|
| 271 |
+
"color": self.all_colors[idx],
|
| 272 |
+
"hierarchy": self.all_hierarchies[idx],
|
| 273 |
+
"similarity": float(similarities[idx]),
|
| 274 |
+
"index": int(idx),
|
| 275 |
+
"url": self.all_urls[idx],
|
| 276 |
+
}
|
| 277 |
+
)
|
| 278 |
+
if len(results) >= self.top_k:
|
| 279 |
+
break
|
| 280 |
+
|
| 281 |
+
print(f"✅ Found {len(results)} results")
|
| 282 |
+
return results
|
| 283 |
+
|
| 284 |
+
@staticmethod
|
| 285 |
+
def _fetch_image_from_url(url: str, timeout: int = 5):
|
| 286 |
+
"""Try to download an image from *url*; return a PIL Image or None."""
|
| 287 |
+
import requests
|
| 288 |
+
from io import BytesIO
|
| 289 |
+
|
| 290 |
+
try:
|
| 291 |
+
resp = requests.get(url, timeout=timeout)
|
| 292 |
+
resp.raise_for_status()
|
| 293 |
+
return Image.open(BytesIO(resp.content)).convert("RGB")
|
| 294 |
+
except Exception:
|
| 295 |
+
return None
|
| 296 |
+
|
| 297 |
+
def display_results(
|
| 298 |
+
self, results: List[dict], query_info: str = "", save_path: Optional[str] = None
|
| 299 |
+
):
|
| 300 |
+
"""Display search results as an image grid with similarity scores.
|
| 301 |
+
|
| 302 |
+
Args:
|
| 303 |
+
results: List of result dicts from search_by_text().
|
| 304 |
+
query_info: Label shown in the plot title.
|
| 305 |
+
save_path: If given, save the figure to this path instead of plt.show().
|
| 306 |
+
"""
|
| 307 |
+
if not results:
|
| 308 |
+
print("❌ No results found")
|
| 309 |
+
return
|
| 310 |
+
|
| 311 |
+
print(f"\n🎯 Search Results for: {query_info}")
|
| 312 |
+
print("=" * 80)
|
| 313 |
+
|
| 314 |
+
n_results = len(results)
|
| 315 |
+
cols = min(5, n_results)
|
| 316 |
+
rows = (n_results + cols - 1) // cols
|
| 317 |
+
|
| 318 |
+
fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 5 * rows))
|
| 319 |
+
if rows == 1:
|
| 320 |
+
axes = axes.reshape(1, -1)
|
| 321 |
+
elif cols == 1:
|
| 322 |
+
axes = axes.reshape(-1, 1)
|
| 323 |
+
|
| 324 |
+
for i, result in enumerate(results):
|
| 325 |
+
row = i // cols
|
| 326 |
+
col = i % cols
|
| 327 |
+
ax = axes[row, col]
|
| 328 |
+
title = (
|
| 329 |
+
f"#{result['rank']} (Sim: {result['similarity']:.3f})\n"
|
| 330 |
+
f"{result['color']} {result['hierarchy']}"
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
# Try local file → URL download → text fallback
|
| 334 |
+
img = None
|
| 335 |
+
if result.get("image_path") and Path(result["image_path"]).is_file():
|
| 336 |
+
try:
|
| 337 |
+
img = Image.open(result["image_path"])
|
| 338 |
+
except Exception:
|
| 339 |
+
pass
|
| 340 |
+
if img is None and result.get("url"):
|
| 341 |
+
img = self._fetch_image_from_url(result["url"])
|
| 342 |
+
|
| 343 |
+
if img is not None:
|
| 344 |
+
ax.imshow(img)
|
| 345 |
+
else:
|
| 346 |
+
ax.set_facecolor("#f0f0f0")
|
| 347 |
+
snippet = result["text"][:80]
|
| 348 |
+
ax.text(
|
| 349 |
+
0.5,
|
| 350 |
+
0.5,
|
| 351 |
+
snippet,
|
| 352 |
+
ha="center",
|
| 353 |
+
va="center",
|
| 354 |
+
transform=ax.transAxes,
|
| 355 |
+
fontsize=8,
|
| 356 |
+
wrap=True,
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
ax.set_title(title, fontsize=10)
|
| 360 |
+
ax.axis("off")
|
| 361 |
+
|
| 362 |
+
for i in range(n_results, rows * cols):
|
| 363 |
+
axes[i // cols, i % cols].axis("off")
|
| 364 |
+
|
| 365 |
+
fig.suptitle(f'Search: "{query_info}"', fontsize=14, fontweight="bold")
|
| 366 |
+
plt.tight_layout()
|
| 367 |
+
|
| 368 |
+
if save_path:
|
| 369 |
+
fig.savefig(save_path, dpi=150, bbox_inches="tight")
|
| 370 |
+
print(f"📊 Figure saved to {save_path}")
|
| 371 |
+
else:
|
| 372 |
+
plt.show()
|
| 373 |
+
plt.close(fig)
|
| 374 |
+
|
| 375 |
+
print("\n📋 Detailed Results:")
|
| 376 |
+
for result in results:
|
| 377 |
+
print(
|
| 378 |
+
f"#{result['rank']:2d} | Similarity: {result['similarity']:.3f} | "
|
| 379 |
+
f"Color: {result['color']:12s} | Category: {result['hierarchy']:15s} | "
|
| 380 |
+
f"Text: {result['text'][:50]}..."
|
| 381 |
+
)
|
| 382 |
+
print(f" 🔗 URL: {result['url']}")
|
| 383 |
+
print()
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
if __name__ == "__main__":
|
| 387 |
+
import argparse
|
| 388 |
+
|
| 389 |
+
parser = argparse.ArgumentParser(
|
| 390 |
+
description="Annex 9.4 — Fashion Search Engine Demo"
|
| 391 |
+
)
|
| 392 |
+
parser.add_argument(
|
| 393 |
+
"--baseline",
|
| 394 |
+
action="store_true",
|
| 395 |
+
help="Use the Fashion-CLIP baseline instead of GAP-CLIP",
|
| 396 |
+
)
|
| 397 |
+
parser.add_argument(
|
| 398 |
+
"--queries",
|
| 399 |
+
nargs="*",
|
| 400 |
+
default=None,
|
| 401 |
+
help="Queries to run (e.g. 'red dress' 'blue pants')",
|
| 402 |
+
)
|
| 403 |
+
args = parser.parse_args()
|
| 404 |
+
|
| 405 |
+
label = "Baseline Fashion-CLIP" if args.baseline else "GAP-CLIP"
|
| 406 |
+
print(f"🎯 Initializing Fashion Search Engine ({label})")
|
| 407 |
+
engine = FashionSearchEngine(top_k=10, max_items=10000, use_baseline=args.baseline)
|
| 408 |
+
print("✅ Engine initialized (models loaded, embeddings precomputed).")
|
| 409 |
+
|
| 410 |
+
if args.queries:
|
| 411 |
+
all_results = {}
|
| 412 |
+
figures_dir = Path(args.save).parent if args.save else Path("evaluation")
|
| 413 |
+
figures_dir.mkdir(parents=True, exist_ok=True)
|
| 414 |
+
(figures_dir / "figures").mkdir(parents=True, exist_ok=True)
|
| 415 |
+
|
| 416 |
+
for query in args.queries:
|
| 417 |
+
results = engine.search_by_text(query)
|
| 418 |
+
slug = query.replace(" ", "_")
|
| 419 |
+
fig_path = (
|
| 420 |
+
figures_dir / f"figures/baseline_{slug}.png"
|
| 421 |
+
if args.baseline
|
| 422 |
+
else figures_dir / f"figures/gapclip_{slug}.png"
|
| 423 |
+
)
|
| 424 |
+
engine.display_results(results, query_info=query, save_path=str(fig_path))
|
| 425 |
+
all_results[query] = results
|
evaluation/basic_test_generalized.py
DELETED
|
@@ -1,425 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Generalized evaluation of the main model with sub-module comparison.
|
| 3 |
-
This file evaluates the main model's performance by comparing specialized parts
|
| 4 |
-
(color and hierarchy) with corresponding specialized models. It calculates similarity
|
| 5 |
-
matrices, linear projections between embedding spaces, and generates detailed statistics
|
| 6 |
-
on alignment between different representations.
|
| 7 |
-
"""
|
| 8 |
-
|
| 9 |
-
import os
|
| 10 |
-
import json
|
| 11 |
-
import argparse
|
| 12 |
-
import config
|
| 13 |
-
import torch
|
| 14 |
-
import torch.nn.functional as F
|
| 15 |
-
import pandas as pd
|
| 16 |
-
from PIL import Image
|
| 17 |
-
from torchvision import transforms
|
| 18 |
-
from transformers import CLIPProcessor, CLIPModel as CLIPModelTransformers
|
| 19 |
-
from tqdm.auto import tqdm
|
| 20 |
-
|
| 21 |
-
# Local imports
|
| 22 |
-
from color_model import ColorCLIP as ColorModel, ColorDataset, Tokenizer
|
| 23 |
-
from config import color_model_path, color_emb_dim, device, hierarchy_model_path, hierarchy_emb_dim
|
| 24 |
-
from hierarchy_model import Model as HierarchyModel, HierarchyExtractor
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def load_color_model(color_model_path, color_emb_dim, device):
|
| 28 |
-
# Load color model
|
| 29 |
-
color_checkpoint = torch.load(color_model_path, map_location=device, weights_only=True)
|
| 30 |
-
color_model = ColorModel(vocab_size=39, embedding_dim=color_emb_dim).to(device)
|
| 31 |
-
color_model.load_state_dict(color_checkpoint)
|
| 32 |
-
|
| 33 |
-
# Load and set the tokenizer
|
| 34 |
-
tokenizer = Tokenizer()
|
| 35 |
-
with open(config.tokeniser_path, 'r') as f:
|
| 36 |
-
vocab_dict = json.load(f)
|
| 37 |
-
color_model.tokenizer = tokenizer
|
| 38 |
-
|
| 39 |
-
color_model.eval()
|
| 40 |
-
return color_model
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
def get_emb_color_model(color_model, image_path_to_encode, text_to_encode):
|
| 44 |
-
# Load and preprocess image
|
| 45 |
-
image = Image.open(image_path_to_encode).convert('RGB')
|
| 46 |
-
|
| 47 |
-
transform = transforms.Compose([
|
| 48 |
-
transforms.Resize((224, 224)),
|
| 49 |
-
transforms.ToTensor(),
|
| 50 |
-
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 51 |
-
])
|
| 52 |
-
|
| 53 |
-
processed_image = transform(image)
|
| 54 |
-
|
| 55 |
-
# Get embeddings
|
| 56 |
-
processed_image_batch = processed_image.unsqueeze(0).to(device) # Shape: [1, 3, 224, 224]
|
| 57 |
-
with torch.no_grad():
|
| 58 |
-
image_emb = color_model.image_encoder(processed_image_batch)
|
| 59 |
-
|
| 60 |
-
# Text embedding via tokenizer + text_encoder
|
| 61 |
-
token_ids = torch.tensor([color_model.tokenizer(text_to_encode)], dtype=torch.long, device=device)
|
| 62 |
-
lengths = torch.tensor([token_ids.size(1) if token_ids.dim() > 1 else token_ids.size(0)], dtype=torch.long, device=device)
|
| 63 |
-
with torch.no_grad():
|
| 64 |
-
txt_emb = color_model.text_encoder(token_ids, lengths)
|
| 65 |
-
|
| 66 |
-
return image_emb, txt_emb
|
| 67 |
-
|
| 68 |
-
def load_main_model(main_model_path, device):
|
| 69 |
-
checkpoint = torch.load(main_model_path, map_location=device)
|
| 70 |
-
main_model = CLIPModel_transformers.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
|
| 71 |
-
state = checkpoint['model_state_dict'] if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint else checkpoint
|
| 72 |
-
try:
|
| 73 |
-
main_model.load_state_dict(state, strict=False)
|
| 74 |
-
except Exception:
|
| 75 |
-
# Fallback: filter matching keys
|
| 76 |
-
model_state = main_model.state_dict()
|
| 77 |
-
filtered = {k: v for k, v in state.items() if k in model_state and model_state[k].shape == v.shape}
|
| 78 |
-
main_model.load_state_dict(filtered, strict=False)
|
| 79 |
-
main_model.to(device)
|
| 80 |
-
main_model.eval()
|
| 81 |
-
processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
|
| 82 |
-
return main_model, processor
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
def load_hierarchy_model(hierarchy_model_path, device):
|
| 86 |
-
checkpoint = torch.load(hierarchy_model_path, map_location=device)
|
| 87 |
-
hierarchy_classes = checkpoint.get('hierarchy_classes', [])
|
| 88 |
-
model = HierarchyModel(num_hierarchy_classes=len(hierarchy_classes), embed_dim=config.hierarchy_emb_dim).to(device)
|
| 89 |
-
model.load_state_dict(checkpoint['model_state'])
|
| 90 |
-
extractor = HierarchyExtractor(hierarchy_classes, verbose=False)
|
| 91 |
-
model.set_hierarchy_extractor(extractor)
|
| 92 |
-
model.eval()
|
| 93 |
-
return model
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
def get_emb_hierarchy_model(hierarchy_model, image_path_to_encode, text_to_encode):
|
| 97 |
-
image = Image.open(image_path_to_encode).convert('RGB')
|
| 98 |
-
transform = transforms.Compose([
|
| 99 |
-
transforms.Resize((224, 224)),
|
| 100 |
-
transforms.ToTensor(),
|
| 101 |
-
])
|
| 102 |
-
image_tensor = transform(image).unsqueeze(0).to(device)
|
| 103 |
-
|
| 104 |
-
with torch.no_grad():
|
| 105 |
-
img_emb = hierarchy_model.get_image_embeddings(image_tensor)
|
| 106 |
-
txt_emb = hierarchy_model.get_text_embeddings(text_to_encode)
|
| 107 |
-
|
| 108 |
-
return img_emb, txt_emb
|
| 109 |
-
|
| 110 |
-
def get_emb_main_model(main_model, processor, image_path_to_encode, text_to_encode):
|
| 111 |
-
image = Image.open(image_path_to_encode).convert('RGB')
|
| 112 |
-
transform = transforms.Compose([
|
| 113 |
-
transforms.Resize((224, 224)),
|
| 114 |
-
transforms.ToTensor(),
|
| 115 |
-
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 116 |
-
])
|
| 117 |
-
image = transform(image)
|
| 118 |
-
image = image.unsqueeze(0).to(device)
|
| 119 |
-
# Prepare text inputs via processor
|
| 120 |
-
text_inputs = processor(text=[text_to_encode], return_tensors="pt", padding=True)
|
| 121 |
-
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
|
| 122 |
-
outputs = main_model(**text_inputs, pixel_values=image)
|
| 123 |
-
text_emb = outputs.text_embeds
|
| 124 |
-
image_emb = outputs.image_embeds
|
| 125 |
-
|
| 126 |
-
return text_emb, image_emb
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
if __name__ == '__main__':
|
| 130 |
-
parser = argparse.ArgumentParser(description='Evaluate main model parts vs small models and build similarity matrices')
|
| 131 |
-
parser.add_argument('--main-checkpoint', type=str, default='models/laion_explicable_model.pth')
|
| 132 |
-
parser.add_argument('--color-checkpoint', type=str, default='models/color_model.pt')
|
| 133 |
-
parser.add_argument('--csv', type=str, default='data/data_with_local_paths.csv')
|
| 134 |
-
parser.add_argument('--color-emb-dim', type=int, default=16)
|
| 135 |
-
parser.add_argument('--num-samples', type=int, default=200)
|
| 136 |
-
parser.add_argument('--seed', type=int, default=42)
|
| 137 |
-
parser.add_argument('--primary-metric', type=str, default='sim_color_txt_img',
|
| 138 |
-
choices=['sim_txt_color_part', 'sim_img_color_part', 'sim_color_txt_img', 'sim_small_txt_img',
|
| 139 |
-
'sim_txt_hierarchy_part', 'sim_img_hierarchy_part'])
|
| 140 |
-
parser.add_argument('--top-k', type=int, default=30)
|
| 141 |
-
parser.add_argument('--heatmap', action='store_true')
|
| 142 |
-
parser.add_argument('--l2-grid', type=str, default='1e-5,1e-4,1e-3,1e-2,1e-1')
|
| 143 |
-
args = parser.parse_args()
|
| 144 |
-
|
| 145 |
-
main_checkpoint = args.main_checkpoint
|
| 146 |
-
color_checkpoint = args.color_checkpoint
|
| 147 |
-
csv = args.csv
|
| 148 |
-
color_emb_dim = args.color_emb_dim
|
| 149 |
-
num_samples = args.num_samples
|
| 150 |
-
seed = args.seed
|
| 151 |
-
primary_metric = args.primary_metric
|
| 152 |
-
top_k = args.top_k
|
| 153 |
-
l2_grid = [float(x) for x in args.l2_grid.split(',') if x]
|
| 154 |
-
device = torch.device("mps")
|
| 155 |
-
|
| 156 |
-
df = pd.read_csv(csv)
|
| 157 |
-
|
| 158 |
-
# Normalize colors (reduce aliasing and sparsity)
|
| 159 |
-
def normalize_color(c):
|
| 160 |
-
if pd.isna(c):
|
| 161 |
-
return c
|
| 162 |
-
s = str(c).strip().lower()
|
| 163 |
-
aliases = {
|
| 164 |
-
'grey': 'gray',
|
| 165 |
-
'navy blue': 'navy',
|
| 166 |
-
'light blue': 'blue',
|
| 167 |
-
'dark blue': 'blue',
|
| 168 |
-
'light grey': 'gray',
|
| 169 |
-
'dark grey': 'gray',
|
| 170 |
-
'light gray': 'gray',
|
| 171 |
-
'dark gray': 'gray',
|
| 172 |
-
}
|
| 173 |
-
return aliases.get(s, s)
|
| 174 |
-
|
| 175 |
-
if config.color_column in df.columns:
|
| 176 |
-
df[config.color_column] = df[config.color_column].apply(normalize_color)
|
| 177 |
-
|
| 178 |
-
color_model = load_color_model(color_checkpoint, color_emb_dim, device)
|
| 179 |
-
main_model, processor = load_main_model(main_checkpoint, device)
|
| 180 |
-
hierarchy_model = load_hierarchy_model(hierarchy_model_path, device)
|
| 181 |
-
|
| 182 |
-
# Results container
|
| 183 |
-
results = []
|
| 184 |
-
|
| 185 |
-
# Accumulators for projection (A: main part, B: small model)
|
| 186 |
-
color_txt_As, color_txt_Bs = [], []
|
| 187 |
-
color_img_As, color_img_Bs = [], []
|
| 188 |
-
hier_txt_As, hier_txt_Bs = [], []
|
| 189 |
-
hier_img_As, hier_img_Bs = [], []
|
| 190 |
-
|
| 191 |
-
# Ensure determinism for sampling
|
| 192 |
-
pd.options.mode.copy_on_write = True
|
| 193 |
-
rng = pd.Series(range(len(df)), dtype=int)
|
| 194 |
-
_ = rng # silence lint
|
| 195 |
-
torch.manual_seed(seed)
|
| 196 |
-
|
| 197 |
-
unique_hiers = sorted(df[config.hierarchy_column].dropna().unique())
|
| 198 |
-
unique_colors = sorted(df[config.color_column].dropna().unique())
|
| 199 |
-
|
| 200 |
-
# Progress bar across all (hierarchy, color) pairs
|
| 201 |
-
total_pairs = len(unique_hiers) * len(unique_colors)
|
| 202 |
-
pair_pbar = tqdm(total=total_pairs, desc="Evaluating pairs", leave=False)
|
| 203 |
-
for hierarchy in unique_hiers:
|
| 204 |
-
for color in unique_colors:
|
| 205 |
-
group = df[(df[config.hierarchy_column] == hierarchy) & (df[config.color_column] == color)]
|
| 206 |
-
|
| 207 |
-
# Sample up to num_samples per (hierarchy, color)
|
| 208 |
-
k = min(num_samples, len(group))
|
| 209 |
-
group_iter = group.sample(n=k, random_state=seed) if len(group) > k else group.iloc[:k]
|
| 210 |
-
|
| 211 |
-
# Progress bar for samples within the pair
|
| 212 |
-
inner_pbar = tqdm(total=len(group_iter), desc=f"{hierarchy}/{color}", leave=False)
|
| 213 |
-
for row_idx, (_, example) in enumerate(group_iter.iterrows()):
|
| 214 |
-
try:
|
| 215 |
-
image_emb, txt_emb = get_emb_color_model(color_model, example['local_image_path'], example['text'])
|
| 216 |
-
image_emb_hier, txt_emb_hier = get_emb_hierarchy_model(hierarchy_model, example['local_image_path'], example['text'])
|
| 217 |
-
text_emb_main_model, image_emb_main_model = get_emb_main_model(
|
| 218 |
-
main_model, processor, example['local_image_path'], example['text']
|
| 219 |
-
)
|
| 220 |
-
|
| 221 |
-
color_part_txt = text_emb_main_model[:, :color_emb_dim]
|
| 222 |
-
color_part_img = image_emb_main_model[:, :color_emb_dim]
|
| 223 |
-
hier_part_txt = text_emb_main_model[:, color_emb_dim:color_emb_dim + hierarchy_emb_dim]
|
| 224 |
-
hier_part_img = image_emb_main_model[:, color_emb_dim:color_emb_dim + hierarchy_emb_dim]
|
| 225 |
-
|
| 226 |
-
# L2-normalize parts and small-model embeddings for stable cosine
|
| 227 |
-
color_part_txt = F.normalize(color_part_txt, dim=1)
|
| 228 |
-
color_part_img = F.normalize(color_part_img, dim=1)
|
| 229 |
-
hier_part_txt = F.normalize(hier_part_txt, dim=1)
|
| 230 |
-
hier_part_img = F.normalize(hier_part_img, dim=1)
|
| 231 |
-
txt_emb = F.normalize(txt_emb, dim=1)
|
| 232 |
-
image_emb = F.normalize(image_emb, dim=1)
|
| 233 |
-
txt_emb_hier = F.normalize(txt_emb_hier, dim=1)
|
| 234 |
-
image_emb_hier = F.normalize(image_emb_hier, dim=1)
|
| 235 |
-
|
| 236 |
-
sim_txt_color_part = F.cosine_similarity(txt_emb, color_part_txt).item()
|
| 237 |
-
sim_img_color_part = F.cosine_similarity(image_emb, color_part_img).item()
|
| 238 |
-
sim_color_txt_img = F.cosine_similarity(color_part_txt, color_part_img).item()
|
| 239 |
-
sim_small_txt_img = F.cosine_similarity(txt_emb, image_emb).item()
|
| 240 |
-
|
| 241 |
-
sim_txt_hierarchy_part = F.cosine_similarity(txt_emb_hier, hier_part_txt).item()
|
| 242 |
-
sim_img_hierarchy_part = F.cosine_similarity(image_emb_hier, hier_part_img).item()
|
| 243 |
-
|
| 244 |
-
# Accumulate for projection fitting later
|
| 245 |
-
color_txt_As.append(color_part_txt.squeeze(0).detach().cpu())
|
| 246 |
-
color_txt_Bs.append(txt_emb.squeeze(0).detach().cpu())
|
| 247 |
-
color_img_As.append(color_part_img.squeeze(0).detach().cpu())
|
| 248 |
-
color_img_Bs.append(image_emb.squeeze(0).detach().cpu())
|
| 249 |
-
|
| 250 |
-
hier_txt_As.append(hier_part_txt.squeeze(0).detach().cpu())
|
| 251 |
-
hier_txt_Bs.append(txt_emb_hier.squeeze(0).detach().cpu())
|
| 252 |
-
hier_img_As.append(hier_part_img.squeeze(0).detach().cpu())
|
| 253 |
-
hier_img_Bs.append(image_emb_hier.squeeze(0).detach().cpu())
|
| 254 |
-
|
| 255 |
-
results.append({
|
| 256 |
-
'hierarchy': hierarchy,
|
| 257 |
-
'color': color,
|
| 258 |
-
'row_index': int(row_idx),
|
| 259 |
-
'sim_txt_color_part': float(sim_txt_color_part),
|
| 260 |
-
'sim_img_color_part': float(sim_img_color_part),
|
| 261 |
-
'sim_color_txt_img': float(sim_color_txt_img),
|
| 262 |
-
'sim_small_txt_img': float(sim_small_txt_img),
|
| 263 |
-
'sim_txt_hierarchy_part': float(sim_txt_hierarchy_part),
|
| 264 |
-
'sim_img_hierarchy_part': float(sim_img_hierarchy_part),
|
| 265 |
-
})
|
| 266 |
-
except Exception as e:
|
| 267 |
-
print(f"Skipping example due to error: {e}")
|
| 268 |
-
finally:
|
| 269 |
-
inner_pbar.update(1)
|
| 270 |
-
inner_pbar.close()
|
| 271 |
-
pair_pbar.update(1)
|
| 272 |
-
pair_pbar.close()
|
| 273 |
-
|
| 274 |
-
results_df = pd.DataFrame(results)
|
| 275 |
-
|
| 276 |
-
# Save raw results
|
| 277 |
-
os.makedirs('evaluation_outputs', exist_ok=True)
|
| 278 |
-
raw_path = os.path.join('evaluation_outputs', 'similarities_raw.csv')
|
| 279 |
-
results_df.to_csv(raw_path, index=False)
|
| 280 |
-
print(f"Saved raw similarities to {raw_path}")
|
| 281 |
-
|
| 282 |
-
# Intelligent averages
|
| 283 |
-
metrics = ['sim_txt_color_part', 'sim_img_color_part', 'sim_color_txt_img', 'sim_small_txt_img',
|
| 284 |
-
'sim_txt_hierarchy_part', 'sim_img_hierarchy_part']
|
| 285 |
-
|
| 286 |
-
# Overall means
|
| 287 |
-
overall_means = results_df[metrics].mean().to_frame(name='mean').T
|
| 288 |
-
overall_means.insert(0, 'level', 'overall')
|
| 289 |
-
|
| 290 |
-
# By hierarchy
|
| 291 |
-
by_hierarchy = results_df.groupby(config.hierarchy_column)[metrics].mean().reset_index()
|
| 292 |
-
by_hierarchy.insert(0, 'level', config.hierarchy_column)
|
| 293 |
-
|
| 294 |
-
# By color
|
| 295 |
-
by_color = results_df.groupby(config.color_column)[metrics].mean().reset_index()
|
| 296 |
-
by_color.insert(0, 'level', config.color_column)
|
| 297 |
-
|
| 298 |
-
# By hierarchy+color
|
| 299 |
-
by_pair = results_df.groupby([config.hierarchy_column, config.color_column])[metrics].mean().reset_index()
|
| 300 |
-
by_pair.insert(0, 'level', 'hierarchy_color')
|
| 301 |
-
|
| 302 |
-
summary_df = pd.concat([overall_means, by_hierarchy, by_color, by_pair], ignore_index=True)
|
| 303 |
-
summary_path = os.path.join('evaluation_outputs', 'similarities_summary.csv')
|
| 304 |
-
summary_df.to_csv(summary_path, index=False)
|
| 305 |
-
print(f"Saved summary statistics to {summary_path}")
|
| 306 |
-
|
| 307 |
-
# =====================
|
| 308 |
-
# Similarity matrices for best hierarchy-color combinations
|
| 309 |
-
# =====================
|
| 310 |
-
try:
|
| 311 |
-
by_pair_core = results_df.groupby([config.hierarchy_column, config.color_column])[metrics].mean().reset_index()
|
| 312 |
-
top_pairs = by_pair_core.nlargest(top_k, primary_metric)
|
| 313 |
-
matrix = top_pairs.pivot(index=config.hierarchy_column, columns=config.color_column, values=primary_metric)
|
| 314 |
-
os.makedirs('evaluation_outputs', exist_ok=True)
|
| 315 |
-
matrix_csv_path = os.path.join('evaluation_outputs', f'similarity_matrix_{primary_metric}_top{top_k}.csv')
|
| 316 |
-
matrix.to_csv(matrix_csv_path)
|
| 317 |
-
print(f"Saved similarity matrix to {matrix_csv_path}")
|
| 318 |
-
|
| 319 |
-
if args.heatmap:
|
| 320 |
-
try:
|
| 321 |
-
import seaborn as sns
|
| 322 |
-
import matplotlib.pyplot as plt
|
| 323 |
-
plt.figure(figsize=(max(6, 0.5 * len(matrix.columns)), max(4, 0.5 * len(matrix.index))))
|
| 324 |
-
sns.heatmap(matrix, annot=False, cmap='viridis')
|
| 325 |
-
plt.title(f'Similarity matrix (top {top_k}) - {primary_metric}')
|
| 326 |
-
heatmap_path = os.path.join('evaluation_outputs', f'similarity_matrix_{primary_metric}_top{top_k}.png')
|
| 327 |
-
plt.tight_layout()
|
| 328 |
-
plt.savefig(heatmap_path, dpi=200)
|
| 329 |
-
plt.close()
|
| 330 |
-
print(f"Saved similarity heatmap to {heatmap_path}")
|
| 331 |
-
except Exception as e:
|
| 332 |
-
print(f"Skipping heatmap generation: {e}")
|
| 333 |
-
except Exception as e:
|
| 334 |
-
print(f"Skipping matrix generation: {e}")
|
| 335 |
-
|
| 336 |
-
# =====================
|
| 337 |
-
# Learn projections A->B and report projected cosine means
|
| 338 |
-
# =====================
|
| 339 |
-
def fit_ridge_projection(A, B, l2_reg=1e-3):
|
| 340 |
-
# A: [N, D_in], B: [N, D_out]
|
| 341 |
-
A = torch.stack(A) # [N, D_in]
|
| 342 |
-
B = torch.stack(B) # [N, D_out]
|
| 343 |
-
# Closed-form ridge: W = (A^T A + λI)^-1 A^T B
|
| 344 |
-
AtA = A.T @ A
|
| 345 |
-
D_in = AtA.shape[0]
|
| 346 |
-
AtA_reg = AtA + l2_reg * torch.eye(D_in)
|
| 347 |
-
W = torch.linalg.solve(AtA_reg, A.T @ B)
|
| 348 |
-
return W # [D_in, D_out]
|
| 349 |
-
|
| 350 |
-
def fit_ridge_with_cv(A, B, l2_values):
|
| 351 |
-
# Simple holdout CV: 80/20 split
|
| 352 |
-
if len(A) < 10:
|
| 353 |
-
# Not enough data for split; fallback to middle lambda
|
| 354 |
-
best_l2 = l2_values[min(len(l2_values) // 2, len(l2_values)-1)]
|
| 355 |
-
W = fit_ridge_projection(A, B, best_l2)
|
| 356 |
-
return W, best_l2, None
|
| 357 |
-
|
| 358 |
-
N = len(A)
|
| 359 |
-
idx = torch.randperm(N)
|
| 360 |
-
split = int(0.8 * N)
|
| 361 |
-
train_idx = idx[:split]
|
| 362 |
-
val_idx = idx[split:]
|
| 363 |
-
|
| 364 |
-
A_tensor = torch.stack(A)
|
| 365 |
-
B_tensor = torch.stack(B)
|
| 366 |
-
|
| 367 |
-
A_train, B_train = A_tensor[train_idx], B_tensor[train_idx]
|
| 368 |
-
A_val, B_val = A_tensor[val_idx], B_tensor[val_idx]
|
| 369 |
-
|
| 370 |
-
def to_list(t):
|
| 371 |
-
return [row for row in t]
|
| 372 |
-
|
| 373 |
-
best_l2 = None
|
| 374 |
-
best_score = -1.0
|
| 375 |
-
for l2 in l2_values:
|
| 376 |
-
W = fit_ridge_projection(to_list(A_train), to_list(B_train), l2)
|
| 377 |
-
score = mean_projected_cosine(to_list(A_val), to_list(B_val), W)
|
| 378 |
-
if score > best_score:
|
| 379 |
-
best_score = score
|
| 380 |
-
best_l2 = l2
|
| 381 |
-
|
| 382 |
-
# Refit on all with best_l2
|
| 383 |
-
W_best = fit_ridge_projection(A, B, best_l2)
|
| 384 |
-
return W_best, best_l2, best_score
|
| 385 |
-
|
| 386 |
-
def mean_projected_cosine(A, B, W):
|
| 387 |
-
A = torch.stack(A)
|
| 388 |
-
B = torch.stack(B)
|
| 389 |
-
A_proj = A @ W
|
| 390 |
-
A_proj = F.normalize(A_proj, dim=1)
|
| 391 |
-
B = F.normalize(B, dim=1)
|
| 392 |
-
return torch.mean(torch.sum(A_proj * B, dim=1)).item()
|
| 393 |
-
|
| 394 |
-
projection_report = {}
|
| 395 |
-
|
| 396 |
-
if len(color_txt_As) >= 8:
|
| 397 |
-
W_ct, best_l2_ct, cv_ct = fit_ridge_with_cv(color_txt_As, color_txt_Bs, l2_grid)
|
| 398 |
-
projection_report['proj_sim_txt_color_part_mean'] = mean_projected_cosine(color_txt_As, color_txt_Bs, W_ct)
|
| 399 |
-
projection_report['proj_txt_color_part_best_l2'] = best_l2_ct
|
| 400 |
-
if cv_ct is not None:
|
| 401 |
-
projection_report['proj_txt_color_part_cv_val'] = cv_ct
|
| 402 |
-
if len(color_img_As) >= 8:
|
| 403 |
-
W_ci, best_l2_ci, cv_ci = fit_ridge_with_cv(color_img_As, color_img_Bs, l2_grid)
|
| 404 |
-
projection_report['proj_sim_img_color_part_mean'] = mean_projected_cosine(color_img_As, color_img_Bs, W_ci)
|
| 405 |
-
projection_report['proj_img_color_part_best_l2'] = best_l2_ci
|
| 406 |
-
if cv_ci is not None:
|
| 407 |
-
projection_report['proj_img_color_part_cv_val'] = cv_ci
|
| 408 |
-
if len(hier_txt_As) >= 8:
|
| 409 |
-
W_ht, best_l2_ht, cv_ht = fit_ridge_with_cv(hier_txt_As, hier_txt_Bs, l2_grid)
|
| 410 |
-
projection_report['proj_sim_txt_hierarchy_part_mean'] = mean_projected_cosine(hier_txt_As, hier_txt_Bs, W_ht)
|
| 411 |
-
projection_report['proj_txt_hierarchy_part_best_l2'] = best_l2_ht
|
| 412 |
-
if cv_ht is not None:
|
| 413 |
-
projection_report['proj_txt_hierarchy_part_cv_val'] = cv_ht
|
| 414 |
-
if len(hier_img_As) >= 8:
|
| 415 |
-
W_hi, best_l2_hi, cv_hi = fit_ridge_with_cv(hier_img_As, hier_img_Bs, l2_grid)
|
| 416 |
-
projection_report['proj_sim_img_hierarchy_part_mean'] = mean_projected_cosine(hier_img_As, hier_img_Bs, W_hi)
|
| 417 |
-
projection_report['proj_img_hierarchy_part_best_l2'] = best_l2_hi
|
| 418 |
-
if cv_hi is not None:
|
| 419 |
-
projection_report['proj_img_hierarchy_part_cv_val'] = cv_hi
|
| 420 |
-
|
| 421 |
-
proj_summary_path = os.path.join('evaluation_outputs', 'projection_summary.json')
|
| 422 |
-
with open(proj_summary_path, 'w') as f:
|
| 423 |
-
json.dump(projection_report, f, indent=2)
|
| 424 |
-
print(f"Saved projection summary to {proj_summary_path}")
|
| 425 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
evaluation/fashion_search.py
DELETED
|
@@ -1,365 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""
|
| 3 |
-
Fashion search system using multi-modal embeddings.
|
| 4 |
-
This file implements a fashion search engine that allows searching for clothing items
|
| 5 |
-
using text queries. It uses embeddings from the main model to calculate cosine similarities
|
| 6 |
-
and return the most relevant items. The system pre-computes embeddings for all items
|
| 7 |
-
in the dataset for fast search.
|
| 8 |
-
"""
|
| 9 |
-
|
| 10 |
-
import torch
|
| 11 |
-
import numpy as np
|
| 12 |
-
import pandas as pd
|
| 13 |
-
from PIL import Image
|
| 14 |
-
import matplotlib.pyplot as plt
|
| 15 |
-
from sklearn.metrics.pairwise import cosine_similarity
|
| 16 |
-
from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
|
| 17 |
-
import warnings
|
| 18 |
-
import os
|
| 19 |
-
from typing import List, Tuple, Union, Optional
|
| 20 |
-
import argparse
|
| 21 |
-
|
| 22 |
-
# Import custom models
|
| 23 |
-
from color_model import CLIPModel as ColorModel
|
| 24 |
-
from hierarchy_model import Model as HierarchyModel, HierarchyExtractor
|
| 25 |
-
from main_model import CustomDataset
|
| 26 |
-
import config
|
| 27 |
-
|
| 28 |
-
warnings.filterwarnings("ignore")
|
| 29 |
-
|
| 30 |
-
class FashionSearchEngine:
|
| 31 |
-
"""
|
| 32 |
-
Fashion search engine using multi-modal embeddings with category emphasis
|
| 33 |
-
"""
|
| 34 |
-
|
| 35 |
-
def __init__(self, top_k: int = 10, max_items: int = 10000):
|
| 36 |
-
"""
|
| 37 |
-
Initialize the fashion search engine
|
| 38 |
-
Args:
|
| 39 |
-
top_k: Number of top results to return
|
| 40 |
-
max_items: Maximum number of items to process (for faster initialization)
|
| 41 |
-
hierarchy_weight: Weight for hierarchy/category dimensions (default: 2.0)
|
| 42 |
-
color_weight: Weight for color dimensions (default: 1.0)
|
| 43 |
-
"""
|
| 44 |
-
self.device = config.device
|
| 45 |
-
self.top_k = top_k
|
| 46 |
-
self.max_items = max_items
|
| 47 |
-
self.color_dim = config.color_emb_dim
|
| 48 |
-
self.hierarchy_dim = config.hierarchy_emb_dim
|
| 49 |
-
|
| 50 |
-
# Load models
|
| 51 |
-
self._load_models()
|
| 52 |
-
|
| 53 |
-
# Load dataset
|
| 54 |
-
self._load_dataset()
|
| 55 |
-
|
| 56 |
-
# Pre-compute embeddings for all items
|
| 57 |
-
self._precompute_embeddings()
|
| 58 |
-
|
| 59 |
-
print("✅ Fashion Search Engine ready!")
|
| 60 |
-
|
| 61 |
-
def _load_models(self):
|
| 62 |
-
"""Load all required models"""
|
| 63 |
-
print("📦 Loading models...")
|
| 64 |
-
|
| 65 |
-
# Load color model
|
| 66 |
-
color_checkpoint = torch.load(config.color_model_path, map_location=self.device, weights_only=True)
|
| 67 |
-
self.color_model = ColorModel(embed_dim=self.color_dim).to(self.device)
|
| 68 |
-
self.color_model.load_state_dict(color_checkpoint)
|
| 69 |
-
self.color_model.eval()
|
| 70 |
-
|
| 71 |
-
# Load hierarchy model
|
| 72 |
-
hierarchy_checkpoint = torch.load(config.hierarchy_model_path, map_location=self.device)
|
| 73 |
-
self.hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', [])
|
| 74 |
-
self.hierarchy_model = HierarchyModel(
|
| 75 |
-
num_hierarchy_classes=len(self.hierarchy_classes),
|
| 76 |
-
embed_dim=self.hierarchy_dim
|
| 77 |
-
).to(self.device)
|
| 78 |
-
self.hierarchy_model.load_state_dict(hierarchy_checkpoint['model_state'])
|
| 79 |
-
|
| 80 |
-
# Set hierarchy extractor
|
| 81 |
-
hierarchy_extractor = HierarchyExtractor(self.hierarchy_classes, verbose=False)
|
| 82 |
-
self.hierarchy_model.set_hierarchy_extractor(hierarchy_extractor)
|
| 83 |
-
self.hierarchy_model.eval()
|
| 84 |
-
|
| 85 |
-
# Load main CLIP model - Use the trained model directly
|
| 86 |
-
self.main_model = CLIPModel_transformers.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
|
| 87 |
-
|
| 88 |
-
# Load the trained weights
|
| 89 |
-
checkpoint = torch.load(config.main_model_path, map_location=self.device)
|
| 90 |
-
if 'model_state_dict' in checkpoint:
|
| 91 |
-
self.main_model.load_state_dict(checkpoint['model_state_dict'])
|
| 92 |
-
else:
|
| 93 |
-
# Fallback: try to load as state dict directly
|
| 94 |
-
self.main_model.load_state_dict(checkpoint)
|
| 95 |
-
print("✅ Loaded model weights directly")
|
| 96 |
-
|
| 97 |
-
self.main_model.to(self.device)
|
| 98 |
-
self.main_model.eval()
|
| 99 |
-
|
| 100 |
-
# Load CLIP processor
|
| 101 |
-
self.clip_processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
|
| 102 |
-
|
| 103 |
-
print(f"✅ Models loaded - Colors: {self.color_dim}D, Hierarchy: {self.hierarchy_dim}D")
|
| 104 |
-
|
| 105 |
-
def _load_dataset(self):
|
| 106 |
-
"""Load the fashion dataset"""
|
| 107 |
-
print("📊 Loading dataset...")
|
| 108 |
-
|
| 109 |
-
# Load dataset
|
| 110 |
-
self.df = pd.read_csv(config.local_dataset_path)
|
| 111 |
-
self.df_clean = self.df.dropna(subset=[config.column_local_image_path])
|
| 112 |
-
|
| 113 |
-
# Create dataset object
|
| 114 |
-
self.dataset = CustomDataset(self.df_clean)
|
| 115 |
-
self.dataset.set_training_mode(False) # No augmentation for search
|
| 116 |
-
|
| 117 |
-
print(f"✅ {len(self.df_clean)} items loaded for search")
|
| 118 |
-
|
| 119 |
-
def _precompute_embeddings(self):
|
| 120 |
-
"""Pre-compute embeddings for all items in the dataset"""
|
| 121 |
-
print("🔄 Pre-computing embeddings...")
|
| 122 |
-
|
| 123 |
-
# OPTIMIZATION: Sample a subset for faster initialization
|
| 124 |
-
print(f"⚠️ Dataset too large ({len(self.dataset)} items). Using stratified sampling of 10 items per color-category combination.")
|
| 125 |
-
|
| 126 |
-
# Stratified sampling by color-category combinations
|
| 127 |
-
sampled_df = self.df_clean.groupby([config.color_column, config.hierarchy_column]).sample(n=20, replace=False)
|
| 128 |
-
|
| 129 |
-
# Get the original indices of sampled items
|
| 130 |
-
sampled_indices = sampled_df.index.tolist()
|
| 131 |
-
|
| 132 |
-
all_embeddings = []
|
| 133 |
-
all_texts = []
|
| 134 |
-
all_colors = []
|
| 135 |
-
all_hierarchies = []
|
| 136 |
-
all_images = []
|
| 137 |
-
all_urls = []
|
| 138 |
-
|
| 139 |
-
# Process in batches for efficiency
|
| 140 |
-
batch_size = 32
|
| 141 |
-
|
| 142 |
-
# Add progress bar
|
| 143 |
-
from tqdm import tqdm
|
| 144 |
-
total_batches = (len(sampled_indices) + batch_size - 1) // batch_size
|
| 145 |
-
|
| 146 |
-
for i in tqdm(range(0, len(sampled_indices), batch_size),
|
| 147 |
-
desc="Computing embeddings",
|
| 148 |
-
total=total_batches):
|
| 149 |
-
batch_end = min(i + batch_size, len(sampled_indices))
|
| 150 |
-
batch_items = []
|
| 151 |
-
|
| 152 |
-
for j in range(i, batch_end):
|
| 153 |
-
try:
|
| 154 |
-
# Use the original dataset with the sampled index
|
| 155 |
-
original_idx = sampled_indices[j]
|
| 156 |
-
image, text, color, hierarchy = self.dataset[original_idx]
|
| 157 |
-
batch_items.append((image, text, color, hierarchy))
|
| 158 |
-
all_texts.append(text)
|
| 159 |
-
all_colors.append(color)
|
| 160 |
-
all_hierarchies.append(hierarchy)
|
| 161 |
-
all_images.append(self.df_clean.iloc[original_idx][config.column_local_image_path])
|
| 162 |
-
all_urls.append(self.df_clean.iloc[original_idx][config.column_url_image])
|
| 163 |
-
except Exception as e:
|
| 164 |
-
print(f"⚠️ Skipping item {j}: {e}")
|
| 165 |
-
continue
|
| 166 |
-
|
| 167 |
-
if not batch_items:
|
| 168 |
-
continue
|
| 169 |
-
|
| 170 |
-
# Process batch
|
| 171 |
-
images = torch.stack([item[0] for item in batch_items]).to(self.device)
|
| 172 |
-
texts = [item[1] for item in batch_items]
|
| 173 |
-
|
| 174 |
-
with torch.no_grad():
|
| 175 |
-
# Get embeddings from main model (text embeddings only)
|
| 176 |
-
text_inputs = self.clip_processor(text=texts, padding=True, return_tensors="pt")
|
| 177 |
-
text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
|
| 178 |
-
|
| 179 |
-
# Create dummy images for the model
|
| 180 |
-
dummy_images = torch.zeros(len(texts), 3, 224, 224).to(self.device)
|
| 181 |
-
|
| 182 |
-
outputs = self.main_model(**text_inputs, pixel_values=dummy_images)
|
| 183 |
-
embeddings = outputs.text_embeds.cpu().numpy()
|
| 184 |
-
|
| 185 |
-
all_embeddings.extend(embeddings)
|
| 186 |
-
|
| 187 |
-
self.all_embeddings = np.array(all_embeddings)
|
| 188 |
-
self.all_texts = all_texts
|
| 189 |
-
self.all_colors = all_colors
|
| 190 |
-
self.all_hierarchies = all_hierarchies
|
| 191 |
-
self.all_images = all_images
|
| 192 |
-
self.all_urls = all_urls
|
| 193 |
-
|
| 194 |
-
print(f"✅ Pre-computed embeddings for {len(self.all_embeddings)} items")
|
| 195 |
-
|
| 196 |
-
def search_by_text(self, query_text: str, filter_category: str = None) -> List[dict]:
|
| 197 |
-
"""
|
| 198 |
-
Search for clothing items using text query
|
| 199 |
-
|
| 200 |
-
Args:
|
| 201 |
-
query_text: Text description to search for
|
| 202 |
-
|
| 203 |
-
Returns:
|
| 204 |
-
List of dictionaries containing search results
|
| 205 |
-
"""
|
| 206 |
-
print(f"🔍 Searching for: '{query_text}'")
|
| 207 |
-
|
| 208 |
-
# Get query embedding
|
| 209 |
-
with torch.no_grad():
|
| 210 |
-
text_inputs = self.clip_processor(text=[query_text], padding=True, return_tensors="pt")
|
| 211 |
-
text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
|
| 212 |
-
|
| 213 |
-
# Create a dummy image tensor to satisfy the model's requirements
|
| 214 |
-
dummy_image = torch.zeros(1, 3, 224, 224).to(self.device)
|
| 215 |
-
|
| 216 |
-
outputs = self.main_model(**text_inputs, pixel_values=dummy_image)
|
| 217 |
-
query_embedding = outputs.text_embeds.cpu().numpy()
|
| 218 |
-
|
| 219 |
-
# Calculate similarities
|
| 220 |
-
similarities = cosine_similarity(query_embedding, self.all_embeddings)[0]
|
| 221 |
-
|
| 222 |
-
# Get top-k results
|
| 223 |
-
top_indices = np.argsort(similarities)[::-1][:self.top_k * 2] # Prendre plus de résultats
|
| 224 |
-
|
| 225 |
-
results = []
|
| 226 |
-
for idx in top_indices:
|
| 227 |
-
if similarities[idx] > -0.5:
|
| 228 |
-
# Filter by category if specified
|
| 229 |
-
if filter_category and filter_category.lower() not in self.all_hierarchies[idx].lower():
|
| 230 |
-
continue
|
| 231 |
-
|
| 232 |
-
results.append({
|
| 233 |
-
'rank': len(results) + 1,
|
| 234 |
-
'image_path': self.all_images[idx],
|
| 235 |
-
'text': self.all_texts[idx],
|
| 236 |
-
'color': self.all_colors[idx],
|
| 237 |
-
'hierarchy': self.all_hierarchies[idx],
|
| 238 |
-
'similarity': float(similarities[idx]),
|
| 239 |
-
'index': int(idx),
|
| 240 |
-
'url': self.all_urls[idx]
|
| 241 |
-
})
|
| 242 |
-
|
| 243 |
-
if len(results) >= self.top_k:
|
| 244 |
-
break
|
| 245 |
-
|
| 246 |
-
print(f"✅ Found {len(results)} results")
|
| 247 |
-
return results
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
def display_results(self, results: List[dict], query_info: str = ""):
|
| 251 |
-
"""
|
| 252 |
-
Display search results with images and information
|
| 253 |
-
|
| 254 |
-
Args:
|
| 255 |
-
results: List of search result dictionaries
|
| 256 |
-
query_info: Information about the query
|
| 257 |
-
"""
|
| 258 |
-
if not results:
|
| 259 |
-
print("❌ No results found")
|
| 260 |
-
return
|
| 261 |
-
|
| 262 |
-
print(f"\n🎯 Search Results for: {query_info}")
|
| 263 |
-
print("=" * 80)
|
| 264 |
-
|
| 265 |
-
# Calculate grid layout
|
| 266 |
-
n_results = len(results)
|
| 267 |
-
cols = min(5, n_results)
|
| 268 |
-
rows = (n_results + cols - 1) // cols
|
| 269 |
-
|
| 270 |
-
fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows))
|
| 271 |
-
if rows == 1:
|
| 272 |
-
axes = axes.reshape(1, -1)
|
| 273 |
-
elif cols == 1:
|
| 274 |
-
axes = axes.reshape(-1, 1)
|
| 275 |
-
|
| 276 |
-
for i, result in enumerate(results):
|
| 277 |
-
row = i // cols
|
| 278 |
-
col = i % cols
|
| 279 |
-
ax = axes[row, col]
|
| 280 |
-
|
| 281 |
-
try:
|
| 282 |
-
# Load and display image
|
| 283 |
-
image = Image.open(result['image_path'])
|
| 284 |
-
ax.imshow(image)
|
| 285 |
-
ax.axis('off')
|
| 286 |
-
|
| 287 |
-
# Add title with similarity score
|
| 288 |
-
title = f"#{result['rank']} (Similarity: {result['similarity']:.3f})\n{result['color']} {result['hierarchy']}"
|
| 289 |
-
ax.set_title(title, fontsize=10, wrap=True)
|
| 290 |
-
|
| 291 |
-
except Exception as e:
|
| 292 |
-
ax.text(0.5, 0.5, f"Error loading image\n{result['image_path']}",
|
| 293 |
-
ha='center', va='center', transform=ax.transAxes)
|
| 294 |
-
ax.axis('off')
|
| 295 |
-
|
| 296 |
-
# Hide empty subplots
|
| 297 |
-
for i in range(n_results, rows * cols):
|
| 298 |
-
row = i // cols
|
| 299 |
-
col = i % cols
|
| 300 |
-
axes[row, col].axis('off')
|
| 301 |
-
|
| 302 |
-
plt.tight_layout()
|
| 303 |
-
plt.show()
|
| 304 |
-
|
| 305 |
-
# Print detailed results
|
| 306 |
-
print("\n📋 Detailed Results:")
|
| 307 |
-
for result in results:
|
| 308 |
-
print(f"#{result['rank']:2d} | Similarity: {result['similarity']:.3f} | "
|
| 309 |
-
f"Color: {result['color']:12s} | Category: {result['hierarchy']:15s} | "
|
| 310 |
-
f"Text: {result['text'][:50]}...")
|
| 311 |
-
print(f" 🔗 URL: {result['url']}")
|
| 312 |
-
print()
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
def main():
|
| 316 |
-
"""Main function for command-line usage"""
|
| 317 |
-
parser = argparse.ArgumentParser(description="Fashion Search Engine with Category Emphasis")
|
| 318 |
-
parser.add_argument("--query", "-q", type=str, help="Search query")
|
| 319 |
-
parser.add_argument("--top-k", "-k", type=int, default=10, help="Number of results (default: 10)")
|
| 320 |
-
parser.add_argument("--fast", "-f", action="store_true", help="Fast mode (less items)")
|
| 321 |
-
parser.add_argument("--interactive", "-i", action="store_true", help="Interactive mode")
|
| 322 |
-
|
| 323 |
-
args = parser.parse_args()
|
| 324 |
-
|
| 325 |
-
print("🎯 Fashion Search Engine with Category Emphasis")
|
| 326 |
-
|
| 327 |
-
search_engine = FashionSearchEngine(
|
| 328 |
-
top_k=args.top_k,
|
| 329 |
-
)
|
| 330 |
-
print("✅ Ready!")
|
| 331 |
-
|
| 332 |
-
# Single query mode
|
| 333 |
-
if args.query:
|
| 334 |
-
print(f"🔍 Search: '{args.query}'...")
|
| 335 |
-
results = search_engine.search_by_text(args.query)
|
| 336 |
-
search_engine.display_results(results, args.query)
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
# Interactive mode
|
| 340 |
-
print("Enter your query (e.g. 'red dress') or 'quit' to exit")
|
| 341 |
-
|
| 342 |
-
while True:
|
| 343 |
-
try:
|
| 344 |
-
user_input = input("\n🔍 Query: ").strip()
|
| 345 |
-
if not user_input or user_input.lower() in ['quit', 'exit', 'q']:
|
| 346 |
-
print("👋 Goodbye!")
|
| 347 |
-
break
|
| 348 |
-
|
| 349 |
-
if user_input.startswith('verify '):
|
| 350 |
-
if 'yellow accessories' in user_input:
|
| 351 |
-
search_engine.display_yellow_accessories()
|
| 352 |
-
continue
|
| 353 |
-
|
| 354 |
-
print(f"🔍 Search: '{user_input}'...")
|
| 355 |
-
results = search_engine.search_by_text(user_input)
|
| 356 |
-
search_engine.display_results(results, user_input)
|
| 357 |
-
|
| 358 |
-
except KeyboardInterrupt:
|
| 359 |
-
print("\n👋 Goodbye!")
|
| 360 |
-
break
|
| 361 |
-
except Exception as e:
|
| 362 |
-
print(f"❌ Error: {e}")
|
| 363 |
-
|
| 364 |
-
if __name__ == "__main__":
|
| 365 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
evaluation/hierarchy_evaluation.py
DELETED
|
@@ -1,1842 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Hierarchy Embedding Evaluation with Fashion-CLIP Baseline Comparison
|
| 3 |
-
|
| 4 |
-
This module provides comprehensive evaluation tools for hierarchy classification models,
|
| 5 |
-
comparing custom model performance against the Fashion-CLIP baseline. It includes:
|
| 6 |
-
|
| 7 |
-
- Embedding quality metrics (intra-class/inter-class similarity)
|
| 8 |
-
- Classification accuracy with multiple methods (nearest neighbor, centroid-based)
|
| 9 |
-
- Confusion matrix generation and visualization
|
| 10 |
-
- Support for multiple datasets (validation set, Fashion-MNIST, Kaggle Marqo)
|
| 11 |
-
- Advanced techniques: ZCA whitening, Mahalanobis distance, Test-Time Augmentation
|
| 12 |
-
|
| 13 |
-
Key Features:
|
| 14 |
-
- Custom model evaluation with full hierarchy classification pipeline
|
| 15 |
-
- Fashion-CLIP baseline comparison for performance benchmarking
|
| 16 |
-
- Multi-dataset evaluation (validation, Fashion-MNIST, Kaggle Marqo)
|
| 17 |
-
- Flexible evaluation options (whitening, Mahalanobis distance)
|
| 18 |
-
- Detailed metrics: accuracy, F1 scores, confusion matrices
|
| 19 |
-
|
| 20 |
-
Author: Fashion Search Team
|
| 21 |
-
License: Apache 2.0
|
| 22 |
-
"""
|
| 23 |
-
|
| 24 |
-
# Standard library imports
|
| 25 |
-
import os
|
| 26 |
-
import warnings
|
| 27 |
-
from collections import defaultdict
|
| 28 |
-
from io import BytesIO
|
| 29 |
-
from typing import Dict, List, Tuple, Optional, Union, Any
|
| 30 |
-
|
| 31 |
-
# Third-party imports
|
| 32 |
-
import numpy as np
|
| 33 |
-
import pandas as pd
|
| 34 |
-
import requests
|
| 35 |
-
import torch
|
| 36 |
-
import matplotlib.pyplot as plt
|
| 37 |
-
import seaborn as sns
|
| 38 |
-
from PIL import Image
|
| 39 |
-
from sklearn.metrics import (
|
| 40 |
-
accuracy_score,
|
| 41 |
-
classification_report,
|
| 42 |
-
confusion_matrix,
|
| 43 |
-
f1_score,
|
| 44 |
-
)
|
| 45 |
-
from sklearn.metrics.pairwise import cosine_similarity
|
| 46 |
-
from sklearn.model_selection import train_test_split
|
| 47 |
-
from torch.utils.data import Dataset, DataLoader
|
| 48 |
-
from torchvision import transforms
|
| 49 |
-
from tqdm import tqdm
|
| 50 |
-
from transformers import CLIPProcessor, CLIPModel as TransformersCLIPModel
|
| 51 |
-
|
| 52 |
-
# Local imports
|
| 53 |
-
import config
|
| 54 |
-
from config import device, hierarchy_model_path, hierarchy_column, local_dataset_path
|
| 55 |
-
from hierarchy_model import Model, HierarchyExtractor, HierarchyDataset, collate_fn
|
| 56 |
-
|
| 57 |
-
# Suppress warnings for cleaner output
|
| 58 |
-
warnings.filterwarnings('ignore')
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
# ============================================================================
|
| 62 |
-
# CONSTANTS AND CONFIGURATION
|
| 63 |
-
# ============================================================================
|
| 64 |
-
|
| 65 |
-
# Maximum number of samples for evaluation to prevent memory issues
|
| 66 |
-
MAX_SAMPLES_EVALUATION = 10000
|
| 67 |
-
|
| 68 |
-
# Maximum number of inter-class comparisons to prevent O(n²) complexity
|
| 69 |
-
MAX_INTER_CLASS_COMPARISONS = 10000
|
| 70 |
-
|
| 71 |
-
# Fashion-MNIST label mapping
|
| 72 |
-
FASHION_MNIST_LABELS = {
|
| 73 |
-
0: "T-shirt/top",
|
| 74 |
-
1: "Trouser",
|
| 75 |
-
2: "Pullover",
|
| 76 |
-
3: "Dress",
|
| 77 |
-
4: "Coat",
|
| 78 |
-
5: "Sandal",
|
| 79 |
-
6: "Shirt",
|
| 80 |
-
7: "Sneaker",
|
| 81 |
-
8: "Bag",
|
| 82 |
-
9: "Ankle boot"
|
| 83 |
-
}
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
# ============================================================================
|
| 87 |
-
# UTILITY FUNCTIONS
|
| 88 |
-
# ============================================================================
|
| 89 |
-
|
| 90 |
-
def convert_fashion_mnist_to_image(pixel_values: np.ndarray) -> Image.Image:
|
| 91 |
-
"""
|
| 92 |
-
Convert Fashion-MNIST pixel values to RGB PIL Image.
|
| 93 |
-
|
| 94 |
-
Args:
|
| 95 |
-
pixel_values: Flat array of 784 pixel values (28x28)
|
| 96 |
-
|
| 97 |
-
Returns:
|
| 98 |
-
PIL Image in RGB format
|
| 99 |
-
"""
|
| 100 |
-
# Reshape to 28x28 and convert to uint8
|
| 101 |
-
image_array = np.array(pixel_values).reshape(28, 28).astype(np.uint8)
|
| 102 |
-
|
| 103 |
-
# Convert grayscale to RGB by duplicating channels
|
| 104 |
-
image_array = np.stack([image_array] * 3, axis=-1)
|
| 105 |
-
|
| 106 |
-
return Image.fromarray(image_array)
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
def get_fashion_mnist_labels() -> Dict[int, str]:
|
| 110 |
-
"""
|
| 111 |
-
Get Fashion-MNIST class labels mapping.
|
| 112 |
-
|
| 113 |
-
Returns:
|
| 114 |
-
Dictionary mapping label IDs to class names
|
| 115 |
-
"""
|
| 116 |
-
return FASHION_MNIST_LABELS.copy()
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
def create_fashion_mnist_to_hierarchy_mapping(
|
| 120 |
-
hierarchy_classes: List[str]
|
| 121 |
-
) -> Dict[int, Optional[str]]:
|
| 122 |
-
"""
|
| 123 |
-
Create mapping from Fashion-MNIST labels to custom hierarchy classes.
|
| 124 |
-
|
| 125 |
-
This function performs intelligent matching between Fashion-MNIST categories
|
| 126 |
-
and the custom model's hierarchy classes using exact, partial, and semantic matching.
|
| 127 |
-
|
| 128 |
-
Args:
|
| 129 |
-
hierarchy_classes: List of hierarchy class names from the custom model
|
| 130 |
-
|
| 131 |
-
Returns:
|
| 132 |
-
Dictionary mapping Fashion-MNIST label IDs to hierarchy class names
|
| 133 |
-
(None if no match found)
|
| 134 |
-
"""
|
| 135 |
-
# Normalize hierarchy classes to lowercase for matching
|
| 136 |
-
hierarchy_classes_lower = [h.lower() for h in hierarchy_classes]
|
| 137 |
-
|
| 138 |
-
# Create mapping dictionary
|
| 139 |
-
mapping = {}
|
| 140 |
-
|
| 141 |
-
for fm_label_id, fm_label in FASHION_MNIST_LABELS.items():
|
| 142 |
-
fm_label_lower = fm_label.lower()
|
| 143 |
-
matched_hierarchy = None
|
| 144 |
-
|
| 145 |
-
# Strategy 1: Try exact match first
|
| 146 |
-
if fm_label_lower in hierarchy_classes_lower:
|
| 147 |
-
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(fm_label_lower)]
|
| 148 |
-
|
| 149 |
-
# Strategy 2: Try partial matches
|
| 150 |
-
elif any(h in fm_label_lower or fm_label_lower in h for h in hierarchy_classes_lower):
|
| 151 |
-
for h_class in hierarchy_classes:
|
| 152 |
-
h_lower = h_class.lower()
|
| 153 |
-
if h_lower in fm_label_lower or fm_label_lower in h_lower:
|
| 154 |
-
matched_hierarchy = h_class
|
| 155 |
-
break
|
| 156 |
-
|
| 157 |
-
# Strategy 3: Semantic matching for common fashion categories
|
| 158 |
-
else:
|
| 159 |
-
# T-shirt/top -> shirt or top
|
| 160 |
-
if fm_label_lower in ['t-shirt/top', 'top']:
|
| 161 |
-
if 'top' in hierarchy_classes_lower:
|
| 162 |
-
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index('top')]
|
| 163 |
-
elif 'shirt' in hierarchy_classes_lower:
|
| 164 |
-
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index('shirt')]
|
| 165 |
-
|
| 166 |
-
# Trouser -> pant, bottom
|
| 167 |
-
elif 'trouser' in fm_label_lower:
|
| 168 |
-
for possible in ['pant', 'pants', 'trousers', 'trouser', 'bottom']:
|
| 169 |
-
if possible in hierarchy_classes_lower:
|
| 170 |
-
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(possible)]
|
| 171 |
-
break
|
| 172 |
-
|
| 173 |
-
# Pullover -> sweater, top
|
| 174 |
-
elif 'pullover' in fm_label_lower:
|
| 175 |
-
for possible in ['sweater', 'pullover', 'top']:
|
| 176 |
-
if possible in hierarchy_classes_lower:
|
| 177 |
-
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(possible)]
|
| 178 |
-
break
|
| 179 |
-
|
| 180 |
-
# Dress -> dress
|
| 181 |
-
elif 'dress' in fm_label_lower:
|
| 182 |
-
if 'dress' in hierarchy_classes_lower:
|
| 183 |
-
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index('dress')]
|
| 184 |
-
|
| 185 |
-
# Coat -> coat, jacket
|
| 186 |
-
elif 'coat' in fm_label_lower:
|
| 187 |
-
for possible in ['coat', 'jacket', 'outerwear']:
|
| 188 |
-
if possible in hierarchy_classes_lower:
|
| 189 |
-
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(possible)]
|
| 190 |
-
break
|
| 191 |
-
|
| 192 |
-
# Footwear: Sandal, Sneaker, Ankle boot -> shoes
|
| 193 |
-
elif fm_label_lower in ['sandal', 'sneaker', 'ankle boot']:
|
| 194 |
-
for possible in ['shoes', 'shoe', 'footwear', 'sandal', 'sneaker', 'boot']:
|
| 195 |
-
if possible in hierarchy_classes_lower:
|
| 196 |
-
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(possible)]
|
| 197 |
-
break
|
| 198 |
-
|
| 199 |
-
# Bag -> bag
|
| 200 |
-
elif 'bag' in fm_label_lower:
|
| 201 |
-
if 'bag' in hierarchy_classes_lower:
|
| 202 |
-
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index('bag')]
|
| 203 |
-
|
| 204 |
-
mapping[fm_label_id] = matched_hierarchy
|
| 205 |
-
|
| 206 |
-
# Print mapping result
|
| 207 |
-
if matched_hierarchy:
|
| 208 |
-
print(f" {fm_label} ({fm_label_id}) -> {matched_hierarchy}")
|
| 209 |
-
else:
|
| 210 |
-
print(f" ⚠️ {fm_label} ({fm_label_id}) -> NO MATCH (will be filtered out)")
|
| 211 |
-
|
| 212 |
-
return mapping
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
# ============================================================================
|
| 216 |
-
# DATASET CLASSES
|
| 217 |
-
# ============================================================================
|
| 218 |
-
|
| 219 |
-
class FashionMNISTDataset(Dataset):
|
| 220 |
-
"""
|
| 221 |
-
Fashion-MNIST Dataset class for evaluation.
|
| 222 |
-
|
| 223 |
-
This dataset handles Fashion-MNIST images with proper preprocessing and
|
| 224 |
-
label mapping to custom hierarchy classes. Aligned with main_model_evaluation.py
|
| 225 |
-
for consistent evaluation across different scripts.
|
| 226 |
-
|
| 227 |
-
Args:
|
| 228 |
-
dataframe: Pandas DataFrame containing Fashion-MNIST data with pixel columns
|
| 229 |
-
image_size: Target size for image resizing (default: 224)
|
| 230 |
-
label_mapping: Optional mapping from Fashion-MNIST label IDs to hierarchy classes
|
| 231 |
-
|
| 232 |
-
Returns:
|
| 233 |
-
Tuple of (image_tensor, description, color, hierarchy)
|
| 234 |
-
"""
|
| 235 |
-
|
| 236 |
-
def __init__(
|
| 237 |
-
self,
|
| 238 |
-
dataframe: pd.DataFrame,
|
| 239 |
-
image_size: int = 224,
|
| 240 |
-
label_mapping: Optional[Dict[int, str]] = None
|
| 241 |
-
):
|
| 242 |
-
self.dataframe = dataframe
|
| 243 |
-
self.image_size = image_size
|
| 244 |
-
self.labels_map = get_fashion_mnist_labels()
|
| 245 |
-
self.label_mapping = label_mapping
|
| 246 |
-
|
| 247 |
-
# Standard ImageNet normalization for transfer learning
|
| 248 |
-
self.transform = transforms.Compose([
|
| 249 |
-
transforms.Resize((image_size, image_size)),
|
| 250 |
-
transforms.ToTensor(),
|
| 251 |
-
transforms.Normalize(
|
| 252 |
-
mean=[0.485, 0.456, 0.406],
|
| 253 |
-
std=[0.229, 0.224, 0.225]
|
| 254 |
-
),
|
| 255 |
-
])
|
| 256 |
-
|
| 257 |
-
def __len__(self) -> int:
|
| 258 |
-
return len(self.dataframe)
|
| 259 |
-
|
| 260 |
-
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str, str, str]:
|
| 261 |
-
"""
|
| 262 |
-
Get a single item from the dataset.
|
| 263 |
-
|
| 264 |
-
Args:
|
| 265 |
-
idx: Index of the item to retrieve
|
| 266 |
-
|
| 267 |
-
Returns:
|
| 268 |
-
Tuple of (image_tensor, description, color, hierarchy)
|
| 269 |
-
"""
|
| 270 |
-
row = self.dataframe.iloc[idx]
|
| 271 |
-
|
| 272 |
-
# Extract pixel values (784 pixels for 28x28 image)
|
| 273 |
-
pixel_cols = [f"pixel{i}" for i in range(1, 785)]
|
| 274 |
-
pixel_values = row[pixel_cols].values
|
| 275 |
-
|
| 276 |
-
# Convert to PIL Image and apply transforms
|
| 277 |
-
image = convert_fashion_mnist_to_image(pixel_values)
|
| 278 |
-
image = self.transform(image)
|
| 279 |
-
|
| 280 |
-
# Get label information
|
| 281 |
-
label_id = int(row['label'])
|
| 282 |
-
description = self.labels_map[label_id]
|
| 283 |
-
color = "unknown" # Fashion-MNIST doesn't have color information
|
| 284 |
-
|
| 285 |
-
# Use mapped hierarchy if available, otherwise use original label
|
| 286 |
-
if self.label_mapping and label_id in self.label_mapping:
|
| 287 |
-
hierarchy = self.label_mapping[label_id]
|
| 288 |
-
else:
|
| 289 |
-
hierarchy = self.labels_map[label_id]
|
| 290 |
-
|
| 291 |
-
return image, description, color, hierarchy
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
class CLIPDataset(Dataset):
|
| 295 |
-
"""
|
| 296 |
-
Dataset class for Fashion-CLIP baseline evaluation.
|
| 297 |
-
|
| 298 |
-
This dataset handles image loading from various sources (local paths, URLs, PIL Images)
|
| 299 |
-
and applies standard validation transforms without augmentation.
|
| 300 |
-
|
| 301 |
-
Args:
|
| 302 |
-
dataframe: Pandas DataFrame containing image and text data
|
| 303 |
-
|
| 304 |
-
Returns:
|
| 305 |
-
Tuple of (image_tensor, description, hierarchy)
|
| 306 |
-
"""
|
| 307 |
-
|
| 308 |
-
def __init__(self, dataframe: pd.DataFrame):
|
| 309 |
-
self.dataframe = dataframe
|
| 310 |
-
|
| 311 |
-
# Validation transforms (no augmentation for fair comparison)
|
| 312 |
-
self.transform = transforms.Compose([
|
| 313 |
-
transforms.Resize((224, 224)),
|
| 314 |
-
transforms.ToTensor(),
|
| 315 |
-
transforms.Normalize(
|
| 316 |
-
mean=[0.485, 0.456, 0.406],
|
| 317 |
-
std=[0.229, 0.224, 0.225]
|
| 318 |
-
)
|
| 319 |
-
])
|
| 320 |
-
|
| 321 |
-
def __len__(self) -> int:
|
| 322 |
-
return len(self.dataframe)
|
| 323 |
-
|
| 324 |
-
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str, str]:
|
| 325 |
-
"""
|
| 326 |
-
Get a single item from the dataset.
|
| 327 |
-
|
| 328 |
-
Args:
|
| 329 |
-
idx: Index of the item to retrieve
|
| 330 |
-
|
| 331 |
-
Returns:
|
| 332 |
-
Tuple of (image_tensor, description, hierarchy)
|
| 333 |
-
"""
|
| 334 |
-
row = self.dataframe.iloc[idx]
|
| 335 |
-
|
| 336 |
-
# Handle image loading from various sources
|
| 337 |
-
image = self._load_image(row, idx)
|
| 338 |
-
|
| 339 |
-
# Apply transforms
|
| 340 |
-
image_tensor = self.transform(image)
|
| 341 |
-
|
| 342 |
-
description = row[config.text_column]
|
| 343 |
-
hierarchy = row[config.hierarchy_column]
|
| 344 |
-
|
| 345 |
-
return image_tensor, description, hierarchy
|
| 346 |
-
|
| 347 |
-
def _load_image(self, row: pd.Series, idx: int) -> Image.Image:
|
| 348 |
-
"""
|
| 349 |
-
Load image from various sources with fallback handling.
|
| 350 |
-
|
| 351 |
-
Args:
|
| 352 |
-
row: DataFrame row containing image information
|
| 353 |
-
idx: Index for error reporting
|
| 354 |
-
|
| 355 |
-
Returns:
|
| 356 |
-
PIL Image in RGB format
|
| 357 |
-
"""
|
| 358 |
-
# Try loading from local path first
|
| 359 |
-
if config.column_local_image_path in row.index and pd.notna(row[config.column_local_image_path]):
|
| 360 |
-
local_path = row[config.column_local_image_path]
|
| 361 |
-
try:
|
| 362 |
-
if os.path.exists(local_path):
|
| 363 |
-
return Image.open(local_path).convert("RGB")
|
| 364 |
-
else:
|
| 365 |
-
print(f"⚠️ Local image not found: {local_path}")
|
| 366 |
-
except Exception as e:
|
| 367 |
-
print(f"⚠️ Failed to load local image {idx}: {e}")
|
| 368 |
-
|
| 369 |
-
# Try loading from various data formats
|
| 370 |
-
image_data = row.get(config.column_url_image)
|
| 371 |
-
|
| 372 |
-
# Handle dictionary format (with bytes)
|
| 373 |
-
if isinstance(image_data, dict) and 'bytes' in image_data:
|
| 374 |
-
return Image.open(BytesIO(image_data['bytes'])).convert('RGB')
|
| 375 |
-
|
| 376 |
-
# Handle numpy array (Fashion-MNIST format)
|
| 377 |
-
if isinstance(image_data, (list, np.ndarray)):
|
| 378 |
-
pixels = np.array(image_data).reshape(28, 28)
|
| 379 |
-
return Image.fromarray(pixels.astype(np.uint8)).convert("RGB")
|
| 380 |
-
|
| 381 |
-
# Handle PIL Image directly
|
| 382 |
-
if isinstance(image_data, Image.Image):
|
| 383 |
-
return image_data.convert("RGB")
|
| 384 |
-
|
| 385 |
-
# Try loading from URL as fallback
|
| 386 |
-
try:
|
| 387 |
-
response = requests.get(image_data, timeout=10)
|
| 388 |
-
response.raise_for_status()
|
| 389 |
-
return Image.open(BytesIO(response.content)).convert("RGB")
|
| 390 |
-
except Exception as e:
|
| 391 |
-
print(f"⚠️ Failed to load image {idx}: {e}")
|
| 392 |
-
# Return gray placeholder image
|
| 393 |
-
return Image.new('RGB', (224, 224), color='gray')
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
# ============================================================================
|
| 397 |
-
# EVALUATOR CLASSES
|
| 398 |
-
# ============================================================================
|
| 399 |
-
|
| 400 |
-
class CLIPBaselineEvaluator:
|
| 401 |
-
"""
|
| 402 |
-
Fashion-CLIP Baseline Evaluator.
|
| 403 |
-
|
| 404 |
-
This class handles the loading and evaluation of the Fashion-CLIP baseline model
|
| 405 |
-
(patrickjohncyh/fashion-clip) for comparison with custom models.
|
| 406 |
-
|
| 407 |
-
Args:
|
| 408 |
-
device: Device to run the model on ('cuda', 'mps', or 'cpu')
|
| 409 |
-
"""
|
| 410 |
-
|
| 411 |
-
def __init__(self, device: str = 'mps'):
|
| 412 |
-
self.device = torch.device(device)
|
| 413 |
-
|
| 414 |
-
# Load Fashion-CLIP model and processor
|
| 415 |
-
print("🤗 Loading Fashion-CLIP baseline model from transformers...")
|
| 416 |
-
model_name = "patrickjohncyh/fashion-clip"
|
| 417 |
-
self.clip_model = TransformersCLIPModel.from_pretrained(model_name).to(self.device)
|
| 418 |
-
self.clip_processor = CLIPProcessor.from_pretrained(model_name)
|
| 419 |
-
|
| 420 |
-
self.clip_model.eval()
|
| 421 |
-
print("✅ Fashion-CLIP model loaded successfully")
|
| 422 |
-
|
| 423 |
-
def extract_clip_embeddings(
|
| 424 |
-
self,
|
| 425 |
-
images: List[Union[torch.Tensor, Image.Image]],
|
| 426 |
-
texts: List[str]
|
| 427 |
-
) -> Tuple[np.ndarray, np.ndarray]:
|
| 428 |
-
"""
|
| 429 |
-
Extract Fashion-CLIP embeddings for images and texts.
|
| 430 |
-
|
| 431 |
-
This method processes images and texts through the Fashion-CLIP model
|
| 432 |
-
to generate normalized embeddings. Aligned with main_model_evaluation.py
|
| 433 |
-
for consistency.
|
| 434 |
-
|
| 435 |
-
Args:
|
| 436 |
-
images: List of images (tensors or PIL Images)
|
| 437 |
-
texts: List of text descriptions
|
| 438 |
-
|
| 439 |
-
Returns:
|
| 440 |
-
Tuple of (image_embeddings, text_embeddings) as numpy arrays
|
| 441 |
-
"""
|
| 442 |
-
all_image_embeddings = []
|
| 443 |
-
all_text_embeddings = []
|
| 444 |
-
|
| 445 |
-
# Process in batches for efficiency
|
| 446 |
-
batch_size = 32
|
| 447 |
-
num_batches = (len(images) + batch_size - 1) // batch_size
|
| 448 |
-
|
| 449 |
-
with torch.no_grad():
|
| 450 |
-
for batch_idx in tqdm(range(num_batches), desc="Extracting CLIP embeddings"):
|
| 451 |
-
start_idx = batch_idx * batch_size
|
| 452 |
-
end_idx = min(start_idx + batch_size, len(images))
|
| 453 |
-
|
| 454 |
-
batch_images = images[start_idx:end_idx]
|
| 455 |
-
batch_texts = texts[start_idx:end_idx]
|
| 456 |
-
|
| 457 |
-
# Extract text embeddings
|
| 458 |
-
text_features = self._extract_text_features(batch_texts)
|
| 459 |
-
|
| 460 |
-
# Extract image embeddings
|
| 461 |
-
image_features = self._extract_image_features(batch_images)
|
| 462 |
-
|
| 463 |
-
# Store results
|
| 464 |
-
all_image_embeddings.append(image_features.cpu().numpy())
|
| 465 |
-
all_text_embeddings.append(text_features.cpu().numpy())
|
| 466 |
-
|
| 467 |
-
# Clear memory
|
| 468 |
-
del text_features, image_features
|
| 469 |
-
if torch.cuda.is_available():
|
| 470 |
-
torch.cuda.empty_cache()
|
| 471 |
-
|
| 472 |
-
return np.vstack(all_image_embeddings), np.vstack(all_text_embeddings)
|
| 473 |
-
|
| 474 |
-
def _extract_text_features(self, texts: List[str]) -> torch.Tensor:
|
| 475 |
-
"""
|
| 476 |
-
Extract text features using Fashion-CLIP.
|
| 477 |
-
|
| 478 |
-
Args:
|
| 479 |
-
texts: List of text descriptions
|
| 480 |
-
|
| 481 |
-
Returns:
|
| 482 |
-
Normalized text feature embeddings
|
| 483 |
-
"""
|
| 484 |
-
# Process text through Fashion-CLIP processor
|
| 485 |
-
text_inputs = self.clip_processor(
|
| 486 |
-
text=texts,
|
| 487 |
-
return_tensors="pt",
|
| 488 |
-
padding=True,
|
| 489 |
-
truncation=True,
|
| 490 |
-
max_length=77
|
| 491 |
-
)
|
| 492 |
-
text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
|
| 493 |
-
|
| 494 |
-
# Get text features using dedicated method
|
| 495 |
-
text_features = self.clip_model.get_text_features(**text_inputs)
|
| 496 |
-
|
| 497 |
-
# Apply L2 normalization (critical for CLIP!)
|
| 498 |
-
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
| 499 |
-
|
| 500 |
-
return text_features
|
| 501 |
-
|
| 502 |
-
def _extract_image_features(
|
| 503 |
-
self,
|
| 504 |
-
images: List[Union[torch.Tensor, Image.Image]]
|
| 505 |
-
) -> torch.Tensor:
|
| 506 |
-
"""
|
| 507 |
-
Extract image features using Fashion-CLIP.
|
| 508 |
-
|
| 509 |
-
Args:
|
| 510 |
-
images: List of images (tensors or PIL Images)
|
| 511 |
-
|
| 512 |
-
Returns:
|
| 513 |
-
Normalized image feature embeddings
|
| 514 |
-
"""
|
| 515 |
-
# Convert tensor images to PIL Images for proper processing
|
| 516 |
-
pil_images = []
|
| 517 |
-
for img in images:
|
| 518 |
-
if isinstance(img, torch.Tensor):
|
| 519 |
-
pil_images.append(self._tensor_to_pil(img))
|
| 520 |
-
elif isinstance(img, Image.Image):
|
| 521 |
-
pil_images.append(img)
|
| 522 |
-
else:
|
| 523 |
-
raise ValueError(f"Unsupported image type: {type(img)}")
|
| 524 |
-
|
| 525 |
-
# Process images through Fashion-CLIP processor
|
| 526 |
-
image_inputs = self.clip_processor(
|
| 527 |
-
images=pil_images,
|
| 528 |
-
return_tensors="pt"
|
| 529 |
-
)
|
| 530 |
-
image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}
|
| 531 |
-
|
| 532 |
-
# Get image features using dedicated method
|
| 533 |
-
image_features = self.clip_model.get_image_features(**image_inputs)
|
| 534 |
-
|
| 535 |
-
# Apply L2 normalization (critical for CLIP!)
|
| 536 |
-
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
| 537 |
-
|
| 538 |
-
return image_features
|
| 539 |
-
|
| 540 |
-
def _tensor_to_pil(self, tensor: torch.Tensor) -> Image.Image:
|
| 541 |
-
"""
|
| 542 |
-
Convert a normalized tensor to PIL Image.
|
| 543 |
-
|
| 544 |
-
Args:
|
| 545 |
-
tensor: Image tensor (C, H, W)
|
| 546 |
-
|
| 547 |
-
Returns:
|
| 548 |
-
PIL Image
|
| 549 |
-
"""
|
| 550 |
-
if tensor.dim() != 3:
|
| 551 |
-
raise ValueError(f"Expected 3D tensor, got {tensor.dim()}D")
|
| 552 |
-
|
| 553 |
-
# Denormalize if normalized (undo ImageNet normalization)
|
| 554 |
-
if tensor.min() < 0 or tensor.max() > 1:
|
| 555 |
-
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
|
| 556 |
-
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
|
| 557 |
-
tensor = tensor * std + mean
|
| 558 |
-
tensor = torch.clamp(tensor, 0, 1)
|
| 559 |
-
|
| 560 |
-
# Convert to PIL
|
| 561 |
-
return transforms.ToPILImage()(tensor)
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
class EmbeddingEvaluator:
|
| 565 |
-
"""
|
| 566 |
-
Comprehensive Embedding Evaluator for Hierarchy Classification.
|
| 567 |
-
|
| 568 |
-
This class provides a complete evaluation pipeline for hierarchy classification models,
|
| 569 |
-
including custom model evaluation and Fashion-CLIP baseline comparison. It supports
|
| 570 |
-
multiple evaluation metrics, datasets, and advanced techniques.
|
| 571 |
-
|
| 572 |
-
Key Features:
|
| 573 |
-
- Custom model loading and evaluation
|
| 574 |
-
- Fashion-CLIP baseline comparison
|
| 575 |
-
- Multiple classification methods (nearest neighbor, centroid, Mahalanobis)
|
| 576 |
-
- Advanced techniques (ZCA whitening, Test-Time Augmentation)
|
| 577 |
-
- Comprehensive metrics (accuracy, F1, confusion matrices)
|
| 578 |
-
|
| 579 |
-
Args:
|
| 580 |
-
model_path: Path to the trained custom model checkpoint
|
| 581 |
-
directory: Output directory for saving evaluation results
|
| 582 |
-
"""
|
| 583 |
-
|
| 584 |
-
def __init__(self, model_path: str, directory: str):
|
| 585 |
-
self.directory = directory
|
| 586 |
-
self.device = device
|
| 587 |
-
|
| 588 |
-
# Load and prepare dataset
|
| 589 |
-
print(f"📁 Using dataset with local images: {local_dataset_path}")
|
| 590 |
-
df = pd.read_csv(local_dataset_path)
|
| 591 |
-
print(f"📁 Loaded {len(df)} samples")
|
| 592 |
-
|
| 593 |
-
# Get unique hierarchy classes
|
| 594 |
-
hierarchy_classes = sorted(df[hierarchy_column].unique().tolist())
|
| 595 |
-
print(f"📋 Found {len(hierarchy_classes)} hierarchy classes")
|
| 596 |
-
|
| 597 |
-
# Limit dataset size to prevent memory issues
|
| 598 |
-
if len(df) > MAX_SAMPLES_EVALUATION:
|
| 599 |
-
print(f"⚠️ Dataset too large ({len(df)} samples), sampling to {MAX_SAMPLES_EVALUATION} samples")
|
| 600 |
-
df = self._stratified_sample(df, MAX_SAMPLES_EVALUATION)
|
| 601 |
-
|
| 602 |
-
# Create validation split (20% of data)
|
| 603 |
-
_, self.val_df = train_test_split(
|
| 604 |
-
df,
|
| 605 |
-
test_size=0.2,
|
| 606 |
-
random_state=42,
|
| 607 |
-
stratify=df['hierarchy']
|
| 608 |
-
)
|
| 609 |
-
|
| 610 |
-
# Load the custom model
|
| 611 |
-
self._load_model(model_path)
|
| 612 |
-
|
| 613 |
-
# Initialize Fashion-CLIP baseline
|
| 614 |
-
self.clip_evaluator = CLIPBaselineEvaluator(device)
|
| 615 |
-
|
| 616 |
-
def _stratified_sample(self, df: pd.DataFrame, max_samples: int) -> pd.DataFrame:
|
| 617 |
-
"""
|
| 618 |
-
Perform stratified sampling to maintain class distribution.
|
| 619 |
-
|
| 620 |
-
Args:
|
| 621 |
-
df: Original DataFrame
|
| 622 |
-
max_samples: Maximum number of samples to keep
|
| 623 |
-
|
| 624 |
-
Returns:
|
| 625 |
-
Sampled DataFrame
|
| 626 |
-
"""
|
| 627 |
-
# Stratified sampling by hierarchy
|
| 628 |
-
df_sampled = df.groupby('hierarchy', group_keys=False).apply(
|
| 629 |
-
lambda x: x.sample(
|
| 630 |
-
n=min(len(x), int(max_samples * len(x) / len(df))),
|
| 631 |
-
random_state=42
|
| 632 |
-
)
|
| 633 |
-
).reset_index(drop=True)
|
| 634 |
-
|
| 635 |
-
# Adjust to reach exactly max_samples if necessary
|
| 636 |
-
if len(df_sampled) < max_samples:
|
| 637 |
-
remaining = max_samples - len(df_sampled)
|
| 638 |
-
extra = df.sample(n=remaining, random_state=42)
|
| 639 |
-
df_sampled = pd.concat([df_sampled, extra]).reset_index(drop=True)
|
| 640 |
-
|
| 641 |
-
return df_sampled
|
| 642 |
-
|
| 643 |
-
def _load_model(self, model_path: str):
|
| 644 |
-
"""
|
| 645 |
-
Load the custom hierarchy classification model.
|
| 646 |
-
|
| 647 |
-
Args:
|
| 648 |
-
model_path: Path to the model checkpoint
|
| 649 |
-
|
| 650 |
-
Raises:
|
| 651 |
-
FileNotFoundError: If model file doesn't exist
|
| 652 |
-
"""
|
| 653 |
-
if not os.path.exists(model_path):
|
| 654 |
-
raise FileNotFoundError(f"Model file {model_path} not found")
|
| 655 |
-
|
| 656 |
-
# Load checkpoint
|
| 657 |
-
checkpoint = torch.load(model_path, map_location=self.device)
|
| 658 |
-
|
| 659 |
-
# Extract configuration
|
| 660 |
-
config_dict = checkpoint.get('config', {})
|
| 661 |
-
saved_hierarchy_classes = checkpoint['hierarchy_classes']
|
| 662 |
-
|
| 663 |
-
# Store hierarchy classes
|
| 664 |
-
self.hierarchy_classes = saved_hierarchy_classes
|
| 665 |
-
|
| 666 |
-
# Create hierarchy extractor
|
| 667 |
-
self.vocab = HierarchyExtractor(saved_hierarchy_classes)
|
| 668 |
-
|
| 669 |
-
# Create model with saved configuration
|
| 670 |
-
self.model = Model(
|
| 671 |
-
num_hierarchy_classes=len(saved_hierarchy_classes),
|
| 672 |
-
embed_dim=config_dict['embed_dim'],
|
| 673 |
-
dropout=config_dict['dropout']
|
| 674 |
-
).to(self.device)
|
| 675 |
-
|
| 676 |
-
# Load model weights
|
| 677 |
-
self.model.load_state_dict(checkpoint['model_state'])
|
| 678 |
-
self.model.eval()
|
| 679 |
-
|
| 680 |
-
# Print model information
|
| 681 |
-
print(f"✅ Custom model loaded with:")
|
| 682 |
-
print(f"📋 Hierarchy classes: {len(saved_hierarchy_classes)}")
|
| 683 |
-
print(f"🎯 Embed dim: {config_dict['embed_dim']}")
|
| 684 |
-
print(f"💧 Dropout: {config_dict['dropout']}")
|
| 685 |
-
print(f"📅 Epoch: {checkpoint.get('epoch', 'unknown')}")
|
| 686 |
-
|
| 687 |
-
def _collate_fn_wrapper(self, batch: List[Tuple]) -> Dict[str, torch.Tensor]:
|
| 688 |
-
"""
|
| 689 |
-
Wrapper for collate_fn that can be pickled (required for DataLoader).
|
| 690 |
-
|
| 691 |
-
Handles both formats:
|
| 692 |
-
- (image, description, hierarchy) for HierarchyDataset
|
| 693 |
-
- (image, description, color, hierarchy) for FashionMNISTDataset
|
| 694 |
-
|
| 695 |
-
Args:
|
| 696 |
-
batch: List of samples from dataset
|
| 697 |
-
|
| 698 |
-
Returns:
|
| 699 |
-
Collated batch dictionary
|
| 700 |
-
"""
|
| 701 |
-
# Check batch format
|
| 702 |
-
if len(batch[0]) == 4:
|
| 703 |
-
# FashionMNISTDataset format: convert to expected format
|
| 704 |
-
batch_converted = [(b[0], b[1], b[3]) for b in batch]
|
| 705 |
-
return collate_fn(batch_converted, self.vocab)
|
| 706 |
-
else:
|
| 707 |
-
# HierarchyDataset format: use as is
|
| 708 |
-
return collate_fn(batch, self.vocab)
|
| 709 |
-
|
| 710 |
-
def create_dataloader(
|
| 711 |
-
self,
|
| 712 |
-
dataframe_or_dataset: Union[pd.DataFrame, Dataset],
|
| 713 |
-
batch_size: int = 16
|
| 714 |
-
) -> DataLoader:
|
| 715 |
-
"""
|
| 716 |
-
Create a DataLoader for the custom model.
|
| 717 |
-
|
| 718 |
-
Aligned with main_model_evaluation.py for consistency.
|
| 719 |
-
|
| 720 |
-
Args:
|
| 721 |
-
dataframe_or_dataset: Either a pandas DataFrame or a Dataset object
|
| 722 |
-
batch_size: Batch size for the DataLoader
|
| 723 |
-
|
| 724 |
-
Returns:
|
| 725 |
-
Configured DataLoader
|
| 726 |
-
"""
|
| 727 |
-
# Check if it's already a Dataset object
|
| 728 |
-
if isinstance(dataframe_or_dataset, Dataset):
|
| 729 |
-
dataset = dataframe_or_dataset
|
| 730 |
-
print(f"🔍 Using pre-created Dataset object")
|
| 731 |
-
|
| 732 |
-
# Otherwise create dataset from dataframe
|
| 733 |
-
elif isinstance(dataframe_or_dataset, pd.DataFrame):
|
| 734 |
-
# Check if this is Fashion-MNIST data
|
| 735 |
-
if 'pixel1' in dataframe_or_dataset.columns:
|
| 736 |
-
print(f"🔍 Detected Fashion-MNIST data, creating FashionMNISTDataset")
|
| 737 |
-
dataset = FashionMNISTDataset(dataframe_or_dataset, image_size=224)
|
| 738 |
-
else:
|
| 739 |
-
dataset = HierarchyDataset(dataframe_or_dataset, image_size=224)
|
| 740 |
-
else:
|
| 741 |
-
raise ValueError(f"Unsupported type: {type(dataframe_or_dataset)}")
|
| 742 |
-
|
| 743 |
-
# Create DataLoader
|
| 744 |
-
# Note: num_workers=0 to avoid pickling issues on macOS
|
| 745 |
-
dataloader = DataLoader(
|
| 746 |
-
dataset,
|
| 747 |
-
batch_size=batch_size,
|
| 748 |
-
shuffle=False,
|
| 749 |
-
collate_fn=self._collate_fn_wrapper,
|
| 750 |
-
num_workers=0,
|
| 751 |
-
pin_memory=False
|
| 752 |
-
)
|
| 753 |
-
|
| 754 |
-
return dataloader
|
| 755 |
-
|
| 756 |
-
def create_clip_dataloader(
|
| 757 |
-
self,
|
| 758 |
-
dataframe_or_dataset: Union[pd.DataFrame, Dataset],
|
| 759 |
-
batch_size: int = 16
|
| 760 |
-
) -> DataLoader:
|
| 761 |
-
"""
|
| 762 |
-
Create a DataLoader for Fashion-CLIP baseline.
|
| 763 |
-
|
| 764 |
-
Args:
|
| 765 |
-
dataframe_or_dataset: Either a pandas DataFrame or a Dataset object
|
| 766 |
-
batch_size: Batch size for the DataLoader
|
| 767 |
-
|
| 768 |
-
Returns:
|
| 769 |
-
Configured DataLoader
|
| 770 |
-
"""
|
| 771 |
-
# Check if it's already a Dataset object
|
| 772 |
-
if isinstance(dataframe_or_dataset, Dataset):
|
| 773 |
-
dataset = dataframe_or_dataset
|
| 774 |
-
print(f"🔍 Using pre-created Dataset object for CLIP")
|
| 775 |
-
|
| 776 |
-
# Otherwise create dataset from dataframe
|
| 777 |
-
elif isinstance(dataframe_or_dataset, pd.DataFrame):
|
| 778 |
-
# Check if this is Fashion-MNIST data
|
| 779 |
-
if 'pixel1' in dataframe_or_dataset.columns:
|
| 780 |
-
print("🔍 Detected Fashion-MNIST data for Fashion-CLIP")
|
| 781 |
-
dataset = FashionMNISTDataset(dataframe_or_dataset, image_size=224)
|
| 782 |
-
else:
|
| 783 |
-
dataset = CLIPDataset(dataframe_or_dataset)
|
| 784 |
-
else:
|
| 785 |
-
raise ValueError(f"Unsupported type: {type(dataframe_or_dataset)}")
|
| 786 |
-
|
| 787 |
-
# Create DataLoader
|
| 788 |
-
dataloader = DataLoader(
|
| 789 |
-
dataset,
|
| 790 |
-
batch_size=batch_size,
|
| 791 |
-
shuffle=False,
|
| 792 |
-
num_workers=0,
|
| 793 |
-
pin_memory=False
|
| 794 |
-
)
|
| 795 |
-
|
| 796 |
-
return dataloader
|
| 797 |
-
|
| 798 |
-
def extract_custom_embeddings(
|
| 799 |
-
self,
|
| 800 |
-
dataloader: DataLoader,
|
| 801 |
-
embedding_type: str = 'text',
|
| 802 |
-
use_tta: bool = False
|
| 803 |
-
) -> Tuple[np.ndarray, List[str], List[str]]:
|
| 804 |
-
"""
|
| 805 |
-
Extract embeddings from custom model with optional Test-Time Augmentation.
|
| 806 |
-
|
| 807 |
-
Args:
|
| 808 |
-
dataloader: DataLoader for the dataset
|
| 809 |
-
embedding_type: Type of embedding to extract ('text', 'image', or 'both')
|
| 810 |
-
use_tta: Whether to use Test-Time Augmentation for images
|
| 811 |
-
|
| 812 |
-
Returns:
|
| 813 |
-
Tuple of (embeddings, labels, texts)
|
| 814 |
-
"""
|
| 815 |
-
all_embeddings = []
|
| 816 |
-
all_labels = []
|
| 817 |
-
all_texts = []
|
| 818 |
-
|
| 819 |
-
with torch.no_grad():
|
| 820 |
-
for batch in tqdm(dataloader, desc=f"Extracting custom {embedding_type} embeddings{' with TTA' if use_tta else ''}"):
|
| 821 |
-
images = batch['image'].to(self.device)
|
| 822 |
-
hierarchy_indices = batch['hierarchy_indices'].to(self.device)
|
| 823 |
-
hierarchy_labels = batch['hierarchy']
|
| 824 |
-
|
| 825 |
-
# Handle Test-Time Augmentation
|
| 826 |
-
if use_tta and embedding_type == 'image' and images.dim() == 5:
|
| 827 |
-
embeddings = self._extract_with_tta(images, hierarchy_indices)
|
| 828 |
-
else:
|
| 829 |
-
# Standard forward pass
|
| 830 |
-
out = self.model(image=images, hierarchy_indices=hierarchy_indices)
|
| 831 |
-
embeddings = out['z_txt'] if embedding_type == 'text' else out['z_img']
|
| 832 |
-
|
| 833 |
-
all_embeddings.append(embeddings.cpu().numpy())
|
| 834 |
-
all_labels.extend(hierarchy_labels)
|
| 835 |
-
all_texts.extend(hierarchy_labels)
|
| 836 |
-
|
| 837 |
-
# Clear memory
|
| 838 |
-
del images, hierarchy_indices, embeddings, out
|
| 839 |
-
if str(self.device) != 'cpu':
|
| 840 |
-
if torch.cuda.is_available():
|
| 841 |
-
torch.cuda.empty_cache()
|
| 842 |
-
|
| 843 |
-
return np.vstack(all_embeddings), all_labels, all_texts
|
| 844 |
-
|
| 845 |
-
def _extract_with_tta(
|
| 846 |
-
self,
|
| 847 |
-
images: torch.Tensor,
|
| 848 |
-
hierarchy_indices: torch.Tensor
|
| 849 |
-
) -> torch.Tensor:
|
| 850 |
-
"""
|
| 851 |
-
Extract embeddings using Test-Time Augmentation.
|
| 852 |
-
|
| 853 |
-
Args:
|
| 854 |
-
images: Images with TTA crops [batch_size, tta_crops, C, H, W]
|
| 855 |
-
hierarchy_indices: Hierarchy class indices
|
| 856 |
-
|
| 857 |
-
Returns:
|
| 858 |
-
Averaged embeddings [batch_size, embed_dim]
|
| 859 |
-
"""
|
| 860 |
-
batch_size, tta_crops, C, H, W = images.shape
|
| 861 |
-
|
| 862 |
-
# Reshape to [batch_size * tta_crops, C, H, W]
|
| 863 |
-
images_flat = images.view(batch_size * tta_crops, C, H, W)
|
| 864 |
-
|
| 865 |
-
# Repeat hierarchy indices for each TTA crop
|
| 866 |
-
hierarchy_indices_repeated = hierarchy_indices.unsqueeze(1).repeat(1, tta_crops).view(-1)
|
| 867 |
-
|
| 868 |
-
# Forward pass on all TTA crops
|
| 869 |
-
out = self.model(image=images_flat, hierarchy_indices=hierarchy_indices_repeated)
|
| 870 |
-
embeddings_flat = out['z_img']
|
| 871 |
-
|
| 872 |
-
# Reshape back to [batch_size, tta_crops, embed_dim]
|
| 873 |
-
embeddings = embeddings_flat.view(batch_size, tta_crops, -1)
|
| 874 |
-
|
| 875 |
-
# Average over TTA crops
|
| 876 |
-
embeddings = embeddings.mean(dim=1)
|
| 877 |
-
|
| 878 |
-
return embeddings
|
| 879 |
-
|
| 880 |
-
def apply_whitening(
|
| 881 |
-
self,
|
| 882 |
-
embeddings: np.ndarray,
|
| 883 |
-
epsilon: float = 1e-5
|
| 884 |
-
) -> np.ndarray:
|
| 885 |
-
"""
|
| 886 |
-
Apply ZCA whitening to embeddings for better feature decorrelation.
|
| 887 |
-
|
| 888 |
-
Whitening removes correlations between dimensions and can improve
|
| 889 |
-
class separation by normalizing the feature space.
|
| 890 |
-
|
| 891 |
-
Args:
|
| 892 |
-
embeddings: Input embeddings [N, D]
|
| 893 |
-
epsilon: Small constant for numerical stability
|
| 894 |
-
|
| 895 |
-
Returns:
|
| 896 |
-
Whitened embeddings [N, D]
|
| 897 |
-
"""
|
| 898 |
-
# Center the data
|
| 899 |
-
mean = np.mean(embeddings, axis=0, keepdims=True)
|
| 900 |
-
centered = embeddings - mean
|
| 901 |
-
|
| 902 |
-
# Compute covariance matrix
|
| 903 |
-
cov = np.cov(centered.T)
|
| 904 |
-
|
| 905 |
-
# Eigenvalue decomposition
|
| 906 |
-
eigenvalues, eigenvectors = np.linalg.eigh(cov)
|
| 907 |
-
|
| 908 |
-
# ZCA whitening transformation
|
| 909 |
-
d = np.diag(1.0 / np.sqrt(eigenvalues + epsilon))
|
| 910 |
-
whiten_transform = eigenvectors @ d @ eigenvectors.T
|
| 911 |
-
|
| 912 |
-
# Apply whitening
|
| 913 |
-
whitened = centered @ whiten_transform
|
| 914 |
-
|
| 915 |
-
# L2 normalize after whitening
|
| 916 |
-
norms = np.linalg.norm(whitened, axis=1, keepdims=True)
|
| 917 |
-
whitened = whitened / (norms + epsilon)
|
| 918 |
-
|
| 919 |
-
return whitened
|
| 920 |
-
|
| 921 |
-
def compute_similarity_metrics(
|
| 922 |
-
self,
|
| 923 |
-
embeddings: np.ndarray,
|
| 924 |
-
labels: List[str],
|
| 925 |
-
apply_whitening_norm: bool = False
|
| 926 |
-
) -> Dict[str, Any]:
|
| 927 |
-
"""
|
| 928 |
-
Compute intra-class and inter-class similarity metrics.
|
| 929 |
-
|
| 930 |
-
Args:
|
| 931 |
-
embeddings: Embedding vectors
|
| 932 |
-
labels: Class labels
|
| 933 |
-
apply_whitening_norm: Whether to apply ZCA whitening
|
| 934 |
-
|
| 935 |
-
Returns:
|
| 936 |
-
Dictionary containing similarity metrics and accuracies
|
| 937 |
-
"""
|
| 938 |
-
# Apply whitening if requested
|
| 939 |
-
if apply_whitening_norm:
|
| 940 |
-
embeddings = self.apply_whitening(embeddings)
|
| 941 |
-
|
| 942 |
-
# Compute pairwise cosine similarities
|
| 943 |
-
similarities = cosine_similarity(embeddings)
|
| 944 |
-
|
| 945 |
-
# Group embeddings by hierarchy
|
| 946 |
-
hierarchy_groups = defaultdict(list)
|
| 947 |
-
for i, hierarchy in enumerate(labels):
|
| 948 |
-
hierarchy_groups[hierarchy].append(i)
|
| 949 |
-
|
| 950 |
-
# Calculate intra-class similarities (same hierarchy)
|
| 951 |
-
intra_class_similarities = self._compute_intra_class_similarities(
|
| 952 |
-
similarities, hierarchy_groups
|
| 953 |
-
)
|
| 954 |
-
|
| 955 |
-
# Calculate inter-class similarities (different hierarchies)
|
| 956 |
-
inter_class_similarities = self._compute_inter_class_similarities(
|
| 957 |
-
similarities, hierarchy_groups
|
| 958 |
-
)
|
| 959 |
-
|
| 960 |
-
# Calculate classification accuracies
|
| 961 |
-
nn_accuracy = self.compute_embedding_accuracy(embeddings, labels, similarities)
|
| 962 |
-
centroid_accuracy = self.compute_centroid_accuracy(embeddings, labels)
|
| 963 |
-
|
| 964 |
-
return {
|
| 965 |
-
'intra_class_similarities': intra_class_similarities,
|
| 966 |
-
'inter_class_similarities': inter_class_similarities,
|
| 967 |
-
'intra_class_mean': np.mean(intra_class_similarities) if intra_class_similarities else 0,
|
| 968 |
-
'inter_class_mean': np.mean(inter_class_similarities) if inter_class_similarities else 0,
|
| 969 |
-
'separation_score': np.mean(intra_class_similarities) - np.mean(inter_class_similarities) if intra_class_similarities and inter_class_similarities else 0,
|
| 970 |
-
'accuracy': nn_accuracy,
|
| 971 |
-
'centroid_accuracy': centroid_accuracy
|
| 972 |
-
}
|
| 973 |
-
|
| 974 |
-
def _compute_intra_class_similarities(
|
| 975 |
-
self,
|
| 976 |
-
similarities: np.ndarray,
|
| 977 |
-
hierarchy_groups: Dict[str, List[int]]
|
| 978 |
-
) -> List[float]:
|
| 979 |
-
"""
|
| 980 |
-
Compute within-class similarities.
|
| 981 |
-
|
| 982 |
-
Args:
|
| 983 |
-
similarities: Pairwise similarity matrix
|
| 984 |
-
hierarchy_groups: Mapping from hierarchy to sample indices
|
| 985 |
-
|
| 986 |
-
Returns:
|
| 987 |
-
List of intra-class similarity values
|
| 988 |
-
"""
|
| 989 |
-
intra_class_similarities = []
|
| 990 |
-
|
| 991 |
-
for hierarchy, indices in hierarchy_groups.items():
|
| 992 |
-
if len(indices) > 1:
|
| 993 |
-
# Compare all pairs within the same class
|
| 994 |
-
for i in range(len(indices)):
|
| 995 |
-
for j in range(i + 1, len(indices)):
|
| 996 |
-
sim = similarities[indices[i], indices[j]]
|
| 997 |
-
intra_class_similarities.append(sim)
|
| 998 |
-
|
| 999 |
-
return intra_class_similarities
|
| 1000 |
-
|
| 1001 |
-
def _compute_inter_class_similarities(
|
| 1002 |
-
self,
|
| 1003 |
-
similarities: np.ndarray,
|
| 1004 |
-
hierarchy_groups: Dict[str, List[int]]
|
| 1005 |
-
) -> List[float]:
|
| 1006 |
-
"""
|
| 1007 |
-
Compute between-class similarities with sampling for efficiency.
|
| 1008 |
-
|
| 1009 |
-
To prevent O(n²) complexity on large datasets, we limit the number
|
| 1010 |
-
of comparisons through sampling.
|
| 1011 |
-
|
| 1012 |
-
Args:
|
| 1013 |
-
similarities: Pairwise similarity matrix
|
| 1014 |
-
hierarchy_groups: Mapping from hierarchy to sample indices
|
| 1015 |
-
|
| 1016 |
-
Returns:
|
| 1017 |
-
List of inter-class similarity values
|
| 1018 |
-
"""
|
| 1019 |
-
inter_class_similarities = []
|
| 1020 |
-
hierarchies = list(hierarchy_groups.keys())
|
| 1021 |
-
comparison_count = 0
|
| 1022 |
-
|
| 1023 |
-
for i in range(len(hierarchies)):
|
| 1024 |
-
for j in range(i + 1, len(hierarchies)):
|
| 1025 |
-
hierarchy1_indices = hierarchy_groups[hierarchies[i]]
|
| 1026 |
-
hierarchy2_indices = hierarchy_groups[hierarchies[j]]
|
| 1027 |
-
|
| 1028 |
-
# Sample if too many comparisons
|
| 1029 |
-
max_samples_per_pair = min(100, len(hierarchy1_indices), len(hierarchy2_indices))
|
| 1030 |
-
sampled_idx1 = np.random.choice(
|
| 1031 |
-
hierarchy1_indices,
|
| 1032 |
-
size=min(max_samples_per_pair, len(hierarchy1_indices)),
|
| 1033 |
-
replace=False
|
| 1034 |
-
)
|
| 1035 |
-
sampled_idx2 = np.random.choice(
|
| 1036 |
-
hierarchy2_indices,
|
| 1037 |
-
size=min(max_samples_per_pair, len(hierarchy2_indices)),
|
| 1038 |
-
replace=False
|
| 1039 |
-
)
|
| 1040 |
-
|
| 1041 |
-
# Compute similarities between sampled pairs
|
| 1042 |
-
for idx1 in sampled_idx1:
|
| 1043 |
-
for idx2 in sampled_idx2:
|
| 1044 |
-
if comparison_count >= MAX_INTER_CLASS_COMPARISONS:
|
| 1045 |
-
break
|
| 1046 |
-
sim = similarities[idx1, idx2]
|
| 1047 |
-
inter_class_similarities.append(sim)
|
| 1048 |
-
comparison_count += 1
|
| 1049 |
-
if comparison_count >= MAX_INTER_CLASS_COMPARISONS:
|
| 1050 |
-
break
|
| 1051 |
-
if comparison_count >= MAX_INTER_CLASS_COMPARISONS:
|
| 1052 |
-
break
|
| 1053 |
-
if comparison_count >= MAX_INTER_CLASS_COMPARISONS:
|
| 1054 |
-
break
|
| 1055 |
-
|
| 1056 |
-
return inter_class_similarities
|
| 1057 |
-
|
| 1058 |
-
def compute_embedding_accuracy(
|
| 1059 |
-
self,
|
| 1060 |
-
embeddings: np.ndarray,
|
| 1061 |
-
labels: List[str],
|
| 1062 |
-
similarities: np.ndarray
|
| 1063 |
-
) -> float:
|
| 1064 |
-
"""
|
| 1065 |
-
Compute classification accuracy using nearest neighbor in embedding space.
|
| 1066 |
-
|
| 1067 |
-
Args:
|
| 1068 |
-
embeddings: Embedding vectors
|
| 1069 |
-
labels: True class labels
|
| 1070 |
-
similarities: Precomputed similarity matrix
|
| 1071 |
-
|
| 1072 |
-
Returns:
|
| 1073 |
-
Classification accuracy
|
| 1074 |
-
"""
|
| 1075 |
-
correct_predictions = 0
|
| 1076 |
-
total_predictions = len(labels)
|
| 1077 |
-
|
| 1078 |
-
for i in range(len(embeddings)):
|
| 1079 |
-
true_label = labels[i]
|
| 1080 |
-
|
| 1081 |
-
# Find the most similar embedding (excluding itself)
|
| 1082 |
-
similarities_row = similarities[i].copy()
|
| 1083 |
-
similarities_row[i] = -1 # Exclude self-similarity
|
| 1084 |
-
nearest_neighbor_idx = np.argmax(similarities_row)
|
| 1085 |
-
predicted_label = labels[nearest_neighbor_idx]
|
| 1086 |
-
|
| 1087 |
-
if predicted_label == true_label:
|
| 1088 |
-
correct_predictions += 1
|
| 1089 |
-
|
| 1090 |
-
return correct_predictions / total_predictions if total_predictions > 0 else 0
|
| 1091 |
-
|
| 1092 |
-
def compute_centroid_accuracy(
|
| 1093 |
-
self,
|
| 1094 |
-
embeddings: np.ndarray,
|
| 1095 |
-
labels: List[str]
|
| 1096 |
-
) -> float:
|
| 1097 |
-
"""
|
| 1098 |
-
Compute classification accuracy using hierarchy centroids.
|
| 1099 |
-
|
| 1100 |
-
Args:
|
| 1101 |
-
embeddings: Embedding vectors
|
| 1102 |
-
labels: True class labels
|
| 1103 |
-
|
| 1104 |
-
Returns:
|
| 1105 |
-
Classification accuracy
|
| 1106 |
-
"""
|
| 1107 |
-
# Create centroids for each hierarchy
|
| 1108 |
-
unique_hierarchies = list(set(labels))
|
| 1109 |
-
centroids = {}
|
| 1110 |
-
|
| 1111 |
-
for hierarchy in unique_hierarchies:
|
| 1112 |
-
hierarchy_indices = [i for i, label in enumerate(labels) if label == hierarchy]
|
| 1113 |
-
hierarchy_embeddings = embeddings[hierarchy_indices]
|
| 1114 |
-
centroids[hierarchy] = np.mean(hierarchy_embeddings, axis=0)
|
| 1115 |
-
|
| 1116 |
-
# Classify each embedding to nearest centroid
|
| 1117 |
-
correct_predictions = 0
|
| 1118 |
-
total_predictions = len(labels)
|
| 1119 |
-
|
| 1120 |
-
for i, embedding in enumerate(embeddings):
|
| 1121 |
-
true_label = labels[i]
|
| 1122 |
-
|
| 1123 |
-
# Find closest centroid
|
| 1124 |
-
best_similarity = -1
|
| 1125 |
-
predicted_label = None
|
| 1126 |
-
|
| 1127 |
-
for hierarchy, centroid in centroids.items():
|
| 1128 |
-
similarity = cosine_similarity([embedding], [centroid])[0][0]
|
| 1129 |
-
if similarity > best_similarity:
|
| 1130 |
-
best_similarity = similarity
|
| 1131 |
-
predicted_label = hierarchy
|
| 1132 |
-
|
| 1133 |
-
if predicted_label == true_label:
|
| 1134 |
-
correct_predictions += 1
|
| 1135 |
-
|
| 1136 |
-
return correct_predictions / total_predictions if total_predictions > 0 else 0
|
| 1137 |
-
|
| 1138 |
-
def compute_mahalanobis_distance(
|
| 1139 |
-
self,
|
| 1140 |
-
point: np.ndarray,
|
| 1141 |
-
centroid: np.ndarray,
|
| 1142 |
-
cov_inv: np.ndarray
|
| 1143 |
-
) -> float:
|
| 1144 |
-
"""
|
| 1145 |
-
Compute Mahalanobis distance between a point and a centroid.
|
| 1146 |
-
|
| 1147 |
-
The Mahalanobis distance takes into account the covariance structure
|
| 1148 |
-
of the data, making it more robust than Euclidean distance for
|
| 1149 |
-
high-dimensional spaces.
|
| 1150 |
-
|
| 1151 |
-
Args:
|
| 1152 |
-
point: Query point
|
| 1153 |
-
centroid: Class centroid
|
| 1154 |
-
cov_inv: Inverse covariance matrix
|
| 1155 |
-
|
| 1156 |
-
Returns:
|
| 1157 |
-
Mahalanobis distance
|
| 1158 |
-
"""
|
| 1159 |
-
diff = point - centroid
|
| 1160 |
-
distance = np.sqrt(np.dot(np.dot(diff, cov_inv), diff.T))
|
| 1161 |
-
return distance
|
| 1162 |
-
|
| 1163 |
-
def predict_hierarchy_from_embeddings(
|
| 1164 |
-
self,
|
| 1165 |
-
embeddings: np.ndarray,
|
| 1166 |
-
labels: List[str],
|
| 1167 |
-
use_mahalanobis: bool = False
|
| 1168 |
-
) -> List[str]:
|
| 1169 |
-
"""
|
| 1170 |
-
Predict hierarchy from embeddings using centroid-based classification.
|
| 1171 |
-
|
| 1172 |
-
Args:
|
| 1173 |
-
embeddings: Embedding vectors
|
| 1174 |
-
labels: Training labels for computing centroids
|
| 1175 |
-
use_mahalanobis: Whether to use Mahalanobis distance
|
| 1176 |
-
|
| 1177 |
-
Returns:
|
| 1178 |
-
List of predicted hierarchy labels
|
| 1179 |
-
"""
|
| 1180 |
-
# Create hierarchy centroids from training data
|
| 1181 |
-
unique_hierarchies = list(set(labels))
|
| 1182 |
-
centroids = {}
|
| 1183 |
-
cov_inverses = {}
|
| 1184 |
-
|
| 1185 |
-
for hierarchy in unique_hierarchies:
|
| 1186 |
-
hierarchy_indices = [i for i, label in enumerate(labels) if label == hierarchy]
|
| 1187 |
-
hierarchy_embeddings = embeddings[hierarchy_indices]
|
| 1188 |
-
centroids[hierarchy] = np.mean(hierarchy_embeddings, axis=0)
|
| 1189 |
-
|
| 1190 |
-
# Compute covariance for Mahalanobis distance
|
| 1191 |
-
if use_mahalanobis and len(hierarchy_embeddings) > 1:
|
| 1192 |
-
cov = np.cov(hierarchy_embeddings.T)
|
| 1193 |
-
# Add regularization for numerical stability
|
| 1194 |
-
cov += np.eye(cov.shape[0]) * 1e-6
|
| 1195 |
-
try:
|
| 1196 |
-
cov_inverses[hierarchy] = np.linalg.inv(cov)
|
| 1197 |
-
except np.linalg.LinAlgError:
|
| 1198 |
-
# If inversion fails, fallback to identity (Euclidean)
|
| 1199 |
-
cov_inverses[hierarchy] = np.eye(cov.shape[0])
|
| 1200 |
-
|
| 1201 |
-
# Predict hierarchy for all embeddings
|
| 1202 |
-
predictions = []
|
| 1203 |
-
|
| 1204 |
-
for embedding in embeddings:
|
| 1205 |
-
if use_mahalanobis:
|
| 1206 |
-
predicted_hierarchy = self._predict_with_mahalanobis(
|
| 1207 |
-
embedding, centroids, cov_inverses
|
| 1208 |
-
)
|
| 1209 |
-
else:
|
| 1210 |
-
predicted_hierarchy = self._predict_with_cosine(
|
| 1211 |
-
embedding, centroids
|
| 1212 |
-
)
|
| 1213 |
-
predictions.append(predicted_hierarchy)
|
| 1214 |
-
|
| 1215 |
-
return predictions
|
| 1216 |
-
|
| 1217 |
-
def _predict_with_mahalanobis(
|
| 1218 |
-
self,
|
| 1219 |
-
embedding: np.ndarray,
|
| 1220 |
-
centroids: Dict[str, np.ndarray],
|
| 1221 |
-
cov_inverses: Dict[str, np.ndarray]
|
| 1222 |
-
) -> str:
|
| 1223 |
-
"""
|
| 1224 |
-
Predict class using Mahalanobis distance (lower is better).
|
| 1225 |
-
|
| 1226 |
-
Args:
|
| 1227 |
-
embedding: Query embedding
|
| 1228 |
-
centroids: Class centroids
|
| 1229 |
-
cov_inverses: Inverse covariance matrices
|
| 1230 |
-
|
| 1231 |
-
Returns:
|
| 1232 |
-
Predicted class label
|
| 1233 |
-
"""
|
| 1234 |
-
best_distance = float('inf')
|
| 1235 |
-
predicted_hierarchy = None
|
| 1236 |
-
|
| 1237 |
-
for hierarchy, centroid in centroids.items():
|
| 1238 |
-
if hierarchy in cov_inverses:
|
| 1239 |
-
distance = self.compute_mahalanobis_distance(
|
| 1240 |
-
embedding, centroid, cov_inverses[hierarchy]
|
| 1241 |
-
)
|
| 1242 |
-
else:
|
| 1243 |
-
# Fallback to cosine similarity for classes with insufficient samples
|
| 1244 |
-
similarity = cosine_similarity([embedding], [centroid])[0][0]
|
| 1245 |
-
distance = 1 - similarity
|
| 1246 |
-
|
| 1247 |
-
if distance < best_distance:
|
| 1248 |
-
best_distance = distance
|
| 1249 |
-
predicted_hierarchy = hierarchy
|
| 1250 |
-
|
| 1251 |
-
return predicted_hierarchy
|
| 1252 |
-
|
| 1253 |
-
def _predict_with_cosine(
|
| 1254 |
-
self,
|
| 1255 |
-
embedding: np.ndarray,
|
| 1256 |
-
centroids: Dict[str, np.ndarray]
|
| 1257 |
-
) -> str:
|
| 1258 |
-
"""
|
| 1259 |
-
Predict class using cosine similarity (higher is better).
|
| 1260 |
-
|
| 1261 |
-
Args:
|
| 1262 |
-
embedding: Query embedding
|
| 1263 |
-
centroids: Class centroids
|
| 1264 |
-
|
| 1265 |
-
Returns:
|
| 1266 |
-
Predicted class label
|
| 1267 |
-
"""
|
| 1268 |
-
best_similarity = -1
|
| 1269 |
-
predicted_hierarchy = None
|
| 1270 |
-
|
| 1271 |
-
for hierarchy, centroid in centroids.items():
|
| 1272 |
-
similarity = cosine_similarity([embedding], [centroid])[0][0]
|
| 1273 |
-
if similarity > best_similarity:
|
| 1274 |
-
best_similarity = similarity
|
| 1275 |
-
predicted_hierarchy = hierarchy
|
| 1276 |
-
|
| 1277 |
-
return predicted_hierarchy
|
| 1278 |
-
|
| 1279 |
-
def create_confusion_matrix(
|
| 1280 |
-
self,
|
| 1281 |
-
true_labels: List[str],
|
| 1282 |
-
predicted_labels: List[str],
|
| 1283 |
-
title: str = "Confusion Matrix"
|
| 1284 |
-
) -> Tuple[plt.Figure, float, np.ndarray]:
|
| 1285 |
-
"""
|
| 1286 |
-
Create and plot confusion matrix.
|
| 1287 |
-
|
| 1288 |
-
Args:
|
| 1289 |
-
true_labels: Ground truth labels
|
| 1290 |
-
predicted_labels: Predicted labels
|
| 1291 |
-
title: Plot title
|
| 1292 |
-
|
| 1293 |
-
Returns:
|
| 1294 |
-
Tuple of (figure, accuracy, confusion_matrix)
|
| 1295 |
-
"""
|
| 1296 |
-
# Get unique labels
|
| 1297 |
-
unique_labels = sorted(list(set(true_labels + predicted_labels)))
|
| 1298 |
-
|
| 1299 |
-
# Create confusion matrix
|
| 1300 |
-
cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)
|
| 1301 |
-
|
| 1302 |
-
# Calculate accuracy
|
| 1303 |
-
accuracy = accuracy_score(true_labels, predicted_labels)
|
| 1304 |
-
|
| 1305 |
-
# Plot confusion matrix
|
| 1306 |
-
plt.figure(figsize=(12, 10))
|
| 1307 |
-
sns.heatmap(
|
| 1308 |
-
cm,
|
| 1309 |
-
annot=True,
|
| 1310 |
-
fmt='d',
|
| 1311 |
-
cmap='Blues',
|
| 1312 |
-
xticklabels=unique_labels,
|
| 1313 |
-
yticklabels=unique_labels
|
| 1314 |
-
)
|
| 1315 |
-
plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)')
|
| 1316 |
-
plt.ylabel('True Hierarchy')
|
| 1317 |
-
plt.xlabel('Predicted Hierarchy')
|
| 1318 |
-
plt.xticks(rotation=45)
|
| 1319 |
-
plt.yticks(rotation=0)
|
| 1320 |
-
plt.tight_layout()
|
| 1321 |
-
|
| 1322 |
-
return plt.gcf(), accuracy, cm
|
| 1323 |
-
|
| 1324 |
-
def evaluate_classification_performance(
|
| 1325 |
-
self,
|
| 1326 |
-
embeddings: np.ndarray,
|
| 1327 |
-
labels: List[str],
|
| 1328 |
-
embedding_type: str = "Embeddings",
|
| 1329 |
-
apply_whitening_norm: bool = False,
|
| 1330 |
-
use_mahalanobis: bool = False
|
| 1331 |
-
) -> Dict[str, Any]:
|
| 1332 |
-
"""
|
| 1333 |
-
Evaluate classification performance and create confusion matrix.
|
| 1334 |
-
|
| 1335 |
-
Args:
|
| 1336 |
-
embeddings: Embedding vectors
|
| 1337 |
-
labels: True class labels
|
| 1338 |
-
embedding_type: Description of embedding type for display
|
| 1339 |
-
apply_whitening_norm: Whether to apply ZCA whitening
|
| 1340 |
-
use_mahalanobis: Whether to use Mahalanobis distance
|
| 1341 |
-
|
| 1342 |
-
Returns:
|
| 1343 |
-
Dictionary containing classification metrics and visualizations
|
| 1344 |
-
"""
|
| 1345 |
-
# Apply whitening if requested
|
| 1346 |
-
if apply_whitening_norm:
|
| 1347 |
-
embeddings = self.apply_whitening(embeddings)
|
| 1348 |
-
|
| 1349 |
-
# Predict hierarchy
|
| 1350 |
-
predictions = self.predict_hierarchy_from_embeddings(
|
| 1351 |
-
embeddings, labels, use_mahalanobis=use_mahalanobis
|
| 1352 |
-
)
|
| 1353 |
-
|
| 1354 |
-
# Calculate accuracy
|
| 1355 |
-
accuracy = accuracy_score(labels, predictions)
|
| 1356 |
-
|
| 1357 |
-
# Calculate F1 scores
|
| 1358 |
-
unique_labels = sorted(list(set(labels)))
|
| 1359 |
-
f1_macro = f1_score(
|
| 1360 |
-
labels, predictions, labels=unique_labels,
|
| 1361 |
-
average='macro', zero_division=0
|
| 1362 |
-
)
|
| 1363 |
-
f1_weighted = f1_score(
|
| 1364 |
-
labels, predictions, labels=unique_labels,
|
| 1365 |
-
average='weighted', zero_division=0
|
| 1366 |
-
)
|
| 1367 |
-
f1_per_class = f1_score(
|
| 1368 |
-
labels, predictions, labels=unique_labels,
|
| 1369 |
-
average=None, zero_division=0
|
| 1370 |
-
)
|
| 1371 |
-
|
| 1372 |
-
# Create confusion matrix
|
| 1373 |
-
fig, acc, cm = self.create_confusion_matrix(
|
| 1374 |
-
labels, predictions,
|
| 1375 |
-
f"{embedding_type} - Hierarchy Classification"
|
| 1376 |
-
)
|
| 1377 |
-
|
| 1378 |
-
# Generate classification report
|
| 1379 |
-
report = classification_report(
|
| 1380 |
-
labels, predictions, labels=unique_labels,
|
| 1381 |
-
target_names=unique_labels, output_dict=True
|
| 1382 |
-
)
|
| 1383 |
-
|
| 1384 |
-
return {
|
| 1385 |
-
'accuracy': accuracy,
|
| 1386 |
-
'f1_macro': f1_macro,
|
| 1387 |
-
'f1_weighted': f1_weighted,
|
| 1388 |
-
'f1_per_class': f1_per_class,
|
| 1389 |
-
'predictions': predictions,
|
| 1390 |
-
'confusion_matrix': cm,
|
| 1391 |
-
'classification_report': report,
|
| 1392 |
-
'figure': fig
|
| 1393 |
-
}
|
| 1394 |
-
|
| 1395 |
-
def evaluate_dataset_with_baselines(
|
| 1396 |
-
self,
|
| 1397 |
-
dataframe: Union[pd.DataFrame, Dataset],
|
| 1398 |
-
dataset_name: str = "Dataset",
|
| 1399 |
-
use_whitening: bool = False,
|
| 1400 |
-
use_mahalanobis: bool = False
|
| 1401 |
-
) -> Dict[str, Dict[str, Any]]:
|
| 1402 |
-
"""
|
| 1403 |
-
Evaluate embeddings on a given dataset with both custom model and CLIP baseline.
|
| 1404 |
-
|
| 1405 |
-
This is the main evaluation method that compares the custom model against
|
| 1406 |
-
the Fashion-CLIP baseline across multiple metrics and embedding types.
|
| 1407 |
-
Aligned with main_model_evaluation.py for consistency (no TTA for fair comparison).
|
| 1408 |
-
|
| 1409 |
-
Args:
|
| 1410 |
-
dataframe: DataFrame or Dataset to evaluate on
|
| 1411 |
-
dataset_name: Name of the dataset for display
|
| 1412 |
-
use_whitening: Whether to apply ZCA whitening
|
| 1413 |
-
use_mahalanobis: Whether to use Mahalanobis distance
|
| 1414 |
-
|
| 1415 |
-
Returns:
|
| 1416 |
-
Dictionary containing results for all models and embedding types
|
| 1417 |
-
"""
|
| 1418 |
-
print(f"\n{'='*60}")
|
| 1419 |
-
print(f"Evaluating {dataset_name}")
|
| 1420 |
-
if use_whitening:
|
| 1421 |
-
print(f"🎯 ZCA Whitening ENABLED for better feature decorrelation")
|
| 1422 |
-
if use_mahalanobis:
|
| 1423 |
-
print(f"🎯 Mahalanobis Distance ENABLED for classification")
|
| 1424 |
-
print(f"{'='*60}")
|
| 1425 |
-
|
| 1426 |
-
results = {}
|
| 1427 |
-
|
| 1428 |
-
# ===== CUSTOM MODEL EVALUATION =====
|
| 1429 |
-
print(f"\n🔧 Evaluating Custom Model on {dataset_name}")
|
| 1430 |
-
print("-" * 40)
|
| 1431 |
-
|
| 1432 |
-
# Create dataloader
|
| 1433 |
-
custom_dataloader = self.create_dataloader(dataframe, batch_size=16)
|
| 1434 |
-
|
| 1435 |
-
# Evaluate text embeddings
|
| 1436 |
-
text_embeddings, text_labels, texts = self.extract_custom_embeddings(
|
| 1437 |
-
custom_dataloader, 'text', use_tta=False
|
| 1438 |
-
)
|
| 1439 |
-
text_metrics = self.compute_similarity_metrics(
|
| 1440 |
-
text_embeddings, text_labels, apply_whitening_norm=use_whitening
|
| 1441 |
-
)
|
| 1442 |
-
text_classification = self.evaluate_classification_performance(
|
| 1443 |
-
text_embeddings, text_labels, "Custom Text Embeddings",
|
| 1444 |
-
apply_whitening_norm=use_whitening, use_mahalanobis=use_mahalanobis
|
| 1445 |
-
)
|
| 1446 |
-
text_metrics.update(text_classification)
|
| 1447 |
-
results['custom_text'] = text_metrics
|
| 1448 |
-
|
| 1449 |
-
# Evaluate image embeddings
|
| 1450 |
-
# NOTE: TTA disabled for fair comparison
|
| 1451 |
-
image_embeddings, image_labels, _ = self.extract_custom_embeddings(
|
| 1452 |
-
custom_dataloader, 'image', use_tta=False
|
| 1453 |
-
)
|
| 1454 |
-
image_metrics = self.compute_similarity_metrics(
|
| 1455 |
-
image_embeddings, image_labels, apply_whitening_norm=use_whitening
|
| 1456 |
-
)
|
| 1457 |
-
whitening_suffix = " + Whitening" if use_whitening else ""
|
| 1458 |
-
mahalanobis_suffix = " + Mahalanobis" if use_mahalanobis else ""
|
| 1459 |
-
image_classification = self.evaluate_classification_performance(
|
| 1460 |
-
image_embeddings, image_labels,
|
| 1461 |
-
f"Custom Image Embeddings{whitening_suffix}{mahalanobis_suffix}",
|
| 1462 |
-
apply_whitening_norm=use_whitening, use_mahalanobis=use_mahalanobis
|
| 1463 |
-
)
|
| 1464 |
-
image_metrics.update(image_classification)
|
| 1465 |
-
results['custom_image'] = image_metrics
|
| 1466 |
-
|
| 1467 |
-
# ===== FASHION-CLIP BASELINE EVALUATION =====
|
| 1468 |
-
print(f"\n🤗 Evaluating Fashion-CLIP Baseline on {dataset_name}")
|
| 1469 |
-
print("-" * 40)
|
| 1470 |
-
|
| 1471 |
-
# Create dataloader for Fashion-CLIP
|
| 1472 |
-
clip_dataloader = self.create_clip_dataloader(dataframe, batch_size=8)
|
| 1473 |
-
|
| 1474 |
-
# Extract data for Fashion-CLIP
|
| 1475 |
-
all_images = []
|
| 1476 |
-
all_texts = []
|
| 1477 |
-
all_labels = []
|
| 1478 |
-
|
| 1479 |
-
for batch in tqdm(clip_dataloader, desc="Preparing data for Fashion-CLIP"):
|
| 1480 |
-
# Handle different batch formats
|
| 1481 |
-
if len(batch) == 4:
|
| 1482 |
-
images, descriptions, colors, hierarchies = batch
|
| 1483 |
-
else:
|
| 1484 |
-
images, descriptions, hierarchies = batch
|
| 1485 |
-
|
| 1486 |
-
all_images.extend(images)
|
| 1487 |
-
all_texts.extend(descriptions)
|
| 1488 |
-
all_labels.extend(hierarchies)
|
| 1489 |
-
|
| 1490 |
-
# Get Fashion-CLIP embeddings
|
| 1491 |
-
clip_image_embeddings, clip_text_embeddings = self.clip_evaluator.extract_clip_embeddings(
|
| 1492 |
-
all_images, all_texts
|
| 1493 |
-
)
|
| 1494 |
-
|
| 1495 |
-
# Evaluate Fashion-CLIP text embeddings
|
| 1496 |
-
clip_text_metrics = self.compute_similarity_metrics(
|
| 1497 |
-
clip_text_embeddings, all_labels
|
| 1498 |
-
)
|
| 1499 |
-
clip_text_classification = self.evaluate_classification_performance(
|
| 1500 |
-
clip_text_embeddings, all_labels, "Fashion-CLIP Text Embeddings"
|
| 1501 |
-
)
|
| 1502 |
-
clip_text_metrics.update(clip_text_classification)
|
| 1503 |
-
results['clip_text'] = clip_text_metrics
|
| 1504 |
-
|
| 1505 |
-
# Evaluate Fashion-CLIP image embeddings
|
| 1506 |
-
clip_image_metrics = self.compute_similarity_metrics(
|
| 1507 |
-
clip_image_embeddings, all_labels
|
| 1508 |
-
)
|
| 1509 |
-
clip_image_classification = self.evaluate_classification_performance(
|
| 1510 |
-
clip_image_embeddings, all_labels, "Fashion-CLIP Image Embeddings"
|
| 1511 |
-
)
|
| 1512 |
-
clip_image_metrics.update(clip_image_classification)
|
| 1513 |
-
results['clip_image'] = clip_image_metrics
|
| 1514 |
-
|
| 1515 |
-
# ===== PRINT COMPARISON RESULTS =====
|
| 1516 |
-
self._print_comparison_results(dataframe, dataset_name, results)
|
| 1517 |
-
|
| 1518 |
-
# ===== SAVE VISUALIZATIONS =====
|
| 1519 |
-
self._save_visualizations(dataset_name, results)
|
| 1520 |
-
|
| 1521 |
-
return results
|
| 1522 |
-
|
| 1523 |
-
def _print_comparison_results(
|
| 1524 |
-
self,
|
| 1525 |
-
dataframe: Union[pd.DataFrame, Dataset],
|
| 1526 |
-
dataset_name: str,
|
| 1527 |
-
results: Dict[str, Dict[str, Any]]
|
| 1528 |
-
):
|
| 1529 |
-
"""
|
| 1530 |
-
Print formatted comparison results.
|
| 1531 |
-
|
| 1532 |
-
Args:
|
| 1533 |
-
dataframe: Dataset being evaluated
|
| 1534 |
-
dataset_name: Name of the dataset
|
| 1535 |
-
results: Evaluation results dictionary
|
| 1536 |
-
"""
|
| 1537 |
-
dataset_size = len(dataframe) if hasattr(dataframe, '__len__') else "N/A"
|
| 1538 |
-
|
| 1539 |
-
print(f"\n{dataset_name} Results Comparison:")
|
| 1540 |
-
print(f"Dataset size: {dataset_size} samples")
|
| 1541 |
-
print("=" * 80)
|
| 1542 |
-
print(f"{'Model':<20} {'Embedding':<10} {'Sep Score':<10} {'NN Acc':<8} {'Centroid Acc':<12} {'F1 Macro':<10}")
|
| 1543 |
-
print("-" * 80)
|
| 1544 |
-
|
| 1545 |
-
for model_type in ['custom', 'clip']:
|
| 1546 |
-
for emb_type in ['text', 'image']:
|
| 1547 |
-
key = f"{model_type}_{emb_type}"
|
| 1548 |
-
if key in results:
|
| 1549 |
-
metrics = results[key]
|
| 1550 |
-
model_name = "Custom Model" if model_type == 'custom' else "Fashion-CLIP Baseline"
|
| 1551 |
-
print(
|
| 1552 |
-
f"{model_name:<20} "
|
| 1553 |
-
f"{emb_type.capitalize():<10} "
|
| 1554 |
-
f"{metrics['separation_score']:<10.4f} "
|
| 1555 |
-
f"{metrics['accuracy']*100:<8.1f}% "
|
| 1556 |
-
f"{metrics['centroid_accuracy']*100:<12.1f}% "
|
| 1557 |
-
f"{metrics['f1_macro']*100:<10.1f}%"
|
| 1558 |
-
)
|
| 1559 |
-
|
| 1560 |
-
def _save_visualizations(
|
| 1561 |
-
self,
|
| 1562 |
-
dataset_name: str,
|
| 1563 |
-
results: Dict[str, Dict[str, Any]]
|
| 1564 |
-
):
|
| 1565 |
-
"""
|
| 1566 |
-
Save confusion matrices and other visualizations.
|
| 1567 |
-
|
| 1568 |
-
Args:
|
| 1569 |
-
dataset_name: Name of the dataset
|
| 1570 |
-
results: Evaluation results dictionary
|
| 1571 |
-
"""
|
| 1572 |
-
os.makedirs(self.directory, exist_ok=True)
|
| 1573 |
-
|
| 1574 |
-
# Save confusion matrices
|
| 1575 |
-
for key, metrics in results.items():
|
| 1576 |
-
if 'figure' in metrics:
|
| 1577 |
-
filename = f'{self.directory}/{dataset_name.lower()}_{key}_confusion_matrix.png'
|
| 1578 |
-
metrics['figure'].savefig(filename, dpi=300, bbox_inches='tight')
|
| 1579 |
-
plt.close(metrics['figure'])
|
| 1580 |
-
|
| 1581 |
-
|
| 1582 |
-
# ============================================================================
|
| 1583 |
-
# DATASET LOADING FUNCTIONS
|
| 1584 |
-
# ============================================================================
|
| 1585 |
-
|
| 1586 |
-
def load_fashion_mnist_dataset(
|
| 1587 |
-
evaluator: EmbeddingEvaluator,
|
| 1588 |
-
max_samples: int = 1000
|
| 1589 |
-
) -> FashionMNISTDataset:
|
| 1590 |
-
"""
|
| 1591 |
-
Load and prepare Fashion-MNIST test dataset.
|
| 1592 |
-
|
| 1593 |
-
This function loads the Fashion-MNIST test set and creates appropriate
|
| 1594 |
-
mappings to the custom model's hierarchy classes.
|
| 1595 |
-
Exactly aligned with main_model_evaluation.py for consistency.
|
| 1596 |
-
|
| 1597 |
-
Args:
|
| 1598 |
-
evaluator: EmbeddingEvaluator instance with loaded model
|
| 1599 |
-
max_samples: Maximum number of samples to use
|
| 1600 |
-
|
| 1601 |
-
Returns:
|
| 1602 |
-
FashionMNISTDataset object
|
| 1603 |
-
"""
|
| 1604 |
-
print("📊 Loading Fashion-MNIST test dataset...")
|
| 1605 |
-
df = pd.read_csv(config.fashion_mnist_test_path)
|
| 1606 |
-
print(f"✅ Fashion-MNIST dataset loaded: {len(df)} samples")
|
| 1607 |
-
|
| 1608 |
-
# Create mapping if hierarchy classes are provided
|
| 1609 |
-
label_mapping = None
|
| 1610 |
-
if evaluator.hierarchy_classes is not None:
|
| 1611 |
-
print("\n🔗 Creating mapping from Fashion-MNIST labels to hierarchy classes:")
|
| 1612 |
-
label_mapping = create_fashion_mnist_to_hierarchy_mapping(
|
| 1613 |
-
evaluator.hierarchy_classes
|
| 1614 |
-
)
|
| 1615 |
-
|
| 1616 |
-
# Filter dataset to only include samples that can be mapped
|
| 1617 |
-
valid_label_ids = [
|
| 1618 |
-
label_id for label_id, hierarchy in label_mapping.items()
|
| 1619 |
-
if hierarchy is not None
|
| 1620 |
-
]
|
| 1621 |
-
df_filtered = df[df['label'].isin(valid_label_ids)]
|
| 1622 |
-
print(
|
| 1623 |
-
f"\n📊 After filtering to mappable labels: "
|
| 1624 |
-
f"{len(df_filtered)} samples (from {len(df)})"
|
| 1625 |
-
)
|
| 1626 |
-
|
| 1627 |
-
# Apply max_samples limit after filtering
|
| 1628 |
-
df_sample = df_filtered.head(max_samples)
|
| 1629 |
-
else:
|
| 1630 |
-
df_sample = df.head(max_samples)
|
| 1631 |
-
|
| 1632 |
-
print(f"📊 Using {len(df_sample)} samples for evaluation")
|
| 1633 |
-
return FashionMNISTDataset(df_sample, label_mapping=label_mapping)
|
| 1634 |
-
|
| 1635 |
-
|
| 1636 |
-
def load_kagl_marqo_dataset(evaluator: EmbeddingEvaluator) -> pd.DataFrame:
|
| 1637 |
-
"""
|
| 1638 |
-
Load and prepare Kaggle Marqo dataset for evaluation.
|
| 1639 |
-
|
| 1640 |
-
This function loads the Marqo fashion dataset from Hugging Face
|
| 1641 |
-
and preprocesses it for evaluation with the custom model.
|
| 1642 |
-
|
| 1643 |
-
Args:
|
| 1644 |
-
evaluator: EmbeddingEvaluator instance with loaded model
|
| 1645 |
-
|
| 1646 |
-
Returns:
|
| 1647 |
-
Formatted pandas DataFrame ready for evaluation
|
| 1648 |
-
"""
|
| 1649 |
-
from datasets import load_dataset
|
| 1650 |
-
|
| 1651 |
-
print("📊 Loading Kaggle Marqo dataset...")
|
| 1652 |
-
|
| 1653 |
-
# Load the dataset from Hugging Face
|
| 1654 |
-
dataset = load_dataset("Marqo/KAGL")
|
| 1655 |
-
df = dataset["data"].to_pandas()
|
| 1656 |
-
|
| 1657 |
-
print(f"✅ Dataset Kaggle loaded")
|
| 1658 |
-
print(f"📊 Before filtering: {len(df)} samples")
|
| 1659 |
-
print(f"📋 Available columns: {list(df.columns)}")
|
| 1660 |
-
print(f"🎨 Available categories: {sorted(df['category2'].unique())}")
|
| 1661 |
-
|
| 1662 |
-
# Map categories to our hierarchy format
|
| 1663 |
-
df['hierarchy'] = df['category2'].str.lower()
|
| 1664 |
-
df['hierarchy'] = df['hierarchy'].replace({
|
| 1665 |
-
'bags': 'bag',
|
| 1666 |
-
'topwear': 'top',
|
| 1667 |
-
'flip flops': 'shoes',
|
| 1668 |
-
'sandal': 'shoes'
|
| 1669 |
-
})
|
| 1670 |
-
|
| 1671 |
-
# Filter to only include valid hierarchies
|
| 1672 |
-
valid_hierarchies = df['hierarchy'].dropna().unique()
|
| 1673 |
-
print(f"🎯 Valid hierarchies found: {sorted(valid_hierarchies)}")
|
| 1674 |
-
print(f"🎯 Model hierarchies: {sorted(evaluator.hierarchy_classes)}")
|
| 1675 |
-
|
| 1676 |
-
df = df[df['hierarchy'].isin(evaluator.hierarchy_classes)]
|
| 1677 |
-
print(f"📊 After filtering to model hierarchies: {len(df)} samples")
|
| 1678 |
-
|
| 1679 |
-
if len(df) == 0:
|
| 1680 |
-
print("❌ No samples left after hierarchy filtering.")
|
| 1681 |
-
return pd.DataFrame()
|
| 1682 |
-
|
| 1683 |
-
# Ensure we have text and image data
|
| 1684 |
-
df = df.dropna(subset=['text', 'image'])
|
| 1685 |
-
print(f"📊 After removing missing text/image: {len(df)} samples")
|
| 1686 |
-
|
| 1687 |
-
# Show sample of text data to verify quality
|
| 1688 |
-
print(f"📝 Sample texts:")
|
| 1689 |
-
for i, (text, hierarchy) in enumerate(zip(df['text'].head(3), df['hierarchy'].head(3))):
|
| 1690 |
-
print(f" {i+1}. [{hierarchy}] {text[:100]}...")
|
| 1691 |
-
|
| 1692 |
-
# Limit size to prevent memory overload
|
| 1693 |
-
max_samples = 1000
|
| 1694 |
-
if len(df) > max_samples:
|
| 1695 |
-
print(f"⚠️ Dataset too large ({len(df)} samples), sampling to {max_samples} samples")
|
| 1696 |
-
df_test = df.sample(n=max_samples, random_state=42).reset_index(drop=True)
|
| 1697 |
-
else:
|
| 1698 |
-
df_test = df.copy()
|
| 1699 |
-
|
| 1700 |
-
print(f"📊 After sampling: {len(df_test)} samples")
|
| 1701 |
-
print(f"📊 Samples per hierarchy:")
|
| 1702 |
-
for hierarchy in sorted(df_test['hierarchy'].unique()):
|
| 1703 |
-
count = len(df_test[df_test['hierarchy'] == hierarchy])
|
| 1704 |
-
print(f" {hierarchy}: {count} samples")
|
| 1705 |
-
|
| 1706 |
-
# Create formatted dataset with proper column names
|
| 1707 |
-
kagl_formatted = pd.DataFrame({
|
| 1708 |
-
'image_url': df_test['image'],
|
| 1709 |
-
'text': df_test['text'],
|
| 1710 |
-
'hierarchy': df_test['hierarchy']
|
| 1711 |
-
})
|
| 1712 |
-
|
| 1713 |
-
print(f"📊 Final dataset size: {len(kagl_formatted)} samples")
|
| 1714 |
-
return kagl_formatted
|
| 1715 |
-
|
| 1716 |
-
|
| 1717 |
-
# ============================================================================
|
| 1718 |
-
# MAIN EXECUTION
|
| 1719 |
-
# ============================================================================
|
| 1720 |
-
|
| 1721 |
-
def main():
|
| 1722 |
-
"""
|
| 1723 |
-
Main evaluation function that runs comprehensive evaluation across multiple datasets.
|
| 1724 |
-
|
| 1725 |
-
This function evaluates the custom hierarchy classification model against the
|
| 1726 |
-
Fashion-CLIP baseline on:
|
| 1727 |
-
1. Validation dataset (from training data)
|
| 1728 |
-
2. Fashion-MNIST test dataset
|
| 1729 |
-
3. Kaggle Marqo dataset
|
| 1730 |
-
|
| 1731 |
-
Results include detailed metrics, confusion matrices, and performance comparisons.
|
| 1732 |
-
"""
|
| 1733 |
-
# Setup output directory
|
| 1734 |
-
directory = "hierarchy_model_analysis"
|
| 1735 |
-
|
| 1736 |
-
print(f"🚀 Starting evaluation with custom model: {hierarchy_model_path}")
|
| 1737 |
-
print(f"🤗 Including Fashion-CLIP baseline comparison")
|
| 1738 |
-
|
| 1739 |
-
# Initialize evaluator
|
| 1740 |
-
evaluator = EmbeddingEvaluator(hierarchy_model_path, directory)
|
| 1741 |
-
|
| 1742 |
-
print(
|
| 1743 |
-
f"📊 Final hierarchy classes after initialization: "
|
| 1744 |
-
f"{len(evaluator.vocab.hierarchy_classes)} classes"
|
| 1745 |
-
)
|
| 1746 |
-
|
| 1747 |
-
# ===== EVALUATION 1: VALIDATION DATASET =====
|
| 1748 |
-
print("\n" + "="*60)
|
| 1749 |
-
print("EVALUATING VALIDATION DATASET - CUSTOM MODEL vs FASHION-CLIP BASELINE")
|
| 1750 |
-
print("="*60)
|
| 1751 |
-
val_results = evaluator.evaluate_dataset_with_baselines(
|
| 1752 |
-
evaluator.val_df,
|
| 1753 |
-
"Validation Dataset"
|
| 1754 |
-
)
|
| 1755 |
-
|
| 1756 |
-
# ===== EVALUATION 2: FASHION-MNIST TEST DATASET =====
|
| 1757 |
-
print("\n" + "="*60)
|
| 1758 |
-
print("EVALUATING FASHION-MNIST TEST DATASET - CUSTOM MODEL vs FASHION-CLIP BASELINE")
|
| 1759 |
-
print("="*60)
|
| 1760 |
-
fashion_mnist_dataset = load_fashion_mnist_dataset(evaluator, max_samples=1000)
|
| 1761 |
-
if fashion_mnist_dataset is not None:
|
| 1762 |
-
# Aligned with main_model_evaluation.py: NO TTA for fair baseline comparison
|
| 1763 |
-
fashion_mnist_results = evaluator.evaluate_dataset_with_baselines(
|
| 1764 |
-
fashion_mnist_dataset,
|
| 1765 |
-
"Fashion-MNIST Test Dataset",
|
| 1766 |
-
use_whitening=False, # Disabled for fair comparison
|
| 1767 |
-
use_mahalanobis=False # Disabled for fair comparison
|
| 1768 |
-
)
|
| 1769 |
-
else:
|
| 1770 |
-
fashion_mnist_results = {}
|
| 1771 |
-
|
| 1772 |
-
# ===== EVALUATION 3: KAGGLE MARQO DATASET =====
|
| 1773 |
-
print("\n" + "="*60)
|
| 1774 |
-
print("EVALUATING KAGGLE MARQO DATASET - CUSTOM MODEL vs FASHION-CLIP BASELINE")
|
| 1775 |
-
print("="*60)
|
| 1776 |
-
df_kagl_marqo = load_kagl_marqo_dataset(evaluator)
|
| 1777 |
-
if len(df_kagl_marqo) > 0:
|
| 1778 |
-
kagl_results = evaluator.evaluate_dataset_with_baselines(
|
| 1779 |
-
df_kagl_marqo,
|
| 1780 |
-
"Kaggle Marqo Dataset"
|
| 1781 |
-
)
|
| 1782 |
-
else:
|
| 1783 |
-
kagl_results = {}
|
| 1784 |
-
|
| 1785 |
-
# ===== FINAL SUMMARY =====
|
| 1786 |
-
print(f"\n{'='*80}")
|
| 1787 |
-
print("FINAL EVALUATION SUMMARY - CUSTOM MODEL vs FASHION-CLIP BASELINE")
|
| 1788 |
-
print(f"{'='*80}")
|
| 1789 |
-
|
| 1790 |
-
# Print validation results
|
| 1791 |
-
print("\n🔍 VALIDATION DATASET RESULTS:")
|
| 1792 |
-
_print_dataset_results(val_results, len(evaluator.val_df))
|
| 1793 |
-
|
| 1794 |
-
# Print Fashion-MNIST results
|
| 1795 |
-
if fashion_mnist_results:
|
| 1796 |
-
print("\n👗 FASHION-MNIST TEST DATASET RESULTS:")
|
| 1797 |
-
_print_dataset_results(fashion_mnist_results, 1000)
|
| 1798 |
-
|
| 1799 |
-
# Print Kaggle results
|
| 1800 |
-
if kagl_results:
|
| 1801 |
-
print("\n🌐 KAGGLE MARQO DATASET RESULTS:")
|
| 1802 |
-
_print_dataset_results(
|
| 1803 |
-
kagl_results,
|
| 1804 |
-
len(df_kagl_marqo) if df_kagl_marqo is not None else 'N/A'
|
| 1805 |
-
)
|
| 1806 |
-
|
| 1807 |
-
# Final completion message
|
| 1808 |
-
print(f"\n✅ Evaluation completed! Check '{directory}/' for visualization files.")
|
| 1809 |
-
print(f"📊 Custom model hierarchy classes: {len(evaluator.vocab.hierarchy_classes)} classes")
|
| 1810 |
-
print(f"🤗 Fashion-CLIP baseline comparison included")
|
| 1811 |
-
|
| 1812 |
-
|
| 1813 |
-
def _print_dataset_results(results: Dict[str, Dict[str, Any]], dataset_size: int):
|
| 1814 |
-
"""
|
| 1815 |
-
Print formatted results for a single dataset.
|
| 1816 |
-
|
| 1817 |
-
Args:
|
| 1818 |
-
results: Dictionary containing evaluation results
|
| 1819 |
-
dataset_size: Number of samples in the dataset
|
| 1820 |
-
"""
|
| 1821 |
-
print(f"Dataset size: {dataset_size} samples")
|
| 1822 |
-
print(f"{'Model':<20} {'Embedding':<10} {'Sep Score':<12} {'NN Acc':<10} {'Centroid Acc':<12} {'F1 Macro':<10}")
|
| 1823 |
-
print("-" * 80)
|
| 1824 |
-
|
| 1825 |
-
for model_type in ['custom', 'clip']:
|
| 1826 |
-
for emb_type in ['text', 'image']:
|
| 1827 |
-
key = f"{model_type}_{emb_type}"
|
| 1828 |
-
if key in results:
|
| 1829 |
-
metrics = results[key]
|
| 1830 |
-
model_name = "Custom Model" if model_type == 'custom' else "Fashion-CLIP Baseline"
|
| 1831 |
-
print(
|
| 1832 |
-
f"{model_name:<20} "
|
| 1833 |
-
f"{emb_type.capitalize():<10} "
|
| 1834 |
-
f"{metrics['separation_score']:<12.4f} "
|
| 1835 |
-
f"{metrics['accuracy']*100:<10.1f}% "
|
| 1836 |
-
f"{metrics['centroid_accuracy']*100:<12.1f}% "
|
| 1837 |
-
f"{metrics['f1_macro']*100:<10.1f}%"
|
| 1838 |
-
)
|
| 1839 |
-
|
| 1840 |
-
|
| 1841 |
-
if __name__ == "__main__":
|
| 1842 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
evaluation/run_all_evaluations.py
CHANGED
|
@@ -1,327 +1,226 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
-
|
| 4 |
-
===========================
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
|
| 9 |
-
Usage
|
| 10 |
-
|
|
|
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
Author: Lea Attia Sarfati
|
| 20 |
"""
|
| 21 |
|
| 22 |
-
import os
|
| 23 |
-
import sys
|
| 24 |
-
import json
|
| 25 |
import argparse
|
| 26 |
-
|
|
|
|
| 27 |
from datetime import datetime
|
| 28 |
-
|
| 29 |
-
import pandas as pd
|
| 30 |
|
| 31 |
-
#
|
| 32 |
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 33 |
|
| 34 |
-
|
| 35 |
-
try:
|
| 36 |
-
from evaluation.main_model_evaluation import (
|
| 37 |
-
evaluate_fashion_mnist,
|
| 38 |
-
evaluate_kaggle_marqo,
|
| 39 |
-
evaluate_local_validation
|
| 40 |
-
)
|
| 41 |
-
from example_usage import load_models_from_hf
|
| 42 |
-
except ImportError as e:
|
| 43 |
-
print(f"⚠️ Import error: {e}")
|
| 44 |
-
print("Make sure you're running from the correct directory")
|
| 45 |
-
sys.exit(1)
|
| 46 |
|
| 47 |
|
| 48 |
class EvaluationRunner:
|
| 49 |
-
"""
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
Runs all available evaluations and generates a summary report.
|
| 53 |
-
"""
|
| 54 |
-
|
| 55 |
-
def __init__(self, repo_id: str, output_dir: str = "evaluation_results"):
|
| 56 |
-
"""
|
| 57 |
-
Initialize the evaluation runner.
|
| 58 |
-
|
| 59 |
-
Args:
|
| 60 |
-
repo_id: Hugging Face repository ID
|
| 61 |
-
output_dir: Directory to save results
|
| 62 |
-
"""
|
| 63 |
-
self.repo_id = repo_id
|
| 64 |
self.output_dir = Path(output_dir)
|
| 65 |
self.output_dir.mkdir(exist_ok=True, parents=True)
|
| 66 |
-
|
| 67 |
-
# Create timestamp for this run
|
| 68 |
self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 69 |
-
self.
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
def
|
| 76 |
-
"""
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
return False
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
"
|
| 91 |
-
print("
|
| 92 |
-
print("👕 Fashion-MNIST Evaluation")
|
| 93 |
-
print("=" * 80)
|
| 94 |
-
|
| 95 |
-
try:
|
| 96 |
-
results = evaluate_fashion_mnist(
|
| 97 |
-
model=self.models['main_model'],
|
| 98 |
-
processor=self.models['processor'],
|
| 99 |
-
device=self.models['device']
|
| 100 |
-
)
|
| 101 |
-
|
| 102 |
-
self.results['fashion_mnist'] = results
|
| 103 |
-
print("✅ Fashion-MNIST evaluation completed")
|
| 104 |
-
return results
|
| 105 |
-
|
| 106 |
-
except Exception as e:
|
| 107 |
-
print(f"❌ Fashion-MNIST evaluation failed: {e}")
|
| 108 |
-
return None
|
| 109 |
-
|
| 110 |
-
def run_kaggle_evaluation(self):
|
| 111 |
-
"""Run KAGL Marqo evaluation."""
|
| 112 |
-
print("\n" + "=" * 80)
|
| 113 |
-
print("🛍️ KAGL Marqo Evaluation")
|
| 114 |
-
print("=" * 80)
|
| 115 |
-
|
| 116 |
-
try:
|
| 117 |
-
results = evaluate_kaggle_marqo(
|
| 118 |
-
model=self.models['main_model'],
|
| 119 |
-
processor=self.models['processor'],
|
| 120 |
-
device=self.models['device']
|
| 121 |
-
)
|
| 122 |
-
|
| 123 |
-
self.results['kaggle_marqo'] = results
|
| 124 |
-
print("✅ KAGL Marqo evaluation completed")
|
| 125 |
-
return results
|
| 126 |
-
|
| 127 |
-
except Exception as e:
|
| 128 |
-
print(f"❌ KAGL Marqo evaluation failed: {e}")
|
| 129 |
-
return None
|
| 130 |
-
|
| 131 |
-
def run_local_evaluation(self):
|
| 132 |
-
"""Run local validation evaluation."""
|
| 133 |
-
print("\n" + "=" * 80)
|
| 134 |
-
print("📁 Local Validation Evaluation")
|
| 135 |
-
print("=" * 80)
|
| 136 |
-
|
| 137 |
try:
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
return results
|
| 147 |
-
|
| 148 |
-
except Exception as e:
|
| 149 |
-
print(f"❌ Local validation evaluation failed: {e}")
|
| 150 |
-
return None
|
| 151 |
-
|
| 152 |
-
def generate_summary(self):
|
| 153 |
-
"""Generate summary report."""
|
| 154 |
-
print("\n" + "=" * 80)
|
| 155 |
-
print("📊 Generating Summary Report")
|
| 156 |
-
print("=" * 80)
|
| 157 |
-
|
| 158 |
-
summary = {
|
| 159 |
-
'timestamp': self.timestamp,
|
| 160 |
-
'repo_id': self.repo_id,
|
| 161 |
-
'evaluations': {}
|
| 162 |
-
}
|
| 163 |
-
|
| 164 |
-
# Collect all results
|
| 165 |
-
for eval_name, eval_results in self.results.items():
|
| 166 |
-
if eval_results:
|
| 167 |
-
summary['evaluations'][eval_name] = eval_results
|
| 168 |
-
|
| 169 |
-
# Save to JSON
|
| 170 |
-
summary_path = self.run_dir / "summary.json"
|
| 171 |
-
with open(summary_path, 'w') as f:
|
| 172 |
-
json.dump(summary, f, indent=2)
|
| 173 |
-
|
| 174 |
-
print(f"✅ Summary saved to: {summary_path}")
|
| 175 |
-
|
| 176 |
-
# Print summary
|
| 177 |
-
self.print_summary(summary)
|
| 178 |
-
|
| 179 |
-
return summary
|
| 180 |
-
|
| 181 |
-
def print_summary(self, summary):
|
| 182 |
-
"""Print formatted summary."""
|
| 183 |
-
print("\n" + "=" * 80)
|
| 184 |
-
print("📈 Evaluation Summary")
|
| 185 |
-
print("=" * 80)
|
| 186 |
-
print(f"\nRepository: {summary['repo_id']}")
|
| 187 |
-
print(f"Timestamp: {summary['timestamp']}\n")
|
| 188 |
-
|
| 189 |
-
for eval_name, eval_results in summary['evaluations'].items():
|
| 190 |
-
print(f"\n{'─' * 40}")
|
| 191 |
-
print(f"📊 {eval_name.upper()}")
|
| 192 |
-
print(f"{'─' * 40}")
|
| 193 |
-
|
| 194 |
-
if isinstance(eval_results, dict):
|
| 195 |
-
for key, value in eval_results.items():
|
| 196 |
-
if isinstance(value, (int, float)):
|
| 197 |
-
print(f" {key}: {value:.4f}")
|
| 198 |
-
else:
|
| 199 |
-
print(f" {key}: {value}")
|
| 200 |
-
|
| 201 |
-
print("\n" + "=" * 80)
|
| 202 |
-
|
| 203 |
-
def create_visualizations(self):
|
| 204 |
-
"""Create summary visualizations."""
|
| 205 |
-
print("\n" + "=" * 80)
|
| 206 |
-
print("📊 Creating Visualizations")
|
| 207 |
-
print("=" * 80)
|
| 208 |
-
|
| 209 |
-
# Create comparison chart
|
| 210 |
-
fig, axes = plt.subplots(1, 2, figsize=(15, 6))
|
| 211 |
-
|
| 212 |
-
# Collect metrics
|
| 213 |
-
datasets = []
|
| 214 |
-
color_accuracies = []
|
| 215 |
-
hierarchy_accuracies = []
|
| 216 |
-
|
| 217 |
-
for eval_name, eval_results in self.results.items():
|
| 218 |
-
if eval_results and isinstance(eval_results, dict):
|
| 219 |
-
datasets.append(eval_name)
|
| 220 |
-
|
| 221 |
-
# Try to get color accuracy
|
| 222 |
-
color_acc = eval_results.get('color_nn_accuracy', 0)
|
| 223 |
-
color_accuracies.append(color_acc)
|
| 224 |
-
|
| 225 |
-
# Try to get hierarchy accuracy
|
| 226 |
-
hier_acc = eval_results.get('hierarchy_nn_accuracy', 0)
|
| 227 |
-
hierarchy_accuracies.append(hier_acc)
|
| 228 |
-
|
| 229 |
-
# Plot color accuracies
|
| 230 |
-
if color_accuracies:
|
| 231 |
-
axes[0].bar(datasets, color_accuracies, color='skyblue')
|
| 232 |
-
axes[0].set_title('Color Classification Accuracy', fontsize=14, fontweight='bold')
|
| 233 |
-
axes[0].set_ylabel('Accuracy', fontsize=12)
|
| 234 |
-
axes[0].set_ylim([0, 1])
|
| 235 |
-
axes[0].grid(axis='y', alpha=0.3)
|
| 236 |
-
|
| 237 |
-
# Add value labels
|
| 238 |
-
for i, v in enumerate(color_accuracies):
|
| 239 |
-
axes[0].text(i, v + 0.02, f'{v:.3f}', ha='center', fontsize=10)
|
| 240 |
-
|
| 241 |
-
# Plot hierarchy accuracies
|
| 242 |
-
if hierarchy_accuracies:
|
| 243 |
-
axes[1].bar(datasets, hierarchy_accuracies, color='lightcoral')
|
| 244 |
-
axes[1].set_title('Hierarchy Classification Accuracy', fontsize=14, fontweight='bold')
|
| 245 |
-
axes[1].set_ylabel('Accuracy', fontsize=12)
|
| 246 |
-
axes[1].set_ylim([0, 1])
|
| 247 |
-
axes[1].grid(axis='y', alpha=0.3)
|
| 248 |
-
|
| 249 |
-
# Add value labels
|
| 250 |
-
for i, v in enumerate(hierarchy_accuracies):
|
| 251 |
-
axes[1].text(i, v + 0.02, f'{v:.3f}', ha='center', fontsize=10)
|
| 252 |
-
|
| 253 |
-
plt.tight_layout()
|
| 254 |
-
|
| 255 |
-
# Save figure
|
| 256 |
-
fig_path = self.run_dir / "summary_comparison.png"
|
| 257 |
-
plt.savefig(fig_path, dpi=300, bbox_inches='tight')
|
| 258 |
-
plt.close()
|
| 259 |
-
|
| 260 |
-
print(f"✅ Visualization saved to: {fig_path}")
|
| 261 |
-
|
| 262 |
-
def run_all(self):
|
| 263 |
-
"""Run all evaluations."""
|
| 264 |
-
print("=" * 80)
|
| 265 |
-
print("🚀 GAP-CLIP Comprehensive Evaluation")
|
| 266 |
-
print("=" * 80)
|
| 267 |
-
print(f"Repository: {self.repo_id}")
|
| 268 |
-
print(f"Output directory: {self.run_dir}\n")
|
| 269 |
-
|
| 270 |
-
# Load models
|
| 271 |
-
if not self.load_models():
|
| 272 |
-
print("❌ Failed to load models. Exiting.")
|
| 273 |
return False
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
self.
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
print("=
|
| 287 |
-
print(
|
| 288 |
-
print(f"
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
|
| 294 |
|
| 295 |
def main():
|
| 296 |
-
"""Main function for command-line usage."""
|
| 297 |
parser = argparse.ArgumentParser(
|
| 298 |
-
description="Run
|
| 299 |
-
formatter_class=argparse.RawDescriptionHelpFormatter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
)
|
| 301 |
-
|
| 302 |
parser.add_argument(
|
| 303 |
-
"--
|
| 304 |
type=str,
|
| 305 |
-
default="
|
| 306 |
-
help=
|
|
|
|
|
|
|
|
|
|
| 307 |
)
|
| 308 |
-
|
| 309 |
parser.add_argument(
|
| 310 |
"--output",
|
| 311 |
type=str,
|
| 312 |
default="evaluation_results",
|
| 313 |
-
help="
|
| 314 |
)
|
| 315 |
-
|
| 316 |
args = parser.parse_args()
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
success = runner.
|
| 325 |
sys.exit(0 if success else 1)
|
| 326 |
|
| 327 |
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
GAP-CLIP Evaluation Runner
|
| 4 |
+
===========================
|
| 5 |
|
| 6 |
+
Orchestrates all evaluation scripts, one per paper section. Each evaluation
|
| 7 |
+
is independent and can be run in isolation via ``--steps``.
|
| 8 |
|
| 9 |
+
Usage
|
| 10 |
+
-----
|
| 11 |
+
Run everything::
|
| 12 |
|
| 13 |
+
python evaluation/run_all_evaluations.py
|
| 14 |
+
|
| 15 |
+
Run specific sections::
|
| 16 |
+
|
| 17 |
+
python evaluation/run_all_evaluations.py --steps sec51,sec52
|
| 18 |
+
python evaluation/run_all_evaluations.py --steps annex92,annex93
|
| 19 |
+
|
| 20 |
+
Available steps
|
| 21 |
+
---------------
|
| 22 |
+
sec51 §5.1 Colour model accuracy (Table 1)
|
| 23 |
+
sec52 §5.2 Category model confusion matrix (Table 2)
|
| 24 |
+
sec533 §5.3.3 NN classification accuracy (Table 3)
|
| 25 |
+
sec5354 §5.3.4+5 Separation & zero-shot semantic eval
|
| 26 |
+
sec536 §5.3.6 Embedding structure Tests A/B/C (Table 4)
|
| 27 |
+
annex92 Annex 9.2 Pairwise colour similarity heatmaps
|
| 28 |
+
annex93 Annex 9.3 t-SNE visualisations
|
| 29 |
+
annex94 Annex 9.4 Fashion search demo
|
| 30 |
|
| 31 |
Author: Lea Attia Sarfati
|
| 32 |
"""
|
| 33 |
|
|
|
|
|
|
|
|
|
|
| 34 |
import argparse
|
| 35 |
+
import sys
|
| 36 |
+
import traceback
|
| 37 |
from datetime import datetime
|
| 38 |
+
from pathlib import Path
|
|
|
|
| 39 |
|
| 40 |
+
# Make sure the repo root is on the path so that `config` is importable.
|
| 41 |
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 42 |
|
| 43 |
+
ALL_STEPS = ["sec51", "sec52", "sec533", "sec5354", "sec536", "annex92", "annex93", "annex94"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
|
| 46 |
class EvaluationRunner:
|
| 47 |
+
"""Runs one or more evaluation sections and collects pass/fail status."""
|
| 48 |
+
|
| 49 |
+
def __init__(self, output_dir: str = "evaluation_results"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
self.output_dir = Path(output_dir)
|
| 51 |
self.output_dir.mkdir(exist_ok=True, parents=True)
|
|
|
|
|
|
|
| 52 |
self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 53 |
+
self.results: dict[str, str] = {} # step -> "ok" | "failed" | "skipped"
|
| 54 |
+
|
| 55 |
+
# ------------------------------------------------------------------
|
| 56 |
+
# Individual section runners (lazy imports to allow partial execution)
|
| 57 |
+
# ------------------------------------------------------------------
|
| 58 |
+
|
| 59 |
+
def run_sec51(self):
|
| 60 |
+
"""§5.1 – Colour model accuracy (Table 1)."""
|
| 61 |
+
from sec51_color_model_eval import ColorEvaluator
|
| 62 |
+
import torch
|
| 63 |
+
device = "mps" if torch.backends.mps.is_available() else "cpu"
|
| 64 |
+
evaluator = ColorEvaluator(device=device, output_dir=str(self.output_dir / "sec51"))
|
| 65 |
+
evaluator.run_full_evaluation()
|
| 66 |
+
|
| 67 |
+
def run_sec52(self):
|
| 68 |
+
"""§5.2 – Category model confusion matrix (Table 2)."""
|
| 69 |
+
from sec52_category_model_eval import CategoryModelEvaluator
|
| 70 |
+
import torch
|
| 71 |
+
device = "mps" if torch.backends.mps.is_available() else "cpu"
|
| 72 |
+
evaluator = CategoryModelEvaluator(device=device, directory=str(self.output_dir / "sec52"))
|
| 73 |
+
evaluator.run_full_evaluation()
|
| 74 |
+
|
| 75 |
+
def run_sec533(self):
|
| 76 |
+
"""§5.3.3 – Nearest-neighbour classification accuracy (Table 3)."""
|
| 77 |
+
from sec533_clip_nn_accuracy import ColorHierarchyEvaluator
|
| 78 |
+
import torch
|
| 79 |
+
device = "mps" if torch.backends.mps.is_available() else "cpu"
|
| 80 |
+
evaluator = ColorHierarchyEvaluator(
|
| 81 |
+
device=device,
|
| 82 |
+
directory=str(self.output_dir / "sec533"),
|
| 83 |
+
)
|
| 84 |
+
max_samples = 10_000
|
| 85 |
+
evaluator.evaluate_fashion_mnist(max_samples=max_samples)
|
| 86 |
+
evaluator.evaluate_kaggle_marqo(max_samples=max_samples)
|
| 87 |
+
evaluator.evaluate_local_validation(max_samples=max_samples)
|
| 88 |
+
evaluator.evaluate_baseline_fashion_mnist(max_samples=max_samples)
|
| 89 |
+
evaluator.evaluate_baseline_kaggle_marqo(max_samples=max_samples)
|
| 90 |
+
evaluator.evaluate_baseline_local_validation(max_samples=max_samples)
|
| 91 |
+
|
| 92 |
+
def run_sec5354(self):
|
| 93 |
+
"""§5.3.4+5 – Embedding separation & zero-shot semantic eval."""
|
| 94 |
+
# sec5354 has a self-contained __main__ block that handles dataset loading.
|
| 95 |
+
import runpy
|
| 96 |
+
runpy.run_path(
|
| 97 |
+
str(Path(__file__).parent / "sec5354_separation_semantic.py"),
|
| 98 |
+
run_name="__main__",
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
def run_sec536(self):
|
| 102 |
+
"""§5.3.6 – Embedding structure Tests A/B/C."""
|
| 103 |
+
from sec536_embedding_structure import main as sec536_main
|
| 104 |
+
sec536_main(selected_tests=["A", "B", "C"])
|
| 105 |
+
|
| 106 |
+
def run_annex92(self):
|
| 107 |
+
"""Annex 9.2 – Pairwise colour similarity heatmaps."""
|
| 108 |
+
# annex92 is a self-contained script; run its __main__ guard.
|
| 109 |
+
import importlib, runpy
|
| 110 |
+
runpy.run_path(
|
| 111 |
+
str(Path(__file__).parent / "annex92_color_heatmaps.py"),
|
| 112 |
+
run_name="__main__",
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
def run_annex93(self):
|
| 116 |
+
"""Annex 9.3 – t-SNE visualisations."""
|
| 117 |
+
import runpy
|
| 118 |
+
runpy.run_path(
|
| 119 |
+
str(Path(__file__).parent / "annex93_tsne.py"),
|
| 120 |
+
run_name="__main__",
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
def run_annex94(self):
|
| 124 |
+
"""Annex 9.4 – Fashion search demo."""
|
| 125 |
+
import runpy
|
| 126 |
+
runpy.run_path(
|
| 127 |
+
str(Path(__file__).parent / "annex94_search_demo.py"),
|
| 128 |
+
run_name="__main__",
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
# ------------------------------------------------------------------
|
| 132 |
+
# Orchestration
|
| 133 |
+
# ------------------------------------------------------------------
|
| 134 |
+
|
| 135 |
+
def _run_step(self, step: str) -> bool:
|
| 136 |
+
method = getattr(self, f"run_{step.replace('-', '_')}", None)
|
| 137 |
+
if method is None:
|
| 138 |
+
print(f"⚠️ Unknown step '{step}' – skipping.")
|
| 139 |
+
self.results[step] = "skipped"
|
| 140 |
return False
|
| 141 |
+
|
| 142 |
+
print(f"\n{'='*70}")
|
| 143 |
+
print(f"▶ Running {step} ({method.__doc__ or ''})")
|
| 144 |
+
print(f"{'='*70}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
try:
|
| 146 |
+
method()
|
| 147 |
+
self.results[step] = "ok"
|
| 148 |
+
print(f"✅ {step} completed successfully.")
|
| 149 |
+
return True
|
| 150 |
+
except Exception:
|
| 151 |
+
self.results[step] = "failed"
|
| 152 |
+
print(f"❌ {step} FAILED:")
|
| 153 |
+
traceback.print_exc()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
return False
|
| 155 |
+
|
| 156 |
+
def run(self, steps: list[str]) -> bool:
|
| 157 |
+
print("=" * 70)
|
| 158 |
+
print(f"🚀 GAP-CLIP Evaluation ({self.timestamp})")
|
| 159 |
+
print(f" Steps: {', '.join(steps)}")
|
| 160 |
+
print(f" Output: {self.output_dir}")
|
| 161 |
+
print("=" * 70)
|
| 162 |
+
|
| 163 |
+
for step in steps:
|
| 164 |
+
self._run_step(step)
|
| 165 |
+
|
| 166 |
+
# Summary
|
| 167 |
+
print(f"\n{'='*70}")
|
| 168 |
+
print("📊 Summary")
|
| 169 |
+
print(f"{'='*70}")
|
| 170 |
+
all_ok = True
|
| 171 |
+
for step in steps:
|
| 172 |
+
status = self.results.get(step, "skipped")
|
| 173 |
+
icon = {"ok": "✅", "failed": "❌", "skipped": "⚠️ "}.get(status, "?")
|
| 174 |
+
print(f" {icon} {step:15s} {status}")
|
| 175 |
+
if status == "failed":
|
| 176 |
+
all_ok = False
|
| 177 |
+
|
| 178 |
+
print("=" * 70)
|
| 179 |
+
return all_ok
|
| 180 |
|
| 181 |
|
| 182 |
def main():
|
|
|
|
| 183 |
parser = argparse.ArgumentParser(
|
| 184 |
+
description="Run GAP-CLIP evaluations.",
|
| 185 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 186 |
+
epilog="\n".join(
|
| 187 |
+
[
|
| 188 |
+
"Available steps:",
|
| 189 |
+
" sec51 §5.1 Colour model (Table 1)",
|
| 190 |
+
" sec52 §5.2 Category model (Table 2)",
|
| 191 |
+
" sec533 §5.3.3 NN accuracy (Table 3)",
|
| 192 |
+
" sec5354 §5.3.4+5 Separation & semantic eval",
|
| 193 |
+
" sec536 §5.3.6 Embedding structure tests (Table 4)",
|
| 194 |
+
" annex92 Annex 9.2 Colour heatmaps",
|
| 195 |
+
" annex93 Annex 9.3 t-SNE",
|
| 196 |
+
" annex94 Annex 9.4 Search demo",
|
| 197 |
+
]
|
| 198 |
+
),
|
| 199 |
)
|
|
|
|
| 200 |
parser.add_argument(
|
| 201 |
+
"--steps",
|
| 202 |
type=str,
|
| 203 |
+
default="all",
|
| 204 |
+
help=(
|
| 205 |
+
"Comma-separated list of steps to run, or 'all' to run everything "
|
| 206 |
+
"(default: all). Example: --steps sec51,sec52,sec536"
|
| 207 |
+
),
|
| 208 |
)
|
|
|
|
| 209 |
parser.add_argument(
|
| 210 |
"--output",
|
| 211 |
type=str,
|
| 212 |
default="evaluation_results",
|
| 213 |
+
help="Directory to save results (default: evaluation_results).",
|
| 214 |
)
|
|
|
|
| 215 |
args = parser.parse_args()
|
| 216 |
+
|
| 217 |
+
if args.steps.strip().lower() == "all":
|
| 218 |
+
steps = ALL_STEPS
|
| 219 |
+
else:
|
| 220 |
+
steps = [s.strip() for s in args.steps.split(",") if s.strip()]
|
| 221 |
+
|
| 222 |
+
runner = EvaluationRunner(output_dir=args.output)
|
| 223 |
+
success = runner.run(steps)
|
| 224 |
sys.exit(0 if success else 1)
|
| 225 |
|
| 226 |
|
evaluation/{color_evaluation.py → sec51_color_model_eval.py}
RENAMED
|
@@ -1,6 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import json
|
|
|
|
|
|
|
| 3 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
|
|
|
|
| 4 |
|
| 5 |
import torch
|
| 6 |
import pandas as pd
|
|
@@ -19,6 +44,12 @@ from io import BytesIO
|
|
| 19 |
import warnings
|
| 20 |
warnings.filterwarnings('ignore')
|
| 21 |
from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
from config import (
|
| 24 |
color_model_path,
|
|
@@ -26,8 +57,9 @@ from config import (
|
|
| 26 |
local_dataset_path,
|
| 27 |
column_local_image_path,
|
| 28 |
tokeniser_path,
|
|
|
|
| 29 |
)
|
| 30 |
-
from color_model import ColorCLIP, Tokenizer
|
| 31 |
|
| 32 |
|
| 33 |
class KaggleDataset(Dataset):
|
|
@@ -145,17 +177,33 @@ class LocalDataset(Dataset):
|
|
| 145 |
|
| 146 |
def __getitem__(self, idx):
|
| 147 |
row = self.dataframe.iloc[idx]
|
| 148 |
-
|
| 149 |
-
# Load image from local path
|
| 150 |
-
image_path = row[column_local_image_path]
|
| 151 |
try:
|
| 152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
except Exception as e:
|
| 154 |
-
print(f"Error loading image at index {idx} from {image_path}: {e}")
|
| 155 |
-
# Create a dummy image if loading fails
|
| 156 |
image = Image.new('RGB', (224, 224), color='gray')
|
| 157 |
-
|
| 158 |
-
# Apply
|
| 159 |
image = self.transform(image)
|
| 160 |
|
| 161 |
# Get text and labels
|
|
@@ -172,9 +220,10 @@ def load_local_validation_dataset(max_samples=5000):
|
|
| 172 |
df = pd.read_csv(local_dataset_path)
|
| 173 |
print(f"✅ Dataset loaded: {len(df)} samples")
|
| 174 |
|
| 175 |
-
# Filter out rows with NaN values in image path
|
| 176 |
-
|
| 177 |
-
|
|
|
|
| 178 |
|
| 179 |
# Filter for colors that were used during training (11 colors)
|
| 180 |
valid_colors = ['beige', 'black', 'blue', 'brown', 'green', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
|
|
@@ -224,10 +273,18 @@ def collate_fn_filter_none(batch):
|
|
| 224 |
class ColorEvaluator:
|
| 225 |
"""Evaluate color 16 embeddings"""
|
| 226 |
|
| 227 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
self.device = torch.device(device)
|
| 229 |
self.directory = directory
|
| 230 |
self.color_emb_dim = color_emb_dim
|
|
|
|
|
|
|
| 231 |
os.makedirs(self.directory, exist_ok=True)
|
| 232 |
|
| 233 |
# Load baseline Fashion CLIP model
|
|
@@ -248,23 +305,34 @@ class ColorEvaluator:
|
|
| 248 |
if self.color_model is not None and self.color_tokenizer is not None:
|
| 249 |
return
|
| 250 |
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
|
| 256 |
-
print("🎨 Loading specialized color model (16D)...")
|
| 257 |
-
|
| 258 |
-
# Load checkpoint first to get the actual vocab size
|
| 259 |
-
state_dict = torch.load(color_model_path, map_location=self.device)
|
| 260 |
-
|
| 261 |
# Get vocab size from the embedding weight shape in checkpoint
|
| 262 |
vocab_size = state_dict['text_encoder.embedding.weight'].shape[0]
|
| 263 |
print(f" Detected vocab size from checkpoint: {vocab_size}")
|
| 264 |
-
|
| 265 |
-
# Load tokenizer vocab
|
| 266 |
-
with open(tokeniser_path, "r") as f:
|
| 267 |
-
vocab = json.load(f)
|
| 268 |
|
| 269 |
self.color_tokenizer = Tokenizer()
|
| 270 |
self.color_tokenizer.load_vocab(vocab)
|
|
@@ -541,8 +609,8 @@ class ColorEvaluator:
|
|
| 541 |
|
| 542 |
accuracy = accuracy_score(filtered_labels, filtered_predictions)
|
| 543 |
fig, acc, cm = self.create_confusion_matrix(
|
| 544 |
-
filtered_labels, filtered_predictions,
|
| 545 |
-
|
| 546 |
label_type
|
| 547 |
)
|
| 548 |
unique_labels = sorted(list(set(filtered_labels)))
|
|
@@ -578,15 +646,15 @@ class ColorEvaluator:
|
|
| 578 |
image_full_embeddings, image_colors_full = self.extract_color_embeddings(dataloader, embedding_type='image', max_samples=max_samples)
|
| 579 |
text_color_metrics = self.compute_similarity_metrics(text_full_embeddings, text_colors_full)
|
| 580 |
text_color_class = self.evaluate_classification_performance(
|
| 581 |
-
text_full_embeddings, text_colors_full,
|
| 582 |
-
"
|
| 583 |
)
|
| 584 |
text_color_metrics.update(text_color_class)
|
| 585 |
results['text_color'] = text_color_metrics
|
| 586 |
image_color_metrics = self.compute_similarity_metrics(image_full_embeddings, image_colors_full)
|
| 587 |
image_color_class = self.evaluate_classification_performance(
|
| 588 |
image_full_embeddings, image_colors_full,
|
| 589 |
-
"
|
| 590 |
)
|
| 591 |
image_color_metrics.update(image_color_class)
|
| 592 |
results['image_color'] = image_color_metrics
|
|
@@ -628,7 +696,7 @@ class ColorEvaluator:
|
|
| 628 |
print(f" Text color embeddings shape: {text_color_embeddings.shape}")
|
| 629 |
text_color_metrics = self.compute_similarity_metrics(text_color_embeddings, text_colors)
|
| 630 |
text_color_class = self.evaluate_classification_performance(
|
| 631 |
-
text_color_embeddings, text_colors, "
|
| 632 |
)
|
| 633 |
text_color_metrics.update(text_color_class)
|
| 634 |
results['text_color'] = text_color_metrics
|
|
@@ -642,7 +710,7 @@ class ColorEvaluator:
|
|
| 642 |
print(f" Image color embeddings shape: {image_color_embeddings.shape}")
|
| 643 |
image_color_metrics = self.compute_similarity_metrics(image_color_embeddings, image_colors)
|
| 644 |
image_color_class = self.evaluate_classification_performance(
|
| 645 |
-
image_color_embeddings, image_colors, "
|
| 646 |
)
|
| 647 |
image_color_metrics.update(image_color_class)
|
| 648 |
results['image_color'] = image_color_metrics
|
|
@@ -687,7 +755,7 @@ class ColorEvaluator:
|
|
| 687 |
text_color_metrics = self.compute_similarity_metrics(text_embeddings, text_colors)
|
| 688 |
|
| 689 |
text_color_classification = self.evaluate_classification_performance(
|
| 690 |
-
text_embeddings, text_colors, "
|
| 691 |
)
|
| 692 |
text_color_metrics.update(text_color_classification)
|
| 693 |
results['text'] = {
|
|
@@ -705,7 +773,7 @@ class ColorEvaluator:
|
|
| 705 |
image_color_metrics = self.compute_similarity_metrics(image_embeddings, image_colors)
|
| 706 |
|
| 707 |
image_color_classification = self.evaluate_classification_performance(
|
| 708 |
-
image_embeddings, image_colors, "
|
| 709 |
)
|
| 710 |
image_color_metrics.update(image_color_classification)
|
| 711 |
results['image'] = {
|
|
@@ -755,7 +823,7 @@ class ColorEvaluator:
|
|
| 755 |
text_color_metrics = self.compute_similarity_metrics(text_embeddings, text_colors)
|
| 756 |
|
| 757 |
text_color_classification = self.evaluate_classification_performance(
|
| 758 |
-
text_embeddings, text_colors, "
|
| 759 |
)
|
| 760 |
text_color_metrics.update(text_color_classification)
|
| 761 |
results['text'] = {
|
|
@@ -773,7 +841,7 @@ class ColorEvaluator:
|
|
| 773 |
image_color_metrics = self.compute_similarity_metrics(image_embeddings, image_colors)
|
| 774 |
|
| 775 |
image_color_classification = self.evaluate_classification_performance(
|
| 776 |
-
image_embeddings, image_colors, "
|
| 777 |
)
|
| 778 |
image_color_metrics.update(image_color_classification)
|
| 779 |
results['image'] = {
|
|
@@ -798,49 +866,99 @@ class ColorEvaluator:
|
|
| 798 |
|
| 799 |
return results
|
| 800 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 801 |
|
| 802 |
if __name__ == "__main__":
|
| 803 |
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
| 804 |
print(f"Using device: {device}")
|
| 805 |
|
| 806 |
-
directory = '
|
| 807 |
max_samples = 10000
|
| 808 |
-
|
| 809 |
-
|
| 810 |
-
|
| 811 |
-
|
| 812 |
-
|
| 813 |
-
print("
|
| 814 |
-
print("
|
| 815 |
-
|
| 816 |
-
|
| 817 |
-
|
| 818 |
-
print("
|
| 819 |
-
print(
|
| 820 |
-
|
| 821 |
-
|
| 822 |
-
print(
|
| 823 |
-
print(f"
|
| 824 |
-
|
| 825 |
-
#
|
| 826 |
-
|
| 827 |
-
print("
|
| 828 |
-
print("
|
| 829 |
-
|
| 830 |
-
|
| 831 |
-
|
| 832 |
-
print("
|
| 833 |
-
print(
|
| 834 |
-
|
| 835 |
-
|
| 836 |
-
print(
|
| 837 |
-
print(f"
|
|
|
|
| 838 |
|
| 839 |
# Evaluate Local Validation Dataset
|
| 840 |
print("\n" + "="*60)
|
| 841 |
print("🚀 Starting evaluation of Local Validation Dataset with Color embeddings")
|
| 842 |
print("="*60)
|
| 843 |
-
results_local = evaluator.evaluate_local_validation(max_samples=
|
| 844 |
|
| 845 |
if results_local is not None:
|
| 846 |
print(f"\n{'='*60}")
|
|
@@ -855,7 +973,7 @@ if __name__ == "__main__":
|
|
| 855 |
print("\n" + "="*60)
|
| 856 |
print("🚀 Starting evaluation of Baseline Fashion CLIP on Local Validation")
|
| 857 |
print("="*60)
|
| 858 |
-
results_baseline_local = evaluator.evaluate_baseline_local_validation(max_samples=
|
| 859 |
|
| 860 |
if results_baseline_local is not None:
|
| 861 |
print(f"\n{'='*60}")
|
|
@@ -867,4 +985,4 @@ if __name__ == "__main__":
|
|
| 867 |
print(f" Image - NN Acc: {results_baseline_local['image']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['image']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['image']['color']['separation_score']:.4f}")
|
| 868 |
|
| 869 |
|
| 870 |
-
print(f"\n✅ Evaluation completed! Check '{directory}/' for visualization files.")
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Section 5.1 — Color Model Evaluation (Table 1)
|
| 3 |
+
===============================================
|
| 4 |
+
|
| 5 |
+
Evaluates the standalone 16D color model (ColorCLIP) on accuracy and
|
| 6 |
+
separation scores across:
|
| 7 |
+
- KAGL Marqo (external, 10k items, 46 colors)
|
| 8 |
+
- Local validation dataset (internal, 5k items, 11 colors)
|
| 9 |
+
|
| 10 |
+
Metrics reported match Table 1 in the paper:
|
| 11 |
+
- Text/image embedding NN accuracy
|
| 12 |
+
- Text/image embedding separation score (intra - inter class distance)
|
| 13 |
+
|
| 14 |
+
Compared against Fashion-CLIP baseline (patrickjohncyh/fashion-clip).
|
| 15 |
+
|
| 16 |
+
Run directly:
|
| 17 |
+
python sec51_color_model_eval.py
|
| 18 |
+
|
| 19 |
+
Paper reference: Section 5.1, Table 1.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
import os
|
| 23 |
import json
|
| 24 |
+
import hashlib
|
| 25 |
+
import requests
|
| 26 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 27 |
+
import sys
|
| 28 |
+
from pathlib import Path
|
| 29 |
|
| 30 |
import torch
|
| 31 |
import pandas as pd
|
|
|
|
| 44 |
import warnings
|
| 45 |
warnings.filterwarnings('ignore')
|
| 46 |
from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
|
| 47 |
+
from huggingface_hub import hf_hub_download
|
| 48 |
+
|
| 49 |
+
# Ensure project root is importable when running this file directly.
|
| 50 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 51 |
+
if str(PROJECT_ROOT) not in sys.path:
|
| 52 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 53 |
|
| 54 |
from config import (
|
| 55 |
color_model_path,
|
|
|
|
| 57 |
local_dataset_path,
|
| 58 |
column_local_image_path,
|
| 59 |
tokeniser_path,
|
| 60 |
+
images_dir,
|
| 61 |
)
|
| 62 |
+
from training.color_model import ColorCLIP, Tokenizer
|
| 63 |
|
| 64 |
|
| 65 |
class KaggleDataset(Dataset):
|
|
|
|
| 177 |
|
| 178 |
def __getitem__(self, idx):
|
| 179 |
row = self.dataframe.iloc[idx]
|
| 180 |
+
|
|
|
|
|
|
|
| 181 |
try:
|
| 182 |
+
# Try local path first
|
| 183 |
+
image_path = row.get(column_local_image_path) if hasattr(row, 'get') else None
|
| 184 |
+
if isinstance(image_path, str) and image_path and os.path.exists(image_path):
|
| 185 |
+
image = Image.open(image_path).convert("RGB")
|
| 186 |
+
else:
|
| 187 |
+
# Fallback: download from image_url with caching
|
| 188 |
+
image_url = row.get('image_url') if hasattr(row, 'get') else None
|
| 189 |
+
if isinstance(image_url, str) and image_url:
|
| 190 |
+
cache_dir = Path(images_dir)
|
| 191 |
+
cache_dir.mkdir(parents=True, exist_ok=True)
|
| 192 |
+
url_hash = hashlib.md5(image_url.encode("utf-8")).hexdigest()
|
| 193 |
+
cache_path = cache_dir / f"{url_hash}.jpg"
|
| 194 |
+
if cache_path.exists():
|
| 195 |
+
image = Image.open(cache_path).convert("RGB")
|
| 196 |
+
else:
|
| 197 |
+
resp = requests.get(image_url, timeout=10)
|
| 198 |
+
resp.raise_for_status()
|
| 199 |
+
image = Image.open(BytesIO(resp.content)).convert("RGB")
|
| 200 |
+
image.save(cache_path, "JPEG", quality=85, optimize=True)
|
| 201 |
+
else:
|
| 202 |
+
raise ValueError("No valid image_path or image_url")
|
| 203 |
except Exception as e:
|
|
|
|
|
|
|
| 204 |
image = Image.new('RGB', (224, 224), color='gray')
|
| 205 |
+
|
| 206 |
+
# Apply transform
|
| 207 |
image = self.transform(image)
|
| 208 |
|
| 209 |
# Get text and labels
|
|
|
|
| 220 |
df = pd.read_csv(local_dataset_path)
|
| 221 |
print(f"✅ Dataset loaded: {len(df)} samples")
|
| 222 |
|
| 223 |
+
# Filter out rows with NaN values in image path (use whichever column exists)
|
| 224 |
+
img_col = column_local_image_path if column_local_image_path in df.columns else 'image_url'
|
| 225 |
+
df_clean = df.dropna(subset=[img_col])
|
| 226 |
+
print(f"📊 After filtering NaN image paths ({img_col}): {len(df_clean)} samples")
|
| 227 |
|
| 228 |
# Filter for colors that were used during training (11 colors)
|
| 229 |
valid_colors = ['beige', 'black', 'blue', 'brown', 'green', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
|
|
|
|
| 273 |
class ColorEvaluator:
|
| 274 |
"""Evaluate color 16 embeddings"""
|
| 275 |
|
| 276 |
+
def __init__(
|
| 277 |
+
self,
|
| 278 |
+
device='mps',
|
| 279 |
+
directory="figures/confusion_matrices/cm_color",
|
| 280 |
+
repo_id="Leacb4/gap-clip",
|
| 281 |
+
cache_dir="./models_cache",
|
| 282 |
+
):
|
| 283 |
self.device = torch.device(device)
|
| 284 |
self.directory = directory
|
| 285 |
self.color_emb_dim = color_emb_dim
|
| 286 |
+
self.repo_id = repo_id
|
| 287 |
+
self.cache_dir = cache_dir
|
| 288 |
os.makedirs(self.directory, exist_ok=True)
|
| 289 |
|
| 290 |
# Load baseline Fashion CLIP model
|
|
|
|
| 305 |
if self.color_model is not None and self.color_tokenizer is not None:
|
| 306 |
return
|
| 307 |
|
| 308 |
+
local_model_exists = os.path.exists(color_model_path)
|
| 309 |
+
local_tokenizer_exists = os.path.exists(tokeniser_path)
|
| 310 |
+
|
| 311 |
+
if local_model_exists and local_tokenizer_exists:
|
| 312 |
+
print("🎨 Loading specialized color model (16D) from local files...")
|
| 313 |
+
state_dict = torch.load(color_model_path, map_location=self.device)
|
| 314 |
+
with open(tokeniser_path, "r") as f:
|
| 315 |
+
vocab = json.load(f)
|
| 316 |
+
else:
|
| 317 |
+
print("🎨 Local color model/tokenizer not found. Loading from Hugging Face...")
|
| 318 |
+
print(f" Repo: {self.repo_id}")
|
| 319 |
+
hf_model_path = hf_hub_download(
|
| 320 |
+
repo_id=self.repo_id,
|
| 321 |
+
filename="color_model.pt",
|
| 322 |
+
cache_dir=self.cache_dir,
|
| 323 |
+
)
|
| 324 |
+
hf_vocab_path = hf_hub_download(
|
| 325 |
+
repo_id=self.repo_id,
|
| 326 |
+
filename="tokenizer_vocab.json",
|
| 327 |
+
cache_dir=self.cache_dir,
|
| 328 |
+
)
|
| 329 |
+
state_dict = torch.load(hf_model_path, map_location=self.device)
|
| 330 |
+
with open(hf_vocab_path, "r") as f:
|
| 331 |
+
vocab = json.load(f)
|
| 332 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
# Get vocab size from the embedding weight shape in checkpoint
|
| 334 |
vocab_size = state_dict['text_encoder.embedding.weight'].shape[0]
|
| 335 |
print(f" Detected vocab size from checkpoint: {vocab_size}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
|
| 337 |
self.color_tokenizer = Tokenizer()
|
| 338 |
self.color_tokenizer.load_vocab(vocab)
|
|
|
|
| 609 |
|
| 610 |
accuracy = accuracy_score(filtered_labels, filtered_predictions)
|
| 611 |
fig, acc, cm = self.create_confusion_matrix(
|
| 612 |
+
filtered_labels, filtered_predictions,
|
| 613 |
+
embedding_type,
|
| 614 |
label_type
|
| 615 |
)
|
| 616 |
unique_labels = sorted(list(set(filtered_labels)))
|
|
|
|
| 646 |
image_full_embeddings, image_colors_full = self.extract_color_embeddings(dataloader, embedding_type='image', max_samples=max_samples)
|
| 647 |
text_color_metrics = self.compute_similarity_metrics(text_full_embeddings, text_colors_full)
|
| 648 |
text_color_class = self.evaluate_classification_performance(
|
| 649 |
+
text_full_embeddings, text_colors_full,
|
| 650 |
+
"KAGL Marqo, text, color confusion matrix", "Color",
|
| 651 |
)
|
| 652 |
text_color_metrics.update(text_color_class)
|
| 653 |
results['text_color'] = text_color_metrics
|
| 654 |
image_color_metrics = self.compute_similarity_metrics(image_full_embeddings, image_colors_full)
|
| 655 |
image_color_class = self.evaluate_classification_performance(
|
| 656 |
image_full_embeddings, image_colors_full,
|
| 657 |
+
"KAGL Marqo, image, color confusion matrix", "Color",
|
| 658 |
)
|
| 659 |
image_color_metrics.update(image_color_class)
|
| 660 |
results['image_color'] = image_color_metrics
|
|
|
|
| 696 |
print(f" Text color embeddings shape: {text_color_embeddings.shape}")
|
| 697 |
text_color_metrics = self.compute_similarity_metrics(text_color_embeddings, text_colors)
|
| 698 |
text_color_class = self.evaluate_classification_performance(
|
| 699 |
+
text_color_embeddings, text_colors, "Test Dataset, text, color confusion matrix", "Color"
|
| 700 |
)
|
| 701 |
text_color_metrics.update(text_color_class)
|
| 702 |
results['text_color'] = text_color_metrics
|
|
|
|
| 710 |
print(f" Image color embeddings shape: {image_color_embeddings.shape}")
|
| 711 |
image_color_metrics = self.compute_similarity_metrics(image_color_embeddings, image_colors)
|
| 712 |
image_color_class = self.evaluate_classification_performance(
|
| 713 |
+
image_color_embeddings, image_colors, "Test Dataset, image, color confusion matrix", "Color"
|
| 714 |
)
|
| 715 |
image_color_metrics.update(image_color_class)
|
| 716 |
results['image_color'] = image_color_metrics
|
|
|
|
| 755 |
text_color_metrics = self.compute_similarity_metrics(text_embeddings, text_colors)
|
| 756 |
|
| 757 |
text_color_classification = self.evaluate_classification_performance(
|
| 758 |
+
text_embeddings, text_colors, "KAGL Marqo, text, color confusion matrix", "Color"
|
| 759 |
)
|
| 760 |
text_color_metrics.update(text_color_classification)
|
| 761 |
results['text'] = {
|
|
|
|
| 773 |
image_color_metrics = self.compute_similarity_metrics(image_embeddings, image_colors)
|
| 774 |
|
| 775 |
image_color_classification = self.evaluate_classification_performance(
|
| 776 |
+
image_embeddings, image_colors, "KAGL Marqo, image, color confusion matrix", "Color"
|
| 777 |
)
|
| 778 |
image_color_metrics.update(image_color_classification)
|
| 779 |
results['image'] = {
|
|
|
|
| 823 |
text_color_metrics = self.compute_similarity_metrics(text_embeddings, text_colors)
|
| 824 |
|
| 825 |
text_color_classification = self.evaluate_classification_performance(
|
| 826 |
+
text_embeddings, text_colors, "Test Dataset, text, color confusion matrix", "Color"
|
| 827 |
)
|
| 828 |
text_color_metrics.update(text_color_classification)
|
| 829 |
results['text'] = {
|
|
|
|
| 841 |
image_color_metrics = self.compute_similarity_metrics(image_embeddings, image_colors)
|
| 842 |
|
| 843 |
image_color_classification = self.evaluate_classification_performance(
|
| 844 |
+
image_embeddings, image_colors, "Test Dataset, image, color confusion matrix", "Color"
|
| 845 |
)
|
| 846 |
image_color_metrics.update(image_color_classification)
|
| 847 |
results['image'] = {
|
|
|
|
| 866 |
|
| 867 |
return results
|
| 868 |
|
| 869 |
+
def analyze_baseline_vs_trained_performance(self, results_trained, results_baseline, dataset_name):
|
| 870 |
+
"""
|
| 871 |
+
Analyse et explique pourquoi la baseline peut performer mieux que le modèle entraîné
|
| 872 |
+
|
| 873 |
+
Raisons possibles:
|
| 874 |
+
1. Capacité dimensionnelle: Baseline utilise toutes les dimensions (512), modèle entraîné utilise seulement des sous-espaces (17 ou 64 dims)
|
| 875 |
+
2. Distribution shift: Dataset de validation différent de celui d'entraînement
|
| 876 |
+
3. Overfitting: Modèle trop spécialisé sur le dataset d'entraînement
|
| 877 |
+
4. Généralisation: Baseline pré-entraînée sur un dataset plus large et diversifié
|
| 878 |
+
5. Perte d'information: Spécialisation excessive peut causer perte d'information générale
|
| 879 |
+
"""
|
| 880 |
+
print(f"\n{'='*60}")
|
| 881 |
+
print(f"📊 ANALYSE: Baseline vs Modèle Entraîné - {dataset_name}")
|
| 882 |
+
print(f"{'='*60}")
|
| 883 |
+
|
| 884 |
+
# Comparer les métriques pour chaque type d'embedding
|
| 885 |
+
comparisons = []
|
| 886 |
+
|
| 887 |
+
# Text Color
|
| 888 |
+
trained_color_text_acc = results_trained.get('text_color', {}).get('accuracy', 0)
|
| 889 |
+
baseline_color_text_acc = results_baseline.get('text', {}).get('color', {}).get('accuracy', 0)
|
| 890 |
+
if trained_color_text_acc > 0 and baseline_color_text_acc > 0:
|
| 891 |
+
diff = baseline_color_text_acc - trained_color_text_acc
|
| 892 |
+
comparisons.append({
|
| 893 |
+
'type': 'Text Color',
|
| 894 |
+
'trained': trained_color_text_acc,
|
| 895 |
+
'baseline': baseline_color_text_acc,
|
| 896 |
+
'diff': diff,
|
| 897 |
+
'trained_dims': '0-15 (16 dims)',
|
| 898 |
+
'baseline_dims': 'All dimensions (512 dims)'
|
| 899 |
+
})
|
| 900 |
+
|
| 901 |
+
# Image Color
|
| 902 |
+
trained_color_img_acc = results_trained.get('image_color', {}).get('accuracy', 0)
|
| 903 |
+
baseline_color_img_acc = results_baseline.get('image', {}).get('color', {}).get('accuracy', 0)
|
| 904 |
+
if trained_color_img_acc > 0 and baseline_color_img_acc > 0:
|
| 905 |
+
diff = baseline_color_img_acc - trained_color_img_acc
|
| 906 |
+
comparisons.append({
|
| 907 |
+
'type': 'Image Color',
|
| 908 |
+
'trained': trained_color_img_acc,
|
| 909 |
+
'baseline': baseline_color_img_acc,
|
| 910 |
+
'diff': diff,
|
| 911 |
+
'trained_dims': '0-15 (16 dims)',
|
| 912 |
+
'baseline_dims': 'All dimensions (512 dims)'
|
| 913 |
+
})
|
| 914 |
+
|
| 915 |
+
return comparisons
|
| 916 |
+
|
| 917 |
+
|
| 918 |
|
| 919 |
if __name__ == "__main__":
|
| 920 |
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
| 921 |
print(f"Using device: {device}")
|
| 922 |
|
| 923 |
+
directory = 'figures/confusion_matrices/cm_color'
|
| 924 |
max_samples = 10000
|
| 925 |
+
local_max_samples = 1000
|
| 926 |
+
|
| 927 |
+
evaluator = ColorEvaluator(device=device, directory=directory, repo_id="Leacb4/gap-clip")
|
| 928 |
+
|
| 929 |
+
# # Evaluate KAGL Marqo (skipped — CMs already generated)
|
| 930 |
+
# print("\n" + "="*60)
|
| 931 |
+
# print("🚀 Starting evaluation of KAGL Marqo with Color embeddings")
|
| 932 |
+
# print("="*60)
|
| 933 |
+
# results_kaggle = evaluator.evaluate_kaggle_marqo(max_samples=max_samples)
|
| 934 |
+
#
|
| 935 |
+
# print(f"\n{'='*60}")
|
| 936 |
+
# print("KAGL MARQO EVALUATION SUMMARY")
|
| 937 |
+
# print(f"{'='*60}")
|
| 938 |
+
#
|
| 939 |
+
# print("\n🎨 COLOR CLASSIFICATION RESULTS:")
|
| 940 |
+
# print(f" Text - NN Acc: {results_kaggle['text_color']['accuracy']*100:.1f}% | Centroid Acc: {results_kaggle['text_color']['centroid_accuracy']*100:.1f}% | Separation: {results_kaggle['text_color']['separation_score']:.4f}")
|
| 941 |
+
# print(f" Image - NN Acc: {results_kaggle['image_color']['accuracy']*100:.1f}% | Centroid Acc: {results_kaggle['image_color']['centroid_accuracy']*100:.1f}% | Separation: {results_kaggle['image_color']['separation_score']:.4f}")
|
| 942 |
+
#
|
| 943 |
+
# # Evaluate Baseline Fashion CLIP on KAGL Marqo
|
| 944 |
+
# print("\n" + "="*60)
|
| 945 |
+
# print("🚀 Starting evaluation of Baseline Fashion CLIP on KAGL Marqo")
|
| 946 |
+
# print("="*60)
|
| 947 |
+
# results_baseline_kaggle = evaluator.evaluate_baseline_kaggle_marqo(max_samples=max_samples)
|
| 948 |
+
#
|
| 949 |
+
# print(f"\n{'='*60}")
|
| 950 |
+
# print("BASELINE KAGL MARQO EVALUATION SUMMARY")
|
| 951 |
+
# print(f"{'='*60}")
|
| 952 |
+
#
|
| 953 |
+
# print("\n🎨 COLOR CLASSIFICATION RESULTS (Baseline):")
|
| 954 |
+
# print(f" Text - NN Acc: {results_baseline_kaggle['text']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_kaggle['text']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_kaggle['text']['color']['separation_score']:.4f}")
|
| 955 |
+
# print(f" Image - NN Acc: {results_baseline_kaggle['image']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_kaggle['image']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_kaggle['image']['color']['separation_score']:.4f}")
|
| 956 |
|
| 957 |
# Evaluate Local Validation Dataset
|
| 958 |
print("\n" + "="*60)
|
| 959 |
print("🚀 Starting evaluation of Local Validation Dataset with Color embeddings")
|
| 960 |
print("="*60)
|
| 961 |
+
results_local = evaluator.evaluate_local_validation(max_samples=local_max_samples)
|
| 962 |
|
| 963 |
if results_local is not None:
|
| 964 |
print(f"\n{'='*60}")
|
|
|
|
| 973 |
print("\n" + "="*60)
|
| 974 |
print("🚀 Starting evaluation of Baseline Fashion CLIP on Local Validation")
|
| 975 |
print("="*60)
|
| 976 |
+
results_baseline_local = evaluator.evaluate_baseline_local_validation(max_samples=local_max_samples)
|
| 977 |
|
| 978 |
if results_baseline_local is not None:
|
| 979 |
print(f"\n{'='*60}")
|
|
|
|
| 985 |
print(f" Image - NN Acc: {results_baseline_local['image']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['image']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['image']['color']['separation_score']:.4f}")
|
| 986 |
|
| 987 |
|
| 988 |
+
print(f"\n✅ Evaluation completed! Check '{directory}/' for visualization files.")
|
evaluation/sec52_category_model_eval.py
ADDED
|
@@ -0,0 +1,1212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Section 5.2 — Category Model Evaluation (Table 2)
|
| 3 |
+
==================================================
|
| 4 |
+
|
| 5 |
+
Evaluates GAP-CLIP vs the Fashion-CLIP baseline on hierarchy (category)
|
| 6 |
+
classification using three datasets:
|
| 7 |
+
- Fashion-MNIST (10 categories)
|
| 8 |
+
- KAGL Marqo (external, real-world fashion e-commerce)
|
| 9 |
+
- Internal validation dataset
|
| 10 |
+
|
| 11 |
+
Produces hierarchy confusion matrices (text + image) for both models on each
|
| 12 |
+
dataset.
|
| 13 |
+
|
| 14 |
+
Metrics match Table 2 in the paper:
|
| 15 |
+
- Text/image embedding NN accuracy
|
| 16 |
+
- Text/image embedding separation score
|
| 17 |
+
|
| 18 |
+
Run directly:
|
| 19 |
+
python sec52_category_model_eval.py
|
| 20 |
+
|
| 21 |
+
Paper reference: Section 5.2, Table 2.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import os
|
| 25 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
import pandas as pd
|
| 29 |
+
import numpy as np
|
| 30 |
+
import matplotlib.pyplot as plt
|
| 31 |
+
import seaborn as sns
|
| 32 |
+
import difflib
|
| 33 |
+
from collections import defaultdict
|
| 34 |
+
import hashlib
|
| 35 |
+
from pathlib import Path
|
| 36 |
+
import requests
|
| 37 |
+
|
| 38 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 39 |
+
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
|
| 40 |
+
from sklearn.preprocessing import normalize
|
| 41 |
+
|
| 42 |
+
from tqdm import tqdm
|
| 43 |
+
from torch.utils.data import Dataset, DataLoader
|
| 44 |
+
from torchvision import transforms
|
| 45 |
+
from PIL import Image
|
| 46 |
+
from io import BytesIO
|
| 47 |
+
|
| 48 |
+
import warnings
|
| 49 |
+
warnings.filterwarnings('ignore')
|
| 50 |
+
|
| 51 |
+
from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
|
| 52 |
+
|
| 53 |
+
from config import (
|
| 54 |
+
main_model_path,
|
| 55 |
+
hierarchy_model_path,
|
| 56 |
+
color_emb_dim,
|
| 57 |
+
hierarchy_emb_dim,
|
| 58 |
+
local_dataset_path,
|
| 59 |
+
column_local_image_path,
|
| 60 |
+
images_dir,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# ============================================================================
|
| 64 |
+
# 1. Fashion-MNIST utilities
|
| 65 |
+
# ============================================================================
|
| 66 |
+
|
| 67 |
+
def get_fashion_mnist_labels():
|
| 68 |
+
return {
|
| 69 |
+
0: "T-shirt/top",
|
| 70 |
+
1: "Trouser",
|
| 71 |
+
2: "Pullover",
|
| 72 |
+
3: "Dress",
|
| 73 |
+
4: "Coat",
|
| 74 |
+
5: "Sandal",
|
| 75 |
+
6: "Shirt",
|
| 76 |
+
7: "Sneaker",
|
| 77 |
+
8: "Bag",
|
| 78 |
+
9: "Ankle boot",
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def create_fashion_mnist_to_hierarchy_mapping(hierarchy_classes):
|
| 83 |
+
fashion_mnist_labels = get_fashion_mnist_labels()
|
| 84 |
+
hierarchy_classes_lower = [h.lower() for h in hierarchy_classes]
|
| 85 |
+
mapping = {}
|
| 86 |
+
|
| 87 |
+
for fm_label_id, fm_label in fashion_mnist_labels.items():
|
| 88 |
+
fm_label_lower = fm_label.lower()
|
| 89 |
+
matched_hierarchy = None
|
| 90 |
+
|
| 91 |
+
if fm_label_lower in hierarchy_classes_lower:
|
| 92 |
+
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(fm_label_lower)]
|
| 93 |
+
elif any(h in fm_label_lower or fm_label_lower in h for h in hierarchy_classes_lower):
|
| 94 |
+
for h_class in hierarchy_classes:
|
| 95 |
+
h_lower = h_class.lower()
|
| 96 |
+
if h_lower in fm_label_lower or fm_label_lower in h_lower:
|
| 97 |
+
matched_hierarchy = h_class
|
| 98 |
+
break
|
| 99 |
+
else:
|
| 100 |
+
if fm_label_lower in ['t-shirt/top', 'top']:
|
| 101 |
+
if 'top' in hierarchy_classes_lower:
|
| 102 |
+
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index('top')]
|
| 103 |
+
|
| 104 |
+
elif 'trouser' in fm_label_lower:
|
| 105 |
+
for possible in ['bottom', 'pants', 'trousers', 'trouser', 'pant']:
|
| 106 |
+
if possible in hierarchy_classes_lower:
|
| 107 |
+
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(possible)]
|
| 108 |
+
break
|
| 109 |
+
|
| 110 |
+
elif 'pullover' in fm_label_lower:
|
| 111 |
+
for possible in ['sweater', 'pullover']:
|
| 112 |
+
if possible in hierarchy_classes_lower:
|
| 113 |
+
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(possible)]
|
| 114 |
+
break
|
| 115 |
+
|
| 116 |
+
elif 'dress' in fm_label_lower:
|
| 117 |
+
if 'dress' in hierarchy_classes_lower:
|
| 118 |
+
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index('dress')]
|
| 119 |
+
|
| 120 |
+
elif 'coat' in fm_label_lower:
|
| 121 |
+
for possible in ['jacket', 'outerwear', 'coat']:
|
| 122 |
+
if possible in hierarchy_classes_lower:
|
| 123 |
+
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(possible)]
|
| 124 |
+
break
|
| 125 |
+
|
| 126 |
+
elif fm_label_lower in ['sandal', 'sneaker', 'ankle boot']:
|
| 127 |
+
for possible in ['shoes', 'shoe', 'sandal', 'sneaker', 'boot']:
|
| 128 |
+
if possible in hierarchy_classes_lower:
|
| 129 |
+
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(possible)]
|
| 130 |
+
break
|
| 131 |
+
|
| 132 |
+
elif 'bag' in fm_label_lower:
|
| 133 |
+
if 'bag' in hierarchy_classes_lower:
|
| 134 |
+
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index('bag')]
|
| 135 |
+
|
| 136 |
+
if matched_hierarchy is None:
|
| 137 |
+
close_matches = difflib.get_close_matches(
|
| 138 |
+
fm_label_lower, hierarchy_classes_lower, n=1, cutoff=0.6
|
| 139 |
+
)
|
| 140 |
+
if close_matches:
|
| 141 |
+
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(close_matches[0])]
|
| 142 |
+
|
| 143 |
+
mapping[fm_label_id] = matched_hierarchy
|
| 144 |
+
if matched_hierarchy:
|
| 145 |
+
print(f" {fm_label} ({fm_label_id}) -> {matched_hierarchy}")
|
| 146 |
+
else:
|
| 147 |
+
print(f" {fm_label} ({fm_label_id}) -> NO MATCH (will be filtered out)")
|
| 148 |
+
|
| 149 |
+
return mapping
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def convert_fashion_mnist_to_image(pixel_values):
|
| 153 |
+
image_array = np.array(pixel_values).reshape(28, 28).astype(np.uint8)
|
| 154 |
+
image_array = np.stack([image_array] * 3, axis=-1)
|
| 155 |
+
return Image.fromarray(image_array)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class FashionMNISTDataset(Dataset):
|
| 159 |
+
def __init__(self, dataframe, image_size=224, label_mapping=None):
|
| 160 |
+
self.dataframe = dataframe
|
| 161 |
+
self.image_size = image_size
|
| 162 |
+
self.labels_map = get_fashion_mnist_labels()
|
| 163 |
+
self.label_mapping = label_mapping
|
| 164 |
+
|
| 165 |
+
self.transform = transforms.Compose([
|
| 166 |
+
transforms.Resize((image_size, image_size)),
|
| 167 |
+
transforms.ToTensor(),
|
| 168 |
+
transforms.Normalize(
|
| 169 |
+
mean=[0.485, 0.456, 0.406],
|
| 170 |
+
std=[0.229, 0.224, 0.225],
|
| 171 |
+
),
|
| 172 |
+
])
|
| 173 |
+
|
| 174 |
+
def __len__(self):
|
| 175 |
+
return len(self.dataframe)
|
| 176 |
+
|
| 177 |
+
def __getitem__(self, idx):
|
| 178 |
+
row = self.dataframe.iloc[idx]
|
| 179 |
+
|
| 180 |
+
pixel_cols = [f"pixel{i}" for i in range(1, 785)]
|
| 181 |
+
pixel_values = row[pixel_cols].values
|
| 182 |
+
|
| 183 |
+
image = convert_fashion_mnist_to_image(pixel_values)
|
| 184 |
+
image = self.transform(image)
|
| 185 |
+
|
| 186 |
+
label_id = int(row['label'])
|
| 187 |
+
description = self.labels_map[label_id]
|
| 188 |
+
color = "unknown"
|
| 189 |
+
|
| 190 |
+
if self.label_mapping and label_id in self.label_mapping:
|
| 191 |
+
hierarchy = self.label_mapping[label_id]
|
| 192 |
+
else:
|
| 193 |
+
hierarchy = self.labels_map[label_id]
|
| 194 |
+
|
| 195 |
+
return image, description, color, hierarchy
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def load_fashion_mnist_dataset(
|
| 199 |
+
max_samples=10000,
|
| 200 |
+
hierarchy_classes=None,
|
| 201 |
+
csv_path=None,
|
| 202 |
+
):
|
| 203 |
+
if csv_path is None:
|
| 204 |
+
csv_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data", "fashion-mnist_test.csv")
|
| 205 |
+
print("Loading Fashion-MNIST test dataset...")
|
| 206 |
+
df = pd.read_csv(csv_path)
|
| 207 |
+
print(f"Fashion-MNIST dataset loaded: {len(df)} samples")
|
| 208 |
+
|
| 209 |
+
label_mapping = None
|
| 210 |
+
if hierarchy_classes is not None:
|
| 211 |
+
print("\nCreating mapping from Fashion-MNIST labels to hierarchy classes:")
|
| 212 |
+
label_mapping = create_fashion_mnist_to_hierarchy_mapping(hierarchy_classes)
|
| 213 |
+
|
| 214 |
+
valid_label_ids = [lid for lid, h in label_mapping.items() if h is not None]
|
| 215 |
+
df_filtered = df[df['label'].isin(valid_label_ids)]
|
| 216 |
+
print(f"\nAfter filtering to mappable labels: {len(df_filtered)} samples (from {len(df)})")
|
| 217 |
+
df_sample = df_filtered.head(max_samples)
|
| 218 |
+
else:
|
| 219 |
+
df_sample = df.head(max_samples)
|
| 220 |
+
|
| 221 |
+
print(f"Using {len(df_sample)} samples for evaluation")
|
| 222 |
+
return FashionMNISTDataset(df_sample, label_mapping=label_mapping)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
# ============================================================================
|
| 226 |
+
# 1b. KAGL Marqo utilities
|
| 227 |
+
# ============================================================================
|
| 228 |
+
|
| 229 |
+
class KaggleHierarchyDataset(Dataset):
|
| 230 |
+
"""KAGL Marqo dataset returning (image, description, color, hierarchy)."""
|
| 231 |
+
|
| 232 |
+
def __init__(self, dataframe, image_size=224):
|
| 233 |
+
self.dataframe = dataframe.reset_index(drop=True)
|
| 234 |
+
self.transform = transforms.Compose([
|
| 235 |
+
transforms.Resize((image_size, image_size)),
|
| 236 |
+
transforms.ToTensor(),
|
| 237 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 238 |
+
])
|
| 239 |
+
|
| 240 |
+
def __len__(self):
|
| 241 |
+
return len(self.dataframe)
|
| 242 |
+
|
| 243 |
+
def __getitem__(self, idx):
|
| 244 |
+
row = self.dataframe.iloc[idx]
|
| 245 |
+
image_data = row["image"]
|
| 246 |
+
if isinstance(image_data, dict) and "bytes" in image_data:
|
| 247 |
+
image = Image.open(BytesIO(image_data["bytes"])).convert("RGB")
|
| 248 |
+
elif hasattr(image_data, "convert"):
|
| 249 |
+
image = image_data.convert("RGB")
|
| 250 |
+
else:
|
| 251 |
+
image = Image.open(BytesIO(image_data)).convert("RGB")
|
| 252 |
+
image = self.transform(image)
|
| 253 |
+
description = str(row["text"])
|
| 254 |
+
color = str(row.get("baseColour", "unknown")).lower()
|
| 255 |
+
hierarchy = str(row["hierarchy"])
|
| 256 |
+
return image, description, color, hierarchy
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def load_kaggle_marqo_with_hierarchy(max_samples=10000, hierarchy_classes=None):
|
| 260 |
+
"""Load KAGL Marqo dataset with hierarchy labels derived from articleType."""
|
| 261 |
+
from datasets import load_dataset
|
| 262 |
+
|
| 263 |
+
print("Loading KAGL Marqo dataset for hierarchy evaluation...")
|
| 264 |
+
dataset = load_dataset("Marqo/KAGL")
|
| 265 |
+
df = dataset["data"].to_pandas()
|
| 266 |
+
print(f"Dataset loaded: {len(df)} samples, columns: {list(df.columns)}")
|
| 267 |
+
|
| 268 |
+
# Use the most specific category column as hierarchy source
|
| 269 |
+
hierarchy_col = None
|
| 270 |
+
for col in ["articleType", "category3", "category2", "subCategory", "masterCategory", "category1"]:
|
| 271 |
+
if col in df.columns:
|
| 272 |
+
hierarchy_col = col
|
| 273 |
+
break
|
| 274 |
+
|
| 275 |
+
if hierarchy_col is None:
|
| 276 |
+
print("WARNING: No hierarchy column found in KAGL dataset")
|
| 277 |
+
return None
|
| 278 |
+
|
| 279 |
+
print(f"Using '{hierarchy_col}' as hierarchy source")
|
| 280 |
+
df = df.dropna(subset=["text", "image", hierarchy_col])
|
| 281 |
+
df["hierarchy"] = df[hierarchy_col].astype(str).str.strip()
|
| 282 |
+
|
| 283 |
+
# If hierarchy_classes provided, map KAGL types to model hierarchy classes
|
| 284 |
+
if hierarchy_classes:
|
| 285 |
+
hierarchy_classes_lower = [h.lower() for h in hierarchy_classes]
|
| 286 |
+
mapped = []
|
| 287 |
+
for _, row in df.iterrows():
|
| 288 |
+
kagl_type = row["hierarchy"].lower()
|
| 289 |
+
matched = None
|
| 290 |
+
# Exact match
|
| 291 |
+
if kagl_type in hierarchy_classes_lower:
|
| 292 |
+
matched = hierarchy_classes[hierarchy_classes_lower.index(kagl_type)]
|
| 293 |
+
else:
|
| 294 |
+
# Substring match
|
| 295 |
+
for h_class in hierarchy_classes:
|
| 296 |
+
h_lower = h_class.lower()
|
| 297 |
+
if h_lower in kagl_type or kagl_type in h_lower:
|
| 298 |
+
matched = h_class
|
| 299 |
+
break
|
| 300 |
+
if matched is None:
|
| 301 |
+
close = difflib.get_close_matches(kagl_type, hierarchy_classes_lower, n=1, cutoff=0.6)
|
| 302 |
+
if close:
|
| 303 |
+
matched = hierarchy_classes[hierarchy_classes_lower.index(close[0])]
|
| 304 |
+
mapped.append(matched)
|
| 305 |
+
df["hierarchy"] = mapped
|
| 306 |
+
df = df.dropna(subset=["hierarchy"])
|
| 307 |
+
print(f"After hierarchy mapping: {len(df)} samples")
|
| 308 |
+
|
| 309 |
+
if len(df) > max_samples:
|
| 310 |
+
df = df.sample(n=max_samples, random_state=42)
|
| 311 |
+
|
| 312 |
+
print(f"Using {len(df)} samples, {df['hierarchy'].nunique()} hierarchy classes: "
|
| 313 |
+
f"{sorted(df['hierarchy'].unique())}")
|
| 314 |
+
return KaggleHierarchyDataset(df)
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
# ============================================================================
|
| 318 |
+
# 1c. Local validation dataset utilities
|
| 319 |
+
# ============================================================================
|
| 320 |
+
|
| 321 |
+
class LocalHierarchyDataset(Dataset):
|
| 322 |
+
"""Local validation dataset returning (image, description, color, hierarchy)."""
|
| 323 |
+
|
| 324 |
+
def __init__(self, dataframe, image_size=224):
|
| 325 |
+
self.dataframe = dataframe.reset_index(drop=True)
|
| 326 |
+
self.transform = transforms.Compose([
|
| 327 |
+
transforms.Resize((image_size, image_size)),
|
| 328 |
+
transforms.ToTensor(),
|
| 329 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 330 |
+
])
|
| 331 |
+
|
| 332 |
+
def __len__(self):
|
| 333 |
+
return len(self.dataframe)
|
| 334 |
+
|
| 335 |
+
def __getitem__(self, idx):
|
| 336 |
+
row = self.dataframe.iloc[idx]
|
| 337 |
+
try:
|
| 338 |
+
image_path = row.get(column_local_image_path) if hasattr(row, "get") else None
|
| 339 |
+
if isinstance(image_path, str) and image_path and os.path.exists(image_path):
|
| 340 |
+
image = Image.open(image_path).convert("RGB")
|
| 341 |
+
else:
|
| 342 |
+
# Fallback: download image from URL (and cache).
|
| 343 |
+
image_url = row.get("image_url") if hasattr(row, "get") else None
|
| 344 |
+
if isinstance(image_url, dict) and "bytes" in image_url:
|
| 345 |
+
image = Image.open(BytesIO(image_url["bytes"])).convert("RGB")
|
| 346 |
+
elif isinstance(image_url, str) and image_url:
|
| 347 |
+
cache_dir = Path(images_dir)
|
| 348 |
+
cache_dir.mkdir(parents=True, exist_ok=True)
|
| 349 |
+
url_hash = hashlib.md5(image_url.encode("utf-8")).hexdigest()
|
| 350 |
+
cache_path = cache_dir / f"{url_hash}.jpg"
|
| 351 |
+
if cache_path.exists():
|
| 352 |
+
image = Image.open(cache_path).convert("RGB")
|
| 353 |
+
else:
|
| 354 |
+
resp = requests.get(image_url, timeout=10)
|
| 355 |
+
resp.raise_for_status()
|
| 356 |
+
image = Image.open(BytesIO(resp.content)).convert("RGB")
|
| 357 |
+
# Cache so repeated runs are faster.
|
| 358 |
+
image.save(cache_path, "JPEG", quality=85, optimize=True)
|
| 359 |
+
else:
|
| 360 |
+
raise ValueError("Missing image_path and image_url")
|
| 361 |
+
except Exception:
|
| 362 |
+
image = Image.new("RGB", (224, 224), color="gray")
|
| 363 |
+
image = self.transform(image)
|
| 364 |
+
description = str(row["text"])
|
| 365 |
+
color = str(row.get("color", "unknown"))
|
| 366 |
+
hierarchy = str(row["hierarchy"])
|
| 367 |
+
return image, description, color, hierarchy
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
def load_local_validation_with_hierarchy(max_samples=10000, hierarchy_classes=None):
|
| 371 |
+
"""Load internal validation dataset with hierarchy labels."""
|
| 372 |
+
print("Loading local validation dataset for hierarchy evaluation...")
|
| 373 |
+
df = pd.read_csv(local_dataset_path)
|
| 374 |
+
print(f"Dataset loaded: {len(df)} samples")
|
| 375 |
+
|
| 376 |
+
# Some internal CSVs only contain `image_url` (no `local_image_path`).
|
| 377 |
+
# If so, we fall back to downloading images on-demand.
|
| 378 |
+
if column_local_image_path in df.columns:
|
| 379 |
+
df = df.dropna(subset=[column_local_image_path, "hierarchy"])
|
| 380 |
+
else:
|
| 381 |
+
df = df.dropna(subset=["hierarchy"])
|
| 382 |
+
df["hierarchy"] = df["hierarchy"].astype(str).str.strip()
|
| 383 |
+
df = df[df["hierarchy"].str.len() > 0]
|
| 384 |
+
|
| 385 |
+
if hierarchy_classes:
|
| 386 |
+
hierarchy_classes_lower = [h.lower() for h in hierarchy_classes]
|
| 387 |
+
df["hierarchy_lower"] = df["hierarchy"].str.lower()
|
| 388 |
+
df = df[df["hierarchy_lower"].isin(hierarchy_classes_lower)]
|
| 389 |
+
# Restore proper casing from hierarchy_classes
|
| 390 |
+
case_map = {h.lower(): h for h in hierarchy_classes}
|
| 391 |
+
df["hierarchy"] = df["hierarchy_lower"].map(case_map)
|
| 392 |
+
df = df.drop(columns=["hierarchy_lower"])
|
| 393 |
+
|
| 394 |
+
print(f"After filtering: {len(df)} samples, {df['hierarchy'].nunique()} classes")
|
| 395 |
+
|
| 396 |
+
if len(df) > max_samples:
|
| 397 |
+
df = df.sample(n=max_samples, random_state=42)
|
| 398 |
+
|
| 399 |
+
print(f"Using {len(df)} samples, classes: {sorted(df['hierarchy'].unique())}")
|
| 400 |
+
return LocalHierarchyDataset(df)
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
# ============================================================================
|
| 404 |
+
# 2. Evaluator
|
| 405 |
+
# ============================================================================
|
| 406 |
+
|
| 407 |
+
class CategoryModelEvaluator:
|
| 408 |
+
"""
|
| 409 |
+
Produces hierarchy confusion matrices for GAP-CLIP and the
|
| 410 |
+
baseline Fashion-CLIP on Fashion-MNIST, KAGL Marqo, and internal datasets.
|
| 411 |
+
"""
|
| 412 |
+
|
| 413 |
+
def __init__(self, device='mps', directory='figures/confusion_matrices/cm_hierarchy'):
|
| 414 |
+
self.device = torch.device(device)
|
| 415 |
+
self.directory = directory
|
| 416 |
+
self.color_emb_dim = color_emb_dim
|
| 417 |
+
self.hierarchy_emb_dim = hierarchy_emb_dim
|
| 418 |
+
os.makedirs(self.directory, exist_ok=True)
|
| 419 |
+
|
| 420 |
+
# --- load GAP-CLIP ---
|
| 421 |
+
print(f"Loading GAP-CLIP model from {main_model_path}")
|
| 422 |
+
if not os.path.exists(main_model_path):
|
| 423 |
+
raise FileNotFoundError(f"GAP-CLIP model file {main_model_path} not found")
|
| 424 |
+
|
| 425 |
+
print("Loading hierarchy classes from hierarchy model...")
|
| 426 |
+
if not os.path.exists(hierarchy_model_path):
|
| 427 |
+
raise FileNotFoundError(f"Hierarchy model file {hierarchy_model_path} not found")
|
| 428 |
+
|
| 429 |
+
hierarchy_checkpoint = torch.load(hierarchy_model_path, map_location=self.device)
|
| 430 |
+
self.hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', [])
|
| 431 |
+
print(f"Found {len(self.hierarchy_classes)} hierarchy classes: {sorted(self.hierarchy_classes)}")
|
| 432 |
+
|
| 433 |
+
self.validation_hierarchy_classes = self._load_validation_hierarchy_classes()
|
| 434 |
+
if self.validation_hierarchy_classes:
|
| 435 |
+
print(f"Validation dataset hierarchies ({len(self.validation_hierarchy_classes)} classes): "
|
| 436 |
+
f"{sorted(self.validation_hierarchy_classes)}")
|
| 437 |
+
else:
|
| 438 |
+
print("Unable to load validation hierarchy classes, falling back to hierarchy model classes.")
|
| 439 |
+
self.validation_hierarchy_classes = self.hierarchy_classes
|
| 440 |
+
|
| 441 |
+
checkpoint = torch.load(main_model_path, map_location=self.device)
|
| 442 |
+
self.processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
|
| 443 |
+
self.model = CLIPModel_transformers.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
|
| 444 |
+
self.model.load_state_dict(checkpoint['model_state_dict'])
|
| 445 |
+
self.model.to(self.device)
|
| 446 |
+
self.model.eval()
|
| 447 |
+
print("GAP-CLIP model loaded successfully")
|
| 448 |
+
|
| 449 |
+
# --- baseline Fashion-CLIP ---
|
| 450 |
+
print("Loading baseline Fashion-CLIP model...")
|
| 451 |
+
patrick_model_name = "patrickjohncyh/fashion-clip"
|
| 452 |
+
self.baseline_processor = CLIPProcessor.from_pretrained(patrick_model_name)
|
| 453 |
+
self.baseline_model = CLIPModel_transformers.from_pretrained(patrick_model_name).to(self.device)
|
| 454 |
+
self.baseline_model.eval()
|
| 455 |
+
print("Baseline Fashion-CLIP model loaded successfully")
|
| 456 |
+
|
| 457 |
+
# ------------------------------------------------------------------
|
| 458 |
+
# helpers
|
| 459 |
+
# ------------------------------------------------------------------
|
| 460 |
+
def _load_validation_hierarchy_classes(self):
|
| 461 |
+
if not os.path.exists(local_dataset_path):
|
| 462 |
+
print(f"Validation dataset not found at {local_dataset_path}")
|
| 463 |
+
return []
|
| 464 |
+
try:
|
| 465 |
+
df = pd.read_csv(local_dataset_path)
|
| 466 |
+
except Exception as exc:
|
| 467 |
+
print(f"Failed to read validation dataset: {exc}")
|
| 468 |
+
return []
|
| 469 |
+
if 'hierarchy' not in df.columns:
|
| 470 |
+
print("Validation dataset does not contain 'hierarchy' column.")
|
| 471 |
+
return []
|
| 472 |
+
hierarchies = df['hierarchy'].dropna().astype(str).str.strip()
|
| 473 |
+
hierarchies = [h for h in hierarchies if h]
|
| 474 |
+
return sorted(set(hierarchies))
|
| 475 |
+
|
| 476 |
+
def prepare_shared_fashion_mnist(self, max_samples=10000, batch_size=8):
|
| 477 |
+
"""
|
| 478 |
+
Build one shared Fashion-MNIST dataset/dataloader to ensure every model
|
| 479 |
+
is evaluated on the exact same items.
|
| 480 |
+
"""
|
| 481 |
+
target_classes = self.validation_hierarchy_classes or self.hierarchy_classes
|
| 482 |
+
fashion_dataset = load_fashion_mnist_dataset(max_samples, hierarchy_classes=target_classes)
|
| 483 |
+
dataloader = DataLoader(fashion_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
|
| 484 |
+
|
| 485 |
+
hierarchy_counts = defaultdict(int)
|
| 486 |
+
if len(fashion_dataset.dataframe) > 0 and fashion_dataset.label_mapping:
|
| 487 |
+
for _, row in fashion_dataset.dataframe.iterrows():
|
| 488 |
+
lid = int(row['label'])
|
| 489 |
+
hierarchy_counts[fashion_dataset.label_mapping.get(lid, 'unknown')] += 1
|
| 490 |
+
|
| 491 |
+
return fashion_dataset, dataloader, dict(hierarchy_counts)
|
| 492 |
+
|
| 493 |
+
@staticmethod
|
| 494 |
+
def _count_labels(labels):
|
| 495 |
+
counts = defaultdict(int)
|
| 496 |
+
for label in labels:
|
| 497 |
+
counts[label] += 1
|
| 498 |
+
return dict(counts)
|
| 499 |
+
|
| 500 |
+
def _validate_label_distribution(self, labels, expected_counts, context):
|
| 501 |
+
observed = self._count_labels(labels)
|
| 502 |
+
if observed != expected_counts:
|
| 503 |
+
raise ValueError(
|
| 504 |
+
f"Label distribution mismatch in {context}. "
|
| 505 |
+
f"Expected {expected_counts}, observed {observed}"
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
# ------------------------------------------------------------------
|
| 509 |
+
# embedding extraction — GAP-CLIP
|
| 510 |
+
# ------------------------------------------------------------------
|
| 511 |
+
def extract_full_embeddings(self, dataloader, embedding_type='text', max_samples=10000):
|
| 512 |
+
"""Full 512D embeddings from GAP-CLIP (text or image)."""
|
| 513 |
+
all_embeddings, all_colors, all_hierarchies = [], [], []
|
| 514 |
+
sample_count = 0
|
| 515 |
+
|
| 516 |
+
with torch.no_grad():
|
| 517 |
+
for batch in tqdm(dataloader, desc=f"GAP-CLIP {embedding_type} embeddings"):
|
| 518 |
+
if sample_count >= max_samples:
|
| 519 |
+
break
|
| 520 |
+
images, texts, colors, hierarchies = batch
|
| 521 |
+
images = images.to(self.device).expand(-1, 3, -1, -1)
|
| 522 |
+
|
| 523 |
+
text_inputs = self.processor(text=list(texts), padding=True, return_tensors="pt")
|
| 524 |
+
text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
|
| 525 |
+
outputs = self.model(**text_inputs, pixel_values=images)
|
| 526 |
+
|
| 527 |
+
if embedding_type == 'image':
|
| 528 |
+
emb = outputs.image_embeds
|
| 529 |
+
else:
|
| 530 |
+
emb = outputs.text_embeds
|
| 531 |
+
|
| 532 |
+
all_embeddings.append(emb.cpu().numpy())
|
| 533 |
+
all_colors.extend(colors)
|
| 534 |
+
all_hierarchies.extend(hierarchies)
|
| 535 |
+
sample_count += len(images)
|
| 536 |
+
|
| 537 |
+
del images, text_inputs, outputs, emb
|
| 538 |
+
if torch.cuda.is_available():
|
| 539 |
+
torch.cuda.empty_cache()
|
| 540 |
+
|
| 541 |
+
return np.vstack(all_embeddings), all_colors, all_hierarchies
|
| 542 |
+
|
| 543 |
+
# ------------------------------------------------------------------
|
| 544 |
+
# embedding extraction — baseline Fashion-CLIP
|
| 545 |
+
# ------------------------------------------------------------------
|
| 546 |
+
def extract_baseline_embeddings_batch(self, dataloader, embedding_type='text', max_samples=10000):
|
| 547 |
+
"""L2-normalised embeddings from baseline Fashion-CLIP."""
|
| 548 |
+
all_embeddings, all_colors, all_hierarchies = [], [], []
|
| 549 |
+
sample_count = 0
|
| 550 |
+
|
| 551 |
+
with torch.no_grad():
|
| 552 |
+
for batch in tqdm(dataloader, desc=f"Baseline {embedding_type} embeddings"):
|
| 553 |
+
if sample_count >= max_samples:
|
| 554 |
+
break
|
| 555 |
+
images, texts, colors, hierarchies = batch
|
| 556 |
+
|
| 557 |
+
if embedding_type == 'text':
|
| 558 |
+
inp = self.baseline_processor(
|
| 559 |
+
text=list(texts), return_tensors="pt",
|
| 560 |
+
padding=True, truncation=True, max_length=77,
|
| 561 |
+
)
|
| 562 |
+
inp = {k: v.to(self.device) for k, v in inp.items()}
|
| 563 |
+
feats = self.baseline_model.get_text_features(**inp)
|
| 564 |
+
feats = feats / feats.norm(dim=-1, keepdim=True)
|
| 565 |
+
emb = feats
|
| 566 |
+
|
| 567 |
+
elif embedding_type == 'image':
|
| 568 |
+
pil_images = []
|
| 569 |
+
for i in range(images.shape[0]):
|
| 570 |
+
t = images[i]
|
| 571 |
+
if t.min() < 0 or t.max() > 1:
|
| 572 |
+
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
|
| 573 |
+
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
|
| 574 |
+
t = torch.clamp(t * std + mean, 0, 1)
|
| 575 |
+
pil_images.append(transforms.ToPILImage()(t))
|
| 576 |
+
|
| 577 |
+
inp = self.baseline_processor(images=pil_images, return_tensors="pt")
|
| 578 |
+
inp = {k: v.to(self.device) for k, v in inp.items()}
|
| 579 |
+
feats = self.baseline_model.get_image_features(**inp)
|
| 580 |
+
feats = feats / feats.norm(dim=-1, keepdim=True)
|
| 581 |
+
emb = feats
|
| 582 |
+
else:
|
| 583 |
+
inp = self.baseline_processor(
|
| 584 |
+
text=list(texts), return_tensors="pt",
|
| 585 |
+
padding=True, truncation=True, max_length=77,
|
| 586 |
+
)
|
| 587 |
+
inp = {k: v.to(self.device) for k, v in inp.items()}
|
| 588 |
+
feats = self.baseline_model.get_text_features(**inp)
|
| 589 |
+
feats = feats / feats.norm(dim=-1, keepdim=True)
|
| 590 |
+
emb = feats
|
| 591 |
+
|
| 592 |
+
all_embeddings.append(emb.cpu().numpy())
|
| 593 |
+
all_colors.extend(colors)
|
| 594 |
+
all_hierarchies.extend(hierarchies)
|
| 595 |
+
sample_count += len(images)
|
| 596 |
+
|
| 597 |
+
del emb
|
| 598 |
+
if torch.cuda.is_available():
|
| 599 |
+
torch.cuda.empty_cache()
|
| 600 |
+
|
| 601 |
+
return np.vstack(all_embeddings), all_colors, all_hierarchies
|
| 602 |
+
|
| 603 |
+
# ------------------------------------------------------------------
|
| 604 |
+
# metrics
|
| 605 |
+
# ------------------------------------------------------------------
|
| 606 |
+
def compute_embedding_accuracy(self, embeddings, labels, similarities=None):
|
| 607 |
+
n = len(embeddings)
|
| 608 |
+
if n == 0:
|
| 609 |
+
return 0.0
|
| 610 |
+
if similarities is None:
|
| 611 |
+
similarities = cosine_similarity(embeddings)
|
| 612 |
+
|
| 613 |
+
correct = 0
|
| 614 |
+
for i in range(n):
|
| 615 |
+
sims = similarities[i].copy()
|
| 616 |
+
sims[i] = -1.0
|
| 617 |
+
nearest_neighbor_idx = int(np.argmax(sims))
|
| 618 |
+
predicted = labels[nearest_neighbor_idx]
|
| 619 |
+
if predicted == labels[i]:
|
| 620 |
+
correct += 1
|
| 621 |
+
return correct / n
|
| 622 |
+
|
| 623 |
+
def compute_similarity_metrics(self, embeddings, labels):
|
| 624 |
+
max_samples = min(5000, len(embeddings))
|
| 625 |
+
if len(embeddings) > max_samples:
|
| 626 |
+
indices = np.random.choice(len(embeddings), max_samples, replace=False)
|
| 627 |
+
embeddings = embeddings[indices]
|
| 628 |
+
labels = [labels[i] for i in indices]
|
| 629 |
+
|
| 630 |
+
similarities = cosine_similarity(embeddings)
|
| 631 |
+
|
| 632 |
+
label_groups = defaultdict(list)
|
| 633 |
+
for i, label in enumerate(labels):
|
| 634 |
+
label_groups[label].append(i)
|
| 635 |
+
|
| 636 |
+
intra = []
|
| 637 |
+
for _, idxs in label_groups.items():
|
| 638 |
+
if len(idxs) > 1:
|
| 639 |
+
for i in range(len(idxs)):
|
| 640 |
+
for j in range(i + 1, len(idxs)):
|
| 641 |
+
intra.append(similarities[idxs[i], idxs[j]])
|
| 642 |
+
|
| 643 |
+
inter = []
|
| 644 |
+
keys = list(label_groups.keys())
|
| 645 |
+
for i in range(len(keys)):
|
| 646 |
+
for j in range(i + 1, len(keys)):
|
| 647 |
+
for idx1 in label_groups[keys[i]]:
|
| 648 |
+
for idx2 in label_groups[keys[j]]:
|
| 649 |
+
inter.append(similarities[idx1, idx2])
|
| 650 |
+
|
| 651 |
+
nn_acc = self.compute_embedding_accuracy(embeddings, labels, similarities)
|
| 652 |
+
|
| 653 |
+
return {
|
| 654 |
+
'intra_class_mean': float(np.mean(intra)) if intra else 0.0,
|
| 655 |
+
'inter_class_mean': float(np.mean(inter)) if inter else 0.0,
|
| 656 |
+
'separation_score': (float(np.mean(intra) - np.mean(inter))
|
| 657 |
+
if intra and inter else 0.0),
|
| 658 |
+
'nn_accuracy': nn_acc,
|
| 659 |
+
}
|
| 660 |
+
|
| 661 |
+
def compute_centroid_accuracy(self, embeddings, labels):
|
| 662 |
+
if len(embeddings) == 0:
|
| 663 |
+
return 0.0
|
| 664 |
+
emb_norm = normalize(embeddings, norm='l2')
|
| 665 |
+
unique_labels = sorted(set(labels))
|
| 666 |
+
centroids = {}
|
| 667 |
+
for label in unique_labels:
|
| 668 |
+
idx = [i for i, l in enumerate(labels) if l == label]
|
| 669 |
+
centroids[label] = normalize([emb_norm[idx].mean(axis=0)], norm='l2')[0]
|
| 670 |
+
|
| 671 |
+
correct = 0
|
| 672 |
+
for i, emb in enumerate(emb_norm):
|
| 673 |
+
best_sim, pred = -1, None
|
| 674 |
+
for label, c in centroids.items():
|
| 675 |
+
sim = cosine_similarity([emb], [c])[0][0]
|
| 676 |
+
if sim > best_sim:
|
| 677 |
+
best_sim, pred = sim, label
|
| 678 |
+
if pred == labels[i]:
|
| 679 |
+
correct += 1
|
| 680 |
+
return correct / len(labels)
|
| 681 |
+
|
| 682 |
+
def predict_labels_from_embeddings(self, embeddings, labels):
|
| 683 |
+
emb_norm = normalize(embeddings, norm='l2')
|
| 684 |
+
unique_labels = sorted(set(labels))
|
| 685 |
+
centroids = {}
|
| 686 |
+
for label in unique_labels:
|
| 687 |
+
idx = [i for i, l in enumerate(labels) if l == label]
|
| 688 |
+
centroids[label] = normalize([emb_norm[idx].mean(axis=0)], norm='l2')[0]
|
| 689 |
+
|
| 690 |
+
preds = []
|
| 691 |
+
for emb in emb_norm:
|
| 692 |
+
best_sim, pred = -1, None
|
| 693 |
+
for label, c in centroids.items():
|
| 694 |
+
sim = cosine_similarity([emb], [c])[0][0]
|
| 695 |
+
if sim > best_sim:
|
| 696 |
+
best_sim, pred = sim, label
|
| 697 |
+
preds.append(pred)
|
| 698 |
+
return preds
|
| 699 |
+
|
| 700 |
+
def predict_labels_nearest_neighbor(self, embeddings, labels):
|
| 701 |
+
"""
|
| 702 |
+
Predict labels using 1-NN on the same embedding set.
|
| 703 |
+
This matches the accuracy logic used in the evaluation pipeline.
|
| 704 |
+
"""
|
| 705 |
+
similarities = cosine_similarity(embeddings)
|
| 706 |
+
preds = []
|
| 707 |
+
for i in range(len(embeddings)):
|
| 708 |
+
sims = similarities[i].copy()
|
| 709 |
+
sims[i] = -1.0
|
| 710 |
+
nearest_neighbor_idx = int(np.argmax(sims))
|
| 711 |
+
preds.append(labels[nearest_neighbor_idx])
|
| 712 |
+
return preds
|
| 713 |
+
|
| 714 |
+
# ------------------------------------------------------------------
|
| 715 |
+
# image + text ensemble
|
| 716 |
+
# ------------------------------------------------------------------
|
| 717 |
+
def _compute_img_centroids(self, embeddings, labels):
|
| 718 |
+
emb_norm = normalize(embeddings, norm='l2')
|
| 719 |
+
centroids = {}
|
| 720 |
+
for label in sorted(set(labels)):
|
| 721 |
+
idx = [i for i, l in enumerate(labels) if l == label]
|
| 722 |
+
centroids[label] = normalize([emb_norm[idx].mean(axis=0)], norm='l2')[0]
|
| 723 |
+
return centroids
|
| 724 |
+
|
| 725 |
+
def predict_labels_image_ensemble(self, img_embeddings, labels,
|
| 726 |
+
text_protos, cls_names, alpha=0.5):
|
| 727 |
+
"""Combine image centroids (512D) with text prototypes (512D)."""
|
| 728 |
+
img_norm = normalize(img_embeddings, norm='l2')
|
| 729 |
+
img_centroids = self._compute_img_centroids(img_norm, labels)
|
| 730 |
+
centroid_mat = np.stack([img_centroids[c] for c in cls_names], axis=0)
|
| 731 |
+
|
| 732 |
+
preds = []
|
| 733 |
+
for i in range(len(img_norm)):
|
| 734 |
+
v = img_norm[i:i + 1]
|
| 735 |
+
sim_img = cosine_similarity(v, centroid_mat)[0]
|
| 736 |
+
sim_txt = cosine_similarity(v, text_protos)[0]
|
| 737 |
+
scores = alpha * sim_img + (1 - alpha) * sim_txt
|
| 738 |
+
preds.append(cls_names[int(np.argmax(scores))])
|
| 739 |
+
return preds
|
| 740 |
+
|
| 741 |
+
# ------------------------------------------------------------------
|
| 742 |
+
# confusion matrix & classification report
|
| 743 |
+
# ------------------------------------------------------------------
|
| 744 |
+
def create_confusion_matrix(self, true_labels, predicted_labels,
|
| 745 |
+
title="Confusion Matrix", label_type="Label"):
|
| 746 |
+
unique_labels = sorted(set(true_labels + predicted_labels))
|
| 747 |
+
cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)
|
| 748 |
+
acc = accuracy_score(true_labels, predicted_labels)
|
| 749 |
+
|
| 750 |
+
plt.figure(figsize=(10, 8))
|
| 751 |
+
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
|
| 752 |
+
xticklabels=unique_labels, yticklabels=unique_labels)
|
| 753 |
+
plt.title(f'{title}\nAccuracy: {acc:.3f} ({acc * 100:.1f}%)')
|
| 754 |
+
plt.ylabel(f'True {label_type}')
|
| 755 |
+
plt.xlabel(f'Predicted {label_type}')
|
| 756 |
+
plt.xticks(rotation=45)
|
| 757 |
+
plt.yticks(rotation=0)
|
| 758 |
+
plt.tight_layout()
|
| 759 |
+
return plt.gcf(), acc, cm
|
| 760 |
+
|
| 761 |
+
def evaluate_classification_performance(self, embeddings, labels,
|
| 762 |
+
embedding_type="Embeddings",
|
| 763 |
+
label_type="Label",
|
| 764 |
+
method="nn"):
|
| 765 |
+
if method == "nn":
|
| 766 |
+
preds = self.predict_labels_nearest_neighbor(embeddings, labels)
|
| 767 |
+
elif method == "centroid":
|
| 768 |
+
preds = self.predict_labels_from_embeddings(embeddings, labels)
|
| 769 |
+
else:
|
| 770 |
+
raise ValueError(f"Unknown classification method: {method}")
|
| 771 |
+
acc = accuracy_score(labels, preds)
|
| 772 |
+
unique_labels = sorted(set(labels))
|
| 773 |
+
fig, _, cm = self.create_confusion_matrix(
|
| 774 |
+
labels, preds,
|
| 775 |
+
embedding_type,
|
| 776 |
+
label_type,
|
| 777 |
+
)
|
| 778 |
+
report = classification_report(labels, preds, labels=unique_labels,
|
| 779 |
+
target_names=unique_labels, output_dict=True)
|
| 780 |
+
return {
|
| 781 |
+
'accuracy': acc,
|
| 782 |
+
'predictions': preds,
|
| 783 |
+
'confusion_matrix': cm,
|
| 784 |
+
'labels': unique_labels,
|
| 785 |
+
'classification_report': report,
|
| 786 |
+
'figure': fig,
|
| 787 |
+
}
|
| 788 |
+
|
| 789 |
+
# ==================================================================
|
| 790 |
+
# 3. GAP-CLIP evaluation on Fashion-MNIST
|
| 791 |
+
# ==================================================================
|
| 792 |
+
def evaluate_gap_clip_fashion_mnist(self, max_samples=10000, dataloader=None, expected_counts=None):
|
| 793 |
+
print(f"\n{'=' * 60}")
|
| 794 |
+
print("Evaluating GAP-CLIP on Fashion-MNIST")
|
| 795 |
+
print(" Hierarchy embeddings (dims 16-79)")
|
| 796 |
+
print(f" Max samples: {max_samples}")
|
| 797 |
+
print(f"{'=' * 60}")
|
| 798 |
+
|
| 799 |
+
if dataloader is None:
|
| 800 |
+
fashion_dataset, dataloader, dataset_counts = self.prepare_shared_fashion_mnist(max_samples=max_samples)
|
| 801 |
+
expected_counts = expected_counts or dataset_counts
|
| 802 |
+
else:
|
| 803 |
+
fashion_dataset = getattr(dataloader, "dataset", None)
|
| 804 |
+
if expected_counts is None:
|
| 805 |
+
raise ValueError("expected_counts must be provided when using a custom dataloader.")
|
| 806 |
+
|
| 807 |
+
if fashion_dataset is not None and len(fashion_dataset.dataframe) > 0 and fashion_dataset.label_mapping:
|
| 808 |
+
print(f"\nHierarchy distribution in dataset:")
|
| 809 |
+
for h in sorted(expected_counts):
|
| 810 |
+
print(f" {h}: {expected_counts[h]} samples")
|
| 811 |
+
|
| 812 |
+
results = {}
|
| 813 |
+
|
| 814 |
+
# --- full 512D embeddings (text & image) ---
|
| 815 |
+
print("\nExtracting full 512-dimensional GAP-CLIP embeddings...")
|
| 816 |
+
text_full, _, text_hier = self.extract_full_embeddings(dataloader, 'text', max_samples)
|
| 817 |
+
img_full, _, img_hier = self.extract_full_embeddings(dataloader, 'image', max_samples)
|
| 818 |
+
self._validate_label_distribution(text_hier, expected_counts, "GAP-CLIP text")
|
| 819 |
+
self._validate_label_distribution(img_hier, expected_counts, "GAP-CLIP image")
|
| 820 |
+
print(f" Text shape: {text_full.shape} | Image shape: {img_full.shape}")
|
| 821 |
+
|
| 822 |
+
# --- TEXT: hierarchy on specialized 64D (dims 16-79) ---
|
| 823 |
+
print("\n--- GAP-CLIP TEXT HIERARCHY (dims 16-79) ---")
|
| 824 |
+
text_hier_spec = text_full[:, self.color_emb_dim:self.color_emb_dim + self.hierarchy_emb_dim]
|
| 825 |
+
print(f" Specialized text hierarchy shape: {text_hier_spec.shape}")
|
| 826 |
+
|
| 827 |
+
text_metrics = self.compute_similarity_metrics(text_hier_spec, text_hier)
|
| 828 |
+
text_class = self.evaluate_classification_performance(
|
| 829 |
+
text_hier_spec, text_hier,
|
| 830 |
+
"Fashion-MNIST, text, hierarchy confusion matrix", "Hierarchy",
|
| 831 |
+
method="nn",
|
| 832 |
+
)
|
| 833 |
+
text_metrics.update(text_class)
|
| 834 |
+
results['text_hierarchy'] = text_metrics
|
| 835 |
+
|
| 836 |
+
# --- IMAGE: 64D vs 512D + ensemble ---
|
| 837 |
+
print("\n--- GAP-CLIP IMAGE HIERARCHY (64D vs 512D) ---")
|
| 838 |
+
img_hier_spec = img_full[:, self.color_emb_dim:self.color_emb_dim + self.hierarchy_emb_dim]
|
| 839 |
+
print(f" Specialized image hierarchy shape: {img_hier_spec.shape}")
|
| 840 |
+
|
| 841 |
+
print(" Testing specialized 64D...")
|
| 842 |
+
spec_metrics = self.compute_similarity_metrics(img_hier_spec, img_hier)
|
| 843 |
+
spec_class = self.evaluate_classification_performance(
|
| 844 |
+
img_hier_spec, img_hier,
|
| 845 |
+
"Fashion-MNIST, image, hierarchy confusion matrix", "Hierarchy",
|
| 846 |
+
method="nn",
|
| 847 |
+
)
|
| 848 |
+
|
| 849 |
+
print(" Testing full 512D...")
|
| 850 |
+
full_metrics = self.compute_similarity_metrics(img_full, img_hier)
|
| 851 |
+
full_class = self.evaluate_classification_performance(
|
| 852 |
+
img_full, img_hier,
|
| 853 |
+
"Fashion-MNIST, image, hierarchy confusion matrix", "Hierarchy",
|
| 854 |
+
method="nn",
|
| 855 |
+
)
|
| 856 |
+
|
| 857 |
+
if full_class['accuracy'] >= spec_class['accuracy']:
|
| 858 |
+
print(f" 512D wins: {full_class['accuracy'] * 100:.1f}% vs {spec_class['accuracy'] * 100:.1f}%")
|
| 859 |
+
img_metrics, img_class = full_metrics, full_class
|
| 860 |
+
else:
|
| 861 |
+
print(f" 64D wins: {spec_class['accuracy'] * 100:.1f}% vs {full_class['accuracy'] * 100:.1f}%")
|
| 862 |
+
img_metrics, img_class = spec_metrics, spec_class
|
| 863 |
+
|
| 864 |
+
# --- ensemble image + text prototypes ---
|
| 865 |
+
print("\n Testing GAP-CLIP image + text ensemble (prototypes per class)...")
|
| 866 |
+
cls_names = sorted(set(img_hier))
|
| 867 |
+
prompts = [f"a photo of a {c}" for c in cls_names]
|
| 868 |
+
text_inputs = self.processor(text=prompts, return_tensors="pt", padding=True, truncation=True)
|
| 869 |
+
text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
|
| 870 |
+
with torch.no_grad():
|
| 871 |
+
txt_feats = self.model.get_text_features(**text_inputs)
|
| 872 |
+
txt_feats = txt_feats / txt_feats.norm(dim=-1, keepdim=True)
|
| 873 |
+
text_protos = txt_feats.cpu().numpy()
|
| 874 |
+
|
| 875 |
+
ensemble_preds = self.predict_labels_image_ensemble(
|
| 876 |
+
img_full, img_hier, text_protos, cls_names, alpha=0.7,
|
| 877 |
+
)
|
| 878 |
+
ensemble_acc = accuracy_score(img_hier, ensemble_preds)
|
| 879 |
+
print(f" Ensemble accuracy (alpha=0.7): {ensemble_acc * 100:.2f}%")
|
| 880 |
+
|
| 881 |
+
img_metrics.update(img_class)
|
| 882 |
+
img_metrics['ensemble_accuracy'] = ensemble_acc
|
| 883 |
+
results['image_hierarchy'] = img_metrics
|
| 884 |
+
|
| 885 |
+
# --- save confusion matrix figures ---
|
| 886 |
+
for key in ['text_hierarchy', 'image_hierarchy']:
|
| 887 |
+
fig = results[key]['figure']
|
| 888 |
+
fig.savefig(
|
| 889 |
+
os.path.join(self.directory, f"gap_clip_{key}_confusion_matrix.png"),
|
| 890 |
+
dpi=300, bbox_inches='tight',
|
| 891 |
+
)
|
| 892 |
+
plt.close(fig)
|
| 893 |
+
|
| 894 |
+
del text_full, img_full, text_hier_spec, img_hier_spec
|
| 895 |
+
if torch.cuda.is_available():
|
| 896 |
+
torch.cuda.empty_cache()
|
| 897 |
+
|
| 898 |
+
return results
|
| 899 |
+
|
| 900 |
+
# ==================================================================
|
| 901 |
+
# 4. Baseline Fashion-CLIP evaluation on Fashion-MNIST
|
| 902 |
+
# ==================================================================
|
| 903 |
+
def evaluate_baseline_fashion_mnist(self, max_samples=10000, dataloader=None, expected_counts=None):
|
| 904 |
+
print(f"\n{'=' * 60}")
|
| 905 |
+
print("Evaluating Baseline Fashion-CLIP on Fashion-MNIST")
|
| 906 |
+
print(f" Max samples: {max_samples}")
|
| 907 |
+
print(f"{'=' * 60}")
|
| 908 |
+
|
| 909 |
+
if dataloader is None:
|
| 910 |
+
_, dataloader, dataset_counts = self.prepare_shared_fashion_mnist(max_samples=max_samples)
|
| 911 |
+
expected_counts = expected_counts or dataset_counts
|
| 912 |
+
elif expected_counts is None:
|
| 913 |
+
raise ValueError("expected_counts must be provided when using a custom dataloader.")
|
| 914 |
+
|
| 915 |
+
results = {}
|
| 916 |
+
|
| 917 |
+
# --- text ---
|
| 918 |
+
print("\nExtracting baseline text embeddings...")
|
| 919 |
+
text_emb, _, text_hier = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples)
|
| 920 |
+
self._validate_label_distribution(text_hier, expected_counts, "baseline text")
|
| 921 |
+
print(f" Baseline text shape: {text_emb.shape}")
|
| 922 |
+
|
| 923 |
+
text_metrics = self.compute_similarity_metrics(text_emb, text_hier)
|
| 924 |
+
text_class = self.evaluate_classification_performance(
|
| 925 |
+
text_emb, text_hier,
|
| 926 |
+
"Fashion-MNIST, text, hierarchy confusion matrix", "Hierarchy",
|
| 927 |
+
method="nn",
|
| 928 |
+
)
|
| 929 |
+
text_metrics.update(text_class)
|
| 930 |
+
results['text'] = {'hierarchy': text_metrics}
|
| 931 |
+
|
| 932 |
+
del text_emb
|
| 933 |
+
if torch.cuda.is_available():
|
| 934 |
+
torch.cuda.empty_cache()
|
| 935 |
+
|
| 936 |
+
# --- image ---
|
| 937 |
+
print("\nExtracting baseline image embeddings...")
|
| 938 |
+
img_emb, _, img_hier = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples)
|
| 939 |
+
self._validate_label_distribution(img_hier, expected_counts, "baseline image")
|
| 940 |
+
print(f" Baseline image shape: {img_emb.shape}")
|
| 941 |
+
|
| 942 |
+
img_metrics = self.compute_similarity_metrics(img_emb, img_hier)
|
| 943 |
+
img_class = self.evaluate_classification_performance(
|
| 944 |
+
img_emb, img_hier,
|
| 945 |
+
"Fashion-MNIST, image, hierarchy confusion matrix", "Hierarchy",
|
| 946 |
+
method="nn",
|
| 947 |
+
)
|
| 948 |
+
img_metrics.update(img_class)
|
| 949 |
+
results['image'] = {'hierarchy': img_metrics}
|
| 950 |
+
|
| 951 |
+
del img_emb
|
| 952 |
+
if torch.cuda.is_available():
|
| 953 |
+
torch.cuda.empty_cache()
|
| 954 |
+
|
| 955 |
+
for key in ['text', 'image']:
|
| 956 |
+
fig = results[key]['hierarchy']['figure']
|
| 957 |
+
fig.savefig(
|
| 958 |
+
os.path.join(self.directory, f"baseline_{key}_hierarchy_confusion_matrix.png"),
|
| 959 |
+
dpi=300, bbox_inches='tight',
|
| 960 |
+
)
|
| 961 |
+
plt.close(fig)
|
| 962 |
+
|
| 963 |
+
return results
|
| 964 |
+
|
| 965 |
+
# ==================================================================
|
| 966 |
+
# 5. Generic dataset evaluation (KAGL Marqo / Internal)
|
| 967 |
+
# ==================================================================
|
| 968 |
+
def evaluate_gap_clip_generic(self, dataloader, dataset_name, max_samples=10000):
|
| 969 |
+
"""Evaluate GAP-CLIP hierarchy performance on any dataset."""
|
| 970 |
+
print(f"\n{'=' * 60}")
|
| 971 |
+
print(f"Evaluating GAP-CLIP on {dataset_name}")
|
| 972 |
+
print(f" Hierarchy embeddings (dims 16-79)")
|
| 973 |
+
print(f"{'=' * 60}")
|
| 974 |
+
|
| 975 |
+
results = {}
|
| 976 |
+
|
| 977 |
+
# --- text hierarchy (64D specialized) ---
|
| 978 |
+
print("\nExtracting GAP-CLIP text embeddings...")
|
| 979 |
+
text_full, _, text_hier = self.extract_full_embeddings(dataloader, 'text', max_samples)
|
| 980 |
+
text_hier_spec = text_full[:, self.color_emb_dim:self.color_emb_dim + self.hierarchy_emb_dim]
|
| 981 |
+
print(f" Text shape: {text_full.shape}, hierarchy subspace: {text_hier_spec.shape}")
|
| 982 |
+
|
| 983 |
+
text_metrics = self.compute_similarity_metrics(text_hier_spec, text_hier)
|
| 984 |
+
text_class = self.evaluate_classification_performance(
|
| 985 |
+
text_hier_spec, text_hier,
|
| 986 |
+
f"{dataset_name}, text, hierarchy confusion matrix", "Hierarchy", method="nn",
|
| 987 |
+
)
|
| 988 |
+
text_metrics.update(text_class)
|
| 989 |
+
results['text_hierarchy'] = text_metrics
|
| 990 |
+
|
| 991 |
+
# --- image hierarchy (best of 64D vs 512D) ---
|
| 992 |
+
print("\nExtracting GAP-CLIP image embeddings...")
|
| 993 |
+
img_full, _, img_hier = self.extract_full_embeddings(dataloader, 'image', max_samples)
|
| 994 |
+
img_hier_spec = img_full[:, self.color_emb_dim:self.color_emb_dim + self.hierarchy_emb_dim]
|
| 995 |
+
|
| 996 |
+
spec_metrics = self.compute_similarity_metrics(img_hier_spec, img_hier)
|
| 997 |
+
spec_class = self.evaluate_classification_performance(
|
| 998 |
+
img_hier_spec, img_hier,
|
| 999 |
+
f"{dataset_name}, image, hierarchy confusion matrix", "Hierarchy", method="nn",
|
| 1000 |
+
)
|
| 1001 |
+
|
| 1002 |
+
full_metrics = self.compute_similarity_metrics(img_full, img_hier)
|
| 1003 |
+
full_class = self.evaluate_classification_performance(
|
| 1004 |
+
img_full, img_hier,
|
| 1005 |
+
f"{dataset_name}, image, hierarchy confusion matrix", "Hierarchy", method="nn",
|
| 1006 |
+
)
|
| 1007 |
+
|
| 1008 |
+
if full_class['accuracy'] >= spec_class['accuracy']:
|
| 1009 |
+
print(f" 512D wins: {full_class['accuracy']*100:.1f}% vs {spec_class['accuracy']*100:.1f}%")
|
| 1010 |
+
img_metrics, img_class = full_metrics, full_class
|
| 1011 |
+
else:
|
| 1012 |
+
print(f" 64D wins: {spec_class['accuracy']*100:.1f}% vs {full_class['accuracy']*100:.1f}%")
|
| 1013 |
+
img_metrics, img_class = spec_metrics, spec_class
|
| 1014 |
+
|
| 1015 |
+
img_metrics.update(img_class)
|
| 1016 |
+
results['image_hierarchy'] = img_metrics
|
| 1017 |
+
|
| 1018 |
+
# --- save confusion matrices ---
|
| 1019 |
+
prefix = dataset_name.lower().replace(" ", "_")
|
| 1020 |
+
for key in ['text_hierarchy', 'image_hierarchy']:
|
| 1021 |
+
fig = results[key]['figure']
|
| 1022 |
+
fig.savefig(
|
| 1023 |
+
os.path.join(self.directory, f"gap_clip_{prefix}_{key}_confusion_matrix.png"),
|
| 1024 |
+
dpi=300, bbox_inches='tight',
|
| 1025 |
+
)
|
| 1026 |
+
plt.close(fig)
|
| 1027 |
+
|
| 1028 |
+
del text_full, img_full, text_hier_spec, img_hier_spec
|
| 1029 |
+
if torch.cuda.is_available():
|
| 1030 |
+
torch.cuda.empty_cache()
|
| 1031 |
+
|
| 1032 |
+
return results
|
| 1033 |
+
|
| 1034 |
+
def evaluate_baseline_generic(self, dataloader, dataset_name, max_samples=10000):
|
| 1035 |
+
"""Evaluate baseline Fashion-CLIP hierarchy performance on any dataset."""
|
| 1036 |
+
print(f"\n{'=' * 60}")
|
| 1037 |
+
print(f"Evaluating Baseline Fashion-CLIP on {dataset_name}")
|
| 1038 |
+
print(f"{'=' * 60}")
|
| 1039 |
+
|
| 1040 |
+
results = {}
|
| 1041 |
+
|
| 1042 |
+
# --- text ---
|
| 1043 |
+
print("\nExtracting baseline text embeddings...")
|
| 1044 |
+
text_emb, _, text_hier = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples)
|
| 1045 |
+
print(f" Baseline text shape: {text_emb.shape}")
|
| 1046 |
+
|
| 1047 |
+
text_metrics = self.compute_similarity_metrics(text_emb, text_hier)
|
| 1048 |
+
text_class = self.evaluate_classification_performance(
|
| 1049 |
+
text_emb, text_hier,
|
| 1050 |
+
f"{dataset_name}, text, hierarchy confusion matrix", "Hierarchy", method="nn",
|
| 1051 |
+
)
|
| 1052 |
+
text_metrics.update(text_class)
|
| 1053 |
+
results['text'] = {'hierarchy': text_metrics}
|
| 1054 |
+
|
| 1055 |
+
del text_emb
|
| 1056 |
+
if torch.cuda.is_available():
|
| 1057 |
+
torch.cuda.empty_cache()
|
| 1058 |
+
|
| 1059 |
+
# --- image ---
|
| 1060 |
+
print("\nExtracting baseline image embeddings...")
|
| 1061 |
+
img_emb, _, img_hier = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples)
|
| 1062 |
+
print(f" Baseline image shape: {img_emb.shape}")
|
| 1063 |
+
|
| 1064 |
+
img_metrics = self.compute_similarity_metrics(img_emb, img_hier)
|
| 1065 |
+
img_class = self.evaluate_classification_performance(
|
| 1066 |
+
img_emb, img_hier,
|
| 1067 |
+
f"{dataset_name}, image, hierarchy confusion matrix", "Hierarchy", method="nn",
|
| 1068 |
+
)
|
| 1069 |
+
img_metrics.update(img_class)
|
| 1070 |
+
results['image'] = {'hierarchy': img_metrics}
|
| 1071 |
+
|
| 1072 |
+
del img_emb
|
| 1073 |
+
if torch.cuda.is_available():
|
| 1074 |
+
torch.cuda.empty_cache()
|
| 1075 |
+
|
| 1076 |
+
prefix = dataset_name.lower().replace(" ", "_")
|
| 1077 |
+
for key in ['text', 'image']:
|
| 1078 |
+
fig = results[key]['hierarchy']['figure']
|
| 1079 |
+
fig.savefig(
|
| 1080 |
+
os.path.join(self.directory, f"baseline_{prefix}_{key}_hierarchy_confusion_matrix.png"),
|
| 1081 |
+
dpi=300, bbox_inches='tight',
|
| 1082 |
+
)
|
| 1083 |
+
plt.close(fig)
|
| 1084 |
+
|
| 1085 |
+
return results
|
| 1086 |
+
|
| 1087 |
+
# ==================================================================
|
| 1088 |
+
# 6. Full evaluation across all datasets
|
| 1089 |
+
# ==================================================================
|
| 1090 |
+
def run_full_evaluation(self, max_samples=10000, local_max_samples=None, batch_size=8):
|
| 1091 |
+
"""Run hierarchy evaluation on all 3 datasets for both models."""
|
| 1092 |
+
if local_max_samples is None:
|
| 1093 |
+
local_max_samples = max_samples
|
| 1094 |
+
all_results = {}
|
| 1095 |
+
|
| 1096 |
+
# --- Fashion-MNIST ---
|
| 1097 |
+
shared_dataset, shared_dataloader, shared_counts = self.prepare_shared_fashion_mnist(
|
| 1098 |
+
max_samples=max_samples, batch_size=batch_size,
|
| 1099 |
+
)
|
| 1100 |
+
all_results['fashion_mnist_gap'] = self.evaluate_gap_clip_fashion_mnist(
|
| 1101 |
+
max_samples=max_samples, dataloader=shared_dataloader, expected_counts=shared_counts,
|
| 1102 |
+
)
|
| 1103 |
+
all_results['fashion_mnist_baseline'] = self.evaluate_baseline_fashion_mnist(
|
| 1104 |
+
max_samples=max_samples, dataloader=shared_dataloader, expected_counts=shared_counts,
|
| 1105 |
+
)
|
| 1106 |
+
|
| 1107 |
+
# --- KAGL Marqo ---
|
| 1108 |
+
try:
|
| 1109 |
+
kaggle_dataset = load_kaggle_marqo_with_hierarchy(
|
| 1110 |
+
max_samples=max_samples,
|
| 1111 |
+
hierarchy_classes=self.validation_hierarchy_classes or self.hierarchy_classes,
|
| 1112 |
+
)
|
| 1113 |
+
if kaggle_dataset is not None and len(kaggle_dataset) > 0:
|
| 1114 |
+
kaggle_dataloader = DataLoader(kaggle_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
|
| 1115 |
+
all_results['kaggle_gap'] = self.evaluate_gap_clip_generic(
|
| 1116 |
+
kaggle_dataloader, "KAGL Marqo", max_samples,
|
| 1117 |
+
)
|
| 1118 |
+
all_results['kaggle_baseline'] = self.evaluate_baseline_generic(
|
| 1119 |
+
kaggle_dataloader, "KAGL Marqo", max_samples,
|
| 1120 |
+
)
|
| 1121 |
+
else:
|
| 1122 |
+
print("WARNING: KAGL Marqo dataset empty after hierarchy mapping, skipping.")
|
| 1123 |
+
except Exception as e:
|
| 1124 |
+
print(f"WARNING: Could not evaluate on KAGL Marqo: {e}")
|
| 1125 |
+
|
| 1126 |
+
# --- Internal (local validation) ---
|
| 1127 |
+
try:
|
| 1128 |
+
local_dataset = load_local_validation_with_hierarchy(
|
| 1129 |
+
max_samples=local_max_samples,
|
| 1130 |
+
hierarchy_classes=self.validation_hierarchy_classes or self.hierarchy_classes,
|
| 1131 |
+
)
|
| 1132 |
+
if local_dataset is not None and len(local_dataset) > 0:
|
| 1133 |
+
local_dataloader = DataLoader(local_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
|
| 1134 |
+
all_results['local_gap'] = self.evaluate_gap_clip_generic(
|
| 1135 |
+
local_dataloader, "Internal", local_max_samples,
|
| 1136 |
+
)
|
| 1137 |
+
all_results['local_baseline'] = self.evaluate_baseline_generic(
|
| 1138 |
+
local_dataloader, "Internal", local_max_samples,
|
| 1139 |
+
)
|
| 1140 |
+
else:
|
| 1141 |
+
print("WARNING: Local validation dataset empty after hierarchy filtering, skipping.")
|
| 1142 |
+
except Exception as e:
|
| 1143 |
+
print(f"WARNING: Could not evaluate on internal dataset: {e}")
|
| 1144 |
+
|
| 1145 |
+
# --- Print summary ---
|
| 1146 |
+
print(f"\n{'=' * 70}")
|
| 1147 |
+
print("CATEGORY MODEL EVALUATION SUMMARY")
|
| 1148 |
+
print(f"{'=' * 70}")
|
| 1149 |
+
for dataset_key, label in [
|
| 1150 |
+
('fashion_mnist_gap', 'Fashion-MNIST (GAP-CLIP)'),
|
| 1151 |
+
('fashion_mnist_baseline', 'Fashion-MNIST (Baseline)'),
|
| 1152 |
+
('kaggle_gap', 'KAGL Marqo (GAP-CLIP)'),
|
| 1153 |
+
('kaggle_baseline', 'KAGL Marqo (Baseline)'),
|
| 1154 |
+
('local_gap', 'Internal (GAP-CLIP)'),
|
| 1155 |
+
('local_baseline', 'Internal (Baseline)'),
|
| 1156 |
+
]:
|
| 1157 |
+
if dataset_key not in all_results:
|
| 1158 |
+
continue
|
| 1159 |
+
res = all_results[dataset_key]
|
| 1160 |
+
print(f"\n{label}:")
|
| 1161 |
+
if 'text_hierarchy' in res:
|
| 1162 |
+
t = res['text_hierarchy']
|
| 1163 |
+
i = res['image_hierarchy']
|
| 1164 |
+
print(f" Text NN Acc: {t['nn_accuracy']*100:.1f}% | Separation: {t['separation_score']:.4f}")
|
| 1165 |
+
print(f" Image NN Acc: {i['nn_accuracy']*100:.1f}% | Separation: {i['separation_score']:.4f}")
|
| 1166 |
+
elif 'text' in res:
|
| 1167 |
+
t = res['text']['hierarchy']
|
| 1168 |
+
i = res['image']['hierarchy']
|
| 1169 |
+
print(f" Text NN Acc: {t['nn_accuracy']*100:.1f}% | Separation: {t['separation_score']:.4f}")
|
| 1170 |
+
print(f" Image NN Acc: {i['nn_accuracy']*100:.1f}% | Separation: {i['separation_score']:.4f}")
|
| 1171 |
+
|
| 1172 |
+
return all_results
|
| 1173 |
+
|
| 1174 |
+
|
| 1175 |
+
# ============================================================================
|
| 1176 |
+
# 7. Main
|
| 1177 |
+
# ============================================================================
|
| 1178 |
+
|
| 1179 |
+
if __name__ == "__main__":
|
| 1180 |
+
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
| 1181 |
+
print(f"Using device: {device}")
|
| 1182 |
+
|
| 1183 |
+
directory = 'figures/confusion_matrices/cm_hierarchy'
|
| 1184 |
+
max_samples = 10000
|
| 1185 |
+
local_max_samples = 1000
|
| 1186 |
+
|
| 1187 |
+
evaluator = CategoryModelEvaluator(device=device, directory=directory)
|
| 1188 |
+
|
| 1189 |
+
# # Full evaluation including Fashion-MNIST and KAGL Marqo (skipped — CMs already generated)
|
| 1190 |
+
# evaluator.run_full_evaluation(max_samples=max_samples, local_max_samples=local_max_samples, batch_size=8)
|
| 1191 |
+
|
| 1192 |
+
# Evaluate only the local/internal dataset
|
| 1193 |
+
local_dataset = load_local_validation_with_hierarchy(
|
| 1194 |
+
max_samples=local_max_samples,
|
| 1195 |
+
hierarchy_classes=evaluator.validation_hierarchy_classes or evaluator.hierarchy_classes,
|
| 1196 |
+
)
|
| 1197 |
+
if local_dataset is not None and len(local_dataset) > 0:
|
| 1198 |
+
local_dl = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0)
|
| 1199 |
+
results_gap = evaluator.evaluate_gap_clip_generic(local_dl, "Internal", local_max_samples)
|
| 1200 |
+
results_base = evaluator.evaluate_baseline_generic(local_dl, "Internal", local_max_samples)
|
| 1201 |
+
|
| 1202 |
+
print(f"\n{'=' * 60}")
|
| 1203 |
+
print("INTERNAL DATASET — HIERARCHY EVALUATION SUMMARY")
|
| 1204 |
+
print(f"{'=' * 60}")
|
| 1205 |
+
print(f"\nGAP-CLIP:")
|
| 1206 |
+
print(f" Text NN Acc: {results_gap['text_hierarchy']['nn_accuracy']*100:.1f}% | Separation: {results_gap['text_hierarchy']['separation_score']:.4f}")
|
| 1207 |
+
print(f" Image NN Acc: {results_gap['image_hierarchy']['nn_accuracy']*100:.1f}% | Separation: {results_gap['image_hierarchy']['separation_score']:.4f}")
|
| 1208 |
+
print(f"\nBaseline:")
|
| 1209 |
+
print(f" Text NN Acc: {results_base['text']['hierarchy']['nn_accuracy']*100:.1f}% | Separation: {results_base['text']['hierarchy']['separation_score']:.4f}")
|
| 1210 |
+
print(f" Image NN Acc: {results_base['image']['hierarchy']['nn_accuracy']*100:.1f}% | Separation: {results_base['image']['hierarchy']['separation_score']:.4f}")
|
| 1211 |
+
else:
|
| 1212 |
+
print("WARNING: Local validation dataset empty after hierarchy filtering.")
|
evaluation/{main_model_evaluation.py → sec533_clip_nn_accuracy.py}
RENAMED
|
@@ -1,202 +1,67 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
import pandas as pd
|
| 6 |
-
import numpy as np
|
| 7 |
-
import matplotlib.pyplot as plt
|
| 8 |
-
import seaborn as sns
|
| 9 |
-
import difflib
|
| 10 |
-
from sklearn.metrics.pairwise import cosine_similarity
|
| 11 |
-
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
|
| 12 |
-
from collections import defaultdict
|
| 13 |
-
from tqdm import tqdm
|
| 14 |
-
from torch.utils.data import Dataset, DataLoader
|
| 15 |
-
from torchvision import transforms
|
| 16 |
-
from PIL import Image
|
| 17 |
-
from io import BytesIO
|
| 18 |
-
import warnings
|
| 19 |
-
warnings.filterwarnings('ignore')
|
| 20 |
-
from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
|
| 21 |
-
|
| 22 |
-
from config import main_model_path, hierarchy_model_path, color_model_path, color_emb_dim, hierarchy_emb_dim, local_dataset_path, column_local_image_path
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def create_fashion_mnist_to_hierarchy_mapping(hierarchy_classes):
|
| 26 |
-
"""Create mapping from Fashion-MNIST labels to hierarchy classes"""
|
| 27 |
-
# Fashion-MNIST labels
|
| 28 |
-
fashion_mnist_labels = {
|
| 29 |
-
0: "T-shirt/top",
|
| 30 |
-
1: "Trouser",
|
| 31 |
-
2: "Pullover",
|
| 32 |
-
3: "Dress",
|
| 33 |
-
4: "Coat",
|
| 34 |
-
5: "Sandal",
|
| 35 |
-
6: "Shirt",
|
| 36 |
-
7: "Sneaker",
|
| 37 |
-
8: "Bag",
|
| 38 |
-
9: "Ankle boot",
|
| 39 |
-
}
|
| 40 |
-
|
| 41 |
-
# Normalize hierarchy classes to lowercase for matching
|
| 42 |
-
hierarchy_classes_lower = [h.lower() for h in hierarchy_classes]
|
| 43 |
-
|
| 44 |
-
# Create mapping dictionary
|
| 45 |
-
mapping = {}
|
| 46 |
-
|
| 47 |
-
for fm_label_id, fm_label in fashion_mnist_labels.items():
|
| 48 |
-
fm_label_lower = fm_label.lower()
|
| 49 |
-
matched_hierarchy = None
|
| 50 |
-
|
| 51 |
-
# Try exact match first
|
| 52 |
-
if fm_label_lower in hierarchy_classes_lower:
|
| 53 |
-
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(fm_label_lower)]
|
| 54 |
-
# Try partial matches
|
| 55 |
-
elif any(h in fm_label_lower or fm_label_lower in h for h in hierarchy_classes_lower):
|
| 56 |
-
for h_class in hierarchy_classes:
|
| 57 |
-
h_lower = h_class.lower()
|
| 58 |
-
if h_lower in fm_label_lower or fm_label_lower in h_lower:
|
| 59 |
-
matched_hierarchy = h_class
|
| 60 |
-
break
|
| 61 |
-
# Try semantic matching
|
| 62 |
-
else:
|
| 63 |
-
# T-shirt/top -> shirt or top
|
| 64 |
-
if fm_label_lower in ['t-shirt/top', 'top']:
|
| 65 |
-
if 'top' in hierarchy_classes_lower:
|
| 66 |
-
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index('top')]
|
| 67 |
-
|
| 68 |
-
# Trouser -> bottom, pants, trousers
|
| 69 |
-
elif 'trouser' in fm_label_lower:
|
| 70 |
-
for possible in ['bottom', 'pants', 'trousers', 'trouser', 'pant']:
|
| 71 |
-
if possible in hierarchy_classes_lower:
|
| 72 |
-
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(possible)]
|
| 73 |
-
break
|
| 74 |
-
|
| 75 |
-
# Pullover -> sweater
|
| 76 |
-
elif 'pullover' in fm_label_lower:
|
| 77 |
-
for possible in ['sweater', 'pullover']:
|
| 78 |
-
if possible in hierarchy_classes_lower:
|
| 79 |
-
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(possible)]
|
| 80 |
-
break
|
| 81 |
-
|
| 82 |
-
# Dress -> dress
|
| 83 |
-
elif 'dress' in fm_label_lower:
|
| 84 |
-
if 'dress' in hierarchy_classes_lower:
|
| 85 |
-
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index('dress')]
|
| 86 |
-
# Coat -> jacket, outerwear, coat
|
| 87 |
-
elif 'coat' in fm_label_lower:
|
| 88 |
-
for possible in ['jacket', 'outerwear', 'coat']:
|
| 89 |
-
if possible in hierarchy_classes_lower:
|
| 90 |
-
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(possible)]
|
| 91 |
-
break
|
| 92 |
-
# Sandal, Sneaker, Ankle boot -> shoes, shoe
|
| 93 |
-
elif fm_label_lower in ['sandal', 'sneaker', 'ankle boot']:
|
| 94 |
-
for possible in ['shoes', 'shoe', 'sandal', 'sneaker', 'boot']:
|
| 95 |
-
if possible in hierarchy_classes_lower:
|
| 96 |
-
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(possible)]
|
| 97 |
-
break
|
| 98 |
-
# Bag -> bag
|
| 99 |
-
elif 'bag' in fm_label_lower:
|
| 100 |
-
if 'bag' in hierarchy_classes_lower:
|
| 101 |
-
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index('bag')]
|
| 102 |
-
|
| 103 |
-
if matched_hierarchy is None:
|
| 104 |
-
close_matches = difflib.get_close_matches(fm_label_lower, hierarchy_classes_lower, n=1, cutoff=0.6)
|
| 105 |
-
if close_matches:
|
| 106 |
-
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(close_matches[0])]
|
| 107 |
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
print(f" {fm_label} ({fm_label_id}) -> {matched_hierarchy}")
|
| 111 |
-
else:
|
| 112 |
-
print(f" ⚠️ {fm_label} ({fm_label_id}) -> NO MATCH (will be filtered out)")
|
| 113 |
-
|
| 114 |
-
return mapping
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
def convert_fashion_mnist_to_image(pixel_values):
|
| 118 |
-
image_array = np.array(pixel_values).reshape(28, 28).astype(np.uint8)
|
| 119 |
-
image_array = np.stack([image_array] * 3, axis=-1)
|
| 120 |
-
image = Image.fromarray(image_array)
|
| 121 |
-
return image
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
def get_fashion_mnist_labels():
|
| 125 |
-
return {
|
| 126 |
-
0: "T-shirt/top",
|
| 127 |
-
1: "Trouser",
|
| 128 |
-
2: "Pullover",
|
| 129 |
-
3: "Dress",
|
| 130 |
-
4: "Coat",
|
| 131 |
-
5: "Sandal",
|
| 132 |
-
6: "Shirt",
|
| 133 |
-
7: "Sneaker",
|
| 134 |
-
8: "Bag",
|
| 135 |
-
9: "Ankle boot",
|
| 136 |
-
}
|
| 137 |
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
-
class
|
| 140 |
-
def __init__(self, dataframe, image_size=224, label_mapping=None):
|
| 141 |
-
self.dataframe = dataframe
|
| 142 |
-
self.image_size = image_size
|
| 143 |
-
self.labels_map = get_fashion_mnist_labels()
|
| 144 |
-
self.label_mapping = label_mapping # Mapping from Fashion-MNIST label ID to hierarchy class
|
| 145 |
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 150 |
-
])
|
| 151 |
|
| 152 |
-
|
| 153 |
-
return len(self.dataframe)
|
| 154 |
|
| 155 |
-
|
| 156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
|
| 158 |
-
|
| 159 |
-
|
| 160 |
|
| 161 |
-
|
| 162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
-
|
| 165 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
-
|
| 168 |
-
# Use mapped hierarchy if available, otherwise use original label
|
| 169 |
-
if self.label_mapping and label_id in self.label_mapping:
|
| 170 |
-
hierarchy = self.label_mapping[label_id]
|
| 171 |
-
else:
|
| 172 |
-
hierarchy = self.labels_map[label_id]
|
| 173 |
|
| 174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
|
| 176 |
|
| 177 |
-
def load_fashion_mnist_dataset(max_samples=1000, hierarchy_classes=None):
|
| 178 |
-
print("📊 Loading Fashion-MNIST test dataset...")
|
| 179 |
-
df = pd.read_csv("/Users/leaattiasarfati/Desktop/docs/search/old/MainModel/data/fashion-mnist_test.csv")
|
| 180 |
-
print(f"✅ Fashion-MNIST dataset loaded: {len(df)} samples")
|
| 181 |
-
|
| 182 |
-
# Create mapping if hierarchy classes are provided
|
| 183 |
-
label_mapping = None
|
| 184 |
-
if hierarchy_classes is not None:
|
| 185 |
-
print("\n🔗 Creating mapping from Fashion-MNIST labels to hierarchy classes:")
|
| 186 |
-
label_mapping = create_fashion_mnist_to_hierarchy_mapping(hierarchy_classes)
|
| 187 |
-
|
| 188 |
-
# Filter dataset to only include samples that can be mapped to hierarchy classes
|
| 189 |
-
valid_label_ids = [label_id for label_id, hierarchy in label_mapping.items() if hierarchy is not None]
|
| 190 |
-
df_filtered = df[df['label'].isin(valid_label_ids)]
|
| 191 |
-
print(f"\n📊 After filtering to mappable labels: {len(df_filtered)} samples (from {len(df)})")
|
| 192 |
-
|
| 193 |
-
# Apply max_samples limit after filtering
|
| 194 |
-
df_sample = df_filtered.head(max_samples)
|
| 195 |
-
else:
|
| 196 |
-
df_sample = df.head(max_samples)
|
| 197 |
-
|
| 198 |
-
print(f"📊 Using {len(df_sample)} samples for evaluation")
|
| 199 |
-
return FashionMNISTDataset(df_sample, label_mapping=label_mapping)
|
| 200 |
|
| 201 |
|
| 202 |
def create_kaggle_marqo_to_hierarchy_mapping(kaggle_labels, hierarchy_classes):
|
|
@@ -378,7 +243,7 @@ class KaggleDataset(Dataset):
|
|
| 378 |
return image, description, color, hierarchy
|
| 379 |
|
| 380 |
|
| 381 |
-
def load_kaggle_marqo_dataset(evaluator, max_samples=
|
| 382 |
"""Load and prepare Kaggle KAGL dataset with memory optimization"""
|
| 383 |
from datasets import load_dataset
|
| 384 |
print("📊 Loading Kaggle KAGL dataset...")
|
|
@@ -450,100 +315,6 @@ def load_kaggle_marqo_dataset(evaluator, max_samples=5000):
|
|
| 450 |
return KaggleDataset(kaggle_formatted)
|
| 451 |
|
| 452 |
|
| 453 |
-
class LocalDataset(Dataset):
|
| 454 |
-
"""Dataset class for local validation dataset"""
|
| 455 |
-
def __init__(self, dataframe, image_size=224):
|
| 456 |
-
self.dataframe = dataframe
|
| 457 |
-
self.image_size = image_size
|
| 458 |
-
|
| 459 |
-
# Transforms for validation (no augmentation)
|
| 460 |
-
self.val_transform = transforms.Compose([
|
| 461 |
-
transforms.Resize((image_size, image_size)),
|
| 462 |
-
transforms.ToTensor(),
|
| 463 |
-
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 464 |
-
])
|
| 465 |
-
|
| 466 |
-
def __len__(self):
|
| 467 |
-
return len(self.dataframe)
|
| 468 |
-
|
| 469 |
-
def __getitem__(self, idx):
|
| 470 |
-
row = self.dataframe.iloc[idx]
|
| 471 |
-
|
| 472 |
-
# Load image from local path
|
| 473 |
-
image_path = row[column_local_image_path]
|
| 474 |
-
try:
|
| 475 |
-
image = Image.open(image_path).convert("RGB")
|
| 476 |
-
except Exception as e:
|
| 477 |
-
print(f"Error loading image at index {idx} from {image_path}: {e}")
|
| 478 |
-
# Create a dummy image if loading fails
|
| 479 |
-
image = Image.new('RGB', (224, 224), color='gray')
|
| 480 |
-
|
| 481 |
-
# Apply validation transform
|
| 482 |
-
image = self.val_transform(image)
|
| 483 |
-
|
| 484 |
-
# Get text and labels
|
| 485 |
-
description = row['text']
|
| 486 |
-
color = row.get('color', 'unknown')
|
| 487 |
-
hierarchy = row['hierarchy']
|
| 488 |
-
|
| 489 |
-
return image, description, color, hierarchy
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
def load_local_validation_dataset(max_samples=5000):
|
| 493 |
-
"""Load and prepare local validation dataset"""
|
| 494 |
-
print("📊 Loading local validation dataset...")
|
| 495 |
-
|
| 496 |
-
if not os.path.exists(local_dataset_path):
|
| 497 |
-
print(f"❌ Local dataset file not found: {local_dataset_path}")
|
| 498 |
-
return None
|
| 499 |
-
|
| 500 |
-
df = pd.read_csv(local_dataset_path)
|
| 501 |
-
print(f"✅ Dataset loaded: {len(df)} samples")
|
| 502 |
-
|
| 503 |
-
# Filter out rows with NaN values in image path
|
| 504 |
-
df_clean = df.dropna(subset=[column_local_image_path])
|
| 505 |
-
print(f"📊 After filtering NaN image paths: {len(df_clean)} samples")
|
| 506 |
-
|
| 507 |
-
if len(df_clean) == 0:
|
| 508 |
-
print("❌ No valid samples after filtering.")
|
| 509 |
-
return None
|
| 510 |
-
|
| 511 |
-
# NO COLOR FILTERING for local dataset - keep all colors for comprehensive evaluation
|
| 512 |
-
if 'color' in df_clean.columns:
|
| 513 |
-
print(f"🎨 Total unique colors in dataset: {len(df_clean['color'].unique())}")
|
| 514 |
-
print(f"🎨 Colors found: {sorted(df_clean['color'].unique())}")
|
| 515 |
-
print(f"🎨 Color distribution (top 15):")
|
| 516 |
-
color_counts = df_clean['color'].value_counts()
|
| 517 |
-
for color in color_counts.index[:15]: # Show top 15 colors
|
| 518 |
-
print(f" {color}: {color_counts[color]} samples")
|
| 519 |
-
|
| 520 |
-
# Ensure we have required columns
|
| 521 |
-
required_cols = ['text', 'hierarchy']
|
| 522 |
-
missing_cols = [col for col in required_cols if col not in df_clean.columns]
|
| 523 |
-
if missing_cols:
|
| 524 |
-
print(f"❌ Missing required columns: {missing_cols}")
|
| 525 |
-
return None
|
| 526 |
-
|
| 527 |
-
# Limit to max_samples with RANDOM SAMPLING to get diverse colors
|
| 528 |
-
if len(df_clean) > max_samples:
|
| 529 |
-
df_clean = df_clean.sample(n=max_samples, random_state=42)
|
| 530 |
-
print(f"📊 Randomly sampled {max_samples} samples")
|
| 531 |
-
|
| 532 |
-
print(f"📊 Using {len(df_clean)} samples for evaluation")
|
| 533 |
-
print(f" Samples per hierarchy:")
|
| 534 |
-
for hierarchy in sorted(df_clean['hierarchy'].unique()):
|
| 535 |
-
count = len(df_clean[df_clean['hierarchy'] == hierarchy])
|
| 536 |
-
print(f" {hierarchy}: {count} samples")
|
| 537 |
-
|
| 538 |
-
# Show color distribution after sampling
|
| 539 |
-
if 'color' in df_clean.columns:
|
| 540 |
-
print(f"\n🎨 Color distribution in sampled data:")
|
| 541 |
-
color_counts = df_clean['color'].value_counts()
|
| 542 |
-
print(f" Total unique colors: {len(color_counts)}")
|
| 543 |
-
for color in color_counts.index[:15]: # Show top 15
|
| 544 |
-
print(f" {color}: {color_counts[color]} samples")
|
| 545 |
-
|
| 546 |
-
return LocalDataset(df_clean)
|
| 547 |
|
| 548 |
|
| 549 |
class ColorHierarchyEvaluator:
|
|
@@ -994,6 +765,7 @@ class ColorHierarchyEvaluator:
|
|
| 994 |
plt.tight_layout()
|
| 995 |
return plt.gcf(), accuracy, cm
|
| 996 |
|
|
|
|
| 997 |
def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings", label_type="Label",
|
| 998 |
full_embeddings=None, ensemble_weight=0.5):
|
| 999 |
"""
|
|
@@ -1010,16 +782,14 @@ class ColorHierarchyEvaluator:
|
|
| 1010 |
if full_embeddings is not None:
|
| 1011 |
# Use ensemble prediction
|
| 1012 |
predictions = self.predict_labels_ensemble(embeddings, full_embeddings, labels, ensemble_weight)
|
| 1013 |
-
title_suffix = f" (Ensemble: {ensemble_weight:.1f} specialized + {1-ensemble_weight:.1f} full)"
|
| 1014 |
else:
|
| 1015 |
# Use only specialized embeddings
|
| 1016 |
predictions = self.predict_labels_from_embeddings(embeddings, labels)
|
| 1017 |
-
title_suffix = ""
|
| 1018 |
|
| 1019 |
accuracy = accuracy_score(labels, predictions)
|
| 1020 |
fig, acc, cm = self.create_confusion_matrix(
|
| 1021 |
labels, predictions,
|
| 1022 |
-
f"{
|
| 1023 |
label_type
|
| 1024 |
)
|
| 1025 |
unique_labels = sorted(list(set(labels)))
|
|
@@ -1346,7 +1116,7 @@ class ColorHierarchyEvaluator:
|
|
| 1346 |
|
| 1347 |
return results
|
| 1348 |
|
| 1349 |
-
def evaluate_baseline_fashion_mnist(self, max_samples=
|
| 1350 |
"""Evaluate baseline Fashion CLIP model on Fashion-MNIST"""
|
| 1351 |
print(f"\n{'='*60}")
|
| 1352 |
print("Evaluating Baseline Fashion CLIP on Fashion-MNIST")
|
|
@@ -1418,7 +1188,7 @@ class ColorHierarchyEvaluator:
|
|
| 1418 |
|
| 1419 |
return results
|
| 1420 |
|
| 1421 |
-
def evaluate_baseline_kaggle_marqo(self, max_samples=
|
| 1422 |
"""Evaluate baseline Fashion CLIP model on KAGL Marqo dataset"""
|
| 1423 |
print(f"\n{'='*60}")
|
| 1424 |
print("Evaluating Baseline Fashion CLIP on KAGL Marqo Dataset")
|
|
@@ -1500,7 +1270,7 @@ class ColorHierarchyEvaluator:
|
|
| 1500 |
|
| 1501 |
return results
|
| 1502 |
|
| 1503 |
-
def evaluate_baseline_local_validation(self, max_samples=
|
| 1504 |
"""Evaluate baseline Fashion CLIP model on local validation dataset"""
|
| 1505 |
print(f"\n{'='*60}")
|
| 1506 |
print("Evaluating Baseline Fashion CLIP on Local Validation Dataset")
|
|
@@ -1598,7 +1368,7 @@ if __name__ == "__main__":
|
|
| 1598 |
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
| 1599 |
print(f"Using device: {device}")
|
| 1600 |
|
| 1601 |
-
directory = '
|
| 1602 |
max_samples = 10000
|
| 1603 |
|
| 1604 |
evaluator = ColorHierarchyEvaluator(device=device, directory=directory)
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
§5.3.3 Nearest-Neighbour Classification Accuracy (Table 3)
|
| 3 |
+
============================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
+
Evaluates the full GAP-CLIP embedding on three datasets and compares with the
|
| 6 |
+
patrickjohncyh/fashion-clip baseline:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
+
- Fashion-MNIST (public benchmark, 10 clothing categories)
|
| 9 |
+
- KAGL Marqo HuggingFace dataset (diverse fashion, colour + category labels)
|
| 10 |
+
- Internal local validation set (50 k images)
|
| 11 |
|
| 12 |
+
For each dataset the ``ColorHierarchyEvaluator`` class extracts:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
+
* **Color slice** (dims 0–15): nearest-neighbour and centroid accuracy per colour class.
|
| 15 |
+
* **Hierarchy slice** (dims 16–79): nearest-neighbour and centroid accuracy per category.
|
| 16 |
+
* **Ensemble mode** (Kaggle/MNIST): sliced dims combined with full 512-D embedding.
|
|
|
|
|
|
|
| 17 |
|
| 18 |
+
Results feed directly into **Table 3** of the paper.
|
|
|
|
| 19 |
|
| 20 |
+
See also:
|
| 21 |
+
- §5.1 (``sec51_color_model_eval.py``) – standalone colour model
|
| 22 |
+
- §5.2 (``sec52_category_model_eval.py``) – confusion-matrix analysis
|
| 23 |
+
- §5.3.4–5 (``sec5354_separation_semantic.py``) – separation scores
|
| 24 |
+
"""
|
| 25 |
+
import os
|
| 26 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 27 |
|
| 28 |
+
import difflib
|
| 29 |
+
import warnings
|
| 30 |
|
| 31 |
+
import matplotlib.pyplot as plt
|
| 32 |
+
import numpy as np
|
| 33 |
+
import pandas as pd
|
| 34 |
+
import seaborn as sns
|
| 35 |
+
import torch
|
| 36 |
+
from collections import defaultdict
|
| 37 |
+
from io import BytesIO
|
| 38 |
|
| 39 |
+
from PIL import Image
|
| 40 |
+
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
|
| 41 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 42 |
+
from torch.utils.data import DataLoader, Dataset
|
| 43 |
+
from torchvision import transforms
|
| 44 |
+
from tqdm import tqdm
|
| 45 |
+
from transformers import CLIPModel as CLIPModel_transformers, CLIPProcessor
|
| 46 |
|
| 47 |
+
warnings.filterwarnings('ignore')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
+
from config import (
|
| 50 |
+
color_emb_dim,
|
| 51 |
+
column_local_image_path,
|
| 52 |
+
hierarchy_emb_dim,
|
| 53 |
+
hierarchy_model_path,
|
| 54 |
+
local_dataset_path,
|
| 55 |
+
main_model_path,
|
| 56 |
+
)
|
| 57 |
+
from utils.datasets import (
|
| 58 |
+
FashionMNISTDataset,
|
| 59 |
+
LocalDataset,
|
| 60 |
+
load_fashion_mnist_dataset,
|
| 61 |
+
load_local_validation_dataset,
|
| 62 |
+
)
|
| 63 |
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
|
| 67 |
def create_kaggle_marqo_to_hierarchy_mapping(kaggle_labels, hierarchy_classes):
|
|
|
|
| 243 |
return image, description, color, hierarchy
|
| 244 |
|
| 245 |
|
| 246 |
+
def load_kaggle_marqo_dataset(evaluator, max_samples=10000):
|
| 247 |
"""Load and prepare Kaggle KAGL dataset with memory optimization"""
|
| 248 |
from datasets import load_dataset
|
| 249 |
print("📊 Loading Kaggle KAGL dataset...")
|
|
|
|
| 315 |
return KaggleDataset(kaggle_formatted)
|
| 316 |
|
| 317 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
|
| 319 |
|
| 320 |
class ColorHierarchyEvaluator:
|
|
|
|
| 765 |
plt.tight_layout()
|
| 766 |
return plt.gcf(), accuracy, cm
|
| 767 |
|
| 768 |
+
|
| 769 |
def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings", label_type="Label",
|
| 770 |
full_embeddings=None, ensemble_weight=0.5):
|
| 771 |
"""
|
|
|
|
| 782 |
if full_embeddings is not None:
|
| 783 |
# Use ensemble prediction
|
| 784 |
predictions = self.predict_labels_ensemble(embeddings, full_embeddings, labels, ensemble_weight)
|
|
|
|
| 785 |
else:
|
| 786 |
# Use only specialized embeddings
|
| 787 |
predictions = self.predict_labels_from_embeddings(embeddings, labels)
|
|
|
|
| 788 |
|
| 789 |
accuracy = accuracy_score(labels, predictions)
|
| 790 |
fig, acc, cm = self.create_confusion_matrix(
|
| 791 |
labels, predictions,
|
| 792 |
+
f"{label_type} Classification",
|
| 793 |
label_type
|
| 794 |
)
|
| 795 |
unique_labels = sorted(list(set(labels)))
|
|
|
|
| 1116 |
|
| 1117 |
return results
|
| 1118 |
|
| 1119 |
+
def evaluate_baseline_fashion_mnist(self, max_samples=10000):
|
| 1120 |
"""Evaluate baseline Fashion CLIP model on Fashion-MNIST"""
|
| 1121 |
print(f"\n{'='*60}")
|
| 1122 |
print("Evaluating Baseline Fashion CLIP on Fashion-MNIST")
|
|
|
|
| 1188 |
|
| 1189 |
return results
|
| 1190 |
|
| 1191 |
+
def evaluate_baseline_kaggle_marqo(self, max_samples=10000):
|
| 1192 |
"""Evaluate baseline Fashion CLIP model on KAGL Marqo dataset"""
|
| 1193 |
print(f"\n{'='*60}")
|
| 1194 |
print("Evaluating Baseline Fashion CLIP on KAGL Marqo Dataset")
|
|
|
|
| 1270 |
|
| 1271 |
return results
|
| 1272 |
|
| 1273 |
+
def evaluate_baseline_local_validation(self, max_samples=10000):
|
| 1274 |
"""Evaluate baseline Fashion CLIP model on local validation dataset"""
|
| 1275 |
print(f"\n{'='*60}")
|
| 1276 |
print("Evaluating Baseline Fashion CLIP on Local Validation Dataset")
|
|
|
|
| 1368 |
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
| 1369 |
print(f"Using device: {device}")
|
| 1370 |
|
| 1371 |
+
directory = 'figures/confusion_matrices'
|
| 1372 |
max_samples = 10000
|
| 1373 |
|
| 1374 |
evaluator = ColorHierarchyEvaluator(device=device, directory=directory)
|
evaluation/sec5354_separation_semantic.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Sections 5.3.4 + 5.3.5 — Separation Score Analysis and Semantic Evaluation
|
| 4 |
+
===========================================================================
|
| 5 |
+
|
| 6 |
+
Section 5.3.4: Separation score analysis on GAP-CLIP full embeddings vs baseline
|
| 7 |
+
across three datasets (reported in paper body; detailed scores in main evaluation).
|
| 8 |
+
|
| 9 |
+
Section 5.3.5: Zero-shot semantic evaluation comparing simple vs. extended text
|
| 10 |
+
descriptions. Three evaluation modes on the internal dataset:
|
| 11 |
+
|
| 12 |
+
(a) Color-only encoding (control): encodes only the color name — tests whether
|
| 13 |
+
the embedding space is consistent for colors.
|
| 14 |
+
(b) Text-to-text classification: encodes the full item description and finds
|
| 15 |
+
the nearest color label in embedding space.
|
| 16 |
+
(c) Image-to-text classification: encodes the item image and finds the nearest
|
| 17 |
+
color label in embedding space.
|
| 18 |
+
|
| 19 |
+
The 40%+ performance gap between GAP-CLIP and baseline on extended descriptions
|
| 20 |
+
(Annex 9.7) demonstrates that the dedicated color/hierarchy subspaces act as
|
| 21 |
+
semantic anchors under verbose, multi-attribute text inputs.
|
| 22 |
+
|
| 23 |
+
Run directly:
|
| 24 |
+
python sec5354_separation_semantic.py
|
| 25 |
+
|
| 26 |
+
Paper reference: Sections 5.3.4 and 5.3.5.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
from __future__ import annotations
|
| 30 |
+
|
| 31 |
+
import os
|
| 32 |
+
import sys
|
| 33 |
+
import warnings
|
| 34 |
+
from pathlib import Path
|
| 35 |
+
|
| 36 |
+
import matplotlib.pyplot as plt
|
| 37 |
+
import numpy as np
|
| 38 |
+
import pandas as pd
|
| 39 |
+
import seaborn as sns
|
| 40 |
+
import torch
|
| 41 |
+
import torch.nn.functional as F
|
| 42 |
+
from PIL import Image
|
| 43 |
+
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
|
| 44 |
+
from torch.utils.data import Dataset
|
| 45 |
+
from torchvision import transforms
|
| 46 |
+
from tqdm import tqdm
|
| 47 |
+
|
| 48 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 49 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 50 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 51 |
+
|
| 52 |
+
# Ensure project root is importable when running this file directly.
|
| 53 |
+
_PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 54 |
+
if str(_PROJECT_ROOT) not in sys.path:
|
| 55 |
+
sys.path.insert(0, str(_PROJECT_ROOT))
|
| 56 |
+
|
| 57 |
+
import config
|
| 58 |
+
from evaluation.utils.model_loader import load_gap_clip, get_text_embedding, get_image_embedding
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ---------------------------------------------------------------------------
|
| 62 |
+
# Dataset
|
| 63 |
+
# ---------------------------------------------------------------------------
|
| 64 |
+
|
| 65 |
+
class CustomCSVDataset(Dataset):
|
| 66 |
+
"""Dataset backed by a local CSV; optionally loads images from disk.
|
| 67 |
+
|
| 68 |
+
Each item returns (image_tensor, text, color).
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
def __init__(self, dataframe: pd.DataFrame, image_size: int = 224, load_images: bool = True):
|
| 72 |
+
self.dataframe = dataframe
|
| 73 |
+
self.image_size = image_size
|
| 74 |
+
self.load_images = load_images
|
| 75 |
+
|
| 76 |
+
self.transform = transforms.Compose([
|
| 77 |
+
transforms.Resize((image_size, image_size)),
|
| 78 |
+
transforms.ToTensor(),
|
| 79 |
+
transforms.Normalize(
|
| 80 |
+
mean=[0.48145466, 0.4578275, 0.40821073],
|
| 81 |
+
std=[0.26862954, 0.26130258, 0.27577711],
|
| 82 |
+
),
|
| 83 |
+
])
|
| 84 |
+
|
| 85 |
+
def __len__(self) -> int:
|
| 86 |
+
return len(self.dataframe)
|
| 87 |
+
|
| 88 |
+
def __getitem__(self, idx):
|
| 89 |
+
row = self.dataframe.iloc[idx]
|
| 90 |
+
text = row[config.text_column]
|
| 91 |
+
color = row[config.color_column]
|
| 92 |
+
|
| 93 |
+
if self.load_images and config.column_local_image_path in row:
|
| 94 |
+
try:
|
| 95 |
+
image = Image.open(row[config.column_local_image_path]).convert("RGB")
|
| 96 |
+
image = self.transform(image)
|
| 97 |
+
except Exception as e:
|
| 98 |
+
print(f"Warning: could not load image {row.get(config.column_local_image_path, 'unknown')}: {e}")
|
| 99 |
+
image = torch.zeros(3, self.image_size, self.image_size)
|
| 100 |
+
else:
|
| 101 |
+
image = torch.zeros(3, self.image_size, self.image_size)
|
| 102 |
+
|
| 103 |
+
return image, text, color
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# ---------------------------------------------------------------------------
|
| 107 |
+
# Evaluation functions (Section 5.3.5)
|
| 108 |
+
# ---------------------------------------------------------------------------
|
| 109 |
+
|
| 110 |
+
def evaluate_color_only_zero_shot(model, dataset, processor):
|
| 111 |
+
"""Control test: encode ONLY the color name (not the full text description).
|
| 112 |
+
|
| 113 |
+
Tests whether the embedding space is consistent for color tokens regardless
|
| 114 |
+
of surrounding context.
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
(true_labels, predicted_labels, accuracy)
|
| 118 |
+
"""
|
| 119 |
+
print("\n=== Section 5.3.5 (a): Color-Only Encoding — Control Test ===")
|
| 120 |
+
print("Encodes ONLY the color name, not the full product description.")
|
| 121 |
+
|
| 122 |
+
model.eval()
|
| 123 |
+
|
| 124 |
+
all_colors = sorted({dataset[i][2] for i in range(len(dataset))})
|
| 125 |
+
print(f"Colors found: {all_colors}")
|
| 126 |
+
|
| 127 |
+
color_embeddings = {
|
| 128 |
+
c: get_text_embedding(model, processor, config.device, c)
|
| 129 |
+
for c in all_colors
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
true_labels, predicted_labels = [], []
|
| 133 |
+
correct = 0
|
| 134 |
+
|
| 135 |
+
for idx in tqdm(range(len(dataset)), desc="Evaluating (color-only)"):
|
| 136 |
+
_, _, true_color = dataset[idx]
|
| 137 |
+
true_color_emb = get_text_embedding(model, processor, config.device, true_color)
|
| 138 |
+
|
| 139 |
+
best_sim = -1.0
|
| 140 |
+
predicted_color = all_colors[0]
|
| 141 |
+
for color, emb in color_embeddings.items():
|
| 142 |
+
sim = F.cosine_similarity(true_color_emb.unsqueeze(0), emb.unsqueeze(0), dim=1).item()
|
| 143 |
+
if sim > best_sim:
|
| 144 |
+
best_sim, predicted_color = sim, color
|
| 145 |
+
|
| 146 |
+
true_labels.append(true_color)
|
| 147 |
+
predicted_labels.append(predicted_color)
|
| 148 |
+
if true_color == predicted_color:
|
| 149 |
+
correct += 1
|
| 150 |
+
|
| 151 |
+
accuracy = accuracy_score(true_labels, predicted_labels)
|
| 152 |
+
print(f"Color-only accuracy: {accuracy:.4f} ({accuracy * 100:.2f}%)")
|
| 153 |
+
print(f"Correct: {correct}/{len(true_labels)}")
|
| 154 |
+
return true_labels, predicted_labels, accuracy
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def evaluate_text_to_text_zero_shot(model, dataset, processor):
|
| 158 |
+
"""Text-to-text classification: compare full product description against color labels.
|
| 159 |
+
|
| 160 |
+
Returns:
|
| 161 |
+
(true_labels, predicted_labels, accuracy)
|
| 162 |
+
"""
|
| 163 |
+
print("\n=== Section 5.3.5 (b): Text-to-Text Classification ===")
|
| 164 |
+
|
| 165 |
+
model.eval()
|
| 166 |
+
|
| 167 |
+
all_colors = sorted({dataset[i][2] for i in range(len(dataset))})
|
| 168 |
+
print(f"Colors found: {all_colors}")
|
| 169 |
+
|
| 170 |
+
color_embeddings = {
|
| 171 |
+
c: get_text_embedding(model, processor, config.device, c)
|
| 172 |
+
for c in all_colors
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
true_labels, predicted_labels = [], []
|
| 176 |
+
correct = 0
|
| 177 |
+
|
| 178 |
+
for idx in tqdm(range(len(dataset)), desc="Evaluating (text-to-text)"):
|
| 179 |
+
_, text, true_color = dataset[idx]
|
| 180 |
+
text_emb = get_text_embedding(model, processor, config.device, text)
|
| 181 |
+
|
| 182 |
+
best_sim = -1.0
|
| 183 |
+
predicted_color = all_colors[0]
|
| 184 |
+
for color, emb in color_embeddings.items():
|
| 185 |
+
sim = F.cosine_similarity(text_emb.unsqueeze(0), emb.unsqueeze(0), dim=1).item()
|
| 186 |
+
if sim > best_sim:
|
| 187 |
+
best_sim, predicted_color = sim, color
|
| 188 |
+
|
| 189 |
+
true_labels.append(true_color)
|
| 190 |
+
predicted_labels.append(predicted_color)
|
| 191 |
+
if true_color == predicted_color:
|
| 192 |
+
correct += 1
|
| 193 |
+
|
| 194 |
+
accuracy = accuracy_score(true_labels, predicted_labels)
|
| 195 |
+
print(f"Text-to-text accuracy: {accuracy:.4f} ({accuracy * 100:.2f}%)")
|
| 196 |
+
print(f"Correct: {correct}/{len(true_labels)}")
|
| 197 |
+
return true_labels, predicted_labels, accuracy
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def evaluate_image_to_text_zero_shot(model, dataset, processor):
|
| 201 |
+
"""Image-to-text classification: compare image embedding against color labels.
|
| 202 |
+
|
| 203 |
+
Returns:
|
| 204 |
+
(true_labels, predicted_labels, accuracy)
|
| 205 |
+
"""
|
| 206 |
+
print("\n=== Section 5.3.5 (c): Image-to-Text Classification ===")
|
| 207 |
+
|
| 208 |
+
model.eval()
|
| 209 |
+
|
| 210 |
+
all_colors = sorted({dataset[i][2] for i in range(len(dataset))})
|
| 211 |
+
print(f"Colors found: {all_colors}")
|
| 212 |
+
|
| 213 |
+
color_embeddings = {
|
| 214 |
+
c: get_text_embedding(model, processor, config.device, c)
|
| 215 |
+
for c in all_colors
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
true_labels, predicted_labels = [], []
|
| 219 |
+
correct = 0
|
| 220 |
+
|
| 221 |
+
for idx in tqdm(range(len(dataset)), desc="Evaluating (image-to-text)"):
|
| 222 |
+
image, _, true_color = dataset[idx]
|
| 223 |
+
image_emb = get_image_embedding(model, image, config.device)
|
| 224 |
+
|
| 225 |
+
best_sim = -1.0
|
| 226 |
+
predicted_color = all_colors[0]
|
| 227 |
+
for color, emb in color_embeddings.items():
|
| 228 |
+
sim = F.cosine_similarity(image_emb, emb.unsqueeze(0), dim=1).item()
|
| 229 |
+
if sim > best_sim:
|
| 230 |
+
best_sim, predicted_color = sim, color
|
| 231 |
+
|
| 232 |
+
true_labels.append(true_color)
|
| 233 |
+
predicted_labels.append(predicted_color)
|
| 234 |
+
if true_color == predicted_color:
|
| 235 |
+
correct += 1
|
| 236 |
+
|
| 237 |
+
accuracy = accuracy_score(true_labels, predicted_labels)
|
| 238 |
+
print(f"Image-to-text accuracy: {accuracy:.4f} ({accuracy * 100:.2f}%)")
|
| 239 |
+
print(f"Correct: {correct}/{len(true_labels)}")
|
| 240 |
+
return true_labels, predicted_labels, accuracy
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
# ---------------------------------------------------------------------------
|
| 244 |
+
# Plotting
|
| 245 |
+
# ---------------------------------------------------------------------------
|
| 246 |
+
|
| 247 |
+
def plot_confusion_matrix(
|
| 248 |
+
true_labels,
|
| 249 |
+
predicted_labels,
|
| 250 |
+
save_path=None,
|
| 251 |
+
title_suffix: str = "text",
|
| 252 |
+
):
|
| 253 |
+
"""Generate and optionally save a percentage-based confusion matrix."""
|
| 254 |
+
print("\n=== Generating Confusion Matrix ===")
|
| 255 |
+
|
| 256 |
+
cm = confusion_matrix(true_labels, predicted_labels)
|
| 257 |
+
unique_labels = sorted(set(true_labels + predicted_labels))
|
| 258 |
+
accuracy = accuracy_score(true_labels, predicted_labels)
|
| 259 |
+
|
| 260 |
+
cm_percent = np.round(cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] * 100).astype(int)
|
| 261 |
+
|
| 262 |
+
plt.figure(figsize=(12, 10))
|
| 263 |
+
sns.heatmap(
|
| 264 |
+
cm_percent,
|
| 265 |
+
annot=True,
|
| 266 |
+
fmt="d",
|
| 267 |
+
cmap="Blues",
|
| 268 |
+
cbar_kws={"label": "Percentage (%)"},
|
| 269 |
+
xticklabels=unique_labels,
|
| 270 |
+
yticklabels=unique_labels,
|
| 271 |
+
)
|
| 272 |
+
plt.title(
|
| 273 |
+
f"Confusion Matrix — {title_suffix} | accuracy: {accuracy:.4f} ({accuracy * 100:.2f}%)",
|
| 274 |
+
fontsize=16,
|
| 275 |
+
)
|
| 276 |
+
plt.xlabel("Predictions", fontsize=12)
|
| 277 |
+
plt.ylabel("True colors", fontsize=12)
|
| 278 |
+
plt.xticks(rotation=45, ha="right")
|
| 279 |
+
plt.yticks(rotation=0)
|
| 280 |
+
plt.tight_layout()
|
| 281 |
+
|
| 282 |
+
if save_path:
|
| 283 |
+
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
| 284 |
+
print(f"Saved: {save_path}")
|
| 285 |
+
|
| 286 |
+
plt.show()
|
| 287 |
+
return cm
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
# ---------------------------------------------------------------------------
|
| 291 |
+
# Entry point
|
| 292 |
+
# ---------------------------------------------------------------------------
|
| 293 |
+
|
| 294 |
+
if __name__ == "__main__":
|
| 295 |
+
print("=== GAP-CLIP: Sections 5.3.4 + 5.3.5 — Semantic Evaluation ===")
|
| 296 |
+
|
| 297 |
+
model, processor = load_gap_clip(config.main_model_path, config.device)
|
| 298 |
+
|
| 299 |
+
df = pd.read_csv(config.local_dataset_path)
|
| 300 |
+
|
| 301 |
+
print("\n" + "=" * 80)
|
| 302 |
+
print("(a) COLOR-TO-COLOR CLASSIFICATION — Control Test")
|
| 303 |
+
print("=" * 80)
|
| 304 |
+
dataset_color = CustomCSVDataset(df, load_images=False)
|
| 305 |
+
true_c, pred_c, acc_c = evaluate_color_only_zero_shot(model, dataset_color, processor)
|
| 306 |
+
plot_confusion_matrix(true_c, pred_c, save_path="confusion_matrix_color_only.png", title_suffix="color-only")
|
| 307 |
+
|
| 308 |
+
print("\n" + "=" * 80)
|
| 309 |
+
print("(b) TEXT-TO-TEXT CLASSIFICATION")
|
| 310 |
+
print("=" * 80)
|
| 311 |
+
dataset_text = CustomCSVDataset(df, load_images=False)
|
| 312 |
+
true_t, pred_t, acc_t = evaluate_text_to_text_zero_shot(model, dataset_text, processor)
|
| 313 |
+
plot_confusion_matrix(true_t, pred_t, save_path="confusion_matrix_text.png", title_suffix="text")
|
| 314 |
+
|
| 315 |
+
print("\n" + "=" * 80)
|
| 316 |
+
print("(c) IMAGE-TO-TEXT CLASSIFICATION")
|
| 317 |
+
print("=" * 80)
|
| 318 |
+
dataset_image = CustomCSVDataset(df, load_images=True)
|
| 319 |
+
true_i, pred_i, acc_i = evaluate_image_to_text_zero_shot(model, dataset_image, processor)
|
| 320 |
+
plot_confusion_matrix(true_i, pred_i, save_path="confusion_matrix_image.png", title_suffix="image")
|
| 321 |
+
|
| 322 |
+
print("\n" + "=" * 80)
|
| 323 |
+
print("SUMMARY — Section 5.3.5")
|
| 324 |
+
print("=" * 80)
|
| 325 |
+
print(f"(a) Color-only (control): {acc_c:.4f} ({acc_c * 100:.2f}%)")
|
| 326 |
+
print(f"(b) Text-to-text: {acc_t:.4f} ({acc_t * 100:.2f}%)")
|
| 327 |
+
print(f"(c) Image-to-text: {acc_i:.4f} ({acc_i * 100:.2f}%)")
|
| 328 |
+
print(f"\nLoss from color-only vs text: {abs(acc_c - acc_t):.4f}")
|
| 329 |
+
print(f"Difference text vs image: {abs(acc_t - acc_i):.4f}")
|
evaluation/sec536_embedding_structure.py
ADDED
|
@@ -0,0 +1,1460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Section 5.3.6 — Embedding Structure Evaluation
|
| 4 |
+
===============================================
|
| 5 |
+
|
| 6 |
+
Verifies that the GAP-CLIP embedding subspaces encode the attributes they are
|
| 7 |
+
designed for, and tests zero-shot vision-language alignment.
|
| 8 |
+
|
| 9 |
+
Test A — Different colors, same hierarchy:
|
| 10 |
+
The 64D hierarchy subspace should be MORE similar between two items that
|
| 11 |
+
share a category but differ in color, compared to the 16D color subspace.
|
| 12 |
+
Expected result: 1000/1000 pass.
|
| 13 |
+
"correlation between the color slice are low and the correlation between the category part are high"
|
| 14 |
+
|
| 15 |
+
Test B — Same color, different hierarchies:
|
| 16 |
+
The 16D color subspace should be MORE similar than the full 512D embedding
|
| 17 |
+
for items sharing a color but differing in category.
|
| 18 |
+
Expected result: 1000/1000 pass.
|
| 19 |
+
|
| 20 |
+
Test C1 — Zero-shot image-to-text classification:
|
| 21 |
+
Each image is used as a query; the highest-scoring text label (cosine in
|
| 22 |
+
shared latent space) is the predicted class. Accuracy is computed across
|
| 23 |
+
three datasets (Fashion-MNIST, KAGL Marqo, Internal).
|
| 24 |
+
|
| 25 |
+
Test C2 — Zero-shot text-to-image retrieval:
|
| 26 |
+
Each text label queries all image embeddings; retrieval is correct when the
|
| 27 |
+
top-1 returned image belongs to the queried label.
|
| 28 |
+
|
| 29 |
+
Paper reference: Section 5.3.6 and Table 4.
|
| 30 |
+
|
| 31 |
+
Run directly:
|
| 32 |
+
python sec536_embedding_structure.py --tests AB # only tests A+B
|
| 33 |
+
python sec536_embedding_structure.py --tests ABC # all tests
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
from __future__ import annotations
|
| 37 |
+
|
| 38 |
+
import argparse
|
| 39 |
+
import os
|
| 40 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 41 |
+
|
| 42 |
+
from dataclasses import dataclass
|
| 43 |
+
from pathlib import Path
|
| 44 |
+
import random
|
| 45 |
+
from typing import Dict, List, Optional, Sequence, Tuple
|
| 46 |
+
|
| 47 |
+
import numpy as np
|
| 48 |
+
import pandas as pd
|
| 49 |
+
import requests
|
| 50 |
+
import torch
|
| 51 |
+
import torch.nn.functional as F
|
| 52 |
+
from io import BytesIO
|
| 53 |
+
from PIL import Image
|
| 54 |
+
from torchvision import transforms
|
| 55 |
+
from transformers import CLIPModel as CLIPModelTransformers
|
| 56 |
+
from transformers import CLIPProcessor
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@dataclass
|
| 60 |
+
class RuntimeConfig:
|
| 61 |
+
color_emb_dim: int = 16
|
| 62 |
+
hierarchy_emb_dim: int = 64
|
| 63 |
+
main_model_path: str = "models/gap_clip.pth"
|
| 64 |
+
device: torch.device = torch.device("cpu")
|
| 65 |
+
|
| 66 |
+
DEFAULT_NUM_EXAMPLES = 1000
|
| 67 |
+
DEFAULT_NUM_PRINTED = 3
|
| 68 |
+
|
| 69 |
+
COLORS = [
|
| 70 |
+
"yellow", "blue", "red", "green", "black", "white", "pink", "purple", "brown", "orange",
|
| 71 |
+
]
|
| 72 |
+
HIERARCHIES = [
|
| 73 |
+
"dress", "shirt", "pants", "skirt", "jacket", "coat", "jeans", "sweater", "shorts", "top",
|
| 74 |
+
]
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
LONG_TEXT_TEMPLATES = [
|
| 78 |
+
"{color} {hierarchy}",
|
| 79 |
+
"{color} {hierarchy} with buttons",
|
| 80 |
+
"{color} {hierarchy} in cotton",
|
| 81 |
+
"casual {color} {hierarchy} for women",
|
| 82 |
+
"elegant {color} {hierarchy} with pockets",
|
| 83 |
+
]
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def build_text_query(color: str, hierarchy: str) -> str:
|
| 87 |
+
template = random.choice(LONG_TEXT_TEMPLATES)
|
| 88 |
+
return template.format(color=color, hierarchy=hierarchy)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def resolve_runtime_config() -> RuntimeConfig:
|
| 92 |
+
"""Resolve config from local config.py if available, else use defaults."""
|
| 93 |
+
cfg = RuntimeConfig()
|
| 94 |
+
try:
|
| 95 |
+
import config # type: ignore
|
| 96 |
+
|
| 97 |
+
cfg.color_emb_dim = getattr(config, "color_emb_dim", cfg.color_emb_dim)
|
| 98 |
+
cfg.hierarchy_emb_dim = getattr(config, "hierarchy_emb_dim", cfg.hierarchy_emb_dim)
|
| 99 |
+
cfg.main_model_path = getattr(config, "main_model_path", cfg.main_model_path)
|
| 100 |
+
cfg.device = getattr(config, "device", cfg.device)
|
| 101 |
+
except Exception:
|
| 102 |
+
if torch.cuda.is_available():
|
| 103 |
+
cfg.device = torch.device("cuda")
|
| 104 |
+
elif torch.backends.mps.is_available():
|
| 105 |
+
cfg.device = torch.device("mps")
|
| 106 |
+
else:
|
| 107 |
+
cfg.device = torch.device("cpu")
|
| 108 |
+
|
| 109 |
+
return cfg
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def load_main_model(device: torch.device, main_model_path: str) -> Tuple[CLIPModelTransformers, CLIPProcessor]:
|
| 113 |
+
"""Load GAP-CLIP (LAION CLIP + finetuned checkpoint) and processor.
|
| 114 |
+
|
| 115 |
+
Delegates to utils.model_loader.load_gap_clip for consistent loading.
|
| 116 |
+
"""
|
| 117 |
+
from evaluation.utils.model_loader import load_gap_clip # type: ignore
|
| 118 |
+
return load_gap_clip(main_model_path, device)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def get_text_embedding(
|
| 122 |
+
model: CLIPModelTransformers, processor: CLIPProcessor, device: torch.device, text: str
|
| 123 |
+
) -> torch.Tensor:
|
| 124 |
+
"""Extract normalized text embedding for a single query."""
|
| 125 |
+
text_inputs = processor(text=[text], padding=True, return_tensors="pt")
|
| 126 |
+
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
|
| 127 |
+
|
| 128 |
+
with torch.no_grad():
|
| 129 |
+
text_outputs = model.text_model(**text_inputs)
|
| 130 |
+
text_features = model.text_projection(text_outputs.pooler_output)
|
| 131 |
+
text_features = F.normalize(text_features, dim=-1)
|
| 132 |
+
|
| 133 |
+
return text_features.squeeze(0)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def cosine(a: torch.Tensor, b: torch.Tensor) -> float:
|
| 137 |
+
return F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0), dim=1).item()
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def delta_percent(reference: float, value: float) -> float:
|
| 141 |
+
"""Relative delta in percent: (value-reference)/|reference|*100."""
|
| 142 |
+
denom = max(abs(reference), 1e-8)
|
| 143 |
+
return ((value - reference) / denom) * 100.0
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def format_bool(ok: bool) -> str:
|
| 147 |
+
return "PASS" if ok else "FAIL"
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def print_table(title: str, headers: List[str], rows: List[List[str]]) -> None:
|
| 151 |
+
print("\n" + "=" * 120)
|
| 152 |
+
print(title)
|
| 153 |
+
print("=" * 120)
|
| 154 |
+
all_rows = [headers] + rows
|
| 155 |
+
col_widths = [max(len(str(r[i])) for r in all_rows) for i in range(len(headers))]
|
| 156 |
+
|
| 157 |
+
def fmt(row: List[str]) -> str:
|
| 158 |
+
return " | ".join(str(v).ljust(col_widths[i]) for i, v in enumerate(row))
|
| 159 |
+
|
| 160 |
+
print(fmt(headers))
|
| 161 |
+
print("-" * (sum(col_widths) + 3 * (len(headers) - 1)))
|
| 162 |
+
for row in rows:
|
| 163 |
+
print(fmt(row))
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def run_test_a(
|
| 167 |
+
model: CLIPModelTransformers,
|
| 168 |
+
processor: CLIPProcessor,
|
| 169 |
+
cfg: RuntimeConfig,
|
| 170 |
+
num_examples: int,
|
| 171 |
+
num_printed: int) -> Dict[str, bool]:
|
| 172 |
+
"""
|
| 173 |
+
A: different colors + same hierarchy.
|
| 174 |
+
Expect hierarchy subspace to be more similar than color subspace.
|
| 175 |
+
"""
|
| 176 |
+
positive_pairs: List[Tuple[str, str]] = []
|
| 177 |
+
negative_pairs: List[Tuple[str, str]] = []
|
| 178 |
+
for _ in range(num_examples):
|
| 179 |
+
hierarchy = random.choice(HIERARCHIES)
|
| 180 |
+
c1, c2 = random.sample(COLORS, 2)
|
| 181 |
+
negative_hierarchy = random.choice([h for h in HIERARCHIES if h != hierarchy])
|
| 182 |
+
positive_pairs.append((build_text_query(c1, hierarchy), build_text_query(c2, hierarchy)))
|
| 183 |
+
negative_pairs.append((build_text_query(c1, hierarchy), build_text_query(c2, negative_hierarchy)))
|
| 184 |
+
|
| 185 |
+
rows: List[List[str]] = []
|
| 186 |
+
pair_outcomes: List[bool] = []
|
| 187 |
+
full512_outcomes: List[bool] = []
|
| 188 |
+
hier_gt_full_outcomes: List[bool] = []
|
| 189 |
+
hier_gt_color_outcomes: List[bool] = []
|
| 190 |
+
delta_color_vs_full_values: List[float] = []
|
| 191 |
+
delta_hier_vs_full_values: List[float] = []
|
| 192 |
+
|
| 193 |
+
for (left, right), (_, negative_right) in zip(positive_pairs, negative_pairs):
|
| 194 |
+
emb_left = get_text_embedding(model, processor, cfg.device, left)
|
| 195 |
+
emb_right = get_text_embedding(model, processor, cfg.device, right)
|
| 196 |
+
emb_negative_right = get_text_embedding(model, processor, cfg.device, negative_right)
|
| 197 |
+
|
| 198 |
+
left_color = emb_left[: cfg.color_emb_dim]
|
| 199 |
+
right_color = emb_right[: cfg.color_emb_dim]
|
| 200 |
+
left_hier = emb_left[cfg.color_emb_dim : cfg.color_emb_dim + cfg.hierarchy_emb_dim]
|
| 201 |
+
right_hier = emb_right[cfg.color_emb_dim : cfg.color_emb_dim + cfg.hierarchy_emb_dim]
|
| 202 |
+
|
| 203 |
+
sim_color = cosine(left_color, right_color)
|
| 204 |
+
sim_hier = cosine(left_hier, right_hier)
|
| 205 |
+
sim_full512 = cosine(emb_left, emb_right)
|
| 206 |
+
sim_full512_negative = cosine(emb_left, emb_negative_right)
|
| 207 |
+
delta_color_vs_full_pct = delta_percent(sim_full512, sim_color)
|
| 208 |
+
delta_hier_vs_full_pct = delta_percent(sim_full512, sim_hier)
|
| 209 |
+
delta_color_vs_full_values.append(delta_color_vs_full_pct)
|
| 210 |
+
delta_hier_vs_full_values.append(delta_hier_vs_full_pct)
|
| 211 |
+
|
| 212 |
+
hierarchy_higher_than_full = sim_hier > sim_full512
|
| 213 |
+
hierarchy_higher_than_color = sim_hier > sim_color
|
| 214 |
+
pair_ok = hierarchy_higher_than_full and hierarchy_higher_than_color
|
| 215 |
+
pair_outcomes.append(pair_ok)
|
| 216 |
+
hier_gt_full_outcomes.append(hierarchy_higher_than_full)
|
| 217 |
+
hier_gt_color_outcomes.append(hierarchy_higher_than_color)
|
| 218 |
+
full512_outcomes.append(sim_full512 > sim_full512_negative)
|
| 219 |
+
|
| 220 |
+
rows.append(
|
| 221 |
+
[
|
| 222 |
+
f"{left} vs {right}",
|
| 223 |
+
f"{sim_color:.4f}",
|
| 224 |
+
f"{sim_hier:.4f}",
|
| 225 |
+
f"{sim_full512:.4f}",
|
| 226 |
+
f"{delta_color_vs_full_pct:+.2f}%",
|
| 227 |
+
f"{delta_hier_vs_full_pct:+.2f}%",
|
| 228 |
+
format_bool(pair_ok),
|
| 229 |
+
]
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
print_table(
|
| 233 |
+
f"Test A: Different colors, same hierarchy (showing {min(num_printed, len(rows))}/{len(rows)} examples)",
|
| 234 |
+
[
|
| 235 |
+
"Pair",
|
| 236 |
+
"CosSim first16(color)",
|
| 237 |
+
"CosSim hier64",
|
| 238 |
+
"CosSim full512",
|
| 239 |
+
"Delta first16 vs full512 (%)",
|
| 240 |
+
"Delta hier64 vs full512 (%)",
|
| 241 |
+
"Result",
|
| 242 |
+
],
|
| 243 |
+
rows[:num_printed],
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
overall = all(pair_outcomes)
|
| 247 |
+
pass_rate = sum(pair_outcomes) / len(pair_outcomes)
|
| 248 |
+
full512_accuracy = sum(full512_outcomes) / len(full512_outcomes)
|
| 249 |
+
hier_gt_full_rate = sum(hier_gt_full_outcomes) / len(hier_gt_full_outcomes)
|
| 250 |
+
hier_gt_color_rate = sum(hier_gt_color_outcomes) / len(hier_gt_color_outcomes)
|
| 251 |
+
avg_delta_color_vs_full = sum(delta_color_vs_full_values) / len(delta_color_vs_full_values)
|
| 252 |
+
avg_delta_hier_vs_full = sum(delta_hier_vs_full_values) / len(delta_hier_vs_full_values)
|
| 253 |
+
print(f"Test A aggregate: {sum(pair_outcomes)}/{len(pair_outcomes)} passed ({pass_rate:.2%})")
|
| 254 |
+
print(f" sub-condition hier > full512: {sum(hier_gt_full_outcomes)}/{len(hier_gt_full_outcomes)} ({hier_gt_full_rate:.2%})")
|
| 255 |
+
print(f" sub-condition hier > color: {sum(hier_gt_color_outcomes)}/{len(hier_gt_color_outcomes)} ({hier_gt_color_rate:.2%})")
|
| 256 |
+
print(
|
| 257 |
+
"Test A full512 pair-discrimination accuracy "
|
| 258 |
+
f"(same-hierarchy > different-hierarchy): {sum(full512_outcomes)}/{len(full512_outcomes)} "
|
| 259 |
+
f"({full512_accuracy:.2%})"
|
| 260 |
+
)
|
| 261 |
+
print(
|
| 262 |
+
"Test A avg deltas: "
|
| 263 |
+
f"first16 vs full512 = {avg_delta_color_vs_full:+.2f}%, "
|
| 264 |
+
f"hier64 vs full512 = {avg_delta_hier_vs_full:+.2f}%"
|
| 265 |
+
)
|
| 266 |
+
return {
|
| 267 |
+
"overall": overall,
|
| 268 |
+
"accuracy_full512": full512_accuracy,
|
| 269 |
+
"pass_rate": pass_rate,
|
| 270 |
+
"hier_gt_full_rate": hier_gt_full_rate,
|
| 271 |
+
"hier_gt_color_rate": hier_gt_color_rate,
|
| 272 |
+
"avg_delta_color_vs_full": avg_delta_color_vs_full,
|
| 273 |
+
"avg_delta_hier_vs_full": avg_delta_hier_vs_full,
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def run_test_b(
|
| 278 |
+
model: CLIPModelTransformers,
|
| 279 |
+
processor: CLIPProcessor,
|
| 280 |
+
cfg: RuntimeConfig,
|
| 281 |
+
num_examples: int,
|
| 282 |
+
num_printed: int) -> Dict[str, bool]:
|
| 283 |
+
"""
|
| 284 |
+
B: same color + different hierarchies.
|
| 285 |
+
Expect similarity in first16 (color) to be higher than full512.
|
| 286 |
+
"""
|
| 287 |
+
positive_pairs: List[Tuple[str, str]] = []
|
| 288 |
+
negative_pairs: List[Tuple[str, str]] = []
|
| 289 |
+
for _ in range(num_examples):
|
| 290 |
+
color = random.choice(COLORS)
|
| 291 |
+
h1, h2 = random.sample(HIERARCHIES, 2)
|
| 292 |
+
negative_color = random.choice([c for c in COLORS if c != color])
|
| 293 |
+
positive_pairs.append((build_text_query(color, h1), build_text_query(color, h2)))
|
| 294 |
+
negative_pairs.append((build_text_query(color, h1), build_text_query(negative_color, h2)))
|
| 295 |
+
|
| 296 |
+
rows: List[List[str]] = []
|
| 297 |
+
pair_outcomes: List[bool] = []
|
| 298 |
+
full512_outcomes: List[bool] = []
|
| 299 |
+
color_gt_full_outcomes: List[bool] = []
|
| 300 |
+
color_gt_hier_outcomes: List[bool] = []
|
| 301 |
+
delta_color_vs_full_values: List[float] = []
|
| 302 |
+
delta_hier_vs_full_values: List[float] = []
|
| 303 |
+
|
| 304 |
+
for (left, right), (_, negative_right) in zip(positive_pairs, negative_pairs):
|
| 305 |
+
emb_left = get_text_embedding(model, processor, cfg.device, left)
|
| 306 |
+
emb_right = get_text_embedding(model, processor, cfg.device, right)
|
| 307 |
+
emb_negative_right = get_text_embedding(model, processor, cfg.device, negative_right)
|
| 308 |
+
|
| 309 |
+
sim_512 = cosine(emb_left, emb_right)
|
| 310 |
+
sim_16 = cosine(emb_left[: cfg.color_emb_dim], emb_right[: cfg.color_emb_dim])
|
| 311 |
+
sim_hier = cosine(
|
| 312 |
+
emb_left[cfg.color_emb_dim : cfg.color_emb_dim + cfg.hierarchy_emb_dim],
|
| 313 |
+
emb_right[cfg.color_emb_dim : cfg.color_emb_dim + cfg.hierarchy_emb_dim],
|
| 314 |
+
)
|
| 315 |
+
sim_512_negative = cosine(emb_left, emb_negative_right)
|
| 316 |
+
delta_color_vs_full_pct = delta_percent(sim_512, sim_16)
|
| 317 |
+
delta_hier_vs_full_pct = delta_percent(sim_512, sim_hier)
|
| 318 |
+
delta_color_vs_full_values.append(delta_color_vs_full_pct)
|
| 319 |
+
delta_hier_vs_full_values.append(delta_hier_vs_full_pct)
|
| 320 |
+
|
| 321 |
+
first16_higher_than_full = sim_16 > sim_512
|
| 322 |
+
color_higher_than_hier = sim_16 > sim_hier
|
| 323 |
+
pair_ok = first16_higher_than_full and color_higher_than_hier
|
| 324 |
+
pair_outcomes.append(pair_ok)
|
| 325 |
+
color_gt_full_outcomes.append(first16_higher_than_full)
|
| 326 |
+
color_gt_hier_outcomes.append(color_higher_than_hier)
|
| 327 |
+
full512_outcomes.append(sim_512 > sim_512_negative)
|
| 328 |
+
|
| 329 |
+
rows.append(
|
| 330 |
+
[
|
| 331 |
+
f"{left} vs {right}",
|
| 332 |
+
f"{sim_16:.4f}",
|
| 333 |
+
f"{sim_hier:.4f}",
|
| 334 |
+
f"{sim_512:.4f}",
|
| 335 |
+
f"{delta_color_vs_full_pct:+.2f}%",
|
| 336 |
+
f"{delta_hier_vs_full_pct:+.2f}%",
|
| 337 |
+
format_bool(pair_ok),
|
| 338 |
+
]
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
print_table(
|
| 342 |
+
f"Test B: Same color, different hierarchies (showing {min(num_printed, len(rows))}/{len(rows)} examples)",
|
| 343 |
+
[
|
| 344 |
+
"Pair",
|
| 345 |
+
"CosSim first16(color)",
|
| 346 |
+
"CosSim hier64",
|
| 347 |
+
"CosSim full512",
|
| 348 |
+
"Delta first16 vs full512 (%)",
|
| 349 |
+
"Delta hier64 vs full512 (%)",
|
| 350 |
+
"Result",
|
| 351 |
+
],
|
| 352 |
+
rows[:num_printed],
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
overall = all(pair_outcomes)
|
| 356 |
+
pass_rate = sum(pair_outcomes) / len(pair_outcomes)
|
| 357 |
+
full512_accuracy = sum(full512_outcomes) / len(full512_outcomes)
|
| 358 |
+
color_gt_full_rate = sum(color_gt_full_outcomes) / len(color_gt_full_outcomes)
|
| 359 |
+
color_gt_hier_rate = sum(color_gt_hier_outcomes) / len(color_gt_hier_outcomes)
|
| 360 |
+
avg_delta_color_vs_full = sum(delta_color_vs_full_values) / len(delta_color_vs_full_values)
|
| 361 |
+
avg_delta_hier_vs_full = sum(delta_hier_vs_full_values) / len(delta_hier_vs_full_values)
|
| 362 |
+
print(f"Test B aggregate: {sum(pair_outcomes)}/{len(pair_outcomes)} passed ({pass_rate:.2%})")
|
| 363 |
+
print(f" sub-condition color > full512: {sum(color_gt_full_outcomes)}/{len(color_gt_full_outcomes)} ({color_gt_full_rate:.2%})")
|
| 364 |
+
print(f" sub-condition color > hier: {sum(color_gt_hier_outcomes)}/{len(color_gt_hier_outcomes)} ({color_gt_hier_rate:.2%})")
|
| 365 |
+
print(
|
| 366 |
+
"Test B full512 pair-discrimination accuracy "
|
| 367 |
+
f"(same-color > different-color): {sum(full512_outcomes)}/{len(full512_outcomes)} "
|
| 368 |
+
f"({full512_accuracy:.2%})"
|
| 369 |
+
)
|
| 370 |
+
print(
|
| 371 |
+
"Test B avg deltas: "
|
| 372 |
+
f"first16 vs full512 = {avg_delta_color_vs_full:+.2f}%, "
|
| 373 |
+
f"hier64 vs full512 = {avg_delta_hier_vs_full:+.2f}%"
|
| 374 |
+
)
|
| 375 |
+
return {
|
| 376 |
+
"overall": overall,
|
| 377 |
+
"accuracy_full512": full512_accuracy,
|
| 378 |
+
"pass_rate": pass_rate,
|
| 379 |
+
"color_gt_full_rate": color_gt_full_rate,
|
| 380 |
+
"color_gt_hier_rate": color_gt_hier_rate,
|
| 381 |
+
"avg_delta_color_vs_full": avg_delta_color_vs_full,
|
| 382 |
+
"avg_delta_hier_vs_full": avg_delta_hier_vs_full,
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
FASHION_MNIST_LABELS = {
|
| 388 |
+
0: "top",
|
| 389 |
+
1: "pant",
|
| 390 |
+
2: "sweater",
|
| 391 |
+
3: "dress",
|
| 392 |
+
4: "coat",
|
| 393 |
+
5: "shoes",
|
| 394 |
+
6: "shirt",
|
| 395 |
+
7: "shoes",
|
| 396 |
+
8: "accessories",
|
| 397 |
+
9: "shoes",
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
FASHION_MNIST_CSV = "data/fashion-mnist_test.csv"
|
| 401 |
+
INTERNAL_DATASET_CSV = "data/data.csv"
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def fashion_mnist_pixels_to_tensor(pixel_values: np.ndarray, image_size: int = 224) -> torch.Tensor:
|
| 405 |
+
img_array = pixel_values.reshape(28, 28).astype(np.uint8)
|
| 406 |
+
img_array = np.stack([img_array] * 3, axis=-1)
|
| 407 |
+
image = Image.fromarray(img_array)
|
| 408 |
+
transform = transforms.Compose([
|
| 409 |
+
transforms.Resize((image_size, image_size)),
|
| 410 |
+
transforms.ToTensor(),
|
| 411 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 412 |
+
])
|
| 413 |
+
return transform(image)
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
def get_image_embedding(
|
| 417 |
+
model: CLIPModelTransformers, processor: CLIPProcessor, device: torch.device, image_tensor: torch.Tensor
|
| 418 |
+
) -> torch.Tensor:
|
| 419 |
+
image_tensor = image_tensor.unsqueeze(0).to(device)
|
| 420 |
+
with torch.no_grad():
|
| 421 |
+
vision_outputs = model.vision_model(pixel_values=image_tensor)
|
| 422 |
+
image_features = model.visual_projection(vision_outputs.pooler_output)
|
| 423 |
+
image_features = F.normalize(image_features, dim=-1)
|
| 424 |
+
return image_features.squeeze(0)
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
def get_image_embedding_from_pil(
|
| 428 |
+
model: CLIPModelTransformers, processor: CLIPProcessor, device: torch.device, image: Image.Image
|
| 429 |
+
) -> torch.Tensor:
|
| 430 |
+
image_inputs = processor(images=[image], return_tensors="pt")
|
| 431 |
+
image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
|
| 432 |
+
with torch.no_grad():
|
| 433 |
+
vision_outputs = model.vision_model(**image_inputs)
|
| 434 |
+
image_features = model.visual_projection(vision_outputs.pooler_output)
|
| 435 |
+
image_features = F.normalize(image_features, dim=-1)
|
| 436 |
+
return image_features.squeeze(0)
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def get_text_embeddings_batch(
|
| 440 |
+
model: CLIPModelTransformers, processor: CLIPProcessor, device: torch.device, texts: List[str]
|
| 441 |
+
) -> torch.Tensor:
|
| 442 |
+
text_inputs = processor(text=texts, padding=True, return_tensors="pt")
|
| 443 |
+
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
|
| 444 |
+
with torch.no_grad():
|
| 445 |
+
text_outputs = model.text_model(**text_inputs)
|
| 446 |
+
text_features = model.text_projection(text_outputs.pooler_output)
|
| 447 |
+
text_features = F.normalize(text_features, dim=-1)
|
| 448 |
+
return text_features
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def get_prompt_ensembled_text_embeddings(
|
| 452 |
+
model: CLIPModelTransformers,
|
| 453 |
+
processor: CLIPProcessor,
|
| 454 |
+
device: torch.device,
|
| 455 |
+
labels: List[str],
|
| 456 |
+
templates: List[str],
|
| 457 |
+
) -> torch.Tensor:
|
| 458 |
+
"""Encode labels with multiple prompt templates and average embeddings."""
|
| 459 |
+
all_prompt_embs: List[torch.Tensor] = []
|
| 460 |
+
for template in templates:
|
| 461 |
+
prompts = [template.format(label=label) for label in labels]
|
| 462 |
+
all_prompt_embs.append(get_text_embeddings_batch(model, processor, device, prompts))
|
| 463 |
+
stacked = torch.stack(all_prompt_embs, dim=0)
|
| 464 |
+
ensembled = stacked.mean(dim=0)
|
| 465 |
+
ensembled = F.normalize(ensembled, dim=-1)
|
| 466 |
+
return ensembled
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
def get_internal_label_prior(labels: List[str]) -> torch.Tensor:
|
| 470 |
+
"""
|
| 471 |
+
Compute label prior from internal dataset hierarchy frequency.
|
| 472 |
+
Falls back to uniform when internal CSV is unavailable.
|
| 473 |
+
"""
|
| 474 |
+
csv_file = Path(INTERNAL_DATASET_CSV)
|
| 475 |
+
if not csv_file.exists():
|
| 476 |
+
return torch.ones(len(labels), dtype=torch.float32) / max(len(labels), 1)
|
| 477 |
+
try:
|
| 478 |
+
df = pd.read_csv(INTERNAL_DATASET_CSV, usecols=["hierarchy"]).dropna()
|
| 479 |
+
except Exception:
|
| 480 |
+
return torch.ones(len(labels), dtype=torch.float32) / max(len(labels), 1)
|
| 481 |
+
if len(df) == 0:
|
| 482 |
+
return torch.ones(len(labels), dtype=torch.float32) / max(len(labels), 1)
|
| 483 |
+
|
| 484 |
+
norm_labels = [normalize_hierarchy_label(v) for v in df["hierarchy"].astype(str)]
|
| 485 |
+
counts = pd.Series(norm_labels).value_counts().to_dict()
|
| 486 |
+
smooth = 1e-3
|
| 487 |
+
probs = torch.tensor([float(counts.get(label, 0.0)) + smooth for label in labels], dtype=torch.float32)
|
| 488 |
+
probs = probs / probs.sum()
|
| 489 |
+
return probs
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def get_adaptive_label_prior(labels: List[str]) -> Tuple[torch.Tensor, float]:
|
| 493 |
+
"""
|
| 494 |
+
Compute label prior with adaptive strength based on overlap between
|
| 495 |
+
candidate labels and the training distribution. When most candidate
|
| 496 |
+
labels are out-of-domain, the recommended weight drops toward zero so
|
| 497 |
+
the prior does not penalise novel categories.
|
| 498 |
+
"""
|
| 499 |
+
csv_file = Path(INTERNAL_DATASET_CSV)
|
| 500 |
+
uniform = torch.ones(len(labels), dtype=torch.float32) / max(len(labels), 1)
|
| 501 |
+
if not csv_file.exists():
|
| 502 |
+
return uniform, 0.0
|
| 503 |
+
try:
|
| 504 |
+
df = pd.read_csv(INTERNAL_DATASET_CSV, usecols=["hierarchy"]).dropna()
|
| 505 |
+
except Exception:
|
| 506 |
+
return uniform, 0.0
|
| 507 |
+
if len(df) == 0:
|
| 508 |
+
return uniform, 0.0
|
| 509 |
+
|
| 510 |
+
norm_labels = [normalize_hierarchy_label(v) for v in df["hierarchy"].astype(str)]
|
| 511 |
+
counts = pd.Series(norm_labels).value_counts().to_dict()
|
| 512 |
+
known_labels = set(counts.keys())
|
| 513 |
+
overlap = sum(1 for l in labels if l in known_labels) / max(len(labels), 1)
|
| 514 |
+
total_count = sum(counts.values())
|
| 515 |
+
default_prob = 1.0 / max(len(labels), 1)
|
| 516 |
+
|
| 517 |
+
probs = torch.tensor(
|
| 518 |
+
[
|
| 519 |
+
counts.get(label, 0.0) / total_count if label in known_labels else default_prob
|
| 520 |
+
for label in labels
|
| 521 |
+
],
|
| 522 |
+
dtype=torch.float32,
|
| 523 |
+
)
|
| 524 |
+
probs = probs / probs.sum()
|
| 525 |
+
recommended_weight = 0.15 * (overlap ** 2)
|
| 526 |
+
return probs, recommended_weight
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
def run_test_c(
|
| 530 |
+
model: CLIPModelTransformers,
|
| 531 |
+
processor: CLIPProcessor,
|
| 532 |
+
cfg: RuntimeConfig,
|
| 533 |
+
num_examples: int,
|
| 534 |
+
num_printed: int,
|
| 535 |
+
csv_path: str = FASHION_MNIST_CSV,
|
| 536 |
+
) -> Dict[str, object]:
|
| 537 |
+
"""
|
| 538 |
+
C: Zero-shot image classification.
|
| 539 |
+
For each image, compute cosine similarity against all candidate text labels
|
| 540 |
+
and check whether the highest-scoring text matches the ground truth.
|
| 541 |
+
"""
|
| 542 |
+
csv_file = Path(csv_path)
|
| 543 |
+
if not csv_file.exists():
|
| 544 |
+
print(f" Skipping Test C: {csv_path} not found")
|
| 545 |
+
return {"overall": True, "accuracy": None}
|
| 546 |
+
|
| 547 |
+
df = pd.read_csv(csv_path)
|
| 548 |
+
df = df.sample(n=min(num_examples, len(df)), random_state=42).reset_index(drop=True)
|
| 549 |
+
|
| 550 |
+
candidate_labels = sorted(set(FASHION_MNIST_LABELS.values()))
|
| 551 |
+
candidate_texts = [f"a photo of {label}" for label in candidate_labels]
|
| 552 |
+
text_embs = get_text_embeddings_batch(model, processor, cfg.device, candidate_texts)
|
| 553 |
+
|
| 554 |
+
pixel_cols = [f"pixel{i}" for i in range(1, 785)]
|
| 555 |
+
rows: List[List[str]] = []
|
| 556 |
+
failed_rows: List[List[str]] = []
|
| 557 |
+
correct = 0
|
| 558 |
+
|
| 559 |
+
for idx in range(len(df)):
|
| 560 |
+
row = df.iloc[idx]
|
| 561 |
+
label_id = int(row["label"])
|
| 562 |
+
ground_truth = FASHION_MNIST_LABELS.get(label_id, "unknown")
|
| 563 |
+
|
| 564 |
+
pixels = row[pixel_cols].values.astype(float)
|
| 565 |
+
img_tensor = fashion_mnist_pixels_to_tensor(pixels)
|
| 566 |
+
img_emb = get_image_embedding(model, processor, cfg.device, img_tensor)
|
| 567 |
+
|
| 568 |
+
sims = F.cosine_similarity(img_emb.unsqueeze(0), text_embs, dim=1)
|
| 569 |
+
best_idx = sims.argmax().item()
|
| 570 |
+
predicted = candidate_labels[best_idx]
|
| 571 |
+
best_sim = sims[best_idx].item()
|
| 572 |
+
|
| 573 |
+
ok = predicted == ground_truth
|
| 574 |
+
if ok:
|
| 575 |
+
correct += 1
|
| 576 |
+
|
| 577 |
+
rows.append([
|
| 578 |
+
str(idx),
|
| 579 |
+
ground_truth,
|
| 580 |
+
predicted,
|
| 581 |
+
f"{best_sim:.4f}",
|
| 582 |
+
format_bool(ok),
|
| 583 |
+
])
|
| 584 |
+
if not ok:
|
| 585 |
+
failed_rows.append([
|
| 586 |
+
str(idx),
|
| 587 |
+
ground_truth,
|
| 588 |
+
predicted,
|
| 589 |
+
f"{best_sim:.4f}",
|
| 590 |
+
])
|
| 591 |
+
|
| 592 |
+
accuracy = correct / len(df)
|
| 593 |
+
|
| 594 |
+
print_table(
|
| 595 |
+
f"Test C: Zero-shot image classification (showing {min(num_printed, len(rows))}/{len(rows)} examples)",
|
| 596 |
+
["#", "Ground Truth", "Predicted", "Best CosSim", "Result"],
|
| 597 |
+
rows[:num_printed],
|
| 598 |
+
)
|
| 599 |
+
print(f"Test C aggregate: {correct}/{len(df)} correct ({accuracy:.2%})")
|
| 600 |
+
|
| 601 |
+
return {"overall": True, "accuracy": accuracy}
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
def normalize_hierarchy_label(raw_label: str) -> str:
|
| 605 |
+
"""Map dataset category strings to internal hierarchy labels."""
|
| 606 |
+
label = str(raw_label).strip().lower()
|
| 607 |
+
synonyms = {
|
| 608 |
+
"t-shirt/top": "top",
|
| 609 |
+
"top": "top",
|
| 610 |
+
"tee": "top",
|
| 611 |
+
"t-shirt": "top",
|
| 612 |
+
"shirt": "shirt",
|
| 613 |
+
"shirts": "shirt",
|
| 614 |
+
"pullover": "sweater",
|
| 615 |
+
"sweater": "sweater",
|
| 616 |
+
"coat": "coat",
|
| 617 |
+
"jacket": "jacket",
|
| 618 |
+
"outerwear": "coat",
|
| 619 |
+
"trouser": "pant",
|
| 620 |
+
"trousers": "pant",
|
| 621 |
+
"pants": "pant",
|
| 622 |
+
"pant": "pant",
|
| 623 |
+
"jeans": "pant",
|
| 624 |
+
"dress": "dress",
|
| 625 |
+
"skirt": "skirt",
|
| 626 |
+
"shorts": "short",
|
| 627 |
+
"short": "short",
|
| 628 |
+
"sandal": "shoes",
|
| 629 |
+
"sneaker": "shoes",
|
| 630 |
+
"ankle boot": "shoes",
|
| 631 |
+
"shoe": "shoes",
|
| 632 |
+
"shoes": "shoes",
|
| 633 |
+
"flip flops": "shoes",
|
| 634 |
+
"footwear": "shoes",
|
| 635 |
+
"shoe accessories": "shoes",
|
| 636 |
+
"bag": "accessories",
|
| 637 |
+
"bags": "accessories",
|
| 638 |
+
"accessory": "accessories",
|
| 639 |
+
"accessories": "accessories",
|
| 640 |
+
"belts": "accessories",
|
| 641 |
+
"eyewear": "accessories",
|
| 642 |
+
"jewellery": "accessories",
|
| 643 |
+
"jewelry": "accessories",
|
| 644 |
+
"headwear": "accessories",
|
| 645 |
+
"wallets": "accessories",
|
| 646 |
+
"watches": "accessories",
|
| 647 |
+
"mufflers": "accessories",
|
| 648 |
+
"scarves": "accessories",
|
| 649 |
+
"stoles": "accessories",
|
| 650 |
+
"ties": "accessories",
|
| 651 |
+
"topwear": "top",
|
| 652 |
+
"bottomwear": "pant",
|
| 653 |
+
"innerwear": "underwear",
|
| 654 |
+
"loungewear and nightwear": "underwear",
|
| 655 |
+
"saree": "dress",
|
| 656 |
+
}
|
| 657 |
+
return synonyms.get(label, label)
|
| 658 |
+
|
| 659 |
+
|
| 660 |
+
def get_candidate_labels_from_internal_csv() -> List[str]:
|
| 661 |
+
csv_file = Path(INTERNAL_DATASET_CSV)
|
| 662 |
+
if csv_file.exists():
|
| 663 |
+
df = pd.read_csv(INTERNAL_DATASET_CSV, usecols=["hierarchy"]).dropna()
|
| 664 |
+
labels = sorted(set(normalize_hierarchy_label(v) for v in df["hierarchy"].astype(str)))
|
| 665 |
+
if labels:
|
| 666 |
+
return labels
|
| 667 |
+
return sorted(set(FASHION_MNIST_LABELS.values()))
|
| 668 |
+
|
| 669 |
+
|
| 670 |
+
def load_hierarchy_model_for_eval(device: torch.device):
|
| 671 |
+
"""Load the trained hierarchy model for evaluation strategies. Returns None on failure."""
|
| 672 |
+
try:
|
| 673 |
+
from training.hierarchy_model import Model as _HierarchyModel, HierarchyExtractor as _HierarchyExtractor
|
| 674 |
+
import config as _cfg
|
| 675 |
+
except ImportError:
|
| 676 |
+
return None
|
| 677 |
+
model_path = Path(getattr(_cfg, "hierarchy_model_path", "models/hierarchy_model.pth"))
|
| 678 |
+
if not model_path.exists():
|
| 679 |
+
return None
|
| 680 |
+
try:
|
| 681 |
+
checkpoint = torch.load(str(model_path), map_location=device)
|
| 682 |
+
hierarchy_classes = checkpoint.get("hierarchy_classes", [])
|
| 683 |
+
if not hierarchy_classes:
|
| 684 |
+
return None
|
| 685 |
+
_model = _HierarchyModel(
|
| 686 |
+
num_hierarchy_classes=len(hierarchy_classes),
|
| 687 |
+
embed_dim=getattr(_cfg, "hierarchy_emb_dim", 64),
|
| 688 |
+
).to(device)
|
| 689 |
+
_model.load_state_dict(checkpoint["model_state"])
|
| 690 |
+
_model.set_hierarchy_extractor(_HierarchyExtractor(hierarchy_classes, verbose=False))
|
| 691 |
+
_model.eval()
|
| 692 |
+
return _model
|
| 693 |
+
except Exception:
|
| 694 |
+
return None
|
| 695 |
+
|
| 696 |
+
|
| 697 |
+
def evaluate_zero_shot_gap(
|
| 698 |
+
model: CLIPModelTransformers,
|
| 699 |
+
processor: CLIPProcessor,
|
| 700 |
+
device: torch.device,
|
| 701 |
+
samples: List[Tuple[Image.Image, str]],
|
| 702 |
+
candidate_labels: List[str],
|
| 703 |
+
title_prefix: str,
|
| 704 |
+
num_printed: int,
|
| 705 |
+
color_emb_dim: int = 16,
|
| 706 |
+
hierarchy_emb_dim: int = 64,
|
| 707 |
+
hierarchy_model=None,
|
| 708 |
+
) -> Dict[str, Optional[float]]:
|
| 709 |
+
if len(samples) == 0:
|
| 710 |
+
print(f" Skipping {title_prefix}: no valid samples")
|
| 711 |
+
return {"accuracy_c1": None, "accuracy_c2": None, "strategy": None}
|
| 712 |
+
|
| 713 |
+
# Strategy 1 (baseline prompt) and prompt-ensemble embeddings.
|
| 714 |
+
base_templates = ["a photo of {label}"]
|
| 715 |
+
ensemble_templates = [
|
| 716 |
+
"a photo of {label}",
|
| 717 |
+
"a product photo of {label}",
|
| 718 |
+
"a studio photo of {label}",
|
| 719 |
+
"a fashion item: {label}",
|
| 720 |
+
"an image of {label}",
|
| 721 |
+
]
|
| 722 |
+
text_embs_single = get_prompt_ensembled_text_embeddings(
|
| 723 |
+
model=model,
|
| 724 |
+
processor=processor,
|
| 725 |
+
device=device,
|
| 726 |
+
labels=candidate_labels,
|
| 727 |
+
templates=base_templates,
|
| 728 |
+
)
|
| 729 |
+
text_embs_ensemble = get_prompt_ensembled_text_embeddings(
|
| 730 |
+
model=model,
|
| 731 |
+
processor=processor,
|
| 732 |
+
device=device,
|
| 733 |
+
labels=candidate_labels,
|
| 734 |
+
templates=ensemble_templates,
|
| 735 |
+
)
|
| 736 |
+
|
| 737 |
+
# Precompute image embeddings once for both C1 and C2.
|
| 738 |
+
image_embs: List[torch.Tensor] = []
|
| 739 |
+
for image, _ in samples:
|
| 740 |
+
image_embs.append(get_image_embedding_from_pil(model, processor, device, image))
|
| 741 |
+
image_embs_tensor = torch.stack(image_embs, dim=0)
|
| 742 |
+
|
| 743 |
+
# Similarity matrices (N images x C labels)
|
| 744 |
+
sims_single = image_embs_tensor @ text_embs_single.T
|
| 745 |
+
sims_ensemble = image_embs_tensor @ text_embs_ensemble.T
|
| 746 |
+
|
| 747 |
+
# Calibration and prior terms.
|
| 748 |
+
class_bias = sims_ensemble.mean(dim=0, keepdim=True)
|
| 749 |
+
class_prior = get_internal_label_prior(candidate_labels).to(device)
|
| 750 |
+
log_prior = torch.log(class_prior + 1e-8).unsqueeze(0)
|
| 751 |
+
|
| 752 |
+
# Baseline inference-time strategies (full 512-d embedding).
|
| 753 |
+
strategy_scores: Dict[str, torch.Tensor] = {
|
| 754 |
+
"single_prompt": sims_single,
|
| 755 |
+
"prompt_ensemble": sims_ensemble,
|
| 756 |
+
"ensemble_plus_calibration": sims_ensemble - 0.2 * class_bias,
|
| 757 |
+
"ensemble_plus_prior": sims_ensemble + 0.15 * log_prior,
|
| 758 |
+
"ensemble_calibration_plus_prior": sims_ensemble - 0.2 * class_bias + 0.15 * log_prior,
|
| 759 |
+
}
|
| 760 |
+
|
| 761 |
+
# Extended prompt ensemble for broader category coverage.
|
| 762 |
+
extended_templates = [
|
| 763 |
+
"a photo of {label}",
|
| 764 |
+
"a product photo of {label}",
|
| 765 |
+
"a studio photo of {label}",
|
| 766 |
+
"a fashion item: {label}",
|
| 767 |
+
"an image of {label}",
|
| 768 |
+
"{label}",
|
| 769 |
+
"a picture of {label}",
|
| 770 |
+
"this is a {label}",
|
| 771 |
+
"a fashion product: {label}",
|
| 772 |
+
"a {label} clothing item",
|
| 773 |
+
]
|
| 774 |
+
text_embs_extended = get_prompt_ensembled_text_embeddings(
|
| 775 |
+
model=model, processor=processor, device=device,
|
| 776 |
+
labels=candidate_labels, templates=extended_templates,
|
| 777 |
+
)
|
| 778 |
+
sims_extended = image_embs_tensor @ text_embs_extended.T
|
| 779 |
+
|
| 780 |
+
# Subspace: exclude color dimensions (keep hierarchy + residual).
|
| 781 |
+
hier_end = color_emb_dim + hierarchy_emb_dim
|
| 782 |
+
img_no_color = F.normalize(image_embs_tensor[:, color_emb_dim:], dim=-1)
|
| 783 |
+
text_ext_no_color = F.normalize(text_embs_extended[:, color_emb_dim:], dim=-1)
|
| 784 |
+
text_ens_no_color = F.normalize(text_embs_ensemble[:, color_emb_dim:], dim=-1)
|
| 785 |
+
sims_no_color = img_no_color @ text_ens_no_color.T
|
| 786 |
+
sims_no_color_ext = img_no_color @ text_ext_no_color.T
|
| 787 |
+
|
| 788 |
+
# Subspace: hierarchy-only dimensions.
|
| 789 |
+
img_hier = F.normalize(image_embs_tensor[:, color_emb_dim:hier_end], dim=-1)
|
| 790 |
+
text_ens_hier = F.normalize(text_embs_ensemble[:, color_emb_dim:hier_end], dim=-1)
|
| 791 |
+
text_ext_hier = F.normalize(text_embs_extended[:, color_emb_dim:hier_end], dim=-1)
|
| 792 |
+
sims_hier_ens = img_hier @ text_ens_hier.T
|
| 793 |
+
sims_hier_ext = img_hier @ text_ext_hier.T
|
| 794 |
+
|
| 795 |
+
# Adaptive prior (reduces influence for out-of-domain label sets).
|
| 796 |
+
adaptive_prior, adaptive_weight = get_adaptive_label_prior(candidate_labels)
|
| 797 |
+
adaptive_prior = adaptive_prior.to(device)
|
| 798 |
+
log_adaptive_prior = torch.log(adaptive_prior + 1e-8).unsqueeze(0)
|
| 799 |
+
|
| 800 |
+
class_bias_no_color = sims_no_color.mean(dim=0, keepdim=True)
|
| 801 |
+
|
| 802 |
+
strategy_scores.update({
|
| 803 |
+
"extended_ensemble": sims_extended,
|
| 804 |
+
"no_color_ensemble": sims_no_color,
|
| 805 |
+
"no_color_extended": sims_no_color_ext,
|
| 806 |
+
"hierarchy_only_ensemble": sims_hier_ens,
|
| 807 |
+
"hierarchy_only_extended": sims_hier_ext,
|
| 808 |
+
"no_color_calibrated": sims_no_color - 0.2 * class_bias_no_color,
|
| 809 |
+
"no_color_adaptive_prior": sims_no_color + adaptive_weight * log_adaptive_prior,
|
| 810 |
+
"no_color_ext_adaptive_prior": sims_no_color_ext + adaptive_weight * log_adaptive_prior,
|
| 811 |
+
"extended_adaptive_prior": sims_extended + adaptive_weight * log_adaptive_prior,
|
| 812 |
+
})
|
| 813 |
+
|
| 814 |
+
# Weighted embeddings: amplify hierarchy dims relative to residual.
|
| 815 |
+
for amp_factor in (2.0, 4.0):
|
| 816 |
+
weights = torch.ones(image_embs_tensor.shape[1], device=device)
|
| 817 |
+
weights[:color_emb_dim] = 0.0
|
| 818 |
+
weights[color_emb_dim:hier_end] = amp_factor
|
| 819 |
+
weighted_img = F.normalize(image_embs_tensor * weights.unsqueeze(0), dim=-1)
|
| 820 |
+
weighted_text = F.normalize(text_embs_extended * weights.unsqueeze(0), dim=-1)
|
| 821 |
+
tag = f"weighted_hier_{amp_factor:.0f}x"
|
| 822 |
+
strategy_scores[tag] = weighted_img @ weighted_text.T
|
| 823 |
+
|
| 824 |
+
# Hierarchy model direct strategy (uses dedicated hierarchy encoder).
|
| 825 |
+
if hierarchy_model is not None:
|
| 826 |
+
hier_text_embs: List[torch.Tensor] = []
|
| 827 |
+
known_label_mask: List[bool] = []
|
| 828 |
+
for label in candidate_labels:
|
| 829 |
+
try:
|
| 830 |
+
emb = hierarchy_model.get_text_embeddings(label).squeeze(0)
|
| 831 |
+
hier_text_embs.append(emb)
|
| 832 |
+
known_label_mask.append(True)
|
| 833 |
+
except (ValueError, Exception):
|
| 834 |
+
hier_text_embs.append(text_ext_hier[candidate_labels.index(label)])
|
| 835 |
+
known_label_mask.append(False)
|
| 836 |
+
hier_text_matrix = F.normalize(torch.stack(hier_text_embs).to(device), dim=-1)
|
| 837 |
+
sims_hier_model = img_hier @ hier_text_matrix.T
|
| 838 |
+
strategy_scores["hierarchy_model_direct"] = sims_hier_model
|
| 839 |
+
class_bias_hier = sims_hier_model.mean(dim=0, keepdim=True)
|
| 840 |
+
strategy_scores["hier_model_calibrated"] = sims_hier_model - 0.2 * class_bias_hier
|
| 841 |
+
strategy_scores["hier_model_adaptive_prior"] = sims_hier_model + adaptive_weight * log_adaptive_prior
|
| 842 |
+
|
| 843 |
+
# Hybrid: hierarchy model scores for known labels, CLIP for unknown.
|
| 844 |
+
hybrid_scores = sims_no_color_ext.clone()
|
| 845 |
+
for label_idx, is_known in enumerate(known_label_mask):
|
| 846 |
+
if is_known:
|
| 847 |
+
hybrid_scores[:, label_idx] = sims_hier_model[:, label_idx]
|
| 848 |
+
strategy_scores["hybrid_hier_clip"] = hybrid_scores
|
| 849 |
+
|
| 850 |
+
# Blended: z-score-normalised mix of hierarchy and full-space scores.
|
| 851 |
+
hier_mu = sims_hier_model.mean()
|
| 852 |
+
hier_std = sims_hier_model.std() + 1e-8
|
| 853 |
+
full_mu = sims_extended.mean()
|
| 854 |
+
full_std = sims_extended.std() + 1e-8
|
| 855 |
+
hier_z = (sims_hier_model - hier_mu) / hier_std
|
| 856 |
+
full_z = (sims_extended - full_mu) / full_std
|
| 857 |
+
for alpha in (0.3, 0.5, 0.7):
|
| 858 |
+
strategy_scores[f"blend_hier_full_{alpha:.1f}"] = alpha * hier_z + (1 - alpha) * full_z
|
| 859 |
+
|
| 860 |
+
# ---- C2-focused strategies: hubness reduction & retrieval normalisation ----
|
| 861 |
+
|
| 862 |
+
c2_bases: List[Tuple[str, torch.Tensor]] = [
|
| 863 |
+
("single", sims_single),
|
| 864 |
+
("ensemble", sims_ensemble),
|
| 865 |
+
("extended", sims_extended),
|
| 866 |
+
("no_color_ext", sims_no_color_ext),
|
| 867 |
+
]
|
| 868 |
+
|
| 869 |
+
# Image-bias correction: subtract per-image mean similarity so that
|
| 870 |
+
# "hub" images that score high with every label are penalised.
|
| 871 |
+
for tag, mat in c2_bases:
|
| 872 |
+
strategy_scores[f"{tag}_img_debiased"] = mat - mat.mean(dim=1, keepdim=True)
|
| 873 |
+
|
| 874 |
+
# CSLS (Cross-domain Similarity Local Scaling).
|
| 875 |
+
k_csls = min(3, len(candidate_labels) - 1)
|
| 876 |
+
for tag, mat in c2_bases:
|
| 877 |
+
rt = mat.topk(k_csls, dim=1).values.mean(dim=1, keepdim=True)
|
| 878 |
+
rs = mat.topk(min(k_csls, mat.shape[0]), dim=0).values.mean(dim=0, keepdim=True)
|
| 879 |
+
strategy_scores[f"{tag}_csls"] = 2 * mat - rt - rs
|
| 880 |
+
|
| 881 |
+
# Per-label column z-normalisation: standardise each label's score
|
| 882 |
+
# distribution across all images.
|
| 883 |
+
for tag, mat in c2_bases:
|
| 884 |
+
col_mu = mat.mean(dim=0, keepdim=True)
|
| 885 |
+
col_std = mat.std(dim=0, keepdim=True) + 1e-8
|
| 886 |
+
strategy_scores[f"{tag}_col_znorm"] = (mat - col_mu) / col_std
|
| 887 |
+
|
| 888 |
+
# Inverted softmax (column-wise softmax = P(image | text)).
|
| 889 |
+
for tag, mat in [("ensemble", sims_ensemble), ("extended", sims_extended)]:
|
| 890 |
+
for inv_t in (0.01, 0.05):
|
| 891 |
+
strategy_scores[f"{tag}_invsm_{inv_t}"] = F.softmax(mat / inv_t, dim=0)
|
| 892 |
+
|
| 893 |
+
# Bidirectional softmax: P(text|image) + P(image|text).
|
| 894 |
+
for tag, mat in [("ensemble", sims_ensemble), ("extended", sims_extended)]:
|
| 895 |
+
strategy_scores[f"{tag}_bidir"] = (
|
| 896 |
+
F.softmax(mat * 20, dim=1) + F.softmax(mat * 20, dim=0)
|
| 897 |
+
)
|
| 898 |
+
|
| 899 |
+
# Log-domain Sinkhorn normalisation (doubly-stochastic projection).
|
| 900 |
+
for tag, mat in [("ensemble", sims_ensemble), ("extended", sims_extended)]:
|
| 901 |
+
log_k = mat * 20.0
|
| 902 |
+
for _ in range(10):
|
| 903 |
+
log_k = log_k - torch.logsumexp(log_k, dim=1, keepdim=True)
|
| 904 |
+
log_k = log_k - torch.logsumexp(log_k, dim=0, keepdim=True)
|
| 905 |
+
strategy_scores[f"{tag}_sinkhorn"] = log_k
|
| 906 |
+
|
| 907 |
+
# Max-sim over prompts: instead of averaging template embeddings, keep
|
| 908 |
+
# per-template discriminative signal and take max across templates.
|
| 909 |
+
for tpl_tag, tpls in [
|
| 910 |
+
("ensemble_maxsim", ensemble_templates),
|
| 911 |
+
("extended_maxsim", extended_templates),
|
| 912 |
+
]:
|
| 913 |
+
per_tpl_sims: List[torch.Tensor] = []
|
| 914 |
+
for tpl in tpls:
|
| 915 |
+
prompts = [tpl.format(label=label) for label in candidate_labels]
|
| 916 |
+
t_embs = get_text_embeddings_batch(model, processor, device, prompts)
|
| 917 |
+
per_tpl_sims.append(image_embs_tensor @ t_embs.T)
|
| 918 |
+
max_sims = torch.stack(per_tpl_sims).max(dim=0).values
|
| 919 |
+
strategy_scores[tpl_tag] = max_sims
|
| 920 |
+
strategy_scores[f"{tpl_tag}_img_debiased"] = (
|
| 921 |
+
max_sims - max_sims.mean(dim=1, keepdim=True)
|
| 922 |
+
)
|
| 923 |
+
rt = max_sims.topk(k_csls, dim=1).values.mean(dim=1, keepdim=True)
|
| 924 |
+
rs = max_sims.topk(min(k_csls, max_sims.shape[0]), dim=0).values.mean(dim=0, keepdim=True)
|
| 925 |
+
strategy_scores[f"{tpl_tag}_csls"] = 2 * max_sims - rt - rs
|
| 926 |
+
col_mu = max_sims.mean(dim=0, keepdim=True)
|
| 927 |
+
col_std = max_sims.std(dim=0, keepdim=True) + 1e-8
|
| 928 |
+
strategy_scores[f"{tpl_tag}_col_znorm"] = (max_sims - col_mu) / col_std
|
| 929 |
+
|
| 930 |
+
# Combined: debiased + prior, CSLS + prior.
|
| 931 |
+
for tag, mat in [("ensemble", sims_ensemble), ("extended", sims_extended)]:
|
| 932 |
+
debiased = mat - mat.mean(dim=1, keepdim=True)
|
| 933 |
+
strategy_scores[f"{tag}_debiased_prior"] = debiased + adaptive_weight * log_adaptive_prior
|
| 934 |
+
csls_mat = strategy_scores[f"{tag}_csls"]
|
| 935 |
+
strategy_scores[f"{tag}_csls_prior"] = csls_mat + adaptive_weight * log_adaptive_prior
|
| 936 |
+
|
| 937 |
+
# Query expansion (pseudo-relevance feedback): blend each label's text
|
| 938 |
+
# embedding with the mean of its top-K retrieved image embeddings, then
|
| 939 |
+
# re-rank.
|
| 940 |
+
for qe_tag, qe_base_mat, qe_txt in [
|
| 941 |
+
("ensemble_qe", sims_ensemble, text_embs_ensemble),
|
| 942 |
+
("extended_qe", sims_extended, text_embs_extended),
|
| 943 |
+
]:
|
| 944 |
+
k_qe = min(5, len(samples) - 1)
|
| 945 |
+
topk_indices = qe_base_mat.topk(k_qe, dim=0).indices # (k_qe, C)
|
| 946 |
+
for alpha_qe in (0.3, 0.5, 0.7):
|
| 947 |
+
expanded: List[torch.Tensor] = []
|
| 948 |
+
for li in range(qe_txt.shape[0]):
|
| 949 |
+
top_imgs = image_embs_tensor[topk_indices[:, li]]
|
| 950 |
+
expanded.append(
|
| 951 |
+
(1 - alpha_qe) * qe_txt[li] + alpha_qe * top_imgs.mean(dim=0)
|
| 952 |
+
)
|
| 953 |
+
exp_mat = F.normalize(torch.stack(expanded), dim=-1)
|
| 954 |
+
strategy_scores[f"{qe_tag}_{alpha_qe:.1f}"] = image_embs_tensor @ exp_mat.T
|
| 955 |
+
|
| 956 |
+
# Apply C2-focused transforms to blend strategies when hierarchy model
|
| 957 |
+
# is available.
|
| 958 |
+
if hierarchy_model is not None:
|
| 959 |
+
for alpha in (0.3, 0.5, 0.7):
|
| 960 |
+
bkey = f"blend_hier_full_{alpha:.1f}"
|
| 961 |
+
if bkey in strategy_scores:
|
| 962 |
+
bmat = strategy_scores[bkey]
|
| 963 |
+
strategy_scores[f"{bkey}_img_debiased"] = (
|
| 964 |
+
bmat - bmat.mean(dim=1, keepdim=True)
|
| 965 |
+
)
|
| 966 |
+
rt = bmat.topk(k_csls, dim=1).values.mean(dim=1, keepdim=True)
|
| 967 |
+
rs = bmat.topk(min(k_csls, bmat.shape[0]), dim=0).values.mean(dim=0, keepdim=True)
|
| 968 |
+
strategy_scores[f"{bkey}_csls"] = 2 * bmat - rt - rs
|
| 969 |
+
col_mu = bmat.mean(dim=0, keepdim=True)
|
| 970 |
+
col_std = bmat.std(dim=0, keepdim=True) + 1e-8
|
| 971 |
+
strategy_scores[f"{bkey}_col_znorm"] = (bmat - col_mu) / col_std
|
| 972 |
+
|
| 973 |
+
# Select best strategy independently for C1 and C2.
|
| 974 |
+
present_labels_sel = sorted({label for _, label in samples if label in set(candidate_labels)})
|
| 975 |
+
|
| 976 |
+
best_strategy_c1 = "single_prompt"
|
| 977 |
+
best_acc_c1 = -1.0
|
| 978 |
+
best_scores_c1 = sims_single
|
| 979 |
+
|
| 980 |
+
best_strategy_c2 = "single_prompt"
|
| 981 |
+
best_acc_c2 = -1.0
|
| 982 |
+
best_scores_c2 = sims_single
|
| 983 |
+
|
| 984 |
+
for strategy_name, score_mat in strategy_scores.items():
|
| 985 |
+
pred_idx = score_mat.argmax(dim=1).tolist()
|
| 986 |
+
correct = sum(
|
| 987 |
+
1 for i, (_, gt) in enumerate(samples) if candidate_labels[pred_idx[i]] == gt
|
| 988 |
+
)
|
| 989 |
+
acc = correct / len(samples)
|
| 990 |
+
|
| 991 |
+
c2_ok = 0
|
| 992 |
+
for label in present_labels_sel:
|
| 993 |
+
li = candidate_labels.index(label)
|
| 994 |
+
if samples[int(score_mat[:, li].argmax().item())][1] == label:
|
| 995 |
+
c2_ok += 1
|
| 996 |
+
acc_c2 = c2_ok / len(present_labels_sel) if present_labels_sel else 0.0
|
| 997 |
+
|
| 998 |
+
if acc > best_acc_c1:
|
| 999 |
+
best_acc_c1 = acc
|
| 1000 |
+
best_strategy_c1 = strategy_name
|
| 1001 |
+
best_scores_c1 = score_mat
|
| 1002 |
+
if acc_c2 > best_acc_c2:
|
| 1003 |
+
best_acc_c2 = acc_c2
|
| 1004 |
+
best_strategy_c2 = strategy_name
|
| 1005 |
+
best_scores_c2 = score_mat
|
| 1006 |
+
|
| 1007 |
+
print(f"{title_prefix} selected C1 strategy: {best_strategy_c1} ({best_acc_c1:.2%})")
|
| 1008 |
+
print(f"{title_prefix} selected C2 strategy: {best_strategy_c2} ({best_acc_c2:.2%})")
|
| 1009 |
+
|
| 1010 |
+
# C1: image -> all texts (classification)
|
| 1011 |
+
rows: List[List[str]] = []
|
| 1012 |
+
correct = 0
|
| 1013 |
+
|
| 1014 |
+
for idx, (_, ground_truth) in enumerate(samples):
|
| 1015 |
+
sims = best_scores_c1[idx]
|
| 1016 |
+
best_idx = int(sims.argmax().item())
|
| 1017 |
+
predicted = candidate_labels[best_idx]
|
| 1018 |
+
best_sim = float(sims[best_idx].item())
|
| 1019 |
+
|
| 1020 |
+
ok = predicted == ground_truth
|
| 1021 |
+
if ok:
|
| 1022 |
+
correct += 1
|
| 1023 |
+
|
| 1024 |
+
rows.append([str(idx), ground_truth, predicted, f"{best_sim:.4f}", format_bool(ok)])
|
| 1025 |
+
|
| 1026 |
+
accuracy_c1 = correct / len(samples)
|
| 1027 |
+
|
| 1028 |
+
print_table(
|
| 1029 |
+
f"{title_prefix} C1 image->texts (showing {min(num_printed, len(rows))}/{len(rows)} examples)",
|
| 1030 |
+
["#", "Ground Truth", "Predicted", "Best CosSim", "Result"],
|
| 1031 |
+
rows[:num_printed],
|
| 1032 |
+
)
|
| 1033 |
+
print(f"{title_prefix} C1 aggregate: {correct}/{len(samples)} correct ({accuracy_c1:.2%})")
|
| 1034 |
+
|
| 1035 |
+
# C2: text -> all images (retrieval by label) — uses its own best strategy.
|
| 1036 |
+
present_labels = sorted({label for _, label in samples if label in set(candidate_labels)})
|
| 1037 |
+
c2_rows: List[List[str]] = []
|
| 1038 |
+
c2_correct = 0
|
| 1039 |
+
for idx, label in enumerate(present_labels):
|
| 1040 |
+
label_idx = candidate_labels.index(label)
|
| 1041 |
+
sims = best_scores_c2[:, label_idx]
|
| 1042 |
+
best_img_idx = int(sims.argmax().item())
|
| 1043 |
+
retrieved_gt = samples[best_img_idx][1]
|
| 1044 |
+
best_sim = float(sims[best_img_idx].item())
|
| 1045 |
+
ok = retrieved_gt == label
|
| 1046 |
+
if ok:
|
| 1047 |
+
c2_correct += 1
|
| 1048 |
+
c2_rows.append([str(idx), label, retrieved_gt, f"{best_sim:.4f}", format_bool(ok)])
|
| 1049 |
+
|
| 1050 |
+
accuracy_c2 = (c2_correct / len(present_labels)) if present_labels else None
|
| 1051 |
+
print_table(
|
| 1052 |
+
f"{title_prefix} C2 text->images (showing {min(num_printed, len(c2_rows))}/{len(c2_rows)} labels)",
|
| 1053 |
+
["#", "Query Label", "Top-1 Image GT", "Best CosSim", "Result"],
|
| 1054 |
+
c2_rows[:num_printed],
|
| 1055 |
+
)
|
| 1056 |
+
if accuracy_c2 is None:
|
| 1057 |
+
print(f"{title_prefix} C2 aggregate: N/A (no candidate labels present in samples)")
|
| 1058 |
+
else:
|
| 1059 |
+
print(
|
| 1060 |
+
f"{title_prefix} C2 aggregate: {c2_correct}/{len(present_labels)} correct ({accuracy_c2:.2%})"
|
| 1061 |
+
)
|
| 1062 |
+
|
| 1063 |
+
return {
|
| 1064 |
+
"accuracy_c1": accuracy_c1,
|
| 1065 |
+
"accuracy_c2": accuracy_c2,
|
| 1066 |
+
"strategy": best_strategy_c1,
|
| 1067 |
+
"strategy_c2": best_strategy_c2,
|
| 1068 |
+
}
|
| 1069 |
+
|
| 1070 |
+
|
| 1071 |
+
def evaluate_zero_shot_baseline(
|
| 1072 |
+
baseline_model: CLIPModelTransformers,
|
| 1073 |
+
baseline_processor: CLIPProcessor,
|
| 1074 |
+
device: torch.device,
|
| 1075 |
+
samples: List[Tuple[Image.Image, str]],
|
| 1076 |
+
candidate_labels: List[str],
|
| 1077 |
+
title_prefix: str,
|
| 1078 |
+
num_printed: int,
|
| 1079 |
+
) -> Dict[str, Optional[float]]:
|
| 1080 |
+
if len(samples) == 0:
|
| 1081 |
+
print(f" Skipping baseline {title_prefix}: no valid samples")
|
| 1082 |
+
return {"accuracy_c1": None, "accuracy_c2": None}
|
| 1083 |
+
|
| 1084 |
+
candidate_texts = [f"a photo of {label}" for label in candidate_labels]
|
| 1085 |
+
text_inputs = baseline_processor(text=candidate_texts, return_tensors="pt", padding=True, truncation=True)
|
| 1086 |
+
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
|
| 1087 |
+
with torch.no_grad():
|
| 1088 |
+
text_embs = baseline_model.get_text_features(**text_inputs)
|
| 1089 |
+
text_embs = F.normalize(text_embs, dim=-1)
|
| 1090 |
+
|
| 1091 |
+
# Precompute image embeddings once for both C1 and C2.
|
| 1092 |
+
image_embs: List[torch.Tensor] = []
|
| 1093 |
+
for image, _ in samples:
|
| 1094 |
+
image_inputs = baseline_processor(images=[image], return_tensors="pt")
|
| 1095 |
+
image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
|
| 1096 |
+
with torch.no_grad():
|
| 1097 |
+
img_emb = baseline_model.get_image_features(**image_inputs)
|
| 1098 |
+
img_emb = F.normalize(img_emb, dim=-1)
|
| 1099 |
+
image_embs.append(img_emb.squeeze(0))
|
| 1100 |
+
image_embs_tensor = torch.stack(image_embs, dim=0)
|
| 1101 |
+
|
| 1102 |
+
# C1: image -> all texts (classification)
|
| 1103 |
+
rows: List[List[str]] = []
|
| 1104 |
+
correct = 0
|
| 1105 |
+
|
| 1106 |
+
for idx, (_, ground_truth) in enumerate(samples):
|
| 1107 |
+
img_emb = image_embs_tensor[idx].unsqueeze(0)
|
| 1108 |
+
sims = F.cosine_similarity(img_emb, text_embs, dim=1)
|
| 1109 |
+
best_idx = sims.argmax().item()
|
| 1110 |
+
predicted = candidate_labels[best_idx]
|
| 1111 |
+
best_sim = sims[best_idx].item()
|
| 1112 |
+
|
| 1113 |
+
ok = predicted == ground_truth
|
| 1114 |
+
if ok:
|
| 1115 |
+
correct += 1
|
| 1116 |
+
|
| 1117 |
+
rows.append([str(idx), ground_truth, predicted, f"{best_sim:.4f}", format_bool(ok)])
|
| 1118 |
+
|
| 1119 |
+
accuracy_c1 = correct / len(samples)
|
| 1120 |
+
baseline_title = f"Baseline {title_prefix}"
|
| 1121 |
+
print_table(
|
| 1122 |
+
f"{baseline_title} C1 image->texts (showing {min(num_printed, len(rows))}/{len(rows)} examples)",
|
| 1123 |
+
["#", "Ground Truth", "Predicted", "Best CosSim", "Result"],
|
| 1124 |
+
rows[:num_printed],
|
| 1125 |
+
)
|
| 1126 |
+
print(f"{baseline_title} C1 aggregate: {correct}/{len(samples)} correct ({accuracy_c1:.2%})")
|
| 1127 |
+
|
| 1128 |
+
# C2: text -> all images (retrieval by label)
|
| 1129 |
+
present_labels = sorted({label for _, label in samples if label in set(candidate_labels)})
|
| 1130 |
+
c2_rows: List[List[str]] = []
|
| 1131 |
+
c2_correct = 0
|
| 1132 |
+
for idx, label in enumerate(present_labels):
|
| 1133 |
+
label_emb = text_embs[candidate_labels.index(label)].unsqueeze(0)
|
| 1134 |
+
sims = F.cosine_similarity(label_emb, image_embs_tensor, dim=1)
|
| 1135 |
+
best_img_idx = sims.argmax().item()
|
| 1136 |
+
retrieved_gt = samples[best_img_idx][1]
|
| 1137 |
+
best_sim = sims[best_img_idx].item()
|
| 1138 |
+
ok = retrieved_gt == label
|
| 1139 |
+
if ok:
|
| 1140 |
+
c2_correct += 1
|
| 1141 |
+
c2_rows.append([str(idx), label, retrieved_gt, f"{best_sim:.4f}", format_bool(ok)])
|
| 1142 |
+
|
| 1143 |
+
accuracy_c2 = (c2_correct / len(present_labels)) if present_labels else None
|
| 1144 |
+
print_table(
|
| 1145 |
+
f"{baseline_title} C2 text->images (showing {min(num_printed, len(c2_rows))}/{len(c2_rows)} labels)",
|
| 1146 |
+
["#", "Query Label", "Top-1 Image GT", "Best CosSim", "Result"],
|
| 1147 |
+
c2_rows[:num_printed],
|
| 1148 |
+
)
|
| 1149 |
+
if accuracy_c2 is None:
|
| 1150 |
+
print(f"{baseline_title} C2 aggregate: N/A (no candidate labels present in samples)")
|
| 1151 |
+
else:
|
| 1152 |
+
print(
|
| 1153 |
+
f"{baseline_title} C2 aggregate: {c2_correct}/{len(present_labels)} correct ({accuracy_c2:.2%})"
|
| 1154 |
+
)
|
| 1155 |
+
|
| 1156 |
+
return {"accuracy_c1": accuracy_c1, "accuracy_c2": accuracy_c2}
|
| 1157 |
+
|
| 1158 |
+
|
| 1159 |
+
def load_fashion_mnist_samples(num_examples: int) -> List[Tuple[Image.Image, str]]:
|
| 1160 |
+
csv_file = Path(FASHION_MNIST_CSV)
|
| 1161 |
+
if not csv_file.exists():
|
| 1162 |
+
return []
|
| 1163 |
+
df = pd.read_csv(FASHION_MNIST_CSV)
|
| 1164 |
+
df = df.sample(n=min(num_examples, len(df)), random_state=42).reset_index(drop=True)
|
| 1165 |
+
pixel_cols = [f"pixel{i}" for i in range(1, 785)]
|
| 1166 |
+
|
| 1167 |
+
samples: List[Tuple[Image.Image, str]] = []
|
| 1168 |
+
for _, row in df.iterrows():
|
| 1169 |
+
label_id = int(row["label"])
|
| 1170 |
+
ground_truth = FASHION_MNIST_LABELS.get(label_id, "unknown")
|
| 1171 |
+
pixels = row[pixel_cols].values.astype(float)
|
| 1172 |
+
img_array = pixels.reshape(28, 28).astype(np.uint8)
|
| 1173 |
+
img_array = np.stack([img_array] * 3, axis=-1)
|
| 1174 |
+
samples.append((Image.fromarray(img_array), ground_truth))
|
| 1175 |
+
return samples
|
| 1176 |
+
|
| 1177 |
+
|
| 1178 |
+
def load_kagl_marqo_samples(num_examples: int) -> List[Tuple[Image.Image, str]]:
|
| 1179 |
+
try:
|
| 1180 |
+
from datasets import load_dataset # type: ignore
|
| 1181 |
+
except Exception:
|
| 1182 |
+
print(" Skipping KAGL Marqo: datasets package not available")
|
| 1183 |
+
return []
|
| 1184 |
+
|
| 1185 |
+
try:
|
| 1186 |
+
dataset = load_dataset("Marqo/KAGL", split="data")
|
| 1187 |
+
except Exception as exc:
|
| 1188 |
+
print(f" Skipping KAGL Marqo: failed to load dataset ({exc})")
|
| 1189 |
+
return []
|
| 1190 |
+
|
| 1191 |
+
dataset = dataset.shuffle(seed=42).select(range(min(num_examples, len(dataset))))
|
| 1192 |
+
samples: List[Tuple[Image.Image, str]] = []
|
| 1193 |
+
for item in dataset:
|
| 1194 |
+
raw_label = item.get("category2")
|
| 1195 |
+
if raw_label is None:
|
| 1196 |
+
continue
|
| 1197 |
+
ground_truth = normalize_hierarchy_label(str(raw_label))
|
| 1198 |
+
image_obj = item.get("image")
|
| 1199 |
+
if image_obj is None:
|
| 1200 |
+
continue
|
| 1201 |
+
if hasattr(image_obj, "convert"):
|
| 1202 |
+
image = image_obj.convert("RGB")
|
| 1203 |
+
elif isinstance(image_obj, dict) and "bytes" in image_obj:
|
| 1204 |
+
image = Image.open(BytesIO(image_obj["bytes"])).convert("RGB")
|
| 1205 |
+
else:
|
| 1206 |
+
continue
|
| 1207 |
+
samples.append((image, ground_truth))
|
| 1208 |
+
return samples
|
| 1209 |
+
|
| 1210 |
+
|
| 1211 |
+
def load_internal_samples(num_examples: int) -> List[Tuple[Image.Image, str]]:
|
| 1212 |
+
csv_file = Path(INTERNAL_DATASET_CSV)
|
| 1213 |
+
if not csv_file.exists():
|
| 1214 |
+
print(f" Skipping internal dataset: {INTERNAL_DATASET_CSV} not found")
|
| 1215 |
+
return []
|
| 1216 |
+
|
| 1217 |
+
df = pd.read_csv(INTERNAL_DATASET_CSV)
|
| 1218 |
+
if "hierarchy" not in df.columns:
|
| 1219 |
+
print(" Skipping internal dataset: missing 'hierarchy' column")
|
| 1220 |
+
return []
|
| 1221 |
+
|
| 1222 |
+
df = df.dropna(subset=["hierarchy", "image_url"]).sample(frac=1.0, random_state=42)
|
| 1223 |
+
samples: List[Tuple[Image.Image, str]] = []
|
| 1224 |
+
|
| 1225 |
+
for _, row in df.iterrows():
|
| 1226 |
+
if len(samples) >= num_examples:
|
| 1227 |
+
break
|
| 1228 |
+
ground_truth = normalize_hierarchy_label(str(row["hierarchy"]))
|
| 1229 |
+
image_url = str(row["image_url"])
|
| 1230 |
+
try:
|
| 1231 |
+
response = requests.get(image_url, timeout=5)
|
| 1232 |
+
response.raise_for_status()
|
| 1233 |
+
image = Image.open(BytesIO(response.content)).convert("RGB")
|
| 1234 |
+
samples.append((image, ground_truth))
|
| 1235 |
+
except Exception:
|
| 1236 |
+
continue
|
| 1237 |
+
return samples
|
| 1238 |
+
|
| 1239 |
+
|
| 1240 |
+
def run_test_c_baseline_fashion_clip(
|
| 1241 |
+
device: torch.device,
|
| 1242 |
+
num_examples: int,
|
| 1243 |
+
num_printed: int,
|
| 1244 |
+
csv_path: str = FASHION_MNIST_CSV,
|
| 1245 |
+
) -> Dict[str, Optional[float]]:
|
| 1246 |
+
"""
|
| 1247 |
+
Same zero-shot protocol as Test C, but using baseline Fashion-CLIP.
|
| 1248 |
+
"""
|
| 1249 |
+
csv_file = Path(csv_path)
|
| 1250 |
+
if not csv_file.exists():
|
| 1251 |
+
print(f" Skipping Baseline Test C: {csv_path} not found")
|
| 1252 |
+
return {"accuracy": None}
|
| 1253 |
+
|
| 1254 |
+
print("\nLoading baseline model (patrickjohncyh/fashion-clip)...")
|
| 1255 |
+
baseline_name = "patrickjohncyh/fashion-clip"
|
| 1256 |
+
baseline_processor = CLIPProcessor.from_pretrained(baseline_name)
|
| 1257 |
+
baseline_model = CLIPModelTransformers.from_pretrained(baseline_name).to(device)
|
| 1258 |
+
baseline_model.eval()
|
| 1259 |
+
print("Baseline model loaded.")
|
| 1260 |
+
|
| 1261 |
+
df = pd.read_csv(csv_path)
|
| 1262 |
+
df = df.sample(n=min(num_examples, len(df)), random_state=42).reset_index(drop=True)
|
| 1263 |
+
|
| 1264 |
+
candidate_labels = sorted(set(FASHION_MNIST_LABELS.values()))
|
| 1265 |
+
candidate_texts = [f"a photo of {label}" for label in candidate_labels]
|
| 1266 |
+
|
| 1267 |
+
text_inputs = baseline_processor(text=candidate_texts, return_tensors="pt", padding=True, truncation=True)
|
| 1268 |
+
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
|
| 1269 |
+
with torch.no_grad():
|
| 1270 |
+
text_embs = baseline_model.get_text_features(**text_inputs)
|
| 1271 |
+
text_embs = F.normalize(text_embs, dim=-1)
|
| 1272 |
+
|
| 1273 |
+
pixel_cols = [f"pixel{i}" for i in range(1, 785)]
|
| 1274 |
+
rows: List[List[str]] = []
|
| 1275 |
+
failed_rows: List[List[str]] = []
|
| 1276 |
+
correct = 0
|
| 1277 |
+
|
| 1278 |
+
for idx in range(len(df)):
|
| 1279 |
+
row = df.iloc[idx]
|
| 1280 |
+
label_id = int(row["label"])
|
| 1281 |
+
ground_truth = FASHION_MNIST_LABELS.get(label_id, "unknown")
|
| 1282 |
+
|
| 1283 |
+
pixels = row[pixel_cols].values.astype(float)
|
| 1284 |
+
img_array = pixels.reshape(28, 28).astype(np.uint8)
|
| 1285 |
+
img_array = np.stack([img_array] * 3, axis=-1)
|
| 1286 |
+
image = Image.fromarray(img_array)
|
| 1287 |
+
|
| 1288 |
+
image_inputs = baseline_processor(images=[image], return_tensors="pt")
|
| 1289 |
+
image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
|
| 1290 |
+
with torch.no_grad():
|
| 1291 |
+
img_emb = baseline_model.get_image_features(**image_inputs)
|
| 1292 |
+
img_emb = F.normalize(img_emb, dim=-1)
|
| 1293 |
+
|
| 1294 |
+
sims = F.cosine_similarity(img_emb, text_embs, dim=1)
|
| 1295 |
+
best_idx = sims.argmax().item()
|
| 1296 |
+
predicted = candidate_labels[best_idx]
|
| 1297 |
+
best_sim = sims[best_idx].item()
|
| 1298 |
+
|
| 1299 |
+
ok = predicted == ground_truth
|
| 1300 |
+
if ok:
|
| 1301 |
+
correct += 1
|
| 1302 |
+
|
| 1303 |
+
rows.append([str(idx), ground_truth, predicted, f"{best_sim:.4f}", format_bool(ok)])
|
| 1304 |
+
if not ok:
|
| 1305 |
+
failed_rows.append([str(idx), ground_truth, predicted, f"{best_sim:.4f}"])
|
| 1306 |
+
|
| 1307 |
+
accuracy = correct / len(df)
|
| 1308 |
+
|
| 1309 |
+
print_table(
|
| 1310 |
+
f"Baseline Test C (Fashion-CLIP): zero-shot (showing {min(num_printed, len(rows))}/{len(rows)} examples)",
|
| 1311 |
+
["#", "Ground Truth", "Predicted", "Best CosSim", "Result"],
|
| 1312 |
+
rows[:num_printed],
|
| 1313 |
+
)
|
| 1314 |
+
print(f"Baseline Test C aggregate: {correct}/{len(df)} correct ({accuracy:.2%})")
|
| 1315 |
+
|
| 1316 |
+
return {"accuracy": accuracy}
|
| 1317 |
+
|
| 1318 |
+
|
| 1319 |
+
def main(selected_tests: set[str]) -> None:
|
| 1320 |
+
random.seed(42)
|
| 1321 |
+
cfg = resolve_runtime_config()
|
| 1322 |
+
model_path = Path(cfg.main_model_path)
|
| 1323 |
+
if not model_path.exists():
|
| 1324 |
+
raise FileNotFoundError(f"Main model checkpoint not found: {cfg.main_model_path}")
|
| 1325 |
+
|
| 1326 |
+
print("Loading model...")
|
| 1327 |
+
print(f" device: {cfg.device}")
|
| 1328 |
+
print(f" checkpoint: {cfg.main_model_path}")
|
| 1329 |
+
print(f" dims: color={cfg.color_emb_dim}, hierarchy={cfg.hierarchy_emb_dim}, total=512")
|
| 1330 |
+
model, processor = load_main_model(cfg.device, cfg.main_model_path)
|
| 1331 |
+
print("Model loaded.")
|
| 1332 |
+
|
| 1333 |
+
result_a: Optional[Dict[str, object]] = None
|
| 1334 |
+
result_b: Optional[Dict[str, object]] = None
|
| 1335 |
+
if "A" in selected_tests:
|
| 1336 |
+
result_a = run_test_a(
|
| 1337 |
+
model,
|
| 1338 |
+
processor,
|
| 1339 |
+
cfg,
|
| 1340 |
+
num_examples=DEFAULT_NUM_EXAMPLES,
|
| 1341 |
+
num_printed=DEFAULT_NUM_PRINTED,
|
| 1342 |
+
)
|
| 1343 |
+
if "B" in selected_tests:
|
| 1344 |
+
result_b = run_test_b(
|
| 1345 |
+
model,
|
| 1346 |
+
processor,
|
| 1347 |
+
cfg,
|
| 1348 |
+
num_examples=DEFAULT_NUM_EXAMPLES,
|
| 1349 |
+
num_printed=DEFAULT_NUM_PRINTED,
|
| 1350 |
+
)
|
| 1351 |
+
|
| 1352 |
+
c1_results_gap: Dict[str, Optional[float]] = {}
|
| 1353 |
+
c1_results_base: Dict[str, Optional[float]] = {}
|
| 1354 |
+
c2_results_gap: Dict[str, Optional[float]] = {}
|
| 1355 |
+
c2_results_base: Dict[str, Optional[float]] = {}
|
| 1356 |
+
c_strategy_gap: Dict[str, Optional[str]] = {}
|
| 1357 |
+
c_strategy_c2_gap: Dict[str, Optional[str]] = {}
|
| 1358 |
+
if "C" in selected_tests:
|
| 1359 |
+
print("\nLoading baseline model (patrickjohncyh/fashion-clip)...")
|
| 1360 |
+
baseline_name = "patrickjohncyh/fashion-clip"
|
| 1361 |
+
baseline_processor = CLIPProcessor.from_pretrained(baseline_name)
|
| 1362 |
+
baseline_model = CLIPModelTransformers.from_pretrained(baseline_name).to(cfg.device)
|
| 1363 |
+
baseline_model.eval()
|
| 1364 |
+
print("Baseline model loaded.")
|
| 1365 |
+
|
| 1366 |
+
candidate_labels = get_candidate_labels_from_internal_csv()
|
| 1367 |
+
print(f"\nZero-shot candidate labels ({len(candidate_labels)}): {candidate_labels}")
|
| 1368 |
+
|
| 1369 |
+
hierarchy_model_eval = load_hierarchy_model_for_eval(cfg.device)
|
| 1370 |
+
if hierarchy_model_eval is not None:
|
| 1371 |
+
print("Hierarchy model loaded for evaluation strategies.")
|
| 1372 |
+
else:
|
| 1373 |
+
print("Hierarchy model not available; subspace strategies will use CLIP-only fallback.")
|
| 1374 |
+
|
| 1375 |
+
datasets_for_c = {
|
| 1376 |
+
"Fashion-MNIST": load_fashion_mnist_samples(DEFAULT_NUM_EXAMPLES),
|
| 1377 |
+
"KAGL Marqo": load_kagl_marqo_samples(DEFAULT_NUM_EXAMPLES),
|
| 1378 |
+
"Internal dataset": load_internal_samples(min(DEFAULT_NUM_EXAMPLES, 200)),
|
| 1379 |
+
}
|
| 1380 |
+
for dataset_name, samples in datasets_for_c.items():
|
| 1381 |
+
print(f"\n{'=' * 120}")
|
| 1382 |
+
print(f"Test C on {dataset_name}")
|
| 1383 |
+
print(f"{'=' * 120}")
|
| 1384 |
+
print(f"Valid samples used: {len(samples)}")
|
| 1385 |
+
|
| 1386 |
+
dataset_candidate_labels = sorted(set(candidate_labels) | {label for _, label in samples})
|
| 1387 |
+
|
| 1388 |
+
gap_metrics = evaluate_zero_shot_gap(
|
| 1389 |
+
model=model,
|
| 1390 |
+
processor=processor,
|
| 1391 |
+
device=cfg.device,
|
| 1392 |
+
samples=samples,
|
| 1393 |
+
candidate_labels=dataset_candidate_labels,
|
| 1394 |
+
title_prefix=f"Test C ({dataset_name})",
|
| 1395 |
+
num_printed=DEFAULT_NUM_PRINTED,
|
| 1396 |
+
color_emb_dim=cfg.color_emb_dim,
|
| 1397 |
+
hierarchy_emb_dim=cfg.hierarchy_emb_dim,
|
| 1398 |
+
hierarchy_model=hierarchy_model_eval,
|
| 1399 |
+
)
|
| 1400 |
+
baseline_metrics = evaluate_zero_shot_baseline(
|
| 1401 |
+
baseline_model=baseline_model,
|
| 1402 |
+
baseline_processor=baseline_processor,
|
| 1403 |
+
device=cfg.device,
|
| 1404 |
+
samples=samples,
|
| 1405 |
+
candidate_labels=dataset_candidate_labels,
|
| 1406 |
+
title_prefix=f"Test C ({dataset_name})",
|
| 1407 |
+
num_printed=DEFAULT_NUM_PRINTED,
|
| 1408 |
+
)
|
| 1409 |
+
c1_results_gap[dataset_name] = gap_metrics["accuracy_c1"]
|
| 1410 |
+
c1_results_base[dataset_name] = baseline_metrics["accuracy_c1"]
|
| 1411 |
+
c2_results_gap[dataset_name] = gap_metrics["accuracy_c2"]
|
| 1412 |
+
c2_results_base[dataset_name] = baseline_metrics["accuracy_c2"]
|
| 1413 |
+
c_strategy_gap[dataset_name] = gap_metrics.get("strategy")
|
| 1414 |
+
c_strategy_c2_gap[dataset_name] = gap_metrics.get("strategy_c2")
|
| 1415 |
+
|
| 1416 |
+
print("\n" + "=" * 120)
|
| 1417 |
+
print("Final Summary")
|
| 1418 |
+
print("=" * 120)
|
| 1419 |
+
print(f"Tests selected: {''.join(sorted(selected_tests))}")
|
| 1420 |
+
if result_a is not None:
|
| 1421 |
+
print(f"Test A overall: {format_bool(bool(result_a['overall']))}")
|
| 1422 |
+
print(f"Test A full512 accuracy: {float(result_a['accuracy_full512']):.2%}")
|
| 1423 |
+
if result_b is not None:
|
| 1424 |
+
print(f"Test B overall: {format_bool(bool(result_b['overall']))}")
|
| 1425 |
+
print(f"Test B full512 accuracy: {float(result_b['accuracy_full512']):.2%}")
|
| 1426 |
+
if "C" in selected_tests:
|
| 1427 |
+
for dataset_name in ["Fashion-MNIST", "KAGL Marqo", "Internal dataset"]:
|
| 1428 |
+
gap_c1 = c1_results_gap.get(dataset_name)
|
| 1429 |
+
base_c1 = c1_results_base.get(dataset_name)
|
| 1430 |
+
gap_c2 = c2_results_gap.get(dataset_name)
|
| 1431 |
+
base_c2 = c2_results_base.get(dataset_name)
|
| 1432 |
+
|
| 1433 |
+
gap_c1_str = f"{gap_c1:.2%}" if gap_c1 is not None else "N/A"
|
| 1434 |
+
base_c1_str = f"{base_c1:.2%}" if base_c1 is not None else "N/A"
|
| 1435 |
+
gap_c2_str = f"{gap_c2:.2%}" if gap_c2 is not None else "N/A"
|
| 1436 |
+
base_c2_str = f"{base_c2:.2%}" if base_c2 is not None else "N/A"
|
| 1437 |
+
|
| 1438 |
+
print(f"Test C1 ({dataset_name}) GAP-CLIP accuracy: {gap_c1_str}")
|
| 1439 |
+
print(f"Test C1 ({dataset_name}) GAP-CLIP selected strategy: {c_strategy_gap.get(dataset_name)}")
|
| 1440 |
+
print(f"Test C1 ({dataset_name}) baseline accuracy: {base_c1_str}")
|
| 1441 |
+
if gap_c1 is not None and base_c1 is not None:
|
| 1442 |
+
print(f"Delta C1 ({dataset_name}, GAP-CLIP - baseline): {gap_c1 - base_c1:+.2%}")
|
| 1443 |
+
|
| 1444 |
+
print(f"Test C2 ({dataset_name}) GAP-CLIP accuracy: {gap_c2_str}")
|
| 1445 |
+
print(f"Test C2 ({dataset_name}) GAP-CLIP selected strategy: {c_strategy_c2_gap.get(dataset_name)}")
|
| 1446 |
+
print(f"Test C2 ({dataset_name}) baseline accuracy: {base_c2_str}")
|
| 1447 |
+
if gap_c2 is not None and base_c2 is not None:
|
| 1448 |
+
print(f"Delta C2 ({dataset_name}, GAP-CLIP - baseline): {gap_c2 - base_c2:+.2%}")
|
| 1449 |
+
|
| 1450 |
+
if result_a is not None:
|
| 1451 |
+
assert bool(result_a["overall"]), "Test A failed: hierarchy behavior did not match expected pattern."
|
| 1452 |
+
if result_b is not None:
|
| 1453 |
+
assert bool(result_b["overall"]), "Test B failed: first16 correlation was not consistently above full512."
|
| 1454 |
+
|
| 1455 |
+
print("\nAll embedding-structure tests passed.")
|
| 1456 |
+
|
| 1457 |
+
|
| 1458 |
+
if __name__ == "__main__":
|
| 1459 |
+
selected_tests = 'ABC'
|
| 1460 |
+
main(selected_tests)
|
evaluation/utils/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
evaluation/utils/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Shared utilities for GAP-CLIP evaluation scripts.
|
evaluation/utils/datasets.py
ADDED
|
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Shared dataset classes and loading utilities for GAP-CLIP evaluation scripts.
|
| 3 |
+
|
| 4 |
+
Provides:
|
| 5 |
+
- FashionMNISTDataset (Fashion-MNIST grayscale images)
|
| 6 |
+
- KaggleDataset (KAGL Marqo HuggingFace dataset)
|
| 7 |
+
- LocalDataset (internal local validation dataset)
|
| 8 |
+
- Matching load_* convenience functions
|
| 9 |
+
- collate_fn_filter_none (for DataLoader)
|
| 10 |
+
- normalize_hierarchy_label (text normalisation helper)
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import difflib
|
| 16 |
+
import hashlib
|
| 17 |
+
import sys
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from io import BytesIO
|
| 20 |
+
from typing import List, Optional
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import pandas as pd
|
| 24 |
+
import torch
|
| 25 |
+
from PIL import Image
|
| 26 |
+
import requests
|
| 27 |
+
from torch.utils.data import Dataset
|
| 28 |
+
from torchvision import transforms
|
| 29 |
+
|
| 30 |
+
# Make project root importable when running evaluation scripts directly.
|
| 31 |
+
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
| 32 |
+
if str(_PROJECT_ROOT) not in sys.path:
|
| 33 |
+
sys.path.insert(0, str(_PROJECT_ROOT))
|
| 34 |
+
|
| 35 |
+
from config import ( # type: ignore
|
| 36 |
+
column_local_image_path,
|
| 37 |
+
fashion_mnist_csv,
|
| 38 |
+
local_dataset_path,
|
| 39 |
+
images_dir,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
_VALID_COLORS = [
|
| 43 |
+
"beige", "black", "blue", "brown", "green",
|
| 44 |
+
"orange", "pink", "purple", "red", "white", "yellow",
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
# ---------------------------------------------------------------------------
|
| 48 |
+
# Fashion-MNIST helpers
|
| 49 |
+
# ---------------------------------------------------------------------------
|
| 50 |
+
|
| 51 |
+
def get_fashion_mnist_labels() -> dict:
|
| 52 |
+
"""Return the 10 Fashion-MNIST integer-to-name mapping."""
|
| 53 |
+
return {
|
| 54 |
+
0: "T-shirt/top",
|
| 55 |
+
1: "Trouser",
|
| 56 |
+
2: "Pullover",
|
| 57 |
+
3: "Dress",
|
| 58 |
+
4: "Coat",
|
| 59 |
+
5: "Sandal",
|
| 60 |
+
6: "Shirt",
|
| 61 |
+
7: "Sneaker",
|
| 62 |
+
8: "Bag",
|
| 63 |
+
9: "Ankle boot",
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def create_fashion_mnist_to_hierarchy_mapping(hierarchy_classes: List[str]) -> dict:
|
| 68 |
+
"""Map Fashion-MNIST integer labels to nearest hierarchy class name.
|
| 69 |
+
|
| 70 |
+
Returns dict {label_id: matched_class_name or None}.
|
| 71 |
+
"""
|
| 72 |
+
fashion_mnist_labels = get_fashion_mnist_labels()
|
| 73 |
+
hierarchy_classes_lower = [h.lower() for h in hierarchy_classes]
|
| 74 |
+
mapping = {}
|
| 75 |
+
|
| 76 |
+
for fm_label_id, fm_label in fashion_mnist_labels.items():
|
| 77 |
+
fm_label_lower = fm_label.lower()
|
| 78 |
+
matched_hierarchy = None
|
| 79 |
+
|
| 80 |
+
if fm_label_lower in hierarchy_classes_lower:
|
| 81 |
+
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(fm_label_lower)]
|
| 82 |
+
elif any(h in fm_label_lower or fm_label_lower in h for h in hierarchy_classes_lower):
|
| 83 |
+
for h_class in hierarchy_classes:
|
| 84 |
+
if h_class.lower() in fm_label_lower or fm_label_lower in h_class.lower():
|
| 85 |
+
matched_hierarchy = h_class
|
| 86 |
+
break
|
| 87 |
+
else:
|
| 88 |
+
if fm_label_lower in ["t-shirt/top", "top"]:
|
| 89 |
+
if "top" in hierarchy_classes_lower:
|
| 90 |
+
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index("top")]
|
| 91 |
+
elif "trouser" in fm_label_lower:
|
| 92 |
+
for p in ["bottom", "pants", "trousers", "trouser", "pant"]:
|
| 93 |
+
if p in hierarchy_classes_lower:
|
| 94 |
+
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(p)]
|
| 95 |
+
break
|
| 96 |
+
elif "pullover" in fm_label_lower:
|
| 97 |
+
for p in ["sweater", "pullover"]:
|
| 98 |
+
if p in hierarchy_classes_lower:
|
| 99 |
+
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(p)]
|
| 100 |
+
break
|
| 101 |
+
elif "dress" in fm_label_lower:
|
| 102 |
+
if "dress" in hierarchy_classes_lower:
|
| 103 |
+
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index("dress")]
|
| 104 |
+
elif "coat" in fm_label_lower:
|
| 105 |
+
for p in ["jacket", "outerwear", "coat"]:
|
| 106 |
+
if p in hierarchy_classes_lower:
|
| 107 |
+
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(p)]
|
| 108 |
+
break
|
| 109 |
+
elif fm_label_lower in ["sandal", "sneaker", "ankle boot"]:
|
| 110 |
+
for p in ["shoes", "shoe", "sandal", "sneaker", "boot"]:
|
| 111 |
+
if p in hierarchy_classes_lower:
|
| 112 |
+
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(p)]
|
| 113 |
+
break
|
| 114 |
+
elif "bag" in fm_label_lower:
|
| 115 |
+
if "bag" in hierarchy_classes_lower:
|
| 116 |
+
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index("bag")]
|
| 117 |
+
|
| 118 |
+
if matched_hierarchy is None:
|
| 119 |
+
close = difflib.get_close_matches(fm_label_lower, hierarchy_classes_lower, n=1, cutoff=0.6)
|
| 120 |
+
if close:
|
| 121 |
+
matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(close[0])]
|
| 122 |
+
|
| 123 |
+
mapping[fm_label_id] = matched_hierarchy
|
| 124 |
+
status = matched_hierarchy if matched_hierarchy else "NO MATCH (will be filtered out)"
|
| 125 |
+
print(f" {fm_label} ({fm_label_id}) -> {status}")
|
| 126 |
+
|
| 127 |
+
return mapping
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def convert_fashion_mnist_to_image(pixel_values) -> Image.Image:
|
| 131 |
+
"""Convert a flat 784-element pixel array to an RGB PIL image."""
|
| 132 |
+
arr = np.array(pixel_values).reshape(28, 28).astype(np.uint8)
|
| 133 |
+
arr = np.stack([arr] * 3, axis=-1)
|
| 134 |
+
return Image.fromarray(arr)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class FashionMNISTDataset(Dataset):
|
| 138 |
+
"""PyTorch dataset wrapping Fashion-MNIST CSV rows."""
|
| 139 |
+
|
| 140 |
+
def __init__(self, dataframe: pd.DataFrame, image_size: int = 224, label_mapping: Optional[dict] = None):
|
| 141 |
+
self.dataframe = dataframe
|
| 142 |
+
self.image_size = image_size
|
| 143 |
+
self.labels_map = get_fashion_mnist_labels()
|
| 144 |
+
self.label_mapping = label_mapping
|
| 145 |
+
|
| 146 |
+
self.transform = transforms.Compose([
|
| 147 |
+
transforms.Resize((image_size, image_size)),
|
| 148 |
+
transforms.ToTensor(),
|
| 149 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 150 |
+
])
|
| 151 |
+
|
| 152 |
+
def __len__(self) -> int:
|
| 153 |
+
return len(self.dataframe)
|
| 154 |
+
|
| 155 |
+
def __getitem__(self, idx):
|
| 156 |
+
row = self.dataframe.iloc[idx]
|
| 157 |
+
pixel_cols = [f"pixel{i}" for i in range(1, 785)]
|
| 158 |
+
image = convert_fashion_mnist_to_image(row[pixel_cols].values)
|
| 159 |
+
image = self.transform(image)
|
| 160 |
+
|
| 161 |
+
label_id = int(row["label"])
|
| 162 |
+
description = self.labels_map[label_id]
|
| 163 |
+
color = "unknown"
|
| 164 |
+
hierarchy = (
|
| 165 |
+
self.label_mapping[label_id]
|
| 166 |
+
if (self.label_mapping and label_id in self.label_mapping)
|
| 167 |
+
else self.labels_map[label_id]
|
| 168 |
+
)
|
| 169 |
+
return image, description, color, hierarchy
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def load_fashion_mnist_dataset(
|
| 173 |
+
max_samples: int = 10000,
|
| 174 |
+
hierarchy_classes: Optional[List[str]] = None,
|
| 175 |
+
csv_path: Optional[str] = None,
|
| 176 |
+
) -> FashionMNISTDataset:
|
| 177 |
+
"""Load Fashion-MNIST test CSV into a FashionMNISTDataset.
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
max_samples: Maximum number of samples to use.
|
| 181 |
+
hierarchy_classes: If provided, maps Fashion-MNIST labels to these classes.
|
| 182 |
+
csv_path: Path to fashion-mnist_test.csv. Defaults to config.fashion_mnist_csv.
|
| 183 |
+
"""
|
| 184 |
+
if csv_path is None:
|
| 185 |
+
csv_path = fashion_mnist_csv
|
| 186 |
+
|
| 187 |
+
print("Loading Fashion-MNIST test dataset...")
|
| 188 |
+
df = pd.read_csv(csv_path)
|
| 189 |
+
print(f"Fashion-MNIST dataset loaded: {len(df)} samples")
|
| 190 |
+
|
| 191 |
+
label_mapping = None
|
| 192 |
+
if hierarchy_classes is not None:
|
| 193 |
+
print("\nCreating mapping from Fashion-MNIST labels to hierarchy classes:")
|
| 194 |
+
label_mapping = create_fashion_mnist_to_hierarchy_mapping(hierarchy_classes)
|
| 195 |
+
valid_ids = [lid for lid, h in label_mapping.items() if h is not None]
|
| 196 |
+
df = df[df["label"].isin(valid_ids)]
|
| 197 |
+
print(f"\nAfter filtering to mappable labels: {len(df)} samples")
|
| 198 |
+
|
| 199 |
+
df_sample = df.head(max_samples)
|
| 200 |
+
print(f"Using {len(df_sample)} samples for evaluation")
|
| 201 |
+
return FashionMNISTDataset(df_sample, label_mapping=label_mapping)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
# ---------------------------------------------------------------------------
|
| 205 |
+
# KAGL Marqo dataset
|
| 206 |
+
# ---------------------------------------------------------------------------
|
| 207 |
+
|
| 208 |
+
class KaggleDataset(Dataset):
|
| 209 |
+
"""Dataset class for KAGL Marqo HuggingFace dataset."""
|
| 210 |
+
|
| 211 |
+
def __init__(self, dataframe: pd.DataFrame, image_size: int = 224, include_hierarchy: bool = False):
|
| 212 |
+
self.dataframe = dataframe
|
| 213 |
+
self.image_size = image_size
|
| 214 |
+
self.include_hierarchy = include_hierarchy
|
| 215 |
+
|
| 216 |
+
self.transform = transforms.Compose([
|
| 217 |
+
transforms.Resize((224, 224)),
|
| 218 |
+
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
|
| 219 |
+
transforms.ToTensor(),
|
| 220 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 221 |
+
])
|
| 222 |
+
|
| 223 |
+
def __len__(self) -> int:
|
| 224 |
+
return len(self.dataframe)
|
| 225 |
+
|
| 226 |
+
def __getitem__(self, idx):
|
| 227 |
+
row = self.dataframe.iloc[idx]
|
| 228 |
+
image_data = row["image_url"]
|
| 229 |
+
|
| 230 |
+
if isinstance(image_data, dict) and "bytes" in image_data:
|
| 231 |
+
image = Image.open(BytesIO(image_data["bytes"])).convert("RGB")
|
| 232 |
+
elif hasattr(image_data, "convert"):
|
| 233 |
+
image = image_data.convert("RGB")
|
| 234 |
+
else:
|
| 235 |
+
image = Image.open(BytesIO(image_data)).convert("RGB")
|
| 236 |
+
|
| 237 |
+
image = self.transform(image)
|
| 238 |
+
description = row["text"]
|
| 239 |
+
color = row["color"]
|
| 240 |
+
|
| 241 |
+
if self.include_hierarchy:
|
| 242 |
+
hierarchy = row.get("hierarchy", "unknown")
|
| 243 |
+
return image, description, color, hierarchy
|
| 244 |
+
return image, description, color
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def load_kaggle_marqo_dataset(
|
| 248 |
+
max_samples: int = 5000,
|
| 249 |
+
include_hierarchy: bool = False,
|
| 250 |
+
) -> KaggleDataset:
|
| 251 |
+
"""Download and prepare the KAGL Marqo HuggingFace dataset."""
|
| 252 |
+
from datasets import load_dataset # type: ignore
|
| 253 |
+
|
| 254 |
+
print("Loading KAGL Marqo dataset...")
|
| 255 |
+
dataset = load_dataset("Marqo/KAGL")
|
| 256 |
+
df = dataset["data"].to_pandas()
|
| 257 |
+
print(f"Dataset loaded: {len(df)} samples, columns: {list(df.columns)}")
|
| 258 |
+
|
| 259 |
+
df = df.dropna(subset=["text", "image"])
|
| 260 |
+
|
| 261 |
+
if len(df) > max_samples:
|
| 262 |
+
df = df.sample(n=max_samples, random_state=42)
|
| 263 |
+
print(f"Sampled {max_samples} items")
|
| 264 |
+
|
| 265 |
+
kaggle_df = pd.DataFrame({
|
| 266 |
+
"image_url": df["image"],
|
| 267 |
+
"text": df["text"],
|
| 268 |
+
"color": df["baseColour"].str.lower().str.replace("grey", "gray"),
|
| 269 |
+
})
|
| 270 |
+
|
| 271 |
+
kaggle_df = kaggle_df.dropna(subset=["color"])
|
| 272 |
+
kaggle_df = kaggle_df[kaggle_df["color"].isin(_VALID_COLORS)]
|
| 273 |
+
print(f"After color filtering: {len(kaggle_df)} samples, colors: {sorted(kaggle_df['color'].unique())}")
|
| 274 |
+
|
| 275 |
+
return KaggleDataset(kaggle_df, include_hierarchy=include_hierarchy)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
# ---------------------------------------------------------------------------
|
| 279 |
+
# Local validation dataset
|
| 280 |
+
# ---------------------------------------------------------------------------
|
| 281 |
+
|
| 282 |
+
class LocalDataset(Dataset):
|
| 283 |
+
"""Dataset class for the internal local validation dataset."""
|
| 284 |
+
|
| 285 |
+
def __init__(self, dataframe: pd.DataFrame, image_size: int = 224, include_hierarchy: bool = False):
|
| 286 |
+
self.dataframe = dataframe
|
| 287 |
+
self.image_size = image_size
|
| 288 |
+
self.include_hierarchy = include_hierarchy
|
| 289 |
+
|
| 290 |
+
self.transform = transforms.Compose([
|
| 291 |
+
transforms.Resize((224, 224)),
|
| 292 |
+
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
|
| 293 |
+
transforms.ToTensor(),
|
| 294 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 295 |
+
])
|
| 296 |
+
|
| 297 |
+
def __len__(self) -> int:
|
| 298 |
+
return len(self.dataframe)
|
| 299 |
+
|
| 300 |
+
def __getitem__(self, idx):
|
| 301 |
+
row = self.dataframe.iloc[idx]
|
| 302 |
+
try:
|
| 303 |
+
image_path = row.get(column_local_image_path) if hasattr(row, "get") else None
|
| 304 |
+
if isinstance(image_path, str) and image_path and Path(image_path).exists():
|
| 305 |
+
image = Image.open(image_path).convert("RGB")
|
| 306 |
+
else:
|
| 307 |
+
# Fallback: download image from URL (and cache).
|
| 308 |
+
image_url = row.get("image_url") if hasattr(row, "get") else None
|
| 309 |
+
if isinstance(image_url, dict) and "bytes" in image_url:
|
| 310 |
+
image = Image.open(BytesIO(image_url["bytes"])).convert("RGB")
|
| 311 |
+
elif isinstance(image_url, str) and image_url:
|
| 312 |
+
cache_dir = Path(images_dir)
|
| 313 |
+
cache_dir.mkdir(parents=True, exist_ok=True)
|
| 314 |
+
url_hash = hashlib.md5(image_url.encode("utf-8")).hexdigest()
|
| 315 |
+
cache_path = cache_dir / f"{url_hash}.jpg"
|
| 316 |
+
if cache_path.exists():
|
| 317 |
+
image = Image.open(cache_path).convert("RGB")
|
| 318 |
+
else:
|
| 319 |
+
resp = requests.get(image_url, timeout=10)
|
| 320 |
+
resp.raise_for_status()
|
| 321 |
+
image = Image.open(BytesIO(resp.content)).convert("RGB")
|
| 322 |
+
image.save(cache_path, "JPEG", quality=85, optimize=True)
|
| 323 |
+
else:
|
| 324 |
+
raise ValueError("Missing image_path and image_url")
|
| 325 |
+
except Exception as e:
|
| 326 |
+
print(f"Error loading image: {e}")
|
| 327 |
+
image = Image.new("RGB", (224, 224), color="gray")
|
| 328 |
+
image = self.transform(image)
|
| 329 |
+
|
| 330 |
+
description = row["text"]
|
| 331 |
+
color = row["color"]
|
| 332 |
+
|
| 333 |
+
if self.include_hierarchy:
|
| 334 |
+
hierarchy = row.get("hierarchy", "unknown")
|
| 335 |
+
return image, description, color, hierarchy
|
| 336 |
+
return image, description, color
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def load_local_validation_dataset(
|
| 340 |
+
max_samples: int = 5000,
|
| 341 |
+
include_hierarchy: bool = False,
|
| 342 |
+
) -> LocalDataset:
|
| 343 |
+
"""Load and prepare the internal local validation dataset."""
|
| 344 |
+
print("Loading local validation dataset...")
|
| 345 |
+
df = pd.read_csv(local_dataset_path)
|
| 346 |
+
print(f"Dataset loaded: {len(df)} samples")
|
| 347 |
+
|
| 348 |
+
if column_local_image_path in df.columns:
|
| 349 |
+
df = df.dropna(subset=[column_local_image_path])
|
| 350 |
+
print(f"After filtering NaN image paths: {len(df)} samples")
|
| 351 |
+
else:
|
| 352 |
+
print(f"Column '{column_local_image_path}' not found; falling back to 'image_url'.")
|
| 353 |
+
|
| 354 |
+
if "color" in df.columns:
|
| 355 |
+
df = df[df["color"].isin(_VALID_COLORS)]
|
| 356 |
+
print(f"After color filtering: {len(df)} samples, colors: {sorted(df['color'].unique())}")
|
| 357 |
+
|
| 358 |
+
if len(df) > max_samples:
|
| 359 |
+
df = df.sample(n=max_samples, random_state=42)
|
| 360 |
+
print(f"Sampled {max_samples} items")
|
| 361 |
+
|
| 362 |
+
print(f"Using {len(df)} samples for evaluation")
|
| 363 |
+
return LocalDataset(df, include_hierarchy=include_hierarchy)
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
# ---------------------------------------------------------------------------
|
| 367 |
+
# DataLoader utilities
|
| 368 |
+
# ---------------------------------------------------------------------------
|
| 369 |
+
|
| 370 |
+
def collate_fn_filter_none(batch):
|
| 371 |
+
"""Collate function that silently drops None items from a batch."""
|
| 372 |
+
original_len = len(batch)
|
| 373 |
+
batch = [item for item in batch if item is not None]
|
| 374 |
+
if original_len > len(batch):
|
| 375 |
+
print(f"Filtered out {original_len - len(batch)} None values from batch")
|
| 376 |
+
if not batch:
|
| 377 |
+
print("Empty batch after filtering None values")
|
| 378 |
+
return torch.tensor([]), [], []
|
| 379 |
+
images, texts, colors = zip(*batch)
|
| 380 |
+
return torch.stack(images), list(texts), list(colors)
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
# ---------------------------------------------------------------------------
|
| 384 |
+
# Text normalisation helpers
|
| 385 |
+
# ---------------------------------------------------------------------------
|
| 386 |
+
|
| 387 |
+
def normalize_hierarchy_label(label: str) -> str:
|
| 388 |
+
"""Lower-case and strip a hierarchy label for consistent comparison."""
|
| 389 |
+
return label.lower().strip() if label else ""
|
evaluation/utils/metrics.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Shared evaluation metrics for GAP-CLIP experiments.
|
| 3 |
+
|
| 4 |
+
Provides nearest-neighbor accuracy, separation score, centroid-based accuracy,
|
| 5 |
+
and confusion matrix generation — used across all evaluation sections.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from collections import defaultdict
|
| 11 |
+
from typing import List, Optional, Tuple
|
| 12 |
+
|
| 13 |
+
import matplotlib.pyplot as plt
|
| 14 |
+
import numpy as np
|
| 15 |
+
import seaborn as sns
|
| 16 |
+
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
|
| 17 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 18 |
+
from sklearn.preprocessing import normalize
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def compute_similarity_metrics(
|
| 22 |
+
embeddings: np.ndarray,
|
| 23 |
+
labels: List[str],
|
| 24 |
+
max_samples: int = 5000,
|
| 25 |
+
) -> dict:
|
| 26 |
+
"""Compute intra/inter-class similarities and nearest-neighbor accuracy.
|
| 27 |
+
|
| 28 |
+
Uses vectorized numpy operations for efficiency.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
embeddings: Array of shape (N, D).
|
| 32 |
+
labels: List of N class labels.
|
| 33 |
+
max_samples: Cap for large datasets (random subsample).
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
Dict with keys: intra_class_mean, inter_class_mean, separation_score,
|
| 37 |
+
accuracy (NN), centroid_accuracy, intra_class_similarities,
|
| 38 |
+
inter_class_similarities.
|
| 39 |
+
"""
|
| 40 |
+
if len(embeddings) > max_samples:
|
| 41 |
+
indices = np.random.choice(len(embeddings), max_samples, replace=False)
|
| 42 |
+
embeddings = embeddings[indices]
|
| 43 |
+
labels = [labels[i] for i in indices]
|
| 44 |
+
|
| 45 |
+
similarities = cosine_similarity(embeddings)
|
| 46 |
+
|
| 47 |
+
label_array = np.array(labels)
|
| 48 |
+
unique_labels = np.unique(label_array)
|
| 49 |
+
label_groups = {label: np.where(label_array == label)[0] for label in unique_labels}
|
| 50 |
+
|
| 51 |
+
intra_class_similarities: List[float] = []
|
| 52 |
+
for indices in label_groups.values():
|
| 53 |
+
if len(indices) > 1:
|
| 54 |
+
sub = similarities[np.ix_(indices, indices)]
|
| 55 |
+
triu = np.triu_indices_from(sub, k=1)
|
| 56 |
+
intra_class_similarities.extend(sub[triu].tolist())
|
| 57 |
+
|
| 58 |
+
inter_class_similarities: List[float] = []
|
| 59 |
+
keys = list(label_groups.keys())
|
| 60 |
+
for i in range(len(keys)):
|
| 61 |
+
for j in range(i + 1, len(keys)):
|
| 62 |
+
inter = similarities[np.ix_(label_groups[keys[i]], label_groups[keys[j]])]
|
| 63 |
+
inter_class_similarities.extend(inter.flatten().tolist())
|
| 64 |
+
|
| 65 |
+
nn_acc = compute_embedding_accuracy(embeddings, labels, similarities)
|
| 66 |
+
centroid_acc = compute_centroid_accuracy(embeddings, labels)
|
| 67 |
+
|
| 68 |
+
return {
|
| 69 |
+
"intra_class_similarities": intra_class_similarities,
|
| 70 |
+
"inter_class_similarities": inter_class_similarities,
|
| 71 |
+
"intra_class_mean": float(np.mean(intra_class_similarities)) if intra_class_similarities else 0.0,
|
| 72 |
+
"inter_class_mean": float(np.mean(inter_class_similarities)) if inter_class_similarities else 0.0,
|
| 73 |
+
"separation_score": (
|
| 74 |
+
float(np.mean(intra_class_similarities) - np.mean(inter_class_similarities))
|
| 75 |
+
if intra_class_similarities and inter_class_similarities
|
| 76 |
+
else 0.0
|
| 77 |
+
),
|
| 78 |
+
"accuracy": nn_acc,
|
| 79 |
+
"centroid_accuracy": centroid_acc,
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def compute_embedding_accuracy(
|
| 84 |
+
embeddings: np.ndarray,
|
| 85 |
+
labels: List[str],
|
| 86 |
+
similarities: Optional[np.ndarray] = None,
|
| 87 |
+
) -> float:
|
| 88 |
+
"""Nearest-neighbor classification accuracy (leave-one-out).
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
embeddings: Array of shape (N, D).
|
| 92 |
+
labels: List of N class labels.
|
| 93 |
+
similarities: Pre-computed cosine similarity matrix (N, N). Computed
|
| 94 |
+
if not provided.
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
Fraction of samples whose nearest neighbor shares their label.
|
| 98 |
+
"""
|
| 99 |
+
n = len(embeddings)
|
| 100 |
+
if n == 0:
|
| 101 |
+
return 0.0
|
| 102 |
+
if similarities is None:
|
| 103 |
+
similarities = cosine_similarity(embeddings)
|
| 104 |
+
|
| 105 |
+
correct = 0
|
| 106 |
+
for i in range(n):
|
| 107 |
+
sims = similarities[i].copy()
|
| 108 |
+
sims[i] = -1.0
|
| 109 |
+
if labels[np.argmax(sims)] == labels[i]:
|
| 110 |
+
correct += 1
|
| 111 |
+
return correct / n
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def compute_centroid_accuracy(
|
| 115 |
+
embeddings: np.ndarray,
|
| 116 |
+
labels: List[str],
|
| 117 |
+
) -> float:
|
| 118 |
+
"""Centroid-based (1-NN centroid) classification accuracy.
|
| 119 |
+
|
| 120 |
+
Uses L2-normalized embeddings and centroids for correct cosine comparison.
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
embeddings: Array of shape (N, D).
|
| 124 |
+
labels: List of N class labels.
|
| 125 |
+
|
| 126 |
+
Returns:
|
| 127 |
+
Fraction of samples classified correctly by nearest centroid.
|
| 128 |
+
"""
|
| 129 |
+
if len(embeddings) == 0:
|
| 130 |
+
return 0.0
|
| 131 |
+
|
| 132 |
+
emb_norm = normalize(embeddings, norm="l2")
|
| 133 |
+
unique_labels = sorted(set(labels))
|
| 134 |
+
centroids = {}
|
| 135 |
+
for label in unique_labels:
|
| 136 |
+
idx = [i for i, l in enumerate(labels) if l == label]
|
| 137 |
+
centroids[label] = normalize([emb_norm[idx].mean(axis=0)], norm="l2")[0]
|
| 138 |
+
|
| 139 |
+
centroid_labels = list(centroids.keys())
|
| 140 |
+
centroid_matrix = np.vstack([centroids[l] for l in centroid_labels])
|
| 141 |
+
sims = cosine_similarity(emb_norm, centroid_matrix)
|
| 142 |
+
predicted = [centroid_labels[int(np.argmax(row))] for row in sims]
|
| 143 |
+
return sum(p == t for p, t in zip(predicted, labels)) / len(labels)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def predict_labels_from_embeddings(
|
| 147 |
+
embeddings: np.ndarray,
|
| 148 |
+
labels: List[str],
|
| 149 |
+
) -> List[str]:
|
| 150 |
+
"""Predict a label for each embedding using nearest centroid.
|
| 151 |
+
|
| 152 |
+
Returns:
|
| 153 |
+
List of predicted labels (same length as embeddings).
|
| 154 |
+
"""
|
| 155 |
+
valid_labels = [l for l in set(labels) if l is not None]
|
| 156 |
+
if not valid_labels:
|
| 157 |
+
return [None] * len(embeddings)
|
| 158 |
+
|
| 159 |
+
emb_norm = normalize(embeddings, norm="l2")
|
| 160 |
+
centroids = {}
|
| 161 |
+
for label in valid_labels:
|
| 162 |
+
mask = np.array(labels) == label
|
| 163 |
+
if np.any(mask):
|
| 164 |
+
centroids[label] = np.mean(emb_norm[mask], axis=0)
|
| 165 |
+
|
| 166 |
+
centroid_labels = list(centroids.keys())
|
| 167 |
+
centroid_matrix = np.vstack([centroids[l] for l in centroid_labels])
|
| 168 |
+
sims = cosine_similarity(emb_norm, centroid_matrix)
|
| 169 |
+
return [centroid_labels[int(np.argmax(row))] for row in sims]
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def create_confusion_matrix(
|
| 173 |
+
true_labels: List[str],
|
| 174 |
+
predicted_labels: List[str],
|
| 175 |
+
title: str = "Confusion Matrix",
|
| 176 |
+
label_type: str = "Label",
|
| 177 |
+
) -> Tuple[plt.Figure, float, np.ndarray]:
|
| 178 |
+
"""Create and return a seaborn confusion-matrix heatmap figure.
|
| 179 |
+
|
| 180 |
+
Args:
|
| 181 |
+
true_labels: Ground-truth labels.
|
| 182 |
+
predicted_labels: Predicted labels.
|
| 183 |
+
title: Plot title prefix.
|
| 184 |
+
label_type: Axis label (e.g. "Color", "Category").
|
| 185 |
+
|
| 186 |
+
Returns:
|
| 187 |
+
(fig, accuracy, cm_array)
|
| 188 |
+
"""
|
| 189 |
+
unique_labels = sorted(set(true_labels + predicted_labels))
|
| 190 |
+
cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)
|
| 191 |
+
acc = accuracy_score(true_labels, predicted_labels)
|
| 192 |
+
|
| 193 |
+
fig = plt.figure(figsize=(10, 8))
|
| 194 |
+
sns.heatmap(
|
| 195 |
+
cm,
|
| 196 |
+
annot=True,
|
| 197 |
+
fmt="d",
|
| 198 |
+
cmap="Blues",
|
| 199 |
+
xticklabels=unique_labels,
|
| 200 |
+
yticklabels=unique_labels,
|
| 201 |
+
)
|
| 202 |
+
plt.title(f"{title}\nAccuracy: {acc:.3f} ({acc * 100:.1f}%)")
|
| 203 |
+
plt.ylabel(f"True {label_type}")
|
| 204 |
+
plt.xlabel(f"Predicted {label_type}")
|
| 205 |
+
plt.xticks(rotation=45)
|
| 206 |
+
plt.yticks(rotation=0)
|
| 207 |
+
plt.tight_layout()
|
| 208 |
+
return fig, acc, cm
|
evaluation/utils/model_loader.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Shared model loading and embedding extraction utilities.
|
| 3 |
+
|
| 4 |
+
All evaluation scripts that need to load GAP-CLIP, the Fashion-CLIP baseline,
|
| 5 |
+
or the specialized color model should import from here instead of duplicating
|
| 6 |
+
the loading logic.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import Tuple
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
from PIL import Image
|
| 20 |
+
from transformers import CLIPModel as CLIPModelTransformers
|
| 21 |
+
from transformers import CLIPProcessor
|
| 22 |
+
|
| 23 |
+
# Make project root importable when running evaluation scripts directly.
|
| 24 |
+
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
| 25 |
+
if str(_PROJECT_ROOT) not in sys.path:
|
| 26 |
+
sys.path.insert(0, str(_PROJECT_ROOT))
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
# GAP-CLIP (main model)
|
| 31 |
+
# ---------------------------------------------------------------------------
|
| 32 |
+
|
| 33 |
+
def load_gap_clip(
|
| 34 |
+
model_path: str,
|
| 35 |
+
device: torch.device,
|
| 36 |
+
) -> Tuple[CLIPModelTransformers, CLIPProcessor]:
|
| 37 |
+
"""Load GAP-CLIP (LAION CLIP + fine-tuned checkpoint) and its processor.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
model_path: Path to the `gap_clip.pth` checkpoint.
|
| 41 |
+
device: Target device.
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
(model, processor) ready for inference.
|
| 45 |
+
"""
|
| 46 |
+
model = CLIPModelTransformers.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
|
| 47 |
+
checkpoint = torch.load(model_path, map_location=device)
|
| 48 |
+
|
| 49 |
+
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
|
| 50 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 51 |
+
else:
|
| 52 |
+
model.load_state_dict(checkpoint)
|
| 53 |
+
|
| 54 |
+
model = model.to(device)
|
| 55 |
+
model.eval()
|
| 56 |
+
processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
|
| 57 |
+
return model, processor
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# ---------------------------------------------------------------------------
|
| 61 |
+
# Fashion-CLIP baseline
|
| 62 |
+
# ---------------------------------------------------------------------------
|
| 63 |
+
|
| 64 |
+
def load_baseline_fashion_clip(
|
| 65 |
+
device: torch.device,
|
| 66 |
+
) -> Tuple[CLIPModelTransformers, CLIPProcessor]:
|
| 67 |
+
"""Load the Fashion-CLIP baseline (patrickjohncyh/fashion-clip).
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
(model, processor) ready for inference.
|
| 71 |
+
"""
|
| 72 |
+
model_name = "patrickjohncyh/fashion-clip"
|
| 73 |
+
processor = CLIPProcessor.from_pretrained(model_name)
|
| 74 |
+
model = CLIPModelTransformers.from_pretrained(model_name).to(device)
|
| 75 |
+
model.eval()
|
| 76 |
+
return model, processor
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# ---------------------------------------------------------------------------
|
| 80 |
+
# Specialized 16D color model
|
| 81 |
+
# ---------------------------------------------------------------------------
|
| 82 |
+
|
| 83 |
+
def load_color_model(
|
| 84 |
+
color_model_path: str,
|
| 85 |
+
tokenizer_path: str,
|
| 86 |
+
color_emb_dim: int,
|
| 87 |
+
device: torch.device,
|
| 88 |
+
repo_id: str = "Leacb4/gap-clip",
|
| 89 |
+
cache_dir: str = "./models_cache",
|
| 90 |
+
):
|
| 91 |
+
"""Load the specialized 16D color model (ColorCLIP) and its tokenizer.
|
| 92 |
+
|
| 93 |
+
Falls back to Hugging Face Hub if local files are not found.
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
(color_model, color_tokenizer)
|
| 97 |
+
"""
|
| 98 |
+
from training.color_model import ColorCLIP, Tokenizer # type: ignore
|
| 99 |
+
|
| 100 |
+
local_model_exists = os.path.exists(color_model_path)
|
| 101 |
+
local_tokenizer_exists = os.path.exists(tokenizer_path)
|
| 102 |
+
|
| 103 |
+
if local_model_exists and local_tokenizer_exists:
|
| 104 |
+
print("Loading specialized color model (16D) from local files...")
|
| 105 |
+
state_dict = torch.load(color_model_path, map_location=device)
|
| 106 |
+
with open(tokenizer_path, "r") as f:
|
| 107 |
+
vocab = json.load(f)
|
| 108 |
+
else:
|
| 109 |
+
from huggingface_hub import hf_hub_download # type: ignore
|
| 110 |
+
|
| 111 |
+
print(f"Local color model/tokenizer not found. Loading from Hugging Face ({repo_id})...")
|
| 112 |
+
hf_model_path = hf_hub_download(
|
| 113 |
+
repo_id=repo_id, filename="color_model.pt", cache_dir=cache_dir
|
| 114 |
+
)
|
| 115 |
+
hf_vocab_path = hf_hub_download(
|
| 116 |
+
repo_id=repo_id, filename="tokenizer_vocab.json", cache_dir=cache_dir
|
| 117 |
+
)
|
| 118 |
+
state_dict = torch.load(hf_model_path, map_location=device)
|
| 119 |
+
with open(hf_vocab_path, "r") as f:
|
| 120 |
+
vocab = json.load(f)
|
| 121 |
+
|
| 122 |
+
vocab_size = state_dict["text_encoder.embedding.weight"].shape[0]
|
| 123 |
+
print(f" Detected vocab size from checkpoint: {vocab_size}")
|
| 124 |
+
|
| 125 |
+
tokenizer = Tokenizer()
|
| 126 |
+
tokenizer.load_vocab(vocab)
|
| 127 |
+
|
| 128 |
+
color_model = ColorCLIP(vocab_size=vocab_size, embedding_dim=color_emb_dim)
|
| 129 |
+
color_model.load_state_dict(state_dict)
|
| 130 |
+
color_model.to(device)
|
| 131 |
+
color_model.eval()
|
| 132 |
+
print("Color model loaded successfully")
|
| 133 |
+
return color_model, tokenizer
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
# ---------------------------------------------------------------------------
|
| 137 |
+
# Embedding extraction helpers
|
| 138 |
+
# ---------------------------------------------------------------------------
|
| 139 |
+
|
| 140 |
+
def get_text_embedding(
|
| 141 |
+
model: CLIPModelTransformers,
|
| 142 |
+
processor: CLIPProcessor,
|
| 143 |
+
device: torch.device,
|
| 144 |
+
text: str,
|
| 145 |
+
) -> torch.Tensor:
|
| 146 |
+
"""Extract a single normalized text embedding (shape: [512])."""
|
| 147 |
+
text_inputs = processor(text=[text], padding=True, return_tensors="pt")
|
| 148 |
+
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
|
| 149 |
+
|
| 150 |
+
with torch.no_grad():
|
| 151 |
+
text_outputs = model.text_model(**text_inputs)
|
| 152 |
+
text_features = model.text_projection(text_outputs.pooler_output)
|
| 153 |
+
text_features = F.normalize(text_features, dim=-1)
|
| 154 |
+
|
| 155 |
+
return text_features.squeeze(0)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def get_text_embeddings_batch(
|
| 159 |
+
model: CLIPModelTransformers,
|
| 160 |
+
processor: CLIPProcessor,
|
| 161 |
+
device: torch.device,
|
| 162 |
+
texts: list[str],
|
| 163 |
+
) -> torch.Tensor:
|
| 164 |
+
"""Extract normalized text embeddings for a batch of strings (shape: [N, 512])."""
|
| 165 |
+
text_inputs = processor(text=texts, padding=True, return_tensors="pt", truncation=True, max_length=77)
|
| 166 |
+
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
|
| 167 |
+
|
| 168 |
+
with torch.no_grad():
|
| 169 |
+
text_outputs = model.text_model(**text_inputs)
|
| 170 |
+
text_features = model.text_projection(text_outputs.pooler_output)
|
| 171 |
+
text_features = F.normalize(text_features, dim=-1)
|
| 172 |
+
|
| 173 |
+
return text_features
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def get_image_embedding(
|
| 177 |
+
model: CLIPModelTransformers,
|
| 178 |
+
image: torch.Tensor,
|
| 179 |
+
device: torch.device,
|
| 180 |
+
) -> torch.Tensor:
|
| 181 |
+
"""Extract a normalized image embedding from a preprocessed tensor.
|
| 182 |
+
|
| 183 |
+
Args:
|
| 184 |
+
model: GAP-CLIP model.
|
| 185 |
+
image: Tensor of shape (C, H, W) or (1, C, H, W) or (N, C, H, W).
|
| 186 |
+
device: Target device.
|
| 187 |
+
|
| 188 |
+
Returns:
|
| 189 |
+
Normalized embedding tensor of shape (1, 512) or (N, 512).
|
| 190 |
+
"""
|
| 191 |
+
model.eval()
|
| 192 |
+
with torch.no_grad():
|
| 193 |
+
if image.dim() == 3 and image.size(0) == 1:
|
| 194 |
+
image = image.expand(3, -1, -1)
|
| 195 |
+
elif image.dim() == 4 and image.size(1) == 1:
|
| 196 |
+
image = image.expand(-1, 3, -1, -1)
|
| 197 |
+
if image.dim() == 3:
|
| 198 |
+
image = image.unsqueeze(0)
|
| 199 |
+
|
| 200 |
+
image = image.to(device)
|
| 201 |
+
vision_outputs = model.vision_model(pixel_values=image)
|
| 202 |
+
image_features = model.visual_projection(vision_outputs.pooler_output)
|
| 203 |
+
return F.normalize(image_features, dim=-1)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def get_image_embedding_from_pil(
|
| 207 |
+
model: CLIPModelTransformers,
|
| 208 |
+
processor: CLIPProcessor,
|
| 209 |
+
device: torch.device,
|
| 210 |
+
pil_image: Image.Image,
|
| 211 |
+
) -> torch.Tensor:
|
| 212 |
+
"""Extract a normalized image embedding from a PIL image (shape: [512])."""
|
| 213 |
+
inputs = processor(images=pil_image, return_tensors="pt")
|
| 214 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 215 |
+
|
| 216 |
+
with torch.no_grad():
|
| 217 |
+
vision_outputs = model.vision_model(**inputs)
|
| 218 |
+
image_features = model.visual_projection(vision_outputs.pooler_output)
|
| 219 |
+
image_features = F.normalize(image_features, dim=-1)
|
| 220 |
+
|
| 221 |
+
return image_features.squeeze(0)
|
example_usage.py
CHANGED
|
@@ -15,8 +15,8 @@ import json
|
|
| 15 |
import os
|
| 16 |
|
| 17 |
# Import local models (to adapt to your structure)
|
| 18 |
-
from color_model import ColorCLIP, Tokenizer
|
| 19 |
-
from hierarchy_model import Model as HierarchyModel, HierarchyExtractor
|
| 20 |
import config
|
| 21 |
|
| 22 |
def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
|
|
|
|
| 15 |
import os
|
| 16 |
|
| 17 |
# Import local models (to adapt to your structure)
|
| 18 |
+
from training.color_model import ColorCLIP, Tokenizer
|
| 19 |
+
from training.hierarchy_model import Model as HierarchyModel, HierarchyExtractor
|
| 20 |
import config
|
| 21 |
|
| 22 |
def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
|
figures/.DS_Store
ADDED
|
Binary file (8.2 kB). View file
|
|
|
color_model.pt → figures/baseline_blue_pant.png
RENAMED
|
File without changes
|
hierarchy_model.pth → figures/baseline_red_dress.png
RENAMED
|
File without changes
|
figures/confusion_matrices/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
gap_clip.pth → figures/confusion_matrices/cm_color/kaggle_baseline_image_color_confusion_matrix.png
RENAMED
|
File without changes
|
figures/confusion_matrices/cm_color/kaggle_baseline_text_color_confusion_matrix.png
ADDED
|
Git LFS Details
|
figures/confusion_matrices/cm_color/kaggle_image_color_confusion_matrix.png
ADDED
|
Git LFS Details
|
figures/confusion_matrices/cm_color/kaggle_text_color_confusion_matrix.png
ADDED
|
Git LFS Details
|
figures/confusion_matrices/cm_color/local_baseline_image_color_confusion_matrix.png
ADDED
|
Git LFS Details
|
figures/confusion_matrices/cm_color/local_baseline_text_color_confusion_matrix.png
ADDED
|
Git LFS Details
|
figures/confusion_matrices/cm_color/local_image_color_confusion_matrix.png
ADDED
|
Git LFS Details
|
figures/confusion_matrices/cm_color/local_text_color_confusion_matrix.png
ADDED
|
Git LFS Details
|
figures/confusion_matrices/cm_hierarchy/baseline_image_hierarchy_confusion_matrix.png
ADDED
|
Git LFS Details
|
figures/confusion_matrices/cm_hierarchy/baseline_internal_image_hierarchy_confusion_matrix.png
ADDED
|
Git LFS Details
|
figures/confusion_matrices/cm_hierarchy/baseline_internal_text_hierarchy_confusion_matrix.png
ADDED
|
Git LFS Details
|
figures/confusion_matrices/cm_hierarchy/baseline_kagl_marqo_image_hierarchy_confusion_matrix.png
ADDED
|
Git LFS Details
|
figures/confusion_matrices/cm_hierarchy/baseline_kagl_marqo_text_hierarchy_confusion_matrix.png
ADDED
|
Git LFS Details
|
figures/confusion_matrices/cm_hierarchy/baseline_text_hierarchy_confusion_matrix.png
ADDED
|
Git LFS Details
|
figures/confusion_matrices/cm_hierarchy/gap_clip_image_hierarchy_confusion_matrix.png
ADDED
|
Git LFS Details
|
figures/confusion_matrices/cm_hierarchy/gap_clip_internal_image_hierarchy_confusion_matrix.png
ADDED
|
Git LFS Details
|
figures/confusion_matrices/cm_hierarchy/gap_clip_internal_text_hierarchy_confusion_matrix.png
ADDED
|
Git LFS Details
|
figures/confusion_matrices/cm_hierarchy/gap_clip_kagl_marqo_image_hierarchy_confusion_matrix.png
ADDED
|
Git LFS Details
|
figures/confusion_matrices/cm_hierarchy/gap_clip_kagl_marqo_text_hierarchy_confusion_matrix.png
ADDED
|
Git LFS Details
|