Upload evaluation/annex94_search_demo.py with huggingface_hub
Browse files
evaluation/annex94_search_demo.py
CHANGED
|
@@ -32,12 +32,8 @@ PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
|
| 32 |
if str(PROJECT_ROOT) not in sys.path:
|
| 33 |
sys.path.insert(0, str(PROJECT_ROOT))
|
| 34 |
|
| 35 |
-
# Import custom models
|
| 36 |
-
|
| 37 |
-
from training.color_model import CLIPModel as ColorModel
|
| 38 |
-
except ModuleNotFoundError:
|
| 39 |
-
ColorModel = None
|
| 40 |
-
from training.hierarchy_model import Model as HierarchyModel, HierarchyExtractor
|
| 41 |
import config
|
| 42 |
|
| 43 |
warnings.filterwarnings("ignore")
|
|
@@ -77,39 +73,28 @@ class FashionSearchEngine:
|
|
| 77 |
print("✅ Fashion Search Engine ready!")
|
| 78 |
|
| 79 |
def _load_models(self):
|
| 80 |
-
"""Load all required models"""
|
| 81 |
print("📦 Loading models...")
|
| 82 |
|
| 83 |
# Load color model (optional for search in this script).
|
| 84 |
self.color_model = None
|
| 85 |
color_model_path = getattr(config, "color_model_path", None)
|
| 86 |
-
if
|
| 87 |
-
print("⚠️ color_model.py not found; continuing without color model.")
|
| 88 |
-
elif not color_model_path or not Path(color_model_path).exists():
|
| 89 |
print("⚠️ color model checkpoint not found; continuing without color model.")
|
| 90 |
else:
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
|
|
|
| 97 |
|
| 98 |
# Load hierarchy model
|
| 99 |
-
|
| 100 |
-
config.hierarchy_model_path,
|
|
|
|
| 101 |
)
|
| 102 |
-
self.hierarchy_classes = hierarchy_checkpoint.get("hierarchy_classes", [])
|
| 103 |
-
self.hierarchy_model = HierarchyModel(
|
| 104 |
-
num_hierarchy_classes=len(self.hierarchy_classes),
|
| 105 |
-
embed_dim=self.hierarchy_dim,
|
| 106 |
-
).to(self.device)
|
| 107 |
-
self.hierarchy_model.load_state_dict(hierarchy_checkpoint["model_state"])
|
| 108 |
-
|
| 109 |
-
# Set hierarchy extractor
|
| 110 |
-
hierarchy_extractor = HierarchyExtractor(self.hierarchy_classes, verbose=False)
|
| 111 |
-
self.hierarchy_model.set_hierarchy_extractor(hierarchy_extractor)
|
| 112 |
-
self.hierarchy_model.eval()
|
| 113 |
|
| 114 |
# Load main CLIP model (baseline or fine-tuned GAP-CLIP)
|
| 115 |
if self.use_baseline:
|
|
@@ -409,7 +394,7 @@ if __name__ == "__main__":
|
|
| 409 |
|
| 410 |
if args.queries:
|
| 411 |
all_results = {}
|
| 412 |
-
figures_dir = Path(
|
| 413 |
figures_dir.mkdir(parents=True, exist_ok=True)
|
| 414 |
(figures_dir / "figures").mkdir(parents=True, exist_ok=True)
|
| 415 |
|
|
|
|
| 32 |
if str(PROJECT_ROOT) not in sys.path:
|
| 33 |
sys.path.insert(0, str(PROJECT_ROOT))
|
| 34 |
|
| 35 |
+
# Import custom models via shared loaders
|
| 36 |
+
from evaluation.utils.model_loader import load_color_model, load_hierarchy_model
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
import config
|
| 38 |
|
| 39 |
warnings.filterwarnings("ignore")
|
|
|
|
| 73 |
print("✅ Fashion Search Engine ready!")
|
| 74 |
|
| 75 |
def _load_models(self):
|
| 76 |
+
"""Load all required models."""
|
| 77 |
print("📦 Loading models...")
|
| 78 |
|
| 79 |
# Load color model (optional for search in this script).
|
| 80 |
self.color_model = None
|
| 81 |
color_model_path = getattr(config, "color_model_path", None)
|
| 82 |
+
if not color_model_path or not Path(color_model_path).exists():
|
|
|
|
|
|
|
| 83 |
print("⚠️ color model checkpoint not found; continuing without color model.")
|
| 84 |
else:
|
| 85 |
+
try:
|
| 86 |
+
self.color_model, _ = load_color_model(
|
| 87 |
+
color_model_path=config.color_model_path,
|
| 88 |
+
device=self.device,
|
| 89 |
+
)
|
| 90 |
+
except Exception as e:
|
| 91 |
+
print(f"⚠️ Failed to load color model: {e}; continuing without it.")
|
| 92 |
|
| 93 |
# Load hierarchy model
|
| 94 |
+
self.hierarchy_model = load_hierarchy_model(
|
| 95 |
+
hierarchy_model_path=config.hierarchy_model_path,
|
| 96 |
+
device=self.device,
|
| 97 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
# Load main CLIP model (baseline or fine-tuned GAP-CLIP)
|
| 100 |
if self.use_baseline:
|
|
|
|
| 394 |
|
| 395 |
if args.queries:
|
| 396 |
all_results = {}
|
| 397 |
+
figures_dir = Path("evaluation")
|
| 398 |
figures_dir.mkdir(parents=True, exist_ok=True)
|
| 399 |
(figures_dir / "figures").mkdir(parents=True, exist_ok=True)
|
| 400 |
|