# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Tests verifying SOMALayer works correctly with PyTorch DataLoader under multiprocessing.

The specific concern is that Warp-based operations (LBS, rotation fitting) and Warp's own
initialization could fail in forked worker processes. Five patterns are exercised:
  1. Warp called only in the main process (workers just load tensors)
  2. SOMALayer initialized at dataset construction, inherited by forked workers
  3. Lazy SOMALayer initialization inside worker __getitem__
  4. Per-worker SOMALayer init via worker_init_fn
  5. spawn multiprocessing context (fresh processes, no fork state)
"""

import os
import tempfile
import unittest
from pathlib import Path

import torch
from torch.utils.data import DataLoader, Dataset

REPO_ROOT = Path(__file__).resolve().parents[1]
ASSETS_DIR = REPO_ROOT / "assets"

NUM_JOINTS = 77


def _assets_available():
    return ASSETS_DIR.is_dir() and (ASSETS_DIR / "SOMA_neutral.npz").is_file()


class SomaPoseDataset(Dataset):
    """Minimal dataset that returns pre-generated pose tensors.

    Workers only load CPU tensors — no Warp calls inside workers.
    The SOMALayer forward pass is run in the main process collation step.
    """

    def __init__(self, id_coeffs_dim, scale_dim, size=4):
        self.poses = torch.zeros(size, NUM_JOINTS, 3)
        self.identity_coeffs = torch.zeros(size, id_coeffs_dim)
        self.scale_params = torch.zeros(size, scale_dim)
        self.transl = torch.zeros(size, 3)

    def __len__(self):
        return len(self.poses)

    def __getitem__(self, idx):
        return (
            self.poses[idx],
            self.identity_coeffs[idx],
            self.scale_params[idx],
            self.transl[idx],
        )


class SomaDataset(Dataset):
    """Dataset that initializes SOMALayer once at construction."""

    def __init__(self, data_root, size=4):
        from soma import SOMALayer

        self.data_root = data_root
        self._layer = SOMALayer(
            data_root=self.data_root,
            device="cpu",
            identity_model_type="mhr",
            mode="warp",
        )
        im = self._layer.identity_model
        self.poses = torch.zeros(size, NUM_JOINTS, 3)
        self.identity_coeffs = torch.zeros(size, im.num_identity_coeffs)
        self.scale_params = torch.zeros(size, im.num_scale_params)
        self.transl = torch.zeros(size, 3)

    def __len__(self):
        return len(self.poses)

    def __getitem__(self, idx):
        pose = self.poses[idx].unsqueeze(0)
        id_coeffs = self.identity_coeffs[idx].unsqueeze(0)
        scale = self.scale_params[idx].unsqueeze(0)
        transl = self.transl[idx].unsqueeze(0)
        with torch.no_grad():
            out = self._layer(pose, id_coeffs, scale, transl)
        return {
            "vertices": out["vertices"].squeeze(0),
            "joints": out["joints"].squeeze(0),
            "pose": pose.squeeze(0),
            "id_coeffs": id_coeffs.squeeze(0),
            "scale": scale.squeeze(0),
            "transl": transl.squeeze(0),
        }


class _LazySomaDataset(Dataset):
    """Dataset that initializes SOMALayer lazily inside __getitem__.

    This ensures Warp is initialized fresh in each worker process rather than
    inheriting state from a fork of the main process.
    """

    def __init__(self, data_root, id_coeffs_dim, scale_dim, size=4):
        self.data_root = data_root
        self.poses = torch.zeros(size, NUM_JOINTS, 3)
        self.identity_coeffs = torch.zeros(size, id_coeffs_dim)
        self.scale_params = torch.zeros(size, scale_dim)
        self.transl = torch.zeros(size, 3)
        self._layer = None

    def __len__(self):
        return len(self.poses)

    def __getitem__(self, idx):
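        # With num_workers > 0, the first __getitem__ call runs inside the
        # worker process (workers are forked before any item is fetched), so
        # SOMALayer and Warp initialize post-fork rather than being inherited.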
        if self._layer is None:
            from soma import SOMALayer

            self._layer = SOMALayer(
                data_root=self.data_root,
                device="cpu",
                identity_model_type="mhr",
                mode="warp",
            )
        pose = self.poses[idx].unsqueeze(0)
        id_coeffs = self.identity_coeffs[idx].unsqueeze(0)
        scale = self.scale_params[idx].unsqueeze(0)
        transl = self.transl[idx].unsqueeze(0)
        with torch.no_grad():
            out = self._layer(pose, id_coeffs, scale, transl)
        return out["vertices"].squeeze(0), out["joints"].squeeze(0)


class _WorkerInitDataset(Dataset):
    """Dataset where SOMALayer is injected by worker_init_fn."""

    def __init__(self, id_coeffs_dim, scale_dim, size=4):
        self.poses = torch.zeros(size, NUM_JOINTS, 3)
        self.identity_coeffs = torch.zeros(size, id_coeffs_dim)
        self.scale_params = torch.zeros(size, scale_dim)
        self.transl = torch.zeros(size, 3)
        self._layer = None  # set by worker_init_fn

    def __len__(self):
        return len(self.poses)

    def __getitem__(self, idx):
        pose = self.poses[idx].unsqueeze(0)
        id_coeffs = self.identity_coeffs[idx].unsqueeze(0)
        scale = self.scale_params[idx].unsqueeze(0)
        transl = self.transl[idx].unsqueeze(0)
        with torch.no_grad():
            out = self._layer(pose, id_coeffs, scale, transl)
        return {
            "vertices": out["vertices"].squeeze(0),
            "joints": out["joints"].squeeze(0),
            "pose": pose.squeeze(0),
            "id_coeffs": id_coeffs.squeeze(0),
            "scale": scale.squeeze(0),
            "transl": transl.squeeze(0),
        }


def _soma_worker_init(worker_id):
    """worker_init_fn: initialize SOMALayer once per worker process."""
    info = torch.utils.data.get_worker_info()
    from soma import SOMALayer

    info.dataset._layer = SOMALayer(
        data_root=str(ASSETS_DIR),
        device="cpu",
        identity_model_type="mhr",
        mode="warp",
    )


class TestSomaLayerDataLoader(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        if not _assets_available():
            raise unittest.SkipTest(
                "Assets not found. Run `git lfs pull` to fetch SOMA_neutral.npz."
            )
        cls.data_root = str(ASSETS_DIR)
        # Query dims once so datasets don't hardcode them.
        layer = cls._make_layer_static(cls.data_root, "cpu")
        im = layer.identity_model
        cls.id_coeffs_dim = im.num_identity_coeffs
        cls.scale_dim = im.num_scale_params

    @staticmethod
    def _make_layer_static(data_root, device):
        from soma import SOMALayer

        return SOMALayer(
            data_root=data_root,
            device=device,
            identity_model_type="mhr",
            mode="warp",
        )

    def _make_layer(self, device="cpu"):
        return self._make_layer_static(self.data_root, device)

    def _assert_output_shapes(self, vertices, joints, batch_size, num_verts):
        self.assertEqual(vertices.dim(), 3)
        self.assertEqual(vertices.shape[0], batch_size)
        self.assertEqual(vertices.shape[2], 3)
        self.assertEqual(vertices.shape[1], num_verts)
        self.assertEqual(joints.shape, (batch_size, NUM_JOINTS, 3))

    def test_no_workers_baseline(self):
        """Sanity check: single-process DataLoader, Warp ops work correctly."""
        layer = self._make_layer("cpu")
        num_verts = layer.bind_shape.shape[0]
        dataset = SomaPoseDataset(self.id_coeffs_dim, self.scale_dim, size=4)
        loader = DataLoader(dataset, batch_size=2, num_workers=0)

        for poses, id_coeffs, scale_params, transl in loader:
            with torch.no_grad():
                out = layer(poses, id_coeffs, scale_params, transl)
            self.assertIn("vertices", out)
            self.assertIn("joints", out)
            self._assert_output_shapes(out["vertices"], out["joints"], 2, num_verts)

    def test_multi_worker_warp_in_main_process(self):
        """Safe pattern: workers only load tensors; Warp called only in the main process."""
        layer = self._make_layer("cpu")
        num_verts = layer.bind_shape.shape[0]
        dataset = SomaPoseDataset(self.id_coeffs_dim, self.scale_dim, size=4)
        loader = DataLoader(dataset, batch_size=2, num_workers=2)

        for poses, id_coeffs, scale_params, transl in loader:
            with torch.no_grad():
                out = layer(poses, id_coeffs, scale_params, transl)
            self.assertIn("vertices", out)
            self.assertIn("joints", out)
            self._assert_output_shapes(out["vertices"], out["joints"], 2, num_verts)

    def test_multi_worker_init_at_construction(self):
        """Warp is initialized fresh inside each forked worker via lazy SOMALayer init."""
        import multiprocessing

        if multiprocessing.get_start_method() != "fork":
            self.skipTest(
                "test requires fork-based DataLoader workers; "
                f"current start method is {multiprocessing.get_start_method()!r}"
            )
        if not torch.cuda.is_available():
            self.skipTest("CUDA not available")
        dataset = SomaDataset(self.data_root, size=4)
        loader = DataLoader(dataset, batch_size=2, num_workers=2)
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            device = "cuda:1"
        soma_layer = self._make_layer(device)

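        # Capture stderr at the file-descriptor level: Warp's native code in
        # the forked workers writes directly to fd 2, which a Python-level
        # sys.stderr swap would not intercept.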
        with tempfile.TemporaryFile() as _tmp:
            _saved = os.dup(2)
            os.dup2(_tmp.fileno(), 2)
            try:
                for data in loader:
                    # move to device
                    for key, value in data.items():
                        data[key] = value.to(device)
                    vertices = data["vertices"]
                    joints = data["joints"]
                    batch_size = vertices.shape[0]
                    with torch.no_grad():
                        out = soma_layer(
                            data["pose"], data["id_coeffs"], data["scale"], data["transl"]
                        )
                        pred_joints = out["joints"]
                    diff_joints = (joints - pred_joints).abs().max()
                    self.assertLess(diff_joints, 1e-3)
                    self.assertEqual(vertices.dim(), 3)
                    self.assertEqual(vertices.shape[2], 3)
                    self.assertEqual(joints.shape, (batch_size, NUM_JOINTS, 3))
            finally:
                os.dup2(_saved, 2)
                os.close(_saved)
            _tmp.seek(0)
            _stderr = _tmp.read().decode("utf-8", errors="replace")

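        # CUDA error 3 is cudaErrorInitializationError, the typical symptom of
        # touching CUDA in a forked child process.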
        self.assertNotIn(
            "Warp CUDA error 3",
            _stderr,
            "CUDA error 3 appeared in worker stderr — fork hook may not be working",
        )

    def test_multi_worker_lazy_init_in_worker(self):
        """Warp is initialized fresh inside each forked worker via lazy SOMALayer init."""
        dataset = _LazySomaDataset(self.data_root, self.id_coeffs_dim, self.scale_dim, size=4)
        # num_verts is unknown without a layer; just check dims
        loader = DataLoader(dataset, batch_size=2, num_workers=2)

        for vertices, joints in loader:
            batch_size = vertices.shape[0]
            self.assertEqual(vertices.dim(), 3)
            self.assertEqual(vertices.shape[2], 3)
            self.assertEqual(joints.shape[0], batch_size)
            self.assertEqual(joints.shape[1], NUM_JOINTS)
            self.assertEqual(joints.shape[2], 3)

    def test_multi_worker_worker_init_fn(self):
        """Recommended pattern: SOMALayer initialized once per worker via worker_init_fn."""
        if not torch.cuda.is_available():
            self.skipTest("CUDA not available")
        dataset = _WorkerInitDataset(self.id_coeffs_dim, self.scale_dim, size=4)
        loader = DataLoader(
            dataset,
            batch_size=2,
            num_workers=2,
            worker_init_fn=_soma_worker_init,
        )

        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            device = "cuda:1"
        soma_layer = self._make_layer(device)

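        # Same fd-level stderr capture as in test_multi_worker_init_at_construction.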
        with tempfile.TemporaryFile() as _tmp:
            _saved = os.dup(2)
            os.dup2(_tmp.fileno(), 2)
            try:
                for data in loader:
                    # move to device
                    for key, value in data.items():
                        data[key] = value.to(device)
                    vertices = data["vertices"]
                    joints = data["joints"]
                    batch_size = vertices.shape[0]
                    with torch.no_grad():
                        out = soma_layer(
                            data["pose"], data["id_coeffs"], data["scale"], data["transl"]
                        )
                        pred_joints = out["joints"]
                    diff_joints = (joints - pred_joints).abs().max()
                    self.assertLess(diff_joints, 1e-3)
                    self.assertEqual(vertices.dim(), 3)
                    self.assertEqual(vertices.shape[2], 3)
                    self.assertEqual(joints.shape, (batch_size, NUM_JOINTS, 3))
            finally:
                os.dup2(_saved, 2)
                os.close(_saved)
            _tmp.seek(0)
            _stderr = _tmp.read().decode("utf-8", errors="replace")

        self.assertNotIn(
            "Warp CUDA error 3",
            _stderr,
            "CUDA error 3 appeared in worker stderr — fork hook may not be working",
        )

    def test_spawn_context(self):
        """spawn multiprocessing context: fresh processes, no fork state inheritance."""
        layer = self._make_layer("cpu")
        num_verts = layer.bind_shape.shape[0]
        dataset = SomaPoseDataset(self.id_coeffs_dim, self.scale_dim, size=4)
        loader = DataLoader(
            dataset,
            batch_size=2,
            num_workers=2,
            multiprocessing_context="spawn",
        )

        for poses, id_coeffs, scale_params, transl in loader:
            with torch.no_grad():
                out = layer(poses, id_coeffs, scale_params, transl)
            self.assertIn("vertices", out)
            self.assertIn("joints", out)
            self._assert_output_shapes(out["vertices"], out["joints"], 2, num_verts)

    def test_cuda_spawn_context(self):
        """CUDA-safe pattern: spawn context avoids CUDA fork issues. Skipped if no GPU."""
        if not torch.cuda.is_available():
            self.skipTest("CUDA not available")

        layer = self._make_layer("cuda")
        num_verts = layer.bind_shape.shape[0]
        dataset = SomaPoseDataset(self.id_coeffs_dim, self.scale_dim, size=4)
        loader = DataLoader(
            dataset,
            batch_size=2,
            num_workers=2,
            multiprocessing_context="spawn",
        )

        for poses, id_coeffs, scale_params, transl in loader:
            poses = poses.cuda()
            id_coeffs = id_coeffs.cuda()
            scale_params = scale_params.cuda()
            transl = transl.cuda()
            with torch.no_grad():
                out = layer(poses, id_coeffs, scale_params, transl)
            self.assertIn("vertices", out)
            self.assertIn("joints", out)
            self._assert_output_shapes(out["vertices"], out["joints"], 2, num_verts)


if __name__ == "__main__":
    unittest.main()