"""Inference script: VAD -> 52-dim ARKit blendshape coefficients (symmetric version).

Loads a trained ``VADToBlendshapeMLP`` checkpoint, maps a
(valence, arousal, dominance) triple -- given directly or derived from an
emotion name -- to 52 blendshape weights, and prints them in one of several
formats.
"""

import argparse
import json
import os
import re
from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn as nn

VAD_DIM = 3
BLENDSHAPE_DIM = 52

# Default hidden layer widths; a tuple so it can never be mutated by accident.
_DEFAULT_HIDDEN_DIMS: Tuple[int, ...] = (256, 512, 256)

# Reference (valence, arousal, dominance) triples per emotion name.
# Several names are aliases of each other (e.g. "joy"/"happy"/"happiness").
BASIC_VAD: Dict[str, Tuple[float, float, float]] = {
    "neutral": (0.00, 0.00, 0.00),
    "happiness": (0.80, 0.60, 0.50),
    "surprise": (0.30, 0.90, 0.20),
    "sadness": (-0.80, -0.40, -0.30),
    "anger": (-0.70, 0.80, 0.70),
    "disgust": (-0.60, 0.30, 0.40),
    "fear": (-0.70, 0.80, -0.30),
    "contempt": (-0.40, 0.30, 0.80),
    "joy": (0.80, 0.60, 0.50),
    "happy": (0.80, 0.60, 0.50),
    "sad": (-0.80, -0.40, -0.30),
    "angry": (-0.70, 0.80, 0.70),
}

# Left-right symmetric pairs, as (left_index, right_index) into
# BLENDSHAPE_NAMES.
SYMMETRIC_PAIRS: List[Tuple[int, int]] = [
    (0, 1), (3, 4), (6, 7), (8, 9), (10, 11), (12, 13),
    (14, 15), (16, 17), (18, 19), (20, 21), (27, 28), (29, 30),
    (33, 34), (35, 36), (43, 44), (45, 46), (47, 48), (49, 50),
]

# Names of the 52 output channels, in model output order.
BLENDSHAPE_NAMES: List[str] = [
    "browDownLeft", "browDownRight", "browInnerUp", "browOuterUpLeft",
    "browOuterUpRight", "cheekPuff", "cheekSquintLeft", "cheekSquintRight",
    "eyeBlinkLeft", "eyeBlinkRight", "eyeLookDownLeft", "eyeLookDownRight",
    "eyeLookInLeft", "eyeLookInRight", "eyeLookOutLeft", "eyeLookOutRight",
    "eyeLookUpLeft", "eyeLookUpRight", "eyeSquintLeft", "eyeSquintRight",
    "eyeWideLeft", "eyeWideRight", "jawForward", "jawLeft",
    "jawOpen", "jawRight", "mouthClose", "mouthDimpleLeft",
    "mouthDimpleRight", "mouthFrownLeft", "mouthFrownRight", "mouthFunnel",
    "mouthLeft", "mouthLowerDownLeft", "mouthLowerDownRight", "mouthPressLeft",
    "mouthPressRight", "mouthPucker", "mouthRight", "mouthRollLower",
    "mouthRollUpper", "mouthShrugLower", "mouthShrugUpper", "mouthSmileLeft",
    "mouthSmileRight", "mouthStretchLeft", "mouthStretchRight",
    "mouthUpperUpLeft", "mouthUpperUpRight", "noseSneerLeft", "noseSneerRight",
    "tongueOut",
]


class VADToBlendshapeMLP(nn.Module):
    """MLP mapping a VAD vector (3 values) to 52 blendshape weights.

    Each hidden stage is Linear -> LayerNorm -> LeakyReLU(0.2) followed by
    Dropout when ``dropout > 0``; a final Linear projects to the output.
    The layer order (and therefore the ``net.<index>`` state-dict keys) must
    stay as is so that existing checkpoints keep loading.
    """

    def __init__(
        self,
        hidden_dims: Optional[Sequence[int]] = None,
        dropout: float = 0.1,
    ) -> None:
        """Build the network.

        Args:
            hidden_dims: Widths of the hidden layers. ``None`` selects the
                default ``(256, 512, 256)``. Any sequence (list or tuple)
                is accepted.
            dropout: Dropout probability after each hidden stage; no
                Dropout layer is inserted when it is ``<= 0``.
        """
        super().__init__()
        if hidden_dims is None:
            hidden_dims = _DEFAULT_HIDDEN_DIMS
        dims = [VAD_DIM, *hidden_dims, BLENDSHAPE_DIM]

        layers: List[nn.Module] = []
        # Pair up consecutive widths for every hidden stage (all but the
        # final projection).
        for in_dim, out_dim in zip(dims[:-2], dims[1:-1]):
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(nn.LayerNorm(out_dim))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(dims[-2], dims[-1]))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the network output clamped to ``[0, 1]``."""
        return torch.clamp(self.net(x), 0.0, 1.0)


def load_model(checkpoint_path: str, metadata_path: Optional[str] = None):
    """Load the model and its metadata.

    Args:
        checkpoint_path: Path to a checkpoint containing a
            ``"model_state_dict"`` entry.
        metadata_path: Path to the metadata JSON. Defaults to
            ``model_metadata.json`` in the checkpoint's directory.

    Returns:
        ``(model, meta)``: the model in eval mode on CPU and the parsed
        metadata dict.
    """
    if metadata_path is None:
        metadata_path = os.path.join(
            os.path.dirname(checkpoint_path), "model_metadata.json"
        )
    with open(metadata_path, "r", encoding="utf-8") as f:
        meta = json.load(f)

    # NOTE(review): dropout is not read from the metadata. A checkpoint
    # trained with dropout=0 has different Sequential indices and would not
    # load with the default used here -- confirm against the training script.
    model = VADToBlendshapeMLP(hidden_dims=meta.get("hidden_dims"))

    # SECURITY: weights_only=False unpickles arbitrary Python objects; only
    # load checkpoints from trusted sources. Switching to weights_only=True
    # is preferable if the checkpoint contains nothing but tensors.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()
    return model, meta


def emotion_to_vad(emotion_str: str, intensity: float = 1.0) -> np.ndarray:
    """Convert an emotion expression into a VAD vector.

    Args:
        emotion_str: One emotion name, or several joined with ``+``
            (e.g. ``"happy+surprise"``). Names are matched
            case-insensitively after stripping whitespace.
        intensity: Scale factor applied to the blended vector. ``1.0``
            returns the plain average of the components; ``0.0`` returns
            zeros. The result is not clipped.

    Returns:
        A float32 array of shape ``(3,)``.

    Raises:
        ValueError: If any component is not a key of ``BASIC_VAD``.
    """
    parts = [p.strip().lower() for p in emotion_str.split("+")]
    for p in parts:
        if p not in BASIC_VAD:
            raise ValueError(f"Unknown emotion '{p}'")

    # Average first, then scale: previously the intensity was folded into
    # the per-part weights and divided back out, so it had no effect.
    blend = np.mean([BASIC_VAD[p] for p in parts], axis=0)
    return (intensity * blend).astype(np.float32)


def predict(
    model: nn.Module,
    vad: Union[Sequence[float], np.ndarray],
) -> np.ndarray:
    """Run the model on one VAD vector and return the blendshape weights.

    Args:
        model: The network to evaluate.
        vad: VAD values as a list, tuple, or NumPy array.

    Returns:
        The model output for that input as a NumPy array, with the batch
        dimension removed.
    """
    # np.asarray handles lists, tuples and arrays of any float dtype.
    vad = np.asarray(vad, dtype=np.float32)
    x = torch.from_numpy(vad).unsqueeze(0)
    with torch.no_grad():
        pred = model(x)
    return pred.squeeze(0).cpu().numpy()


def enforce_symmetry(bs):
    """Post-process blendshape to force exact left-right symmetry.

    Returns a copy of ``bs`` in which both entries of every pair in
    ``SYMMETRIC_PAIRS`` are replaced by their average; the input is left
    untouched.
    """
    bs = bs.copy()
    for li, ri in SYMMETRIC_PAIRS:
        avg = (bs[li] + bs[ri]) / 2.0
        bs[li] = avg
        bs[ri] = avg
    return bs


def compute_asymmetry(bs):
    """Compute mean |left-right| difference over ``SYMMETRIC_PAIRS``."""
    return np.mean([abs(bs[li] - bs[ri]) for li, ri in SYMMETRIC_PAIRS])


def main() -> None:
    """Command-line entry point: parse arguments, predict, and print."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, default="best_model.pt")
    parser.add_argument("--metadata", type=str, default=None)
    parser.add_argument("--vad", type=float, nargs=3, metavar=("V", "A", "D"),
                        help="Explicit VAD triple (takes precedence over --emotion)")
    parser.add_argument("--emotion", type=str,
                        help="Emotion name, or several joined with '+'")
    parser.add_argument("--intensity", type=float, default=1.0,
                        help="Scale factor applied to the VAD derived from --emotion")
    parser.add_argument("--enforce-symmetry", action="store_true",
                        help="Force exact left-right symmetry on output")
    parser.add_argument("--topk", type=int, default=10)
    parser.add_argument("--format", choices=["json", "list", "arkit"], default="list")
    args = parser.parse_args()

    model, meta = load_model(args.checkpoint, args.metadata)

    if args.vad is not None:
        vad = np.array(args.vad, dtype=np.float32)
    elif args.emotion is not None:
        vad = emotion_to_vad(args.emotion, args.intensity)
    else:
        # parser.error() exits, so ``vad`` is always bound below.
        parser.error("Provide --vad or --emotion")

    bs = predict(model, vad)
    if args.enforce_symmetry:
        bs = enforce_symmetry(bs)
    asym = compute_asymmetry(bs)

    # NOTE(review): this summary line goes to stdout for every format, so the
    # "json" and "arkit" outputs are preceded by one non-JSON line. Kept as is
    # because existing consumers may already skip it.
    print(f"VAD: [{vad[0]:+.2f}, {vad[1]:+.2f}, {vad[2]:+.2f}] Asymmetry: {asym:.6f}")

    if args.format == "json":
        result = {
            "vad": vad.tolist(),
            "blendshape": bs.tolist(),
            "blendshape_dict": {
                name: float(val) for name, val in zip(BLENDSHAPE_NAMES, bs)
            },
            "top_active": [
                {"name": BLENDSHAPE_NAMES[i], "value": float(bs[i])}
                for i in np.argsort(bs)[::-1][:args.topk]
            ],
            "asymmetry": float(asym),
        }
        print(json.dumps(result, indent=2, ensure_ascii=False))
    elif args.format == "arkit":
        arr = [
            {"blendshapeName": name, "weight": float(val)}
            for name, val in zip(BLENDSHAPE_NAMES, bs)
        ]
        print(json.dumps(arr, indent=2))
    else:
        print("Blendshape (52-dim):", bs.round(4).tolist())
        topk = np.argsort(bs)[::-1][:args.topk]
        print(f"\nTop {args.topk} active blendshapes:")
        for i in topk:
            print(f" {BLENDSHAPE_NAMES[i]}: {bs[i]:.4f}")


if __name__ == "__main__":
    main()