| """ |
| Inference script: VAD -> 52-dim ARKit Blendshape coefficients (SYMMETRIC VERSION) |
| """ |
| import os |
| import re |
| import json |
| import argparse |
| from typing import List, Dict, Optional, Union |
| import numpy as np |
| import torch |
| import torch.nn as nn |
|
|
# Width of the model input (valence, arousal, dominance) and of the model
# output (one coefficient per entry of BLENDSHAPE_NAMES).
VAD_DIM = 3
BLENDSHAPE_DIM = 52


# Canonical (valence, arousal, dominance) anchors for the basic emotions.
_CANONICAL_VAD = {
    "neutral":   ( 0.00,  0.00,  0.00),
    "happiness": ( 0.80,  0.60,  0.50),
    "surprise":  ( 0.30,  0.90,  0.20),
    "sadness":   (-0.80, -0.40, -0.30),
    "anger":     (-0.70,  0.80,  0.70),
    "disgust":   (-0.60,  0.30,  0.40),
    "fear":      (-0.70,  0.80, -0.30),
    "contempt":  (-0.40,  0.30,  0.80),
}

# Alternative spellings that share an anchor with a canonical emotion.
_VAD_ALIASES = {
    "joy": "happiness",
    "happy": "happiness",
    "sad": "sadness",
    "angry": "anger",
}

# Lookup table used by emotion_to_vad: canonical names first, then aliases.
BASIC_VAD = {
    **_CANONICAL_VAD,
    **{alias: _CANONICAL_VAD[target] for alias, target in _VAD_ALIASES.items()},
}
|
|
| |
# Output channel order of the model: index i of a prediction is the
# coefficient for BLENDSHAPE_NAMES[i].
BLENDSHAPE_NAMES = [
    "browDownLeft", "browDownRight", "browInnerUp", "browOuterUpLeft", "browOuterUpRight",
    "cheekPuff", "cheekSquintLeft", "cheekSquintRight", "eyeBlinkLeft", "eyeBlinkRight",
    "eyeLookDownLeft", "eyeLookDownRight", "eyeLookInLeft", "eyeLookInRight",
    "eyeLookOutLeft", "eyeLookOutRight", "eyeLookUpLeft", "eyeLookUpRight",
    "eyeSquintLeft", "eyeSquintRight", "eyeWideLeft", "eyeWideRight",
    "jawForward", "jawLeft", "jawOpen", "jawRight",
    "mouthClose", "mouthDimpleLeft", "mouthDimpleRight", "mouthFrownLeft", "mouthFrownRight",
    "mouthFunnel", "mouthLeft", "mouthLowerDownLeft", "mouthLowerDownRight",
    "mouthPressLeft", "mouthPressRight", "mouthPucker", "mouthRight",
    "mouthRollLower", "mouthRollUpper", "mouthShrugLower", "mouthShrugUpper",
    "mouthSmileLeft", "mouthSmileRight", "mouthStretchLeft", "mouthStretchRight",
    "mouthUpperUpLeft", "mouthUpperUpRight", "noseSneerLeft", "noseSneerRight",
    "tongueOut",
]

# Name stems whose "<stem>Left" / "<stem>Right" channels are treated as a
# mirrored pair. jawLeft/jawRight and mouthLeft/mouthRight are intentionally
# absent (presumably because they describe sideways motion of the whole jaw /
# mouth rather than one side of the face — confirm before adding them).
_SYMMETRIC_STEMS = (
    "browDown", "browOuterUp", "cheekSquint", "eyeBlink", "eyeLookDown",
    "eyeLookIn", "eyeLookOut", "eyeLookUp", "eyeSquint", "eyeWide",
    "mouthDimple", "mouthFrown", "mouthLowerDown", "mouthPress",
    "mouthSmile", "mouthStretch", "mouthUpperUp", "noseSneer",
)

_BLENDSHAPE_INDEX = {name: i for i, name in enumerate(BLENDSHAPE_NAMES)}

# (left_index, right_index) pairs into BLENDSHAPE_NAMES. Resolved from the
# names rather than hard-coded so the table cannot silently drift if the name
# list is edited: a missing "<stem>Left"/"<stem>Right" raises KeyError at import.
SYMMETRIC_PAIRS = [
    (_BLENDSHAPE_INDEX[stem + "Left"], _BLENDSHAPE_INDEX[stem + "Right"])
    for stem in _SYMMETRIC_STEMS
]
|
|
class VADToBlendshapeMLP(nn.Module):
    """MLP mapping a VAD vector to BLENDSHAPE_DIM blendshape coefficients.

    For each hidden width the network applies
    ``Linear -> LayerNorm -> LeakyReLU(0.2) [-> Dropout]``, then a final
    ``Linear`` projection to BLENDSHAPE_DIM outputs. ``forward`` clamps the
    result to ``[0, 1]``.

    Args:
        hidden_dims: Hidden-layer widths in order. Any sequence of ints
            (list or tuple) is accepted; ``None`` means the default
            ``(256, 512, 256)``. The layer order determines the
            ``nn.Sequential`` child indices and therefore the state_dict keys
            (``net.0.weight``, ...), so it must match the checkpoint loaded.
        dropout: Dropout probability after each hidden activation. When it
            is ``0`` (or negative) no Dropout module is inserted at all.
    """

    def __init__(self, hidden_dims=(256, 512, 256), dropout=0.1):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = (256, 512, 256)
        # Build a fresh list: the default is an immutable tuple (no shared
        # mutable default), and unpacking accepts lists and tuples alike.
        dims = [VAD_DIM, *hidden_dims, BLENDSHAPE_DIM]
        layers = []
        for in_dim, out_dim in zip(dims[:-2], dims[1:-1]):
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(nn.LayerNorm(out_dim))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
        # Output projection: no norm/activation here; forward() clamps instead.
        layers.append(nn.Linear(dims[-2], dims[-1]))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map inputs whose last dim is VAD_DIM to coefficients clamped to [0, 1]."""
        return torch.clamp(self.net(x), 0.0, 1.0)
|
|
|
|
def load_model(checkpoint_path, metadata_path=None):
    """Build a VADToBlendshapeMLP and load trained weights into it.

    Args:
        checkpoint_path: Path to a file readable by ``torch.load``. It may be
            either a dict holding the weights under ``"model_state_dict"`` or
            a bare state dict.
        metadata_path: JSON file with model hyper-parameters. Defaults to
            ``model_metadata.json`` in the checkpoint's directory. Only the
            ``"hidden_dims"`` key is used here (default ``[256, 512, 256]``).

    Returns:
        ``(model, meta)``: the model on CPU in eval mode, and the parsed
        metadata dict.

    Raises:
        FileNotFoundError: if the metadata or checkpoint file does not exist.
        RuntimeError: from ``load_state_dict`` when the weights do not match
            the architecture described by the metadata.
    """
    if metadata_path is None:
        metadata_path = os.path.join(os.path.dirname(checkpoint_path), "model_metadata.json")
    with open(metadata_path, "r", encoding="utf-8") as f:
        meta = json.load(f)

    model = VADToBlendshapeMLP(hidden_dims=meta.get("hidden_dims", [256, 512, 256]))
    # SECURITY: weights_only=False unpickles arbitrary Python objects and can
    # execute code embedded in the file. Only load checkpoints from trusted
    # sources; prefer weights_only=True if the checkpoint holds only tensors
    # and plain containers.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    # Accept both {"model_state_dict": ...} and a bare state dict (a dict
    # subclass without that key).
    state_dict = ckpt.get("model_state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
    model.load_state_dict(state_dict)
    model.eval()
    return model, meta
|
|
|
|
def emotion_to_vad(emotion_str, intensity=1.0, vad_table=None):
    """Convert an emotion name, or a ``+``-joined blend of names, to a VAD vector.

    Each component is matched case-insensitively, ignoring surrounding
    whitespace. The blend is the equal-weight mean of the components'
    anchors, which is then multiplied by *intensity*. Scaling moves the point
    along the line through the origin, which is where BASIC_VAD's "neutral"
    anchor sits.

    Args:
        emotion_str: e.g. ``"happiness"`` or ``"sadness + anger"``. Empty
            segments, such as one left by a trailing ``+``, are ignored.
        intensity: Scale applied to the blended vector. 1.0 leaves it
            unchanged and 0.0 gives the zero vector. The result is not
            clipped, so values above 1 can leave the anchors' range.
        vad_table: Mapping from lowercase name to a ``(v, a, d)`` triple.
            Defaults to BASIC_VAD.

    Returns:
        float32 ndarray of shape ``(3,)``.

    Raises:
        ValueError: if a component is not in the table, or no emotion is given.
    """
    table = BASIC_VAD if vad_table is None else vad_table
    parts = [p for p in (s.strip().lower() for s in emotion_str.split("+")) if p]
    if not parts:
        raise ValueError(f"No emotion given in '{emotion_str}'")
    for p in parts:
        if p not in table:
            raise ValueError(f"Unknown emotion '{p}'")
    anchors = np.array([table[p] for p in parts], dtype=np.float32)
    # Average first, then scale. Folding intensity into the per-component
    # weights and then normalising by their sum would cancel it out.
    return (anchors.mean(axis=0) * intensity).astype(np.float32)
|
|
|
|
def predict(model, vad):
    """Run *model* on one VAD vector, or a batch, and return NumPy outputs.

    Args:
        model: Callable taking a float32 tensor whose last dim is the model's
            input width (VAD_DIM for VADToBlendshapeMLP). If it is an
            ``nn.Module`` with parameters, the input is moved to the device
            of its first parameter; otherwise the input stays on CPU.
        vad: Array-like (list, tuple or ndarray) of shape ``(3,)`` or
            ``(N, 3)``. It is converted to a contiguous float32 array.

    Returns:
        ndarray on CPU with the batch layout of *vad* preserved, e.g. ``(52,)``
        for a single vector and ``(N, 52)`` for a batch.
    """
    # ascontiguousarray also accepts tuples and negatively-strided views,
    # which torch.from_numpy rejects.
    arr = np.ascontiguousarray(vad, dtype=np.float32)
    device = torch.device("cpu")
    if isinstance(model, nn.Module):
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device
    # Add a leading axis for the model call and remove it again afterwards.
    x = torch.from_numpy(arr).unsqueeze(0).to(device)
    with torch.no_grad():
        pred = model(x)
    return pred.squeeze(0).cpu().numpy()
|
|
|
|
def enforce_symmetry(bs, pairs=None):
    """Return a copy of *bs* in which each left/right pair is set to its mean.

    Args:
        bs: Blendshape coefficients, either one vector ``(52,)`` or a batch
            ``(N, 52)``. Pair indices address the last axis, so each row of a
            batch is symmetrised independently.
        pairs: ``(left_index, right_index)`` tuples. Defaults to
            SYMMETRIC_PAIRS. Pairs are applied in order.

    Returns:
        A new ndarray with the same shape as *bs*, and the same dtype when
        *bs* is an ndarray. The input is not modified, and channels that are
        not in any pair are copied unchanged.
    """
    if pairs is None:
        pairs = SYMMETRIC_PAIRS
    out = np.array(bs, copy=True)
    for li, ri in pairs:
        # Index the last axis: bs[li] on a (N, 52) batch would pick rows.
        avg = (out[..., li] + out[..., ri]) / 2.0
        out[..., li] = avg
        out[..., ri] = avg
    return out
|
|
|
|
def compute_asymmetry(bs, pairs=None):
    """Return the mean absolute left/right difference over the symmetric pairs.

    Args:
        bs: Blendshape coefficients, either ``(52,)`` or a batch ``(N, 52)``.
            Pair indices address the last axis. For a batch the mean is taken
            over every row and every pair.
        pairs: ``(left_index, right_index)`` tuples. Defaults to
            SYMMETRIC_PAIRS.

    Returns:
        NumPy scalar mean of ``|bs[..., left] - bs[..., right]|``. Returns
        ``0.0`` when *pairs* is empty; np.mean of nothing would give NaN and
        a RuntimeWarning.
    """
    if pairs is None:
        pairs = SYMMETRIC_PAIRS
    if len(pairs) == 0:
        return 0.0
    arr = np.asarray(bs)
    left = [li for li, _ in pairs]
    right = [ri for _, ri in pairs]
    # Last-axis indexing keeps batches paired channel-wise, not row-wise.
    return np.mean(np.abs(arr[..., left] - arr[..., right]))
|
|
|
|
def _build_arg_parser():
    """Return the argument parser for the command-line interface."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, default="best_model.pt")
    parser.add_argument("--metadata", type=str, default=None)
    parser.add_argument("--vad", type=float, nargs=3, metavar=("V", "A", "D"))
    parser.add_argument("--emotion", type=str)
    parser.add_argument("--intensity", type=float, default=1.0)
    parser.add_argument("--enforce-symmetry", action="store_true", help="Force exact left-right symmetry on output")
    parser.add_argument("--topk", type=int, default=10)
    parser.add_argument("--format", choices=["json", "list", "arkit"], default="list")
    return parser


def _resolve_vad(args, parser):
    """Turn --vad / --emotion into a float32 VAD vector; exit via parser.error on bad input.

    --vad takes precedence when both are given. --intensity only applies to --emotion.
    """
    if args.vad is not None:
        return np.array(args.vad, dtype=np.float32)
    if args.emotion is not None:
        try:
            return emotion_to_vad(args.emotion, args.intensity)
        except ValueError as err:
            # Report bad emotion names as a usage error rather than a traceback.
            parser.error(str(err))
    parser.error("Provide --vad or --emotion")


def main():
    """Command-line entry point: predict blendshapes for a VAD vector or emotion.

    --format selects the stdout payload:
      * "list" (default): a human-readable summary, the rounded 52-dim vector
        and the top-k channels.
      * "json": one object with the vad, blendshape vector and dict,
        top_active and asymmetry.
      * "arkit": a list of {"blendshapeName", "weight"} records.
    For "json" and "arkit" the one-line VAD/asymmetry summary goes to stderr,
    so stdout is valid JSON.
    """
    import sys  # local import: only the CLI needs the standard streams

    parser = _build_arg_parser()
    args = parser.parse_args()

    # Validate the input before loading the checkpoint, so a usage error is
    # not masked by a missing or slow-to-load checkpoint file.
    vad = _resolve_vad(args, parser)

    model, _meta = load_model(args.checkpoint, args.metadata)
    bs = predict(model, vad)
    if args.enforce_symmetry:
        bs = enforce_symmetry(bs)

    asym = compute_asymmetry(bs)
    summary_stream = sys.stdout if args.format == "list" else sys.stderr
    print(f"VAD: [{vad[0]:+.2f}, {vad[1]:+.2f}, {vad[2]:+.2f}] Asymmetry: {asym:.6f}", file=summary_stream)

    # A negative --topk would slice from the end, and an oversized one would be
    # announced as "Top N" while showing fewer entries; clamp to the valid range.
    topk = max(0, min(args.topk, len(BLENDSHAPE_NAMES)))
    top_idx = np.argsort(bs)[::-1][:topk]

    if args.format == "json":
        result = {
            "vad": vad.tolist(),
            "blendshape": bs.tolist(),
            "blendshape_dict": {name: float(val) for name, val in zip(BLENDSHAPE_NAMES, bs)},
            "top_active": [
                {"name": BLENDSHAPE_NAMES[i], "value": float(bs[i])}
                for i in top_idx
            ],
            "asymmetry": float(asym),
        }
        print(json.dumps(result, indent=2, ensure_ascii=False))
    elif args.format == "arkit":
        arr = [{"blendshapeName": name, "weight": float(val)} for name, val in zip(BLENDSHAPE_NAMES, bs)]
        print(json.dumps(arr, indent=2))
    else:
        print("Blendshape (52-dim):", bs.round(4).tolist())
        print(f"\nTop {topk} active blendshapes:")
        for i in top_idx:
            print(f" {BLENDSHAPE_NAMES[i]}: {bs[i]:.4f}")


if __name__ == "__main__":
    main()
|
|