TrungTran commited on
Commit
7e0629f
Β·
verified Β·
1 Parent(s): fbf77cb

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +311 -0
app.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FaceAge-DINOv3 — Gradio demo for HuggingFace Spaces.
3
+
4
+ Face detection : YuNet (OpenCV built-in, ~350 KB model, no extra deps)
5
+ Age/gender : FaceAge-DINOv3 ONNX (CPU, ~1.2 GB)
6
+ """
7
+ import urllib.request
8
+ import os
9
+ import numpy as np
10
+ import gradio as gr
11
+ from PIL import Image, ImageDraw, ImageFont
12
+
13
+ # ---------------------------------------------------------------------------
14
+ # Age/gender preprocessing (ImageNet normalisation, matches training)
15
+ # ---------------------------------------------------------------------------
16
+
17
# Per-channel statistics (RGB order) used for ImageNet-style normalisation,
# plus the side length of the square image that is fed to the model.
_IMG_SIZE = 224
_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def _preprocess(img_rgb: np.ndarray) -> np.ndarray:
    """Convert an RGB image array into a normalised batch of one.

    The image is resized to ``_IMG_SIZE`` x ``_IMG_SIZE`` with bicubic
    resampling, scaled to [0, 1], standardised per channel with ``_MEAN`` and
    ``_STD``, and reordered from HWC to CHW with a leading batch axis.

    Args:
        img_rgb: Image array accepted by ``PIL.Image.fromarray``
            (callers in this file pass HxWx3 uint8 RGB).

    Returns:
        A C-contiguous float32 array of shape (1, 3, _IMG_SIZE, _IMG_SIZE).
    """
    target_size = (_IMG_SIZE, _IMG_SIZE)
    resized = Image.fromarray(img_rgb).resize(target_size, Image.BICUBIC)
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    normalised = (pixels - _MEAN) / _STD
    channels_first = normalised.transpose(2, 0, 1)
    return np.ascontiguousarray(channels_first[np.newaxis])
29
+
30
+
31
def _decode_age(logits: np.ndarray) -> float:
    """Decode CORAL-style ordinal logits into an age.

    The age is the sum of the element-wise sigmoids of ``logits``. Logits are
    clipped to [-88, 88] first so that ``exp`` stays finite.
    """
    bounded = np.clip(logits, -88.0, 88.0)
    rank_probs = 1.0 / (1.0 + np.exp(-bounded))
    return float(rank_probs.sum())
35
+
36
+
37
def _decode_gender(logits: np.ndarray) -> tuple[str, float]:
    """Turn gender logits into a label and its softmax probability.

    Index 1 is reported as ``"male"``; any other winning index is reported as
    ``"female"``.

    Returns:
        ``(label, probability_of_the_winning_index)``.
    """
    # Subtracting the maximum keeps the exponentials from overflowing.
    weights = np.exp(logits - logits.max())
    probs = weights / weights.sum()
    winner = int(np.argmax(probs))
    label = "male" if winner == 1 else "female"
    return label, float(probs[winner])
42
+
43
+
44
+ # ---------------------------------------------------------------------------
45
+ # Age/gender model (ONNX, loaded from HF Hub)
46
+ # ---------------------------------------------------------------------------
47
+
48
# Lazily created ONNX Runtime session and the name of its first input.
# Both stay None until _load_age_model() has run successfully.
_ORT_SESSION = None
_ORT_IN_NAME = None


def _load_age_model():
    """Create the module-wide ONNX Runtime session on first use.

    Does nothing when ``_ORT_SESSION`` is already set. Otherwise fetches
    ``faceage_dino_fp32.onnx`` from the ``TrungTran/faceage-dino`` repo via
    ``hf_hub_download`` and opens it with the CPU execution provider, storing
    the session and its first input name in the module globals.
    """
    global _ORT_SESSION, _ORT_IN_NAME
    if _ORT_SESSION is not None:
        return

    import onnxruntime as ort
    from huggingface_hub import hf_hub_download

    print("[AgeModel] Downloading ONNX from HuggingFace Hub …")
    model_file = hf_hub_download(
        repo_id="TrungTran/faceage-dino",
        filename="faceage_dino_fp32.onnx",
    )

    session_options = ort.SessionOptions()
    session_options.intra_op_num_threads = 4
    session_options.graph_optimization_level = (
        ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    )

    session = ort.InferenceSession(
        model_file,
        sess_options=session_options,
        providers=["CPUExecutionProvider"],
    )
    _ORT_SESSION = session
    _ORT_IN_NAME = session.get_inputs()[0].name
    print(f"[AgeModel] Ready ({model_file})")
75
+
76
+
77
def _predict_crop(face_rgb: np.ndarray) -> dict:
    """Run the age/gender model on a single RGB crop.

    Relies on ``_ORT_SESSION`` / ``_ORT_IN_NAME`` having been populated by
    ``_load_age_model()``.

    Returns:
        A dict with keys ``"age"`` (float), ``"gender"`` (``"male"`` or
        ``"female"``) and ``"conf"`` (probability of the reported gender).
    """
    batch = _preprocess(face_rgb)
    # The session is expected to yield exactly two outputs: age, then gender.
    age_out, gender_out = _ORT_SESSION.run(None, {_ORT_IN_NAME: batch})
    label, confidence = _decode_gender(gender_out[0])
    return {
        "age": _decode_age(age_out[0]),
        "gender": label,
        "conf": confidence,
    }
83
+
84
+
85
+ # ---------------------------------------------------------------------------
86
+ # YuNet face detector (cv2.FaceDetectorYN, OpenCV ≥ 4.5.4)
87
+ # ---------------------------------------------------------------------------
88
+
89
_YUNET_URL = (
    "https://github.com/opencv/opencv_zoo/raw/main/models/"
    "face_detection_yunet/face_detection_yunet_2023mar.onnx"
)
_YUNET_PATH = "/tmp/face_detection_yunet_2023mar.onnx"
# Detector state: None = not initialised yet, the string "unavailable" =
# setup failed (detection is skipped), anything else = the detector object.
_DETECTOR = None


def _load_detector():
    """Initialise the module-wide YuNet face detector on first use.

    Downloads the YuNet ONNX model to ``_YUNET_PATH`` when it is not cached,
    then builds a ``cv2.FaceDetectorYN``. Any failure (download, missing
    OpenCV, detector construction) is logged and recorded by setting
    ``_DETECTOR`` to ``"unavailable"``; this function never raises for those
    cases. Once ``_DETECTOR`` is set (either way) later calls return at once.
    """
    global _DETECTOR
    if _DETECTOR is not None:
        return

    # Download model if not cached
    if not os.path.exists(_YUNET_PATH):
        print("[YuNet] Downloading model …")
        # Download to a side file and move it into place only when complete,
        # so an interrupted transfer can never be mistaken for a cached model.
        partial_path = f"{_YUNET_PATH}.part"
        try:
            urllib.request.urlretrieve(_YUNET_URL, partial_path)
            os.replace(partial_path, _YUNET_PATH)
            print(f"[YuNet] Saved to {_YUNET_PATH}")
        except Exception as e:
            print(f"[YuNet] Download failed: {e} — face detection disabled")
            try:
                os.remove(partial_path)
            except OSError:
                # Nothing was written (or it is already gone); nothing to clean.
                pass
            _DETECTOR = "unavailable"
            return

    try:
        # Imported inside the try so a missing/old OpenCV degrades to
        # "unavailable" like every other setup failure.
        import cv2

        _DETECTOR = cv2.FaceDetectorYN.create(
            model=_YUNET_PATH,
            config="",
            input_size=(320, 320),
            score_threshold=0.6,
            nms_threshold=0.3,
            top_k=100,
        )
        print("[YuNet] Face detector ready")
    except Exception as e:
        print(f"[YuNet] Init failed: {e} — face detection disabled")
        _DETECTOR = "unavailable"
127
+
128
+
129
def _detect_faces(img_rgb: np.ndarray,
                  min_face_px: int = 20) -> list[tuple[int, int, int, int]]:
    """Detect faces in an RGB image with the module-wide YuNet detector.

    Args:
        img_rgb: Image array whose first two axes are height and width.
        min_face_px: Minimum width and height (after clipping to the image)
            a box must have to be kept.

    Returns:
        ``(x0, y0, x1, y1)`` boxes clipped to the image bounds, largest area
        first. Empty when the detector is not initialised, is unavailable, or
        finds nothing.
    """
    if _DETECTOR is None or _DETECTOR == "unavailable":
        return []

    import cv2

    height, width = img_rgb.shape[:2]

    # YuNet requires input_size to match the image
    _DETECTOR.setInputSize((width, height))
    bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
    _, detections = _DETECTOR.detect(bgr)  # detections: None or Nx15
    if detections is None:
        return []

    def _clip(row):
        # The first four values of a detection row are x, y, width, height.
        left, top, box_w, box_h = row[:4].astype(int)
        return (max(0, left), max(0, top),
                min(width, left + box_w), min(height, top + box_h))

    def _area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    clipped = (_clip(row) for row in detections)
    kept = [
        box for box in clipped
        if box[2] - box[0] >= min_face_px and box[3] - box[1] >= min_face_px
    ]
    return sorted(kept, key=_area, reverse=True)
159
+
160
+
161
+ # ---------------------------------------------------------------------------
162
+ # Drawing
163
+ # ---------------------------------------------------------------------------
164
+
165
# Colours cycled through when annotating successive results.
_PALETTE = [
    "#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4",
    "#FFEAA7", "#DDA0DD", "#98D8C8", "#F7DC6F",
]

# Candidate TrueType fonts, tried in this order.
_FONT_PATHS = [
    "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
    "/usr/share/fonts/dejavu/DejaVuSans-Bold.ttf",
    "/System/Library/Fonts/Helvetica.ttc",
]


def _get_font(size: int):
    """Return the first font from ``_FONT_PATHS`` that loads at ``size``.

    Falls back to ``ImageFont.load_default()`` when none of the candidates
    can be loaded.
    """
    for candidate in _FONT_PATHS:
        try:
            font = ImageFont.truetype(candidate, size)
        except Exception:
            # Best-effort lookup: any failure just means "try the next path".
            continue
        else:
            return font
    return ImageFont.load_default()
182
+
183
+
184
def _draw_results(pil_img: Image.Image, results: list[dict]) -> Image.Image:
    """Annotate ``pil_img`` in place with one label per result and return it.

    Args:
        pil_img: Image to draw on. It is modified directly, so pass a copy
            when the original must stay untouched.
        results: Dicts with ``"gender"``, ``"age"`` and ``"conf"`` keys. A
            result that also has a truthy ``"bbox"`` of ``(x0, y0, x1, y1)``
            gets an outlined box with a "gender age" label above it; a result
            without one gets a "gender age (conf)" banner in the top-left
            corner.

    Returns:
        The same ``pil_img`` object, after drawing.
    """
    draw = ImageDraw.Draw(pil_img)
    font = _get_font(20)
    label_h = 24  # fixed height of the filled strip behind a box label

    for i, r in enumerate(results):
        color = _PALETTE[i % len(_PALETTE)]
        bbox = r.get("bbox")

        if bbox:
            label = f"{r['gender']} {r['age']:.0f} y"
            x0, y0, x1, y1 = bbox
            # Box
            draw.rectangle([x0, y0, x1, y1], outline=color, width=3)
            # Label background, placed above the box but never above the image
            text_w = int(draw.textlength(label, font=font))
            lx0, ly0 = x0, max(0, y0 - label_h - 4)
            draw.rectangle(
                [lx0, ly0, lx0 + text_w + 10, ly0 + label_h + 4], fill=color
            )
            draw.text((lx0 + 5, ly0 + 2), label, fill="white", font=font)
        else:
            # Full-image fallback — overlay in top-left corner
            full_label = f"{r['gender']} {r['age']:.0f} y ({r['conf']:.0%})"
            text_w = int(draw.textlength(full_label, font=font))
            draw.rectangle([8, 8, text_w + 18, 38], fill=color)
            draw.text((13, 10), full_label, fill="white", font=font)

    return pil_img
212
+
213
+
214
+ # ---------------------------------------------------------------------------
215
+ # Main predict function
216
+ # ---------------------------------------------------------------------------
217
+
218
def predict(image: Image.Image, max_faces: int,
            conf_thresh: float) -> tuple[Image.Image, str]:
    """Detect faces in ``image`` and estimate age and gender for each.

    Args:
        image: Uploaded photo, or ``None`` when nothing was uploaded.
        max_faces: Upper bound on the number of (largest) faces to analyse.
            ``None`` means no limit.
        conf_thresh: Detection score threshold taken from the UI slider. It is
            applied to the face detector before detection runs; ``None``
            leaves the detector's current threshold unchanged.

    Returns:
        ``(annotated_image, markdown_summary)``. When ``image`` is ``None``
        the first element is ``None`` and the summary asks for an upload.
        When no face is detected the model is run on the whole image instead.
    """
    if image is None:
        return None, "⬆️ Please upload a photo."

    _load_age_model()
    _load_detector()

    # Honour the "Detection confidence" slider; without this the detector
    # always keeps the threshold it was created with.
    detector_ready = _DETECTOR is not None and _DETECTOR != "unavailable"
    if detector_ready and conf_thresh is not None:
        _DETECTOR.setScoreThreshold(float(conf_thresh))

    rgb_image = image.convert("RGB")
    img_rgb = np.asarray(rgb_image)

    # Slice bounds must be integers; UI widgets may hand over a float.
    limit = None if max_faces is None else int(max_faces)
    bboxes = _detect_faces(img_rgb)[:limit]

    results = []
    if bboxes:
        for bbox in bboxes:
            x0, y0, x1, y1 = bbox
            r = _predict_crop(img_rgb[y0:y1, x0:x1])
            r["bbox"] = bbox
            results.append(r)
    else:
        # No faces found — run on the entire image
        results.append(_predict_crop(img_rgb))

    # Annotated output image (drawn on a copy so the RGB source stays clean)
    out_img = _draw_results(rgb_image.copy(), results)

    # Text summary
    if bboxes:
        plural = "" if len(bboxes) == 1 else "s"
        mode = f"({len(bboxes)} face{plural} detected)"
    else:
        mode = "(no face detected — full image used)"

    lines = [f"**{mode}**\n"]
    for i, r in enumerate(results, 1):
        icon = "👨" if r["gender"] == "male" else "👩"
        lines.append(
            f"{icon} **Face {i}** — Age **{r['age']:.1f}** · "
            f"{r['gender']} ({r['conf']:.0%})"
        )

    return out_img, "\n".join(lines)
260
+
261
+
262
+ # ---------------------------------------------------------------------------
263
+ # Gradio UI
264
+ # ---------------------------------------------------------------------------
265
+
266
# Markdown shown at the top of the demo page.
# NOTE(review): the model-card link below points at "trungthanhtran/faceage-dino",
# while the weights are downloaded from "TrungTran/faceage-dino" elsewhere in
# this file — confirm which account is the intended one.
_DESC = """
## FaceAge-DINOv3 — Age & Gender Estimation

Upload a photo. **YuNet** auto-detects faces, then **FaceAge-DINOv3** predicts age and gender.

| | |
|--|--|
| 🏆 LAGENDA 84k MAE | **3.760** (beats MiVOLO v2 measured 3.859) |
| 🧠 Backbone | DINOv3-ViT-L/16 (Meta AI, 307M params) |
| ⚑ Speed | ~100 ms / face on CPU (ONNX FP32) |
| 🔍 Detector | YuNet (OpenCV, ~350 KB) |

[📄 Model Card](https://huggingface.co/trungthanhtran/faceage-dino)
"""
280
+
281
# Page layout: description on top, then an input column (photo, two sliders,
# predict button) next to an output column (annotated image, text summary).
with gr.Blocks(title="FaceAge-DINOv3", theme=gr.themes.Soft()) as demo:
    gr.Markdown(_DESC)

    with gr.Row():
        with gr.Column(scale=1):
            inp_img = gr.Image(type="pil", label="📷 Upload photo")
            with gr.Row():
                inp_max = gr.Slider(1, 10, value=5, step=1,
                                    label="Max faces")
                inp_conf = gr.Slider(0.3, 0.9, value=0.6, step=0.05,
                                     label="Detection confidence")
            btn = gr.Button("🔍 Predict", variant="primary", size="lg")

        with gr.Column(scale=1):
            out_img = gr.Image(type="pil", label="Result")
            out_text = gr.Markdown()

    # Inputs are passed to predict() positionally in this order.
    btn.click(
        fn=predict,
        inputs=[inp_img, inp_max, inp_conf],
        outputs=[out_img, out_text],
    )

    gr.Markdown("""
---
*Trained on LAGENDA · IMDB-Clean · UTKFace · AgeDB · FairFace · Open Images.*
*DINOv3-ViT-L pretrained by Meta AI on LVD-1.68B images.*
""")

if __name__ == "__main__":
    demo.launch()