File size: 10,050 Bytes
bd95c9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Mesh-level validation of PoseMirror_SOMA and PoseMirror_MHR.

For each mirror class, compares the parameter-mirrored mesh against a geometric
mesh mirror (flip X + mirror_vert_indices) in Nova topology, excluding facial
inner geometry.

Pose data paths are resolved exclusively from environment variables.
Tests skip gracefully when the variables are not set or the data is unavailable,
so CI and external users are unaffected.

Environment variables:
    SOMA_POSE_NPZ  Path to a single Nova pose .npz file (key: pose_local or transforms).
    SOMA_POSE_DIR  Directory of Nova pose .npz files (first file in sorted order is used).
    MHR_POSE_NPZ   Path to a single MHR pose .npz file (key: pose_params).
    MHR_POSE_DIR   Directory of MHR pose .npz files (first file in sorted order is used).
When both are set, the *_NPZ variable takes precedence over the matching *_DIR variable.

Usage:
    SOMA_POSE_NPZ="path/to/soma_pose.npz" MHR_POSE_NPZ="path/to/mhr_pose.npz" \
        pytest tests/test_pose_mirror.py -v
"""

from __future__ import annotations

import os
from pathlib import Path

import numpy as np
import pytest
import torch
from scipy.sparse import csc_matrix

# Repository layout: this file lives one directory below the repo root.
REPO_ROOT = Path(__file__).resolve().parent.parent
ASSETS_DIR = REPO_ROOT.joinpath("assets")
CORE_ASSET = ASSETS_DIR.joinpath("SOMA_neutral.npz")

MAX_FRAMES = 100  # upper bound on frames evaluated from one pose file
BATCH_SIZE = 64  # frames per skinning / model forward call


# ---------------------------------------------------------------------------
# Data-path resolution helpers
# ---------------------------------------------------------------------------


def _resolve_npz_path(env_single: str, env_dir: str) -> Path | None:
    """Return the first .npz found via the given environment variables.

    Checks *env_single* first (should point to a single file), then
    *env_dir* (should point to a directory; the first .npz is used).
    Returns ``None`` when neither variable is set or no file is found.
    """
    if val := os.environ.get(env_single):
        p = Path(val)
        return p if p.is_file() else None

    if val := os.environ.get(env_dir):
        search_dir = Path(val)
        if search_dir.is_dir():
            npzs = sorted(search_dir.glob("*.npz"))
            if npzs:
                return npzs[0]
    return None


def _load_mhr_param_names(npz_path: Path):
    """Extract the first 204 parameter names from parameter_transform.npz."""
    with np.load(npz_path, allow_pickle=False) as data:
        return list(data["parameter_names"][:204])


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture(scope="module")
def device():
    """Torch device shared by this module's tests: CUDA when available, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


@pytest.fixture(scope="module")
def soma_rig(device):
    """Load Nova rig data needed for mesh-level mirror comparison.

    Reads ``CORE_ASSET`` and returns a dict with:

    * ``joint_names`` -- list of joint names.
    * ``joint_parent_ids`` -- int32 tensor of parent indices on *device*.
    * ``bind_pose_world`` / ``bind_shape`` -- bind-pose tensors on *device*.
    * ``skinning_weights`` -- dense skinning-weight tensor on *device*, built
      from the CSC components (data / indices / indptr / shape) in the archive.
    * ``mirror_vert_indices`` -- NumPy index array used to reorder vertices
      when building the geometric mesh mirror.
    * ``body_mask`` -- NumPy bool mask over ``bind_shape`` vertices that is
      ``False`` for the eye-bag and mouth-bag segments (facial inner geometry).

    Skips dependent tests when the asset is missing (e.g. LFS files not
    fetched) or does not contain ``mirror_vert_indices``.
    """
    if not CORE_ASSET.is_file():
        pytest.skip(
            f"Core asset not found: {CORE_ASSET}. Run `git lfs pull` to fetch LFS-tracked files."
        )

    # Use the archive as a context manager so its file handle is released;
    # every array below is materialised before the ``with`` block exits.
    with np.load(CORE_ASSET, allow_pickle=False) as rig_data:
        if "mirror_vert_indices" not in rig_data:
            pytest.skip(
                "SOMA_neutral.npz does not contain 'mirror_vert_indices'. "
                "Either update the NPZ or point the test at an NPZ that includes it."
            )

        # ``toarray`` yields a dense ndarray directly; ``todense`` returns an
        # ``np.matrix`` that would need another conversion.
        skinning_weights = csc_matrix(
            (
                rig_data["skinning_weights_data"],
                rig_data["skinning_weights_indices"],
                rig_data["skinning_weights_indptr"],
            ),
            shape=tuple(rig_data["skinning_weights_shape"]),
        ).toarray()

        bind_shape = rig_data["bind_shape"]
        facial_inner = np.concatenate(
            [
                rig_data["segment_eye_bags"],
                rig_data["segment_mouth_bag"],
            ]
        )
        body_mask = np.ones(len(bind_shape), dtype=bool)
        body_mask[facial_inner] = False

        return dict(
            joint_names=rig_data["joint_names"].tolist(),
            joint_parent_ids=torch.from_numpy(
                rig_data["joint_parent_ids"].astype(np.int32).copy()
            ).to(device),
            bind_pose_world=torch.from_numpy(rig_data["bind_pose_world"]).to(device),
            bind_shape=torch.from_numpy(bind_shape).to(device),
            skinning_weights=torch.from_numpy(skinning_weights).to(device),
            mirror_vert_indices=rig_data["mirror_vert_indices"],
            body_mask=body_mask,
        )


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


