File size: 10,419 Bytes
7e0629f
b20ec2d
7e0629f
 
b20ec2d
7e0629f
 
 
2a9e2ae
 
7e0629f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9fbb0c
7e0629f
2a9e2ae
7e0629f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
049d210
7e0629f
 
049d210
 
7e0629f
 
 
 
 
 
 
 
049d210
7e0629f
049d210
 
 
 
 
 
 
 
 
7e0629f
 
049d210
7e0629f
 
 
 
 
 
 
 
 
 
 
 
5a621f1
 
 
7e0629f
5a621f1
7e0629f
5a621f1
7e0629f
 
 
 
 
 
049d210
7e0629f
 
 
049d210
7e0629f
 
 
 
 
 
 
5a621f1
 
 
 
 
 
 
7e0629f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8599c1
7e0629f
 
 
 
 
 
 
 
 
 
 
 
 
c8599c1
7e0629f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b20ec2d
7e0629f
b20ec2d
7e0629f
 
 
5a621f1
7e0629f
 
 
 
fd39792
7e0629f
 
b20ec2d
7e0629f
 
 
 
77d52f7
 
7e0629f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8599c1
7e0629f
 
 
 
a7cab75
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
"""
FaceAge ClientScan — Gradio demo for HuggingFace Spaces.

Face detection : YuNet (OpenCV cv2.FaceDetectorYN, ~350 KB model downloaded from the HF Hub)
Age/gender     : FaceAge ClientScan ONNX (CPU, ~1.2 GB)
"""
import os
import numpy as np

_HF_TOKEN = os.getenv("HF_TOKEN") or None  # unset or empty -> None (anonymous Hub access)
import gradio as gr
from PIL import Image, ImageDraw, ImageFont

# ---------------------------------------------------------------------------
# Age/gender preprocessing  (ImageNet normalisation, matches training)
# ---------------------------------------------------------------------------

# ImageNet per-channel statistics (RGB order) and the square side length the
# face crops are resized to before inference.
_MEAN = np.asarray((0.485, 0.456, 0.406), dtype=np.float32)
_STD = np.asarray((0.229, 0.224, 0.225), dtype=np.float32)
_IMG_SIZE = 224


def _preprocess(img_rgb: np.ndarray) -> np.ndarray:
    """Turn an HxWx3 uint8 RGB crop into a 1x3x224x224 float32 model input.

    The crop is bicubically resized to ``_IMG_SIZE`` x ``_IMG_SIZE``, scaled
    to [0, 1], normalised with ``_MEAN`` / ``_STD``, moved to channel-first
    layout and given a leading batch axis of size one.  The result is
    C-contiguous.
    """
    from PIL import Image as _PILImage

    side = _IMG_SIZE
    resized = _PILImage.fromarray(img_rgb).resize((side, side), _PILImage.BICUBIC)
    scaled = np.asarray(resized, dtype=np.float32) / 255.0
    normed = (scaled - _MEAN) / _STD
    chw = normed.transpose(2, 0, 1)
    return np.ascontiguousarray(chw[None, ...])


def _decode_age(logits: np.ndarray) -> float:
    """CORAL: age = Ξ£ sigmoid(logits)."""
    logits = np.clip(logits, -88.0, 88.0)
    return float((1.0 / (1.0 + np.exp(-logits))).sum())


def _decode_gender(logits: np.ndarray) -> tuple[str, float]:
    ex    = np.exp(logits - logits.max())
    probs = ex / ex.sum()
    idx   = int(probs.argmax())
    return ("male" if idx == 1 else "female"), float(probs[idx])


# ---------------------------------------------------------------------------
# Age/gender model  (ONNX, loaded from HF Hub)
# ---------------------------------------------------------------------------

# Lazily created ONNX Runtime session and the name of its first input;
# both remain None until _load_age_model() has run.
_ORT_SESSION = _ORT_IN_NAME = None


