Leacb4 commited on
Commit
8debb94
·
verified ·
1 Parent(s): 89c04d6

Upload example_usage.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. example_usage.py +206 -0
example_usage.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Exemple d'utilisation des modèles depuis Hugging Face
4
+ """
5
+
6
+ import torch
7
+ from PIL import Image
8
+ from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
9
+ from huggingface_hub import hf_hub_download
10
+ import json
11
+ import os
12
+
13
+ # Import des modèles locaux (à adapter selon votre structure)
14
+ from color_model import ColorCLIP, SimpleTokenizer
15
+ from hierarchy_model import Model as HierarchyModel, HierarchyExtractor
16
+ from config import color_emb_dim, hierarchy_emb_dim
17
+
18
+ def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
19
+ """
20
+ Charger les modèles depuis Hugging Face
21
+
22
+ Args:
23
+ repo_id: ID du repository Hugging Face
24
+ cache_dir: Dossier de cache local
25
+ """
26
+
27
+ os.makedirs(cache_dir, exist_ok=True)
28
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+
30
+ print(f"📥 Chargement des modèles depuis '{repo_id}'...")
31
+
32
+ # 1. Charger le modèle de couleur
33
+ print(" 📦 Chargement du modèle de couleur...")
34
+ color_model_path = hf_hub_download(
35
+ repo_id=repo_id,
36
+ filename="color_model.pt",
37
+ cache_dir=cache_dir
38
+ )
39
+
40
+ # Charger le vocabulaire
41
+ vocab_path = hf_hub_download(
42
+ repo_id=repo_id,
43
+ filename="tokenizer_vocab.json",
44
+ cache_dir=cache_dir
45
+ )
46
+
47
+ with open(vocab_path, 'r') as f:
48
+ vocab_dict = json.load(f)
49
+
50
+ tokenizer = SimpleTokenizer()
51
+ tokenizer.load_vocab(vocab_dict)
52
+
53
+ checkpoint = torch.load(color_model_path, map_location=device)
54
+ vocab_size = checkpoint['text_encoder.embedding.weight'].shape[0]
55
+ color_model = ColorCLIP(vocab_size=vocab_size, embedding_dim=color_emb_dim).to(device)
56
+ color_model.tokenizer = tokenizer
57
+ color_model.load_state_dict(checkpoint)
58
+ color_model.eval()
59
+ print(" ✅ Modèle de couleur chargé")
60
+
61
+ # 2. Charger le modèle de hiérarchie
62
+ print(" 📦 Chargement du modèle de hiérarchie...")
63
+ hierarchy_model_path = hf_hub_download(
64
+ repo_id=repo_id,
65
+ filename="hierarchy_model.pth",
66
+ cache_dir=cache_dir
67
+ )
68
+
69
+ hierarchy_checkpoint = torch.load(hierarchy_model_path, map_location=device)
70
+ hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', [])
71
+
72
+ hierarchy_model = HierarchyModel(
73
+ num_hierarchy_classes=len(hierarchy_classes),
74
+ embed_dim=hierarchy_emb_dim
75
+ ).to(device)
76
+ hierarchy_model.load_state_dict(hierarchy_checkpoint['model_state'])
77
+
78
+ hierarchy_extractor = HierarchyExtractor(hierarchy_classes, verbose=False)
79
+ hierarchy_model.set_hierarchy_extractor(hierarchy_extractor)
80
+ hierarchy_model.eval()
81
+ print(" ✅ Modèle de hiérarchie chargé")
82
+
83
+ # 3. Charger le modèle principal CLIP
84
+ print(" 📦 Chargement du modèle principal CLIP...")
85
+ main_model_path = hf_hub_download(
86
+ repo_id=repo_id,
87
+ filename="laion_explicable_model.pth",
88
+ cache_dir=cache_dir
89
+ )
90
+
91
+ clip_model = CLIPModel_transformers.from_pretrained(
92
+ 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
93
+ )
94
+ checkpoint = torch.load(main_model_path, map_location=device)
95
+
96
+ # Gérer différentes structures de checkpoint
97
+ if isinstance(checkpoint, dict):
98
+ if 'model_state_dict' in checkpoint:
99
+ clip_model.load_state_dict(checkpoint['model_state_dict'])
100
+ else:
101
+ # Si le checkpoint est directement le state_dict
102
+ clip_model.load_state_dict(checkpoint)
103
+ else:
104
+ clip_model.load_state_dict(checkpoint)
105
+
106
+ clip_model = clip_model.to(device)
107
+ clip_model.eval()
108
+
109
+ processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
110
+ print(" ✅ Modèle principal CLIP chargé")
111
+
112
+ print("\n✅ Tous les modèles sont chargés!")
113
+
114
+ return {
115
+ 'color_model': color_model,
116
+ 'hierarchy_model': hierarchy_model,
117
+ 'main_model': clip_model,
118
+ 'processor': processor,
119
+ 'device': device
120
+ }
121
+
122
+
123
+ def example_search(models, image_path: str = None, text_query: str = None):
124
+ """
125
+ Exemple de recherche avec les modèles
126
+
127
+ Args:
128
+ models: Dictionnaire des modèles chargés
129
+ image_path: Chemin vers une image (optionnel)
130
+ text_query: Requête textuelle (optionnel)
131
+ """
132
+
133
+ color_model = models['color_model']
134
+ hierarchy_model = models['hierarchy_model']
135
+ main_model = models['main_model']
136
+ processor = models['processor']
137
+ device = models['device']
138
+
139
+ print("\n🔍 Exemple de recherche...")
140
+
141
+ if text_query:
142
+ print(f" 📝 Requête textuelle: '{text_query}'")
143
+
144
+ # Obtenir les embeddings de couleur et hiérarchie
145
+ color_emb = color_model.get_text_embeddings([text_query])
146
+ hierarchy_emb = hierarchy_model.get_text_embeddings([text_query])
147
+
148
+ print(f" 🎨 Embedding couleur: {color_emb.shape}")
149
+ print(f" 📂 Embedding hiérarchie: {hierarchy_emb.shape}")
150
+
151
+ # Obtenir les embeddings du modèle principal
152
+ text_inputs = processor(text=[text_query], padding=True, return_tensors="pt")
153
+ text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
154
+
155
+ with torch.no_grad():
156
+ outputs = main_model(**text_inputs)
157
+ text_features = outputs.text_embeds
158
+
159
+ print(f" 🎯 Embedding principal: {text_features.shape}")
160
+
161
+ if image_path and os.path.exists(image_path):
162
+ print(f" 🖼️ Image: {image_path}")
163
+ image = Image.open(image_path).convert("RGB")
164
+
165
+ # Obtenir les embeddings d'image
166
+ image_inputs = processor(images=[image], return_tensors="pt")
167
+ image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
168
+
169
+ with torch.no_grad():
170
+ outputs = main_model(**image_inputs)
171
+ image_features = outputs.image_embeds
172
+
173
+ print(f" 🎯 Embedding image: {image_features.shape}")
174
+
175
+
176
+ if __name__ == "__main__":
177
+ import argparse
178
+
179
+ parser = argparse.ArgumentParser(description="Exemple d'utilisation des modèles")
180
+ parser.add_argument(
181
+ "--repo-id",
182
+ type=str,
183
+ required=True,
184
+ help="ID du repository Hugging Face"
185
+ )
186
+ parser.add_argument(
187
+ "--text",
188
+ type=str,
189
+ default="red dress",
190
+ help="Requête textuelle de recherche"
191
+ )
192
+ parser.add_argument(
193
+ "--image",
194
+ type=str,
195
+ default=None,
196
+ help="Chemin vers une image"
197
+ )
198
+
199
+ args = parser.parse_args()
200
+
201
+ # Charger les modèles
202
+ models = load_models_from_hf(args.repo_id)
203
+
204
+ # Exemple de recherche
205
+ example_search(models, image_path=args.image, text_query=args.text)
206
+