@torch.no_grad()
def test_soma_pose_mirror(device, soma_rig):
    """PoseMirror_SOMA: param-mirrored mesh should match geometric mesh mirror.

    Skins up to ``MAX_FRAMES`` frames of a Nova pose file twice: once as given,
    and once after ``PoseMirror_SOMA`` mirrors the world-space joint
    transforms. The second result is compared against the original mesh
    mirrored geometrically (negate X, then reorder with
    ``mirror_vert_indices``), excluding facial inner geometry via
    ``body_mask``. Skips when no pose file is configured or the file has no
    frames.
    """
    from soma.geometry.batched_skinning import BatchedSkinning
    from soma.geometry.rig_utils import (
        PoseMirror_SOMA,
        joint_local_to_world,
        joint_world_to_local,
    )

    npz_path = _resolve_npz_path("SOMA_POSE_NPZ", "SOMA_POSE_DIR")
    if npz_path is None:
        pytest.skip("Nova pose data not available (set SOMA_POSE_NPZ or SOMA_POSE_DIR)")

    # Close the archive as soon as the pose array has been materialised.
    with np.load(npz_path) as soma_npz:
        key = "pose_local" if "pose_local" in soma_npz else "transforms"
        pose_local_np = soma_npz[key].astype(np.float32)
    n_frames = min(pose_local_np.shape[0], MAX_FRAMES)
    if n_frames == 0:
        # Nothing to compare; torch.cat / np.percentile below would fail on empty input.
        pytest.skip(f"Nova pose file contains no frames: {npz_path}")
    pose_local = torch.from_numpy(pose_local_np[:n_frames]).to(device)

    joint_parent_ids = soma_rig["joint_parent_ids"]
    pose_world = joint_local_to_world(pose_local, joint_parent_ids)

    # Mirror in world space, then convert back to local transforms for skinning.
    mirror = PoseMirror_SOMA(soma_rig["joint_names"])
    pose_mirror_world = mirror(pose_world)
    pose_mirror_local = joint_world_to_local(pose_mirror_world, joint_parent_ids)

    skinning = BatchedSkinning(
        joint_parent_ids,
        soma_rig["skinning_weights"],
        soma_rig["bind_pose_world"],
        soma_rig["bind_shape"],
        mode="warp",
    )

    verts_orig_list, verts_pm_list = [], []
    for s in range(0, n_frames, BATCH_SIZE):
        e = min(s + BATCH_SIZE, n_frames)
        # Passes every joint's [:3, :3] block plus the [:3, 3] column of joint 1.
        # NOTE(review): presumably rotations + root translation -- confirm.
        verts_orig_list.append(
            skinning.pose(pose_local[s:e, :, :3, :3], pose_local[s:e, 1, :3, 3]).cpu()
        )
        verts_pm_list.append(
            skinning.pose(pose_mirror_local[s:e, :, :3, :3], pose_mirror_local[s:e, 1, :3, 3]).cpu()
        )

    verts_orig = torch.cat(verts_orig_list)
    verts_param_mirror = torch.cat(verts_pm_list)

    # Geometric mirror: flip X, then swap each vertex with its mirror partner.
    mvi = soma_rig["mirror_vert_indices"]
    verts_mesh_mirror = verts_orig.clone()
    verts_mesh_mirror[..., 0] *= -1
    verts_mesh_mirror = verts_mesh_mirror[:, mvi]

    body_mask = soma_rig["body_mask"]
    err = (verts_param_mirror - verts_mesh_mirror).norm(dim=-1)[:, body_mask].numpy()
    mean_err = float(err.mean())
    p99_err = float(np.percentile(err, 99))

    print(f"\n  [SOMA mirror] frames={n_frames}  mean={mean_err:.6f}  p99={p99_err:.6f}")

    assert mean_err < 0.01, f"SOMA mirror mean error {mean_err:.6f} exceeds 0.01 threshold"
    assert p99_err < 0.02, f"SOMA mirror P99 error {p99_err:.6f} exceeds 0.02 threshold"


@torch.no_grad()
def test_mhr_pose_mirror(device, soma_rig):
    """PoseMirror_MHR: param-mirrored mesh should match geometric mesh mirror.

    Runs the MHR TorchScript model on up to ``MAX_FRAMES`` pose-parameter
    vectors and on their ``PoseMirror_MHR``-mirrored counterparts. Both
    outputs are mapped onto the vertices of ``SOMA_wrap_lod1.obj`` with a
    barycentric interpolator. The mirrored-parameter mesh is then compared
    against a geometric mirror of the original mesh, excluding facial inner
    geometry. Skips when pose data, MHR assets, or ``trimesh`` are unavailable.
    """
    from soma.geometry.barycentric_interp import BarycentricInterpolator
    from soma.geometry.rig_utils import PoseMirror_MHR

    npz_path = _resolve_npz_path("MHR_POSE_NPZ", "MHR_POSE_DIR")
    if npz_path is None:
        pytest.skip("MHR pose data not available (set MHR_POSE_NPZ or MHR_POSE_DIR)")

    pt_path = ASSETS_DIR / "MHR" / "parameter_transform.npz"
    model_path = ASSETS_DIR / "MHR" / "mhr_model_lod1.pt"
    mesh_mhr_path = ASSETS_DIR / "MHR" / "base_body_lod1.obj"
    mesh_soma_path = ASSETS_DIR / "MHR" / "SOMA_wrap_lod1.obj"
    for p in (pt_path, model_path, mesh_mhr_path, mesh_soma_path):
        if not p.is_file():
            pytest.skip(f"Required MHR asset not found: {p}")

    # Optional dependency: imported only once the data is known to exist, and
    # turned into a skip (not an error) when it is not installed.
    trimesh = pytest.importorskip("trimesh")

    n_params = 204  # width of the pose-parameter vector fed to the MHR model

    param_names = _load_mhr_param_names(pt_path)

    # Close the archive as soon as the parameters have been materialised.
    with np.load(npz_path) as mhr_npz:
        pose_params_np = mhr_npz["pose_params"]
    n_frames = min(pose_params_np.shape[0], MAX_FRAMES)
    if n_frames == 0:
        # Nothing to compare; torch.cat / np.percentile below would fail on empty input.
        pytest.skip(f"MHR pose file contains no frames: {npz_path}")

    # Trim frames first so padding / truncation only touches frames that are used.
    pp = torch.from_numpy(pose_params_np[:n_frames]).float().to(device)
    if pp.shape[1] < n_params:
        pp = torch.cat([pp, pp.new_zeros(n_frames, n_params - pp.shape[1])], dim=1)
    else:
        pp = pp[:, :n_params]

    mirror = PoseMirror_MHR(param_names)
    pp_mirrored = mirror(pp)

    mhr_model = torch.jit.load(str(model_path), map_location=device)

    mesh_mhr = trimesh.load(str(mesh_mhr_path), maintain_order=True, process=False)
    mesh_soma = trimesh.load(str(mesh_soma_path), maintain_order=True, process=False)
    V_mhr = torch.from_numpy(mesh_mhr.vertices).float().to(device)
    F_mhr = torch.from_numpy(mesh_mhr.faces).to(device)
    V_soma = torch.from_numpy(mesh_soma.vertices).float().to(device)
    mhr_to_soma = BarycentricInterpolator(V_mhr, F_mhr, V_soma)

    # The model's other two inputs are held at zero for every frame.
    # NOTE(review): presumably identity and expression coefficients -- confirm.
    id_c = torch.zeros(1, 45, device=device)
    fe = torch.zeros(1, 72, device=device)

    def mhr_forward(params_204):
        """Run the model in BATCH_SIZE chunks and return SOMA-topology verts on CPU."""
        results = []
        for s in range(0, params_204.shape[0], BATCH_SIZE):
            e = min(s + BATCH_SIZE, params_204.shape[0])
            v, _ = mhr_model(
                id_c.expand(e - s, -1),
                params_204[s:e],
                fe.expand(e - s, -1),
            )
            results.append(mhr_to_soma(v).cpu())
            del v  # drop the device-side output before the next chunk
        return torch.cat(results)

    verts_orig = mhr_forward(pp)
    verts_param_mirror = mhr_forward(pp_mirrored)

    # Geometric mirror: flip X, then swap each vertex with its mirror partner.
    mvi = soma_rig["mirror_vert_indices"]
    verts_mesh_mirror = verts_orig.clone()
    verts_mesh_mirror[..., 0] *= -1
    verts_mesh_mirror = verts_mesh_mirror[:, mvi]

    body_mask = soma_rig["body_mask"]
    err = (verts_param_mirror - verts_mesh_mirror).norm(dim=-1)[:, body_mask].numpy()
    mean_err = float(err.mean())
    p99_err = float(np.percentile(err, 99))

    print(f"\n  [MHR mirror] frames={n_frames}  mean={mean_err:.6f}  p99={p99_err:.6f}")

    assert mean_err < 0.005, f"MHR mirror mean error {mean_err:.6f} exceeds 0.005 threshold"
    assert p99_err < 0.01, f"MHR mirror P99 error {p99_err:.6f} exceeds 0.01 threshold"