def _load_age_model():
    """Download the age/gender ONNX model and open an ORT session (once).

    Returns immediately once ``_ORT_SESSION`` is set.  Otherwise the model is
    fetched from the ``TrungTran/faceage_ClientScan`` Hub repo (authenticated
    with ``_HF_TOKEN`` when set) and opened on ONNX Runtime's CPU provider
    with full graph optimisation and 4 intra-op threads.  On success both
    ``_ORT_SESSION`` and ``_ORT_IN_NAME`` (name of the model's first input)
    are set.  Download and session errors propagate to the caller; the
    globals are left untouched in that case so the next call retries.
    """
    global _ORT_SESSION, _ORT_IN_NAME
    if _ORT_SESSION is not None:
        return

    import onnxruntime as ort
    from huggingface_hub import hf_hub_download

    print("[AgeModel] Downloading ONNX from HuggingFace Hub …")
    onnx_path = hf_hub_download(
        repo_id  = "TrungTran/faceage_ClientScan",
        filename = "faceage_dino_fp32.onnx",
        token    = _HF_TOKEN,
    )

    opts = ort.SessionOptions()
    opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    opts.intra_op_num_threads = 4
    session = ort.InferenceSession(
        onnx_path, sess_options=opts,
        providers=["CPUExecutionProvider"],
    )
    in_name = session.get_inputs()[0].name

    # Publish only after every step succeeded, and the input name before the
    # session: the guard above keys on _ORT_SESSION, so once it is non-None
    # _ORT_IN_NAME must already be valid (a half-initialised pair would never
    # be repaired).
    _ORT_IN_NAME = in_name
    _ORT_SESSION = session
    print(f"[AgeModel] Ready  ({onnx_path})")


def _predict_crop(face_rgb: np.ndarray) -> dict:
    """Run the age/gender model on a single RGB crop.

    Expects ``_load_age_model()`` to have been called.  The session must
    yield exactly two outputs, age logits then gender logits; element ``[0]``
    of each (the single batch item) is decoded.

    Returns ``{"age": float, "gender": str, "conf": float}``.
    """
    feed = {_ORT_IN_NAME: _preprocess(face_rgb)}
    age_out, gender_out = _ORT_SESSION.run(None, feed)
    age = _decode_age(age_out[0])
    label, prob = _decode_gender(gender_out[0])
    return {"age": age, "gender": label, "conf": prob}


# ---------------------------------------------------------------------------
# YuNet face detector  (cv2.FaceDetectorYN, loaded from HuggingFace Hub)
# ---------------------------------------------------------------------------

# Location of the YuNet ONNX weights on the Hugging Face Hub, plus the lazily
# created detector: None before _load_detector() runs, afterwards either a
# cv2.FaceDetectorYN instance or the sentinel string "unavailable".
_YUNET_REPO, _YUNET_FILE = (
    "opencv/face_detection_yunet",
    "face_detection_yunet_2023mar.onnx",
)
_DETECTOR = None


def _load_detector():
    """Create the shared YuNet face detector on first call (best effort).

    Later calls return immediately.  On success ``_DETECTOR`` is a
    ``cv2.FaceDetectorYN`` built from ``_YUNET_FILE`` in the ``_YUNET_REPO``
    Hub repo (initial input size 320x320, score threshold 0.6, NMS threshold
    0.3, top-k 100).  If OpenCV cannot be imported, the download fails, or the
    detector cannot be built, ``_DETECTOR`` becomes the sentinel string
    ``"unavailable"``.  ``_detect_faces`` then returns no boxes, and the app
    falls back to whole-image prediction instead of erroring.
    """
    global _DETECTOR
    if _DETECTOR is not None:
        return

    from huggingface_hub import hf_hub_download

    # Import OpenCV inside the guarded path.  A missing or broken cv2 then
    # disables detection like the other failures below.  Otherwise the
    # ImportError would escape on every call, since _DETECTOR stays None.
    try:
        import cv2
    except ImportError as e:
        print(f"[YuNet] OpenCV import failed: {e} — face detection disabled")
        _DETECTOR = "unavailable"
        return

    try:
        yunet_path = hf_hub_download(repo_id=_YUNET_REPO, filename=_YUNET_FILE)
        print(f"[YuNet] Model: {yunet_path}")
    except Exception as e:
        print(f"[YuNet] Download failed: {e} — face detection disabled")
        _DETECTOR = "unavailable"
        return

    try:
        _DETECTOR = cv2.FaceDetectorYN.create(
            model           = yunet_path,
            config          = "",
            input_size      = (320, 320),
            score_threshold = 0.6,
            nms_threshold   = 0.3,
            top_k           = 100,
        )
        print("[YuNet] Face detector ready")
    except Exception as e:
        print(f"[YuNet] Init failed: {e} — face detection disabled")
        _DETECTOR = "unavailable"


# Padding added on every side of a detected face box, as a fraction of the
# box's own width/height (10%) — matches LAGENDA benchmark MAE=3.555.
_FACE_PAD = 0.1


def _detect_faces(img_rgb: np.ndarray,
                  min_face_px: int = 20) -> list[tuple[int, int, int, int]]:
    """
    Run YuNet on an RGB image and return padded face boxes, largest first.

    Each YuNet box is grown by ``_FACE_PAD`` of its own width/height on every
    side, clipped to the image, and returned as ``(x0, y0, x1, y1)``.  Boxes
    whose padded width or height is below ``min_face_px`` are discarded.
    An empty list is returned when the detector has not been loaded, is
    marked ``"unavailable"``, or detects nothing.
    """
    if _DETECTOR is None or _DETECTOR == "unavailable":
        return []

    import cv2

    img_h, img_w = img_rgb.shape[:2]
    bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)

    # The detector object is a module-level singleton, so its input size is
    # reset to this image's dimensions on every call.
    _DETECTOR.setInputSize((img_w, img_h))
    _, detections = _DETECTOR.detect(bgr)  # None, or Nx15 rows: [x, y, w, h, ...]
    if detections is None:
        return []

    kept = []
    for det in detections:
        bx, by, bw, bh = det[:4].astype(int)
        dx = int(bw * _FACE_PAD)
        dy = int(bh * _FACE_PAD)
        box = (max(0, bx - dx), max(0, by - dy),
               min(img_w, bx + bw + dx), min(img_h, by + bh + dy))
        if box[2] - box[0] >= min_face_px and box[3] - box[1] >= min_face_px:
            kept.append(box)

    return sorted(kept, key=lambda b: (b[2] - b[0]) * (b[3] - b[1]), reverse=True)


# ---------------------------------------------------------------------------
# Drawing
# ---------------------------------------------------------------------------

# Outline/label colours, assigned to results in order and cycled when there
# are more results than colours.
_PALETTE = [
    "#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4",
    "#FFEAA7", "#DDA0DD", "#98D8C8", "#F7DC6F",
]

# Candidate bold TrueType fonts, tried in this order (two common Linux
# DejaVu locations, then the macOS system Helvetica).
_FONT_PATHS = [
    "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
    "/usr/share/fonts/dejavu/DejaVuSans-Bold.ttf",
    "/System/Library/Fonts/Helvetica.ttc",
]


def _get_font(size: int):
    """Return a font of roughly *size* px for drawing labels.

    Each TrueType file in ``_FONT_PATHS`` is tried in order.  If none loads,
    Pillow's built-in default font is used, requesting *size* where the
    installed Pillow supports it (``ImageFont.load_default(size=...)``,
    Pillow >= 10.1).  Otherwise the fixed-size default font is returned, which
    ignores *size*.
    """
    for path in _FONT_PATHS:
        try:
            return ImageFont.truetype(path, size)
        except Exception:
            # Best effort: missing or unreadable font file -> try the next one.
            continue
    try:
        return ImageFont.load_default(size=size)
    except (TypeError, ImportError, OSError):
        # TypeError: Pillow < 10.1 has no ``size`` kwarg.
        # ImportError/OSError: no FreeType support to render a scalable font.
        return ImageFont.load_default()


def _draw_results(pil_img: Image.Image, results: list[dict]) -> Image.Image:
    """Draw one label per prediction onto *pil_img* (in place) and return it.

    Every result needs ``"gender"`` and ``"age"``.  A result with a truthy
    ``"bbox"`` ``(x0, y0, x1, y1)`` gets a 3 px outline and a filled label tab
    directly above the box, clamped to the top edge.  A result without one
    (the whole-image fallback) gets a banner in the top-left corner that also
    shows ``"conf"`` as a percentage.  Colours cycle through ``_PALETTE``.
    """
    draw    = ImageDraw.Draw(pil_img)
    font_lg = _get_font(20)

    for i, r in enumerate(results):
        color = _PALETTE[i % len(_PALETTE)]
        bbox  = r.get("bbox")

        if bbox:
            x0, y0, x1, y1 = bbox
            label = f"{r['gender']}  {r['age']:.1f} y"
            # Box
            draw.rectangle([x0, y0, x1, y1], outline=color, width=3)
            # Label tab: fixed 24 px text height (presumably sized for the
            # 20 px font), lifted 4 px above the box, never above y = 0.
            tw = int(draw.textlength(label, font=font_lg))
            th = 24
            lx0, ly0 = x0, max(0, y0 - th - 4)
            draw.rectangle([lx0, ly0, lx0 + tw + 10, ly0 + th + 4], fill=color)
            draw.text((lx0 + 5, ly0 + 2), label, fill="white", font=font_lg)
        else:
            # Full-image fallback — overlay in top-left corner
            full_label = f"{r['gender']}  {r['age']:.1f} y  ({r['conf']:.0%})"
            tw = int(draw.textlength(full_label, font=font_lg))
            draw.rectangle([8, 8, tw + 18, 38], fill=color)
            draw.text((13, 10), full_label, fill="white", font=font_lg)

    return pil_img


# ---------------------------------------------------------------------------
# Main predict function
# ---------------------------------------------------------------------------

def predict(image: Image.Image, max_faces: int,
            conf_thresh: float) -> tuple[Image.Image, str]:
    """Gradio callback: detect faces, estimate age/gender, annotate, summarise.

    Args:
        image: Uploaded picture, or ``None`` when the input is empty.
        max_faces: Maximum number of detected faces to analyse; the largest
            boxes are kept.
        conf_thresh: YuNet score threshold from the "Detection confidence"
            slider; applied to the detector before this call's detection.

    Returns:
        ``(annotated RGB image, Markdown summary)``, or ``(None, prompt)``
        when no image was given.  If no face box survives (including when the
        detector is unavailable), the whole image is fed to the age/gender
        model instead.
    """
    if image is None:
        return None, "⬆️  Please upload a photo."

    _load_age_model()
    _load_detector()

    # Forward the UI threshold to YuNet; without this the slider has no
    # effect and the detector keeps the threshold it was created with.
    if _DETECTOR is not None and _DETECTOR != "unavailable":
        _DETECTOR.setScoreThreshold(float(conf_thresh))

    rgb_image = image.convert("RGB")
    img_rgb   = np.asarray(rgb_image)
    # int(): slicing needs an integer even if the slider hands back a float.
    bboxes    = _detect_faces(img_rgb)[:int(max_faces)]

    results = []
    for bbox in bboxes:
        x0, y0, x1, y1 = bbox
        r = _predict_crop(img_rgb[y0:y1, x0:x1])
        r["bbox"] = bbox
        results.append(r)
    if not bboxes:
        # No faces found — run on the entire image
        results.append(_predict_crop(img_rgb))

    # Annotated output image (drawn on a copy of the RGB image)
    out_img = _draw_results(rgb_image.copy(), results)

    # Text summary
    n_faces = len(bboxes)
    if bboxes:
        mode = f"({n_faces} face{'s' if n_faces != 1 else ''} detected)"
    else:
        mode = "(no face detected — full image used)"
    lines = [f"**{mode}**\n"]
    for i, r in enumerate(results, 1):
        icon = "👨" if r["gender"] == "male" else "👩"
        lines.append(
            f"{icon} **Face {i}** — Age **{r['age']:.1f}**  ·  "
            f"{r['gender']} ({r['conf']:.0%})"
        )
    summary = "\n".join(lines)

    return out_img, summary


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------

# Markdown header shown at the top of the Gradio page (first element of the
# Blocks layout below).
_DESC = """
## FaceAge ClientScan — Age & Gender Estimation

Upload a photo. **YuNet** auto-detects faces, then **FaceAge ClientScan** predicts age and gender.

| | |
|--|--|
| 🏆 LAGENDA 84k MAE | **3.555** (beats MiVOLO v2 paper 3.650, face-only) |
| 🧠 Backbone | DINOv3-ViT-L/16 (Meta AI, 307M params) |
| ⚡ Speed | ~100 ms / face on CPU (ONNX FP32) |
| 🔍 Detector | YuNet (OpenCV, ~350 KB) |

[📄 Model Card](https://huggingface.co/TrungTran/faceage_ClientScan)
"""

# Page layout: header, then a row with two equal-width columns (first: image
# input, sliders and Predict button; second: annotated image and Markdown
# summary), then a footer note.
with gr.Blocks(title="FaceAge ClientScan", theme=gr.themes.Soft()) as demo:
    gr.Markdown(_DESC)

    with gr.Row():
        with gr.Column(scale=1):
            inp_img   = gr.Image(type="pil", label="📷 Upload photo or use webcam",
                                 sources=["upload", "webcam", "clipboard"])
            with gr.Row():
                inp_max   = gr.Slider(1, 10, value=5, step=1,
                                      label="Max faces")
                # Passed to predict() as conf_thresh (YuNet score threshold).
                inp_conf  = gr.Slider(0.3, 0.9, value=0.6, step=0.05,
                                      label="Detection confidence")
            btn = gr.Button("🔍 Predict", variant="primary", size="lg")

        with gr.Column(scale=1):
            out_img  = gr.Image(type="pil", label="Result")
            out_text = gr.Markdown()

    # Inputs map positionally onto predict(image, max_faces, conf_thresh);
    # its (image, summary) return value fills the two output components.
    btn.click(
        fn      = predict,
        inputs  = [inp_img, inp_max, inp_conf],
        outputs = [out_img, out_text],
    )

    gr.Markdown("""
---
*Our Collection: 4M Images.*
*DINOv3-ViT-L pretrained by Meta AI on LVD-1.68B images.*
""")

# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()