TrungTran's picture
Update app.py
b1f9f19 verified
raw
history blame
10.1 kB
"""
FaceAge-DINOv3 β€” Gradio demo for HuggingFace Spaces.
Face detection : YuNet (OpenCV built-in, ~350 KB model, no extra deps)
Age/gender : FaceAge-DINOv3 ONNX (CPU, ~1.2 GB)
"""
import os
import numpy as np
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
# ---------------------------------------------------------------------------
# Age/gender preprocessing (ImageNet normalisation, matches training)
# ---------------------------------------------------------------------------
_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)  # ImageNet per-channel means (RGB)
_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)   # ImageNet per-channel std-devs (RGB)
_IMG_SIZE = 224  # square input resolution expected by the ONNX model
def _preprocess(img_rgb: np.ndarray) -> np.ndarray:
    """Convert an HxWx3 uint8 RGB array to a 1x3x224x224 float32 tensor.

    Resizes with bicubic interpolation, scales to [0, 1], applies ImageNet
    mean/std normalisation, then reorders HWC -> NCHW with a batch axis.
    """
    # FIX: use the module-level `Image` import; the previous function-local
    # `from PIL import Image as _PIL` was a redundant re-import.
    pil = Image.fromarray(img_rgb).resize((_IMG_SIZE, _IMG_SIZE), Image.BICUBIC)
    arr = np.asarray(pil, dtype=np.float32) / 255.0
    arr = (arr - _MEAN) / _STD
    # Contiguous layout is required by ONNX Runtime input binding.
    return np.ascontiguousarray(arr.transpose(2, 0, 1)[np.newaxis])
def _decode_age(logits: np.ndarray) -> float:
"""CORAL: age = Ξ£ sigmoid(logits)."""
logits = np.clip(logits, -88.0, 88.0)
return float((1.0 / (1.0 + np.exp(-logits))).sum())
def _decode_gender(logits: np.ndarray) -> tuple[str, float]:
ex = np.exp(logits - logits.max())
probs = ex / ex.sum()
idx = int(probs.argmax())
return ("male" if idx == 1 else "female"), float(probs[idx])
# ---------------------------------------------------------------------------
# Age/gender model (ONNX, loaded from HF Hub)
# ---------------------------------------------------------------------------
_ORT_SESSION = None  # onnxruntime.InferenceSession, lazily created by _load_age_model()
_ORT_IN_NAME = None  # name of the session's first input tensor
def _load_age_model():
    """Lazily initialise the global ONNX session and input name (idempotent)."""
    global _ORT_SESSION, _ORT_IN_NAME
    if _ORT_SESSION is not None:
        return  # already initialised — nothing to do

    import onnxruntime as ort
    from huggingface_hub import hf_hub_download

    print("[AgeModel] Downloading ONNX from HuggingFace Hub …")
    onnx_path = hf_hub_download(
        repo_id="TrungTran/faceage-dino",
        filename="faceage_dino_fp32.onnx",
    )

    session_options = ort.SessionOptions()
    session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    session_options.intra_op_num_threads = 4

    _ORT_SESSION = ort.InferenceSession(
        onnx_path,
        sess_options=session_options,
        providers=["CPUExecutionProvider"],
    )
    _ORT_IN_NAME = _ORT_SESSION.get_inputs()[0].name
    print(f"[AgeModel] Ready ({onnx_path})")
def _predict_crop(face_rgb: np.ndarray) -> dict:
    """Run the age/gender model on one RGB face crop.

    Assumes _load_age_model() has already populated the global session.
    Returns {"age": float, "gender": str, "conf": float}.
    """
    tensor = _preprocess(face_rgb)
    age_logits, gender_logits = _ORT_SESSION.run(None, {_ORT_IN_NAME: tensor})
    gender, confidence = _decode_gender(gender_logits[0])
    return {
        "age": _decode_age(age_logits[0]),
        "gender": gender,
        "conf": confidence,
    }
