# somatosmpl/tests/test_dataloader.py
# Uploaded by zirobtc using huggingface_hub (commit bd95c9c, verified).
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Tests verifying SOMALayer works correctly with PyTorch DataLoader under multiprocessing.
The specific concern is that Warp-based operations (LBS, rotation fitting) and Warp's own
initialization could fail in forked worker processes. Four patterns are exercised:
1. Warp called only in main process (workers just load tensors)
2. Lazy SOMALayer initialization inside worker __getitem__
3. Per-worker SOMALayer init via worker_init_fn
4. spawn multiprocessing context (fresh processes, no fork state)
"""
import os
import tempfile
import unittest
from pathlib import Path
import torch
from torch.utils.data import DataLoader, Dataset
# Parent of the directory that contains this test file.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Directory checked for the SOMA asset files used by the tests below.
ASSETS_DIR = REPO_ROOT.joinpath("assets")
# Joint count used for pose tensors and joint-shape assertions in this file.
NUM_JOINTS = 77
def _assets_available():
    """Return True when ``ASSETS_DIR`` exists and contains ``SOMA_neutral.npz``."""
    if not ASSETS_DIR.is_dir():
        return False
    neutral_model = ASSETS_DIR / "SOMA_neutral.npz"
    return neutral_model.is_file()
class SomaPoseDataset(Dataset):
    """Tensor-only dataset of all-zero SOMA inputs.

    ``__getitem__`` only indexes CPU tensors; this class never constructs or
    calls a SOMALayer. Callers are expected to run the layer themselves on the
    batches the DataLoader yields.
    """

    def __init__(self, id_coeffs_dim, scale_dim, size=4):
        """Allocate ``size`` zero-filled samples.

        Args:
            id_coeffs_dim: Width of each identity-coefficient vector.
            scale_dim: Width of each scale-parameter vector.
            size: Number of samples in the dataset.
        """
        self.poses = torch.zeros(size, NUM_JOINTS, 3)
        self.identity_coeffs = torch.zeros(size, id_coeffs_dim)
        self.scale_params = torch.zeros(size, scale_dim)
        self.transl = torch.zeros(size, 3)

    def __len__(self):
        """Return the number of samples (leading dimension of ``poses``)."""
        return self.poses.shape[0]

    def __getitem__(self, idx):
        """Return ``(pose, identity_coeffs, scale_params, transl)`` for ``idx``."""
        columns = (self.poses, self.identity_coeffs, self.scale_params, self.transl)
        return tuple(column[idx] for column in columns)
class SomaDataset(Dataset):
    """Dataset that builds its CPU SOMALayer eagerly in ``__init__``.

    Each item is produced by running that layer on one zero-filled sample.
    """

    def __init__(self, data_root, size=4):
        """Create the layer and allocate ``size`` zero-filled samples.

        Args:
            data_root: Asset directory passed to ``SOMALayer``.
            size: Number of samples in the dataset.
        """
        from soma import SOMALayer

        self.data_root = data_root
        self._layer = SOMALayer(
            data_root=self.data_root,
            device="cpu",
            identity_model_type="mhr",
            mode="warp",
        )
        # Input widths come from the layer's identity model rather than being hardcoded.
        identity_model = self._layer.identity_model
        self.poses = torch.zeros(size, NUM_JOINTS, 3)
        self.identity_coeffs = torch.zeros(size, identity_model.num_identity_coeffs)
        self.scale_params = torch.zeros(size, identity_model.num_scale_params)
        self.transl = torch.zeros(size, 3)

    def __len__(self):
        """Return the number of samples (leading dimension of ``poses``)."""
        return self.poses.shape[0]

    def __getitem__(self, idx):
        """Run the layer on sample ``idx`` and return outputs plus inputs.

        Returns:
            dict with keys ``vertices``, ``joints``, ``pose``, ``id_coeffs``,
            ``scale`` and ``transl``; every value has the size-1 batch
            dimension squeezed away again.
        """
        # The layer is called with a leading batch dimension of one.
        batched = {
            "pose": self.poses[idx][None],
            "id_coeffs": self.identity_coeffs[idx][None],
            "scale": self.scale_params[idx][None],
            "transl": self.transl[idx][None],
        }
        with torch.no_grad():
            # Dict order matches the positional order (pose, id_coeffs, scale, transl).
            out = self._layer(*batched.values())
        sample = {name: out[name].squeeze(0) for name in ("vertices", "joints")}
        sample.update((name, value.squeeze(0)) for name, value in batched.items())
        return sample
class _LazySomaDataset(Dataset):
    """Dataset that defers SOMALayer construction to the first ``__getitem__``.

    Nothing layer-related is created in ``__init__``; the layer is built by
    whichever process first indexes the dataset, then cached on the instance.
    The intent is that DataLoader workers each create their own layer instead
    of receiving one built in the main process.
    """

    def __init__(self, data_root, id_coeffs_dim, scale_dim, size=4):
        """Store the asset path and allocate ``size`` zero-filled samples.

        Args:
            data_root: Asset directory later passed to ``SOMALayer``.
            id_coeffs_dim: Width of each identity-coefficient vector.
            scale_dim: Width of each scale-parameter vector.
            size: Number of samples in the dataset.
        """
        self.data_root = data_root
        self.poses = torch.zeros(size, NUM_JOINTS, 3)
        self.identity_coeffs = torch.zeros(size, id_coeffs_dim)
        self.scale_params = torch.zeros(size, scale_dim)
        self.transl = torch.zeros(size, 3)
        self._layer = None

    def __len__(self):
        """Return the number of samples (leading dimension of ``poses``)."""
        return self.poses.shape[0]

    def __getitem__(self, idx):
        """Return ``(vertices, joints)`` computed by the layer for sample ``idx``."""
        layer = self._layer
        if layer is None:
            # First access in this process: build the layer and cache it.
            from soma import SOMALayer

            layer = self._layer = SOMALayer(
                data_root=self.data_root,
                device="cpu",
                identity_model_type="mhr",
                mode="warp",
            )
        # The layer is called with a leading batch dimension of one.
        sources = (self.poses, self.identity_coeffs, self.scale_params, self.transl)
        layer_args = tuple(source[idx].unsqueeze(0) for source in sources)
        with torch.no_grad():
            out = layer(*layer_args)
        return out["vertices"].squeeze(0), out["joints"].squeeze(0)
class _WorkerInitDataset(Dataset):
    """Dataset whose SOMALayer is injected after construction.

    ``_layer`` starts as ``None`` and must be assigned (for example by a
    DataLoader ``worker_init_fn``) before any item is requested.
    """

    def __init__(self, id_coeffs_dim, scale_dim, size=4):
        """Allocate ``size`` zero-filled samples; no layer is created here.

        Args:
            id_coeffs_dim: Width of each identity-coefficient vector.
            scale_dim: Width of each scale-parameter vector.
            size: Number of samples in the dataset.
        """
        self.poses = torch.zeros(size, NUM_JOINTS, 3)
        self.identity_coeffs = torch.zeros(size, id_coeffs_dim)
        self.scale_params = torch.zeros(size, scale_dim)
        self.transl = torch.zeros(size, 3)
        self._layer = None  # set by worker_init_fn

    def __len__(self):
        """Return the number of samples."""
        return len(self.poses)

    def __getitem__(self, idx):
        """Run the injected layer on sample ``idx`` and return outputs plus inputs.

        Returns:
            dict with keys ``vertices``, ``joints``, ``pose``, ``id_coeffs``,
            ``scale`` and ``transl``; every value has the size-1 batch
            dimension squeezed away again.

        Raises:
            RuntimeError: If no layer has been injected yet.
        """
        if self._layer is None:
            # Without this guard the failure is an opaque
            # "'NoneType' object is not callable" raised from the call below.
            raise RuntimeError(
                "_WorkerInitDataset._layer is not set; pass a worker_init_fn "
                "that assigns a SOMALayer to the DataLoader."
            )
        # The layer is called with a leading batch dimension of one.
        pose = self.poses[idx].unsqueeze(0)
        id_coeffs = self.identity_coeffs[idx].unsqueeze(0)
        scale = self.scale_params[idx].unsqueeze(0)
        transl = self.transl[idx].unsqueeze(0)
        with torch.no_grad():
            out = self._layer(pose, id_coeffs, scale, transl)
        return {
            "vertices": out["vertices"].squeeze(0),
            "joints": out["joints"].squeeze(0),
            "pose": pose.squeeze(0),
            "id_coeffs": id_coeffs.squeeze(0),
            "scale": scale.squeeze(0),
            "transl": transl.squeeze(0),
        }
def _soma_worker_init(worker_id):
    """DataLoader ``worker_init_fn`` that attaches a CPU SOMALayer to the dataset.

    Args:
        worker_id: Worker index supplied by the DataLoader; not used here.
    """
    worker_dataset = torch.utils.data.get_worker_info().dataset
    from soma import SOMALayer

    layer_options = {
        "data_root": str(ASSETS_DIR),
        "device": "cpu",
        "identity_model_type": "mhr",
        "mode": "warp",
    }
    # Stored on the dataset object returned by get_worker_info(), which is the
    # one this worker's __getitem__ calls will use.
    worker_dataset._layer = SOMALayer(**layer_options)
class TestSomaLayerDataLoader(unittest.TestCase):
    """Checks that SOMALayer works together with ``DataLoader`` worker processes.

    The whole class is skipped when the SOMA assets are missing. Tests that
    need a GPU skip themselves before doing any expensive setup.
    """

    @classmethod
    def setUpClass(cls):
        """Skip without assets; otherwise cache the asset path and input widths."""
        if not _assets_available():
            raise unittest.SkipTest(
                "Assets not found. Run `git lfs pull` to fetch SOMA_neutral.npz."
            )
        cls.data_root = str(ASSETS_DIR)
        # Query dims once so datasets don't hardcode them.
        layer = cls._make_layer_static(cls.data_root, "cpu")
        im = layer.identity_model
        cls.id_coeffs_dim = im.num_identity_coeffs
        cls.scale_dim = im.num_scale_params

    @staticmethod
    def _make_layer_static(data_root, device):
        """Build a SOMALayer (``mhr`` identity model, ``warp`` mode) on ``device``."""
        from soma import SOMALayer

        return SOMALayer(
            data_root=data_root,
            device=device,
            identity_model_type="mhr",
            mode="warp",
        )

    def _make_layer(self, device="cpu"):
        """Build a SOMALayer on ``device`` using the class-level asset path."""
        return self._make_layer_static(self.data_root, device)

    def _assert_output_shapes(self, vertices, joints, batch_size, num_verts):
        """Assert ``vertices`` is (batch, num_verts, 3) and ``joints`` is (batch, NUM_JOINTS, 3)."""
        self.assertEqual(vertices.dim(), 3)
        self.assertEqual(vertices.shape[0], batch_size)
        self.assertEqual(vertices.shape[2], 3)
        self.assertEqual(vertices.shape[1], num_verts)
        self.assertEqual(joints.shape, (batch_size, NUM_JOINTS, 3))

    def _require_cuda_device(self):
        """Return the CUDA device used for reference results, or skip the test.

        ``cuda:1`` is chosen when more than one GPU is visible, ``cuda:0``
        otherwise.
        """
        if not torch.cuda.is_available():
            self.skipTest("CUDA not available")
        return "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"

    def _check_forward_in_main_process(self, loader, layer, to_cuda=False, batch_size=2):
        """Run ``layer`` in this process on every batch from ``loader`` and check shapes.

        Args:
            loader: DataLoader yielding ``(poses, id_coeffs, scale_params, transl)``.
            layer: SOMALayer to evaluate in the main process.
            to_cuda: Move each batch with ``.cuda()`` before the forward pass.
            batch_size: Expected leading dimension of every output.
        """
        num_verts = layer.bind_shape.shape[0]
        for batch in loader:
            if to_cuda:
                batch = [tensor.cuda() for tensor in batch]
            poses, id_coeffs, scale_params, transl = batch
            with torch.no_grad():
                out = layer(poses, id_coeffs, scale_params, transl)
            self.assertIn("vertices", out)
            self.assertIn("joints", out)
            self._assert_output_shapes(out["vertices"], out["joints"], batch_size, num_verts)

    def _check_worker_outputs_against_gpu(self, loader, device):
        """Compare worker-computed joints with a main-process layer on ``device``.

        While the loader is iterated, file descriptor 2 is redirected to a
        temporary file so that anything written to stderr during iteration can
        be searched for the "Warp CUDA error 3" message afterwards.

        Args:
            loader: DataLoader yielding dicts with ``vertices``, ``joints``,
                ``pose``, ``id_coeffs``, ``scale`` and ``transl``.
            device: Device string for the reference layer built here.
        """
        import sys  # local import: the file-level import block is left untouched

        # Built before iteration starts, as in the original test flow.
        soma_layer = self._make_layer(device)
        with tempfile.TemporaryFile() as captured:
            # Push out anything already buffered so it is not attributed to the run.
            sys.stderr.flush()
            saved_fd = os.dup(2)
            try:
                # Inside the try so a failing dup2 cannot leak ``saved_fd``.
                os.dup2(captured.fileno(), 2)
                for data in loader:
                    # move to device
                    data = {key: value.to(device) for key, value in data.items()}
                    vertices = data["vertices"]
                    joints = data["joints"]
                    batch_size = vertices.shape[0]
                    with torch.no_grad():
                        out = soma_layer(
                            data["pose"], data["id_coeffs"], data["scale"], data["transl"]
                        )
                    pred_joints = out["joints"]
                    diff_joints = (joints - pred_joints).abs().max()
                    self.assertLess(diff_joints, 1e-3)
                    self.assertEqual(vertices.dim(), 3)
                    self.assertEqual(vertices.shape[2], 3)
                    self.assertEqual(joints.shape, (batch_size, NUM_JOINTS, 3))
            finally:
                # Flush while fd 2 still points at the capture file, then restore it.
                sys.stderr.flush()
                os.dup2(saved_fd, 2)
                os.close(saved_fd)
            captured.seek(0)
            stderr_text = captured.read().decode("utf-8", errors="replace")
        self.assertNotIn(
            "Warp CUDA error 3",
            stderr_text,
            "CUDA error 3 appeared in worker stderr — fork hook may not be working",
        )

    def test_no_workers_baseline(self):
        """Sanity check: single-process DataLoader, Warp ops work correctly."""
        layer = self._make_layer("cpu")
        dataset = SomaPoseDataset(self.id_coeffs_dim, self.scale_dim, size=4)
        loader = DataLoader(dataset, batch_size=2, num_workers=0)
        self._check_forward_in_main_process(loader, layer)

    def test_multi_worker_warp_in_main_process(self):
        """Safe pattern: workers only load tensors; Warp called only in the main process."""
        layer = self._make_layer("cpu")
        dataset = SomaPoseDataset(self.id_coeffs_dim, self.scale_dim, size=4)
        loader = DataLoader(dataset, batch_size=2, num_workers=2)
        self._check_forward_in_main_process(loader, layer)

    def test_multi_worker_init_at_construction(self):
        """SOMALayer is built in the main process by the dataset, then used in forked workers."""
        import multiprocessing

        start_method = multiprocessing.get_start_method()
        if start_method != "fork":
            self.skipTest(
                "test requires fork-based DataLoader workers; "
                f"current start method is {start_method!r}"
            )
        # Decide on skipping before SomaDataset builds its own SOMALayer.
        device = self._require_cuda_device()
        dataset = SomaDataset(self.data_root, size=4)
        loader = DataLoader(dataset, batch_size=2, num_workers=2)
        self._check_worker_outputs_against_gpu(loader, device)

    def test_multi_worker_lazy_init_in_worker(self):
        """SOMALayer is created lazily inside each worker on first ``__getitem__``."""
        dataset = _LazySomaDataset(self.data_root, self.id_coeffs_dim, self.scale_dim, size=4)
        # num_verts is unknown without a layer; just check dims
        loader = DataLoader(dataset, batch_size=2, num_workers=2)
        for vertices, joints in loader:
            batch_size = vertices.shape[0]
            self.assertEqual(vertices.dim(), 3)
            self.assertEqual(vertices.shape[2], 3)
            self.assertEqual(joints.shape[0], batch_size)
            self.assertEqual(joints.shape[1], NUM_JOINTS)
            self.assertEqual(joints.shape[2], 3)

    def test_multi_worker_worker_init_fn(self):
        """Recommended pattern: SOMALayer initialized once per worker via worker_init_fn."""
        # Decide on skipping before any dataset or loader is built.
        device = self._require_cuda_device()
        dataset = _WorkerInitDataset(self.id_coeffs_dim, self.scale_dim, size=4)
        loader = DataLoader(
            dataset,
            batch_size=2,
            num_workers=2,
            worker_init_fn=_soma_worker_init,
        )
        self._check_worker_outputs_against_gpu(loader, device)

    def test_spawn_context(self):
        """spawn multiprocessing context: fresh processes, no fork state inheritance."""
        layer = self._make_layer("cpu")
        dataset = SomaPoseDataset(self.id_coeffs_dim, self.scale_dim, size=4)
        loader = DataLoader(
            dataset,
            batch_size=2,
            num_workers=2,
            multiprocessing_context="spawn",
        )
        self._check_forward_in_main_process(loader, layer)

    def test_cuda_spawn_context(self):
        """CUDA-safe pattern: spawn context avoids CUDA fork issues. Skipped if no GPU."""
        if not torch.cuda.is_available():
            self.skipTest("CUDA not available")
        layer = self._make_layer("cuda")
        dataset = SomaPoseDataset(self.id_coeffs_dim, self.scale_dim, size=4)
        loader = DataLoader(
            dataset,
            batch_size=2,
            num_workers=2,
            multiprocessing_context="spawn",
        )
        self._check_forward_in_main_process(loader, layer, to_cuda=True)
if __name__ == "__main__":
    # Run the test cases when this file is executed directly as a script.
    unittest.main()