Leacb4 commited on
Commit
8278380
·
verified ·
1 Parent(s): 8355833

Upload evaluation/annex94_search_demo.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. evaluation/annex94_search_demo.py +15 -30
evaluation/annex94_search_demo.py CHANGED
@@ -32,12 +32,8 @@ PROJECT_ROOT = Path(__file__).resolve().parents[1]
32
  if str(PROJECT_ROOT) not in sys.path:
33
  sys.path.insert(0, str(PROJECT_ROOT))
34
 
35
- # Import custom models
36
- try:
37
- from training.color_model import CLIPModel as ColorModel
38
- except ModuleNotFoundError:
39
- ColorModel = None
40
- from training.hierarchy_model import Model as HierarchyModel, HierarchyExtractor
41
  import config
42
 
43
  warnings.filterwarnings("ignore")
@@ -77,39 +73,28 @@ class FashionSearchEngine:
77
  print("✅ Fashion Search Engine ready!")
78
 
79
  def _load_models(self):
80
- """Load all required models"""
81
  print("📦 Loading models...")
82
 
83
  # Load color model (optional for search in this script).
84
  self.color_model = None
85
  color_model_path = getattr(config, "color_model_path", None)
86
- if ColorModel is None:
87
- print("⚠️ color_model.py not found; continuing without color model.")
88
- elif not color_model_path or not Path(color_model_path).exists():
89
  print("⚠️ color model checkpoint not found; continuing without color model.")
90
  else:
91
- color_checkpoint = torch.load(
92
- color_model_path, map_location=self.device, weights_only=True
93
- )
94
- self.color_model = ColorModel(embed_dim=self.color_dim).to(self.device)
95
- self.color_model.load_state_dict(color_checkpoint)
96
- self.color_model.eval()
 
97
 
98
  # Load hierarchy model
99
- hierarchy_checkpoint = torch.load(
100
- config.hierarchy_model_path, map_location=self.device
 
101
  )
102
- self.hierarchy_classes = hierarchy_checkpoint.get("hierarchy_classes", [])
103
- self.hierarchy_model = HierarchyModel(
104
- num_hierarchy_classes=len(self.hierarchy_classes),
105
- embed_dim=self.hierarchy_dim,
106
- ).to(self.device)
107
- self.hierarchy_model.load_state_dict(hierarchy_checkpoint["model_state"])
108
-
109
- # Set hierarchy extractor
110
- hierarchy_extractor = HierarchyExtractor(self.hierarchy_classes, verbose=False)
111
- self.hierarchy_model.set_hierarchy_extractor(hierarchy_extractor)
112
- self.hierarchy_model.eval()
113
 
114
  # Load main CLIP model (baseline or fine-tuned GAP-CLIP)
115
  if self.use_baseline:
@@ -409,7 +394,7 @@ if __name__ == "__main__":
409
 
410
  if args.queries:
411
  all_results = {}
412
- figures_dir = Path(args.save).parent if args.save else Path("evaluation")
413
  figures_dir.mkdir(parents=True, exist_ok=True)
414
  (figures_dir / "figures").mkdir(parents=True, exist_ok=True)
415
 
 
32
  if str(PROJECT_ROOT) not in sys.path:
33
  sys.path.insert(0, str(PROJECT_ROOT))
34
 
35
+ # Import custom models via shared loaders
36
+ from evaluation.utils.model_loader import load_color_model, load_hierarchy_model
 
 
 
 
37
  import config
38
 
39
  warnings.filterwarnings("ignore")
 
73
  print("✅ Fashion Search Engine ready!")
74
 
75
  def _load_models(self):
76
+ """Load all required models."""
77
  print("📦 Loading models...")
78
 
79
  # Load color model (optional for search in this script).
80
  self.color_model = None
81
  color_model_path = getattr(config, "color_model_path", None)
82
+ if not color_model_path or not Path(color_model_path).exists():
 
 
83
  print("⚠️ color model checkpoint not found; continuing without color model.")
84
  else:
85
+ try:
86
+ self.color_model, _ = load_color_model(
87
+ color_model_path=config.color_model_path,
88
+ device=self.device,
89
+ )
90
+ except Exception as e:
91
+ print(f"⚠️ Failed to load color model: {e}; continuing without it.")
92
 
93
  # Load hierarchy model
94
+ self.hierarchy_model = load_hierarchy_model(
95
+ hierarchy_model_path=config.hierarchy_model_path,
96
+ device=self.device,
97
  )
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  # Load main CLIP model (baseline or fine-tuned GAP-CLIP)
100
  if self.use_baseline:
 
394
 
395
  if args.queries:
396
  all_results = {}
397
+ figures_dir = Path("evaluation")
398
  figures_dir.mkdir(parents=True, exist_ok=True)
399
  (figures_dir / "figures").mkdir(parents=True, exist_ok=True)
400