# ---------------------------------------------------------------------------
# YuNet face detector (cv2.FaceDetectorYN, loaded from HuggingFace Hub)
# ---------------------------------------------------------------------------
_YUNET_REPO = "opencv/face_detection_yunet"        # HF Hub repo holding the detector
_YUNET_FILE = "face_detection_yunet_2023mar.onnx"  # ~350 KB detector weights
_DETECTOR = None  # cv2.FaceDetectorYN, lazily created; "unavailable" sentinel on failure
def _load_detector():
    """Lazily initialise the global YuNet detector (idempotent).

    On any download/initialisation failure the global is set to the
    sentinel string "unavailable" so callers can degrade gracefully.
    """
    global _DETECTOR
    if _DETECTOR is not None:
        return  # already initialised (or already marked unavailable)

    import cv2
    from huggingface_hub import hf_hub_download

    try:
        model_path = hf_hub_download(repo_id=_YUNET_REPO, filename=_YUNET_FILE)
        print(f"[YuNet] Model: {model_path}")
    except Exception as err:
        print(f"[YuNet] Download failed: {err} — face detection disabled")
        _DETECTOR = "unavailable"
        return

    try:
        _DETECTOR = cv2.FaceDetectorYN.create(
            model=model_path,
            config="",
            input_size=(320, 320),
            score_threshold=0.6,
            nms_threshold=0.3,
            top_k=100,
        )
    except Exception as err:
        print(f"[YuNet] Init failed: {err} — face detection disabled")
        _DETECTOR = "unavailable"
        return

    print("[YuNet] Face detector ready")
def _detect_faces(img_rgb: np.ndarray,
                  min_face_px: int = 20,
                  margin: int = 15) -> list[tuple[int, int, int, int]]:
    """
    Detect faces with YuNet and return margin-padded (x0, y0, x1, y1) boxes,
    largest area first. Returns an empty list when YuNet is unavailable.
    """
    if _DETECTOR is None or _DETECTOR == "unavailable":
        return []
    import cv2

    height, width = img_rgb.shape[:2]
    _DETECTOR.setInputSize((width, height))
    # YuNet wants BGR; detect() yields None or an Nx15 array: [x, y, w, h, ...]
    _, detections = _DETECTOR.detect(cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR))
    if detections is None:
        return []

    boxes: list[tuple[int, int, int, int]] = []
    for det in detections:
        left, top, fw, fh = det[:4].astype(int)
        # Pad by `margin` and clamp to the image bounds.
        x0 = max(0, left - margin)
        y0 = max(0, top - margin)
        x1 = min(width, left + fw + margin)
        y1 = min(height, top + fh + margin)
        # Keep only boxes at least min_face_px on both sides.
        if min(x1 - x0, y1 - y0) >= min_face_px:
            boxes.append((x0, y0, x1, y1))

    return sorted(boxes,
                  key=lambda b: (b[2] - b[0]) * (b[3] - b[1]),
                  reverse=True)
# ---------------------------------------------------------------------------
# Drawing
# ---------------------------------------------------------------------------
# Distinct box/label colours, cycled per detected face.
_PALETTE = ["#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4",
            "#FFEAA7", "#DDA0DD", "#98D8C8", "#F7DC6F"]
# Candidate TTF locations (Linux/HF Spaces first, then macOS); first hit wins.
_FONT_PATHS = [
    "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
    "/usr/share/fonts/dejavu/DejaVuSans-Bold.ttf",
    "/System/Library/Fonts/Helvetica.ttc",
]
def _get_font(size: int):
    """Return the first loadable TrueType font from _FONT_PATHS, falling
    back to PIL's built-in bitmap font when none is available."""
    for candidate in _FONT_PATHS:
        try:
            return ImageFont.truetype(candidate, size)
        except Exception:
            continue  # missing/unreadable font file — try the next path
    return ImageFont.load_default()
def _draw_results(pil_img: Image.Image, results: list[dict]) -> Image.Image:
    """Annotate `pil_img` in place with one box + label per result dict.

    Each result has "age", "gender", "conf" and optionally "bbox"
    (x0, y0, x1, y1). Results without a bbox (full-image fallback) get a
    single overlay in the top-left corner. Returns the same image object.
    """
    draw = ImageDraw.Draw(pil_img)
    font_lg = _get_font(20)
    # FIX: dropped the unused `font_sm = _get_font(15)` — it was loaded
    # on every call but never referenced.
    for i, r in enumerate(results):
        color = _PALETTE[i % len(_PALETTE)]
        bbox = r.get("bbox")
        label = f"{r['gender']} {r['age']:.0f} y"
        if bbox:
            x0, y0, x1, y1 = bbox
            draw.rectangle([x0, y0, x1, y1], outline=color, width=3)
            # Filled label strip just above the box (clamped to the top edge).
            tw = int(draw.textlength(label, font=font_lg))
            th = 24
            lx0, ly0 = x0, max(0, y0 - th - 4)
            draw.rectangle([lx0, ly0, lx0 + tw + 10, ly0 + th + 4], fill=color)
            draw.text((lx0 + 5, ly0 + 2), label, fill="white", font=font_lg)
        else:
            # Full-image fallback — overlay in the top-left corner.
            full_label = f"{r['gender']} {r['age']:.0f} y ({r['conf']:.0%})"
            tw = int(draw.textlength(full_label, font=font_lg))
            draw.rectangle([8, 8, tw + 18, 38], fill=color)
            draw.text((13, 10), full_label, fill="white", font=font_lg)
    return pil_img
# ---------------------------------------------------------------------------
# Main predict function
# ---------------------------------------------------------------------------
def predict(image: Image.Image, max_faces: int,
            conf_thresh: float) -> tuple[Image.Image, str]:
    """Full pipeline: detect faces, estimate age/gender per face, annotate.

    Parameters
    ----------
    image       : uploaded photo, or None when nothing was provided
    max_faces   : keep at most this many faces (largest first)
    conf_thresh : YuNet detection score threshold from the UI slider

    Returns (annotated RGB image, markdown summary).
    """
    if image is None:
        return None, "⬆️ Please upload a photo."

    _load_age_model()
    _load_detector()

    # BUGFIX: `conf_thresh` was previously accepted but never used, so the
    # detector always ran at its hard-coded 0.6 threshold. Forward the
    # slider value to YuNet before detecting.
    if _DETECTOR is not None and _DETECTOR != "unavailable":
        try:
            _DETECTOR.setScoreThreshold(float(conf_thresh))
        except Exception:
            pass  # best effort — keep the default if the setter is missing

    img_rgb = np.asarray(image.convert("RGB"))
    # int() guards against Gradio delivering the slider value as a float.
    bboxes = _detect_faces(img_rgb)[:int(max_faces)]

    results = []
    if bboxes:
        for bbox in bboxes:
            x0, y0, x1, y1 = bbox
            r = _predict_crop(img_rgb[y0:y1, x0:x1])
            r["bbox"] = bbox
            results.append(r)
    else:
        # No faces found — run the model on the entire image.
        results.append(_predict_crop(img_rgb))

    # Annotated output image (never draw on the caller's original).
    out_img = _draw_results(image.convert("RGB").copy(), results)

    # Text summary.
    n = len(bboxes)
    mode = (f"({n} face{'s' if n != 1 else ''} detected)"
            if bboxes else "(no face detected — full image used)")
    lines = [f"**{mode}**\n"]
    for i, r in enumerate(results, 1):
        icon = "👨" if r["gender"] == "male" else "👩"
        lines.append(
            f"{icon} **Face {i}** — Age **{r['age']:.1f}** · "
            f"{r['gender']} ({r['conf']:.0%})"
        )
    return out_img, "\n".join(lines)
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# Markdown header shown at the top of the Space.
_DESC = """
## FaceAge-DINOv3 — Age & Gender Estimation
Upload a photo. **YuNet** auto-detects faces, then **FaceAge-DINOv3** predicts age and gender.
| | |
|--|--|
| 🏆 LAGENDA 84k MAE | **3.760** (beats MiVOLO v2 measured on the public model 3.859) |
| 🧠 Backbone | DINOv3-ViT-L/16 (Meta AI, 307M params) |
| ⚡ Speed | ~100 ms / face on CPU (ONNX FP32) |
| 🔍 Detector | YuNet (OpenCV, ~350 KB) |
[📄 Model Card](https://huggingface.co/trungthanhtran/faceage-dino)
"""

# Two-column layout: controls on the left, annotated result + summary right.
with gr.Blocks(title="FaceAge-DINOv3", theme=gr.themes.Soft()) as demo:
    gr.Markdown(_DESC)
    with gr.Row():
        with gr.Column(scale=1):
            inp_img = gr.Image(type="pil", label="📷 Upload photo")
            with gr.Row():
                inp_max = gr.Slider(1, 10, value=5, step=1,
                                    label="Max faces")
                inp_conf = gr.Slider(0.3, 0.9, value=0.6, step=0.05,
                                     label="Detection confidence")
            btn = gr.Button("🔍 Predict", variant="primary", size="lg")
        with gr.Column(scale=1):
            out_img = gr.Image(type="pil", label="Result")
            out_text = gr.Markdown()
    # Wire the button to the pipeline: (image, max_faces, conf) -> (image, md).
    btn.click(
        fn = predict,
        inputs = [inp_img, inp_max, inp_conf],
        outputs = [out_img, out_text],
    )
    gr.Markdown("""
---
*Trained on Our data collection*
*DINOv3-ViT-L pretrained by Meta AI on LVD-1.68B images.*
""")

if __name__ == "__main__":
    demo.launch()