# vad-to-blendshape / inference.py
# Uploaded by karie666666 via huggingface_hub (commit 3002a11, verified)
"""
Inference script: VAD -> 52-dim ARKit Blendshape coefficients (SYMMETRIC VERSION)
"""
import os
import re
import json
import argparse
from typing import List, Dict, Optional, Union
import numpy as np
import torch
import torch.nn as nn
# Dimensionality of the model input: (valence, arousal, dominance).
VAD_DIM = 3
# Dimensionality of the model output: one weight per ARKit blendshape.
BLENDSHAPE_DIM = 52
# Anchor VAD coordinates for named basic emotions, each component in [-1, 1].
# Synonyms ("joy"/"happy"/"happiness", "sad"/"sadness", "angry"/"anger") share
# identical anchors so either spelling can be used on the CLI.
BASIC_VAD = {
    "neutral": ( 0.00, 0.00, 0.00),
    "happiness": ( 0.80, 0.60, 0.50),
    "surprise": ( 0.30, 0.90, 0.20),
    "sadness": (-0.80, -0.40, -0.30),
    "anger": (-0.70, 0.80, 0.70),
    "disgust": (-0.60, 0.30, 0.40),
    "fear": (-0.70, 0.80, -0.30),
    "contempt": (-0.40, 0.30, 0.80),
    "joy": ( 0.80, 0.60, 0.50),
    "happy": ( 0.80, 0.60, 0.50),
    "sad": (-0.80, -0.40, -0.30),
    "angry": (-0.70, 0.80, 0.70),
}
# Left-right symmetric pairs: (left_index, right_index) into BLENDSHAPE_NAMES
# for shapes that mirror each other across the face (e.g. eyeBlinkLeft /
# eyeBlinkRight). Note jawLeft/jawRight and mouthLeft/mouthRight do not appear
# here — they are directional poses, not mirrored variants of one shape.
SYMMETRIC_PAIRS = [
    (0, 1), (3, 4), (6, 7), (8, 9), (10, 11), (12, 13), (14, 15), (16, 17),
    (18, 19), (20, 21), (27, 28), (29, 30), (33, 34), (35, 36), (43, 44),
    (45, 46), (47, 48), (49, 50),
]
# The 52 ARKit blendshape names, in the order the model emits coefficients.
BLENDSHAPE_NAMES = [
    "browDownLeft", "browDownRight", "browInnerUp", "browOuterUpLeft", "browOuterUpRight",
    "cheekPuff", "cheekSquintLeft", "cheekSquintRight", "eyeBlinkLeft", "eyeBlinkRight",
    "eyeLookDownLeft", "eyeLookDownRight", "eyeLookInLeft", "eyeLookInRight",
    "eyeLookOutLeft", "eyeLookOutRight", "eyeLookUpLeft", "eyeLookUpRight",
    "eyeSquintLeft", "eyeSquintRight", "eyeWideLeft", "eyeWideRight",
    "jawForward", "jawLeft", "jawOpen", "jawRight",
    "mouthClose", "mouthDimpleLeft", "mouthDimpleRight", "mouthFrownLeft", "mouthFrownRight",
    "mouthFunnel", "mouthLeft", "mouthLowerDownLeft", "mouthLowerDownRight",
    "mouthPressLeft", "mouthPressRight", "mouthPucker", "mouthRight",
    "mouthRollLower", "mouthRollUpper", "mouthShrugLower", "mouthShrugUpper",
    "mouthSmileLeft", "mouthSmileRight", "mouthStretchLeft", "mouthStretchRight",
    "mouthUpperUpLeft", "mouthUpperUpRight", "noseSneerLeft", "noseSneerRight",
    "tongueOut"
]
class VADToBlendshapeMLP(nn.Module):
    """MLP mapping a 3-dim VAD vector to 52 ARKit blendshape coefficients.

    Each hidden layer is Linear -> LayerNorm -> LeakyReLU(0.2) [-> Dropout],
    followed by a final Linear projection. The output is clamped to [0, 1],
    the valid range for blendshape weights.

    Args:
        hidden_dims: sizes of the hidden layers (any sequence; the default
            is a tuple to avoid the mutable-default-argument pitfall —
            callers that pass a list still work unchanged).
        dropout: dropout probability; 0 disables the Dropout layers.
    """

    def __init__(self, hidden_dims=(256, 512, 256), dropout=0.1):
        super().__init__()
        # Star-unpacking accepts both tuple and list hidden_dims.
        dims = [VAD_DIM, *hidden_dims, BLENDSHAPE_DIM]
        layers = []
        for i in range(len(dims) - 2):
            layers.append(nn.Linear(dims[i], dims[i + 1]))
            layers.append(nn.LayerNorm(dims[i + 1]))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(dims[-2], dims[-1]))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return blendshape coefficients clamped to [0, 1]."""
        return torch.clamp(self.net(x), 0.0, 1.0)
def load_model(checkpoint_path, metadata_path=None):
    """Load a trained VADToBlendshapeMLP and its metadata dict.

    If metadata_path is not given, "model_metadata.json" next to the
    checkpoint is used. Returns (model, metadata) with the model in eval mode.
    """
    if metadata_path is None:
        ckpt_dir = os.path.dirname(checkpoint_path)
        metadata_path = os.path.join(ckpt_dir, "model_metadata.json")
    with open(metadata_path, "r") as fh:
        meta = json.load(fh)
    hidden = meta.get("hidden_dims", [256, 512, 256])
    model = VADToBlendshapeMLP(hidden_dims=hidden)
    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    state = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    model.load_state_dict(state["model_state_dict"])
    model.eval()
    return model, meta
def emotion_to_vad(emotion_str, intensity=1.0):
    """Convert an emotion name (or "a+b" blend) into a VAD vector.

    Each "+"-separated part contributes its BASIC_VAD anchor; the anchors are
    averaged and the mean is scaled by `intensity`.

    Bug fix: the previous version accumulated weights of intensity/len(parts)
    and then divided the vector by their sum (== intensity), which cancelled
    `intensity` entirely — every nonzero intensity produced the same output.
    `intensity` now scales the averaged anchor as the parameter name implies.

    Args:
        emotion_str: e.g. "happiness" or "fear+surprise" (case-insensitive).
        intensity: scale factor applied to the averaged VAD anchor.

    Returns:
        np.ndarray of shape (3,), dtype float32.

    Raises:
        ValueError: if any part is not a key of BASIC_VAD.
    """
    parts = [p.strip().lower() for p in emotion_str.split("+")]
    vad = np.zeros(3, dtype=np.float32)
    for p in parts:
        if p not in BASIC_VAD:
            raise ValueError(f"Unknown emotion '{p}'")
        vad += np.array(BASIC_VAD[p], dtype=np.float32)
    # Average over the blended parts, then apply the intensity scale.
    # str.split always yields at least one element, so len(parts) >= 1.
    vad *= np.float32(intensity) / len(parts)
    return vad
def predict(model, vad):
    """Run `model` on a single VAD vector and return a numpy prediction.

    Accepts the VAD input as a python list or a float32 numpy array; the
    model is evaluated without gradients on a batch of one.
    """
    vector = np.array(vad, dtype=np.float32) if isinstance(vad, list) else vad
    batch = torch.from_numpy(vector).float().unsqueeze(0)
    with torch.no_grad():
        output = model(batch)
    return output.squeeze(0).cpu().numpy()
def enforce_symmetry(bs):
    """Return a copy of `bs` with every mirrored pair set to its mean.

    The input array is not modified; for each (left, right) index pair in
    SYMMETRIC_PAIRS both entries are replaced by their average.
    """
    symmetric = bs.copy()
    for left, right in SYMMETRIC_PAIRS:
        mean_val = 0.5 * (symmetric[left] + symmetric[right])
        symmetric[left] = mean_val
        symmetric[right] = mean_val
    return symmetric
def compute_asymmetry(bs):
    """Return the mean |left - right| difference over SYMMETRIC_PAIRS."""
    diffs = np.array([abs(bs[left] - bs[right]) for left, right in SYMMETRIC_PAIRS])
    return diffs.mean()
def main():
    """CLI entry point: map a VAD triple or a named emotion to blendshapes.

    Requires either --vad V A D or --emotion NAME; prints the prediction in
    the chosen --format (plain list, JSON report, or ARKit-style array).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, default="best_model.pt")
    parser.add_argument("--metadata", type=str, default=None)
    parser.add_argument("--vad", type=float, nargs=3, metavar=("V", "A", "D"))
    parser.add_argument("--emotion", type=str)
    parser.add_argument("--intensity", type=float, default=1.0)
    parser.add_argument("--enforce-symmetry", action="store_true", help="Force exact left-right symmetry on output")
    parser.add_argument("--topk", type=int, default=10)
    parser.add_argument("--format", choices=["json", "list", "arkit"], default="list")
    args = parser.parse_args()

    model, meta = load_model(args.checkpoint, args.metadata)

    # Resolve the input VAD vector from whichever flag was supplied.
    if args.vad is not None:
        vad = np.array(args.vad, dtype=np.float32)
    elif args.emotion is not None:
        vad = emotion_to_vad(args.emotion, args.intensity)
    else:
        parser.error("Provide --vad or --emotion")

    bs = predict(model, vad)
    if args.enforce_symmetry:
        bs = enforce_symmetry(bs)
    asym = compute_asymmetry(bs)
    print(f"VAD: [{vad[0]:+.2f}, {vad[1]:+.2f}, {vad[2]:+.2f}] Asymmetry: {asym:.6f}")

    # Indices sorted by coefficient value, largest first.
    ranked = np.argsort(bs)[::-1]
    if args.format == "json":
        report = {
            "vad": vad.tolist(),
            "blendshape": bs.tolist(),
            "blendshape_dict": {name: float(val) for name, val in zip(BLENDSHAPE_NAMES, bs)},
            "top_active": [
                {"name": BLENDSHAPE_NAMES[i], "value": float(bs[i])}
                for i in ranked[:args.topk]
            ],
            "asymmetry": float(asym),
        }
        print(json.dumps(report, indent=2, ensure_ascii=False))
    elif args.format == "arkit":
        frame = [{"blendshapeName": name, "weight": float(val)} for name, val in zip(BLENDSHAPE_NAMES, bs)]
        print(json.dumps(frame, indent=2))
    else:
        print("Blendshape (52-dim):", bs.round(4).tolist())
        print(f"\nTop {args.topk} active blendshapes:")
        for i in ranked[:args.topk]:
            print(f" {BLENDSHAPE_NAMES[i]}: {bs[i]:.4f}")
# Script entry point.
if __name__ == "__main__":
    main()