# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 import argparse import os import sys from pathlib import Path import imageio.v2 as imageio import numpy as np import torch import torch.nn.functional as F from tqdm import tqdm from soma import SOMALayer from soma.geometry.rig_utils import joint_local_to_world, joint_world_to_local from tools.vis_pyrender import ( MeshRenderer, default_pyopengl_platform, look_at, set_pyopengl_platform, ) # -------------------------------------------------------------------------------- # Joint Names & Mapping (from nvhuman_layer/joint_names.py) # -------------------------------------------------------------------------------- # fmt: off nvskel93_name = [ "Hips", "Spine1", "Spine2", "Chest", "Neck1", "Neck2", "Head", "HeadEnd", "Jaw", "LeftEye", "RightEye", "LeftShoulder", "LeftArm", "LeftForeArm", "LeftHand", "LeftHandThumb1", "LeftHandThumb2", "LeftHandThumb3", "LeftHandThumbEnd", "LeftHandIndex1", "LeftHandIndex2", "LeftHandIndex3", "LeftHandIndex4", "LeftHandIndexEnd", "LeftHandMiddle1", "LeftHandMiddle2", "LeftHandMiddle3", "LeftHandMiddle4", "LeftHandMiddleEnd", "LeftHandRing1", "LeftHandRing2", "LeftHandRing3", "LeftHandRing4", "LeftHandRingEnd", "LeftHandPinky1", "LeftHandPinky2", "LeftHandPinky3", "LeftHandPinky4", "LeftHandPinkyEnd", "LeftForeArmTwist1", "LeftForeArmTwist2", "LeftArmTwist1", "LeftArmTwist2", "RightShoulder", "RightArm", "RightForeArm", "RightHand", "RightHandThumb1", "RightHandThumb2", "RightHandThumb3", "RightHandThumbEnd", "RightHandIndex1", "RightHandIndex2", "RightHandIndex3", "RightHandIndex4", "RightHandIndexEnd", "RightHandMiddle1", "RightHandMiddle2", "RightHandMiddle3", "RightHandMiddle4", "RightHandMiddleEnd", "RightHandRing1", "RightHandRing2", "RightHandRing3", "RightHandRing4", "RightHandRingEnd", "RightHandPinky1", "RightHandPinky2", "RightHandPinky3", "RightHandPinky4", "RightHandPinkyEnd", "RightForeArmTwist1", "RightForeArmTwist2", "RightArmTwist1", "RightArmTwist2", "LeftLeg", "LeftShin", "LeftFoot", "LeftToeBase", "LeftToeEnd", "LeftShinTwist1", "LeftShinTwist2", "LeftLegTwist1", "LeftLegTwist2", "RightLeg", "RightShin", "RightFoot", "RightToeBase", "RightToeEnd", "RightShinTwist1", "RightShinTwist2", "RightLegTwist1", "RightLegTwist2", ] nvskel77_name = [ "Hips", "Spine1", "Spine2", "Chest", "Neck1", "Neck2", "Head", "HeadEnd", "Jaw", "LeftEye", "RightEye", "LeftShoulder", "LeftArm", "LeftForeArm", "LeftHand", "LeftHandThumb1", "LeftHandThumb2", "LeftHandThumb3", "LeftHandThumbEnd", "LeftHandIndex1", "LeftHandIndex2", "LeftHandIndex3", "LeftHandIndex4", "LeftHandIndexEnd", "LeftHandMiddle1", "LeftHandMiddle2", "LeftHandMiddle3", "LeftHandMiddle4", "LeftHandMiddleEnd", "LeftHandRing1", "LeftHandRing2", "LeftHandRing3", "LeftHandRing4", "LeftHandRingEnd", "LeftHandPinky1", "LeftHandPinky2", "LeftHandPinky3", "LeftHandPinky4", "LeftHandPinkyEnd", "RightShoulder", "RightArm", "RightForeArm", "RightHand", "RightHandThumb1", "RightHandThumb2", "RightHandThumb3", "RightHandThumbEnd", "RightHandIndex1", "RightHandIndex2", "RightHandIndex3", "RightHandIndex4", "RightHandIndexEnd", "RightHandMiddle1", "RightHandMiddle2", "RightHandMiddle3", "RightHandMiddle4", "RightHandMiddleEnd", "RightHandRing1", "RightHandRing2", "RightHandRing3", "RightHandRing4", "RightHandRingEnd", "RightHandPinky1", "RightHandPinky2", "RightHandPinky3", "RightHandPinky4", "RightHandPinkyEnd", "LeftLeg", "LeftShin", "LeftFoot", "LeftToeBase", "LeftToeEnd", "RightLeg", "RightShin", "RightFoot", "RightToeBase", "RightToeEnd", ] # fmt: on nvskel93to77_idx = [nvskel93_name.index(name) for name in nvskel77_name] color_map = { "soma": (0.4, 0.8, 0.4, 1.0), # light green "mhr": (0.98, 0.65, 0.15, 1.0), # blue "anny": (0.25, 0.75, 1.0, 1.0), # yellow "smpl": (0.55, 0.15, 0.85, 1.0), # pink "smplx": (0.55, 0.15, 0.85, 1.0), # pink "garment": (0.15, 0.15, 1.0, 1.0), # orange } def get_smooth_noise(T, dim, device, num_keyframes=None, mode="normal"): if num_keyframes is None: num_keyframes = max(3, T // 30) if mode == "normal": keyframes = torch.randn(1, dim, num_keyframes, device=device) elif mode == "uniform": keyframes = torch.rand(1, dim, num_keyframes, device=device) res = F.interpolate(keyframes, size=T, mode="linear", align_corners=True)[0].T return res def save_video(frames, path, fps=30): imageio.mimsave(path, frames, fps=fps) print(f"Saved {path}") def main(): parser = argparse.ArgumentParser(description="SOMA pyrender demo") parser.add_argument("--data-root", default="assets", help="Path to SOMA assets") parser.add_argument( "--motion-file", default="assets/example_animation.npy", help="Path to motion file (.npy). If None, uses a dummy motion.", ) parser.add_argument("--device", default="cuda:0") parser.add_argument("--output-dir", default="out/vis_identity_model") parser.add_argument("--image-size", type=int, default=1920) parser.add_argument("--pyopengl-platform", default=default_pyopengl_platform()) parser.add_argument("--random-shape", action="store_true", default=False) parser.add_argument( "--identity-model-type", default="soma,mhr,anny,smpl,smplx,garment", help="Comma-separated list of identity models to use. Options: soma, mhr, anny, smpl, smplx garment (default: soma,mhr,anny,smpl,smplx,garment)", ) parser.add_argument( "--pose-batch-size", type=int, default=0, help="Run forward pass in batches of this many poses to reduce GPU memory. 0 = process all frames at once (default). Try 32 or 64 if OOM.", ) parser.add_argument( "--low-lod", action="store_true", default=False, help="Use low level-of-detail mesh (fewer vertices/faces)", ) parser.add_argument( "--apply-correctives", action="store_true", default=False, help="Apply pose corrective offsets (default: False)", ) parser.add_argument( "--gender", default="neutral", help="Gender of the model (default: neutral). Only used for smpl and smplx models.", ) args = parser.parse_args() identity_models = [m.strip().lower() for m in args.identity_model_type.split(",")] valid_models = {"soma", "mhr", "anny", "smpl", "smplx", "garment"} invalid_models = set(identity_models) - valid_models if invalid_models: raise ValueError( f"Invalid identity model type(s): {invalid_models}. Valid options: {valid_models}" ) args.identity_models = identity_models repo_root = Path(__file__).resolve().parents[1] if str(repo_root) not in sys.path: sys.path.insert(0, str(repo_root)) set_pyopengl_platform(args.pyopengl_platform) device = torch.device(args.device if torch.cuda.is_available() else "cpu") os.makedirs(args.output_dir, exist_ok=True) print(f"Initializing models: {', '.join(args.identity_models)}...") models = {} for identity_model_type in args.identity_models: if identity_model_type == "smpl": identity_model_kwargs = { "gender": args.gender, } else: identity_model_kwargs = {} models[identity_model_type] = SOMALayer( data_root=args.data_root, low_lod=args.low_lod, device=str(device), identity_model_type=identity_model_type, mode="warp", identity_model_kwargs=identity_model_kwargs, ).to(device) reference_model = models[args.identity_models[0]] if args.motion_file and os.path.exists(args.motion_file): print(f"Loading motion from {args.motion_file}...") motion_full = torch.from_numpy(np.load(args.motion_file)).float().to(device) joint_rot_mats_local = motion_full[..., :3, :3] root_trans = motion_full[..., 1, :3, 3] else: print("No motion file provided or file not found. Using dummy motion (T-pose rotation).") T = 30 joint_rot_mats_local = ( torch.eye(3, device=device).unsqueeze(0).unsqueeze(0).repeat(T, 78, 1, 1) ) angle = torch.linspace(0, 2 * np.pi, T, device=device) cos = torch.cos(angle) sin = torch.sin(angle) zeros = torch.zeros_like(angle) ones = torch.ones_like(angle) rot_y = torch.stack( [ torch.stack([cos, zeros, sin], dim=-1), torch.stack([zeros, ones, zeros], dim=-1), torch.stack([-sin, zeros, cos], dim=-1), ], dim=-2, ) # (T, 3, 3) joint_rot_mats_local[:, 1] = rot_y # Rotate Hips root_trans = torch.zeros(T, 3, device=device) if joint_rot_mats_local.shape[1] == 94: subset_idx = [0] + [i + 1 for i in nvskel93to77_idx] joint_rot_mats_local = joint_rot_mats_local[:, subset_idx] correction = reference_model.t_pose_world[:, :3, :3].transpose(-2, -1) joint_rot_mats_world = joint_local_to_world( joint_rot_mats_local, reference_model.joint_parent_ids ) joint_rot_mats_world = joint_rot_mats_world @ correction joint_rot_mats_local = joint_world_to_local( joint_rot_mats_world, reference_model.joint_parent_ids ) T = joint_rot_mats_local.shape[0] global_orient = joint_rot_mats_local[:T, 1] # (T, 3, 3) - Hips is index 1 body_pose = joint_rot_mats_local[:T, 2:] # (T, 77, 3, 3) pose = torch.cat([global_orient.unsqueeze(1), body_pose], dim=1) # Prepare Identity Parameters identity_coeffs_map = {} for model_type, model in models.items(): n = model.identity_model.num_identity_coeffs if model_type == "anny": anny_im = model.identity_model.identity_model if args.random_shape: phenotypes = { k: get_smooth_noise(T, 1, device, mode="uniform").squeeze(-1) for k in anny_im.phenotype_labels } else: phenotypes = { k: torch.ones(T, device=device) * 0.5 for k in anny_im.phenotype_labels } local_changes = {k: torch.zeros(T, device=device) for k in anny_im.local_change_labels} identity_coeffs_map["anny"] = (phenotypes, local_changes) elif model_type == "mhr": n_scale = model.identity_model.num_scale_params if args.random_shape: coeffs = get_smooth_noise(T, n, device) scale = get_smooth_noise(T, n_scale, device, mode="normal") * 0.2 else: coeffs = torch.zeros(T, n, device=device) scale = torch.zeros(T, n_scale, device=device) identity_coeffs_map[model_type] = (coeffs, scale) else: if args.random_shape: coeffs = get_smooth_noise(T, n, device) else: coeffs = torch.zeros(T, n, device=device) identity_coeffs_map[model_type] = (coeffs, None) transl = root_trans[:T] # 4. Forward Pass using prepare_identity() + pose() API. # When identity is constant (not random_shape), prepare_identity is called # once per model and only pose() runs per batch -- skipping the expensive # identity model + skeleton transfer on every frame. pose_batch_size = args.pose_batch_size if args.pose_batch_size > 0 else T print(f"Running forward pass (pose_batch_size={pose_batch_size})...") outputs = {} with torch.no_grad(): if not args.random_shape: for model_type, model in models.items(): coeffs, scale = identity_coeffs_map[model_type] if isinstance(coeffs, dict): coeffs_single = {k: v[:1] for k, v in coeffs.items()} scale_single = {k: v[:1] for k, v in scale.items()} if scale else None else: coeffs_single = coeffs[:1] scale_single = scale[:1] if scale is not None else None model.prepare_identity(coeffs_single, scale_single) for start in range(0, T, pose_batch_size): end = min(start + pose_batch_size, T) pose_b = pose[start:end] transl_b = transl[start:end] for model_type, model in models.items(): if args.random_shape: coeffs, scale = identity_coeffs_map[model_type] if isinstance(coeffs, dict): coeffs_b = {k: v[start:end] for k, v in coeffs.items()} scale_b = {k: v[start:end] for k, v in scale.items()} if scale else None else: coeffs_b = coeffs[start:end] scale_b = scale[start:end] if scale is not None else None model.prepare_identity(coeffs_b, scale_b) out_b = model.pose( pose_b, transl=transl_b, pose2rot=False, apply_correctives=args.apply_correctives, ) if model_type not in outputs: outputs[model_type] = {"vertices": [], "joints": []} outputs[model_type]["vertices"].append(out_b["vertices"]) outputs[model_type]["joints"].append(out_b["joints"]) for model_type in list(outputs.keys()): outputs[model_type]["vertices"] = torch.cat(outputs[model_type]["vertices"], dim=0) outputs[model_type]["joints"] = torch.cat(outputs[model_type]["joints"], dim=0) # 5. Render (model-first loop with streaming video writer) print("Rendering videos...") suffix = "rand_shape" if args.random_shape else "fixed_shape" faces = { model_type: models[model_type].faces.detach().cpu().numpy() for model_type in args.identity_models } cam_pose = look_at( eye=np.array([0.0, 1.0, 6.0]), target=np.array([0.0, 1.0, 0.0]), up=np.array([0.0, 1.0, 0.0]), ) light_dir = np.array([0.0, -0.5, -1.0]) renderer = MeshRenderer(image_size=args.image_size, light_intensity=5) for model_type in args.identity_models: out_path = f"{args.output_dir}/{model_type}_{suffix}.mp4" renderer.setup_mesh( faces=faces[model_type], mesh_color=color_map[model_type], cam_pose=cam_pose, light_dir=light_dir, metallic=0.0, roughness=0.5, base_color_factor=[0.9, 0.9, 0.9, 1.0], ) writer = imageio.get_writer(out_path, fps=30) for t in tqdm(range(T), desc=model_type): verts = outputs[model_type]["vertices"][t].detach().cpu().numpy() img = renderer.render_frame(verts) writer.append_data(img[..., ::-1]) writer.close() print(f"Saved {out_path}") renderer.delete() if __name__ == "__main__": main()