| |
| |
|
|
"""
Tests verifying SOMALayer works correctly with PyTorch DataLoader under multiprocessing.

The specific concern is that Warp-based operations (LBS, rotation fitting) and Warp's own
initialization could fail in forked worker processes. Four patterns are exercised:
1. Warp called only in the main process (workers just load tensors)
2. Lazy SOMALayer initialization inside worker __getitem__
3. Per-worker SOMALayer init via worker_init_fn
4. spawn multiprocessing context (fresh processes, no fork state)
"""
|
|
| import os |
| import tempfile |
| import unittest |
| from pathlib import Path |
|
|
| import torch |
| from torch.utils.data import DataLoader, Dataset |
|
|
# Repository root: the directory one level above the folder that holds this file.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Directory the tests read model assets from (e.g. SOMA_neutral.npz).
ASSETS_DIR = REPO_ROOT.joinpath("assets")

# Length of the joint axis used for every pose tensor built in this file and
# expected on the layer's "joints" output.
NUM_JOINTS = 77
|
|
|
|
def _assets_available():
    """Return True only when the assets directory and SOMA_neutral.npz both exist."""
    if not ASSETS_DIR.is_dir():
        return False
    return (ASSETS_DIR / "SOMA_neutral.npz").is_file()
|
|
|
|
class SomaPoseDataset(Dataset):
    """Tiny dataset of all-zero SOMA inputs stored as pre-allocated tensors.

    Indexing only slices those tensors, so nothing in this class calls Warp or
    SOMALayer; running the layer on the resulting batches is left to the code
    that consumes the DataLoader.
    """

    def __init__(self, id_coeffs_dim, scale_dim, size=4):
        """Allocate ``size`` zero-valued samples.

        Args:
            id_coeffs_dim: Width of each identity-coefficient vector.
            scale_dim: Width of each scale-parameter vector.
            size: Number of samples in the dataset.
        """
        num_samples = size
        self.poses = torch.zeros(num_samples, NUM_JOINTS, 3)
        self.identity_coeffs = torch.zeros(num_samples, id_coeffs_dim)
        self.scale_params = torch.zeros(num_samples, scale_dim)
        self.transl = torch.zeros(num_samples, 3)

    def __len__(self):
        """Return the number of samples (leading dimension of ``poses``)."""
        return self.poses.shape[0]

    def __getitem__(self, idx):
        """Return ``(pose, identity_coeffs, scale_params, transl)`` for sample ``idx``."""
        fields = (self.poses, self.identity_coeffs, self.scale_params, self.transl)
        return tuple(field[idx] for field in fields)
|
|
|
|
class SomaDataset(Dataset):
    """Dataset that builds its SOMALayer eagerly, inside ``__init__``.

    Each ``__getitem__`` call runs that single layer on one all-zero sample and
    returns the layer outputs together with the inputs that produced them.
    """

    def __init__(self, data_root, size=4):
        """Create the layer and allocate ``size`` zero-valued samples.

        Args:
            data_root: Passed through to ``SOMALayer(data_root=...)``.
            size: Number of samples in the dataset.
        """
        # Function-scope import, matching how the rest of this file pulls in soma.
        from soma import SOMALayer

        self.data_root = data_root
        self._layer = SOMALayer(
            data_root=self.data_root,
            device="cpu",
            identity_model_type="mhr",
            mode="warp",
        )
        # The identity model reports how wide the coefficient vectors must be.
        identity_model = self._layer.identity_model
        self.poses = torch.zeros(size, NUM_JOINTS, 3)
        self.identity_coeffs = torch.zeros(size, identity_model.num_identity_coeffs)
        self.scale_params = torch.zeros(size, identity_model.num_scale_params)
        self.transl = torch.zeros(size, 3)

    def __len__(self):
        """Return the number of samples (leading dimension of ``poses``)."""
        return self.poses.shape[0]

    def __getitem__(self, idx):
        """Run the layer on sample ``idx`` and return outputs plus inputs.

        Returns:
            dict with keys ``vertices``, ``joints``, ``pose``, ``id_coeffs``,
            ``scale`` and ``transl``; every value has the batch axis removed.
        """
        # The layer is called with a leading batch axis of length 1.
        batched = {
            "pose": self.poses[idx].unsqueeze(0),
            "id_coeffs": self.identity_coeffs[idx].unsqueeze(0),
            "scale": self.scale_params[idx].unsqueeze(0),
            "transl": self.transl[idx].unsqueeze(0),
        }
        with torch.no_grad():
            result = self._layer(
                batched["pose"],
                batched["id_coeffs"],
                batched["scale"],
                batched["transl"],
            )
        sample = {name: result[name].squeeze(0) for name in ("vertices", "joints")}
        sample.update((name, tensor.squeeze(0)) for name, tensor in batched.items())
        return sample
|
|
|
|
| class _LazySomaDataset(Dataset): |
| """Dataset that initializes SOMALayer lazily inside __getitem__. |
| |
| This ensures Warp is initialized fresh in each worker process rather than |
| inheriting state from a fork of the main process. |
| """ |
|
|
| def __init__(self, data_root, id_coeffs_dim, scale_dim, size=4): |
| self.data_root = data_root |
| self.poses = torch.zeros(size, NUM_JOINTS, 3) |
| self.identity_coeffs = torch.zeros(size, id_coeffs_dim) |
| self.scale_params = torch.zeros(size, scale_dim) |
| self.transl = torch.zeros(size, 3) |
| self._layer = None |
|
|
| def __len__(self): |
| return len(self.poses) |
|
|
| def __getitem__(self, idx): |
| if self._layer is None: |
| from soma import SOMALayer |
|
|
| self._layer = SOMALayer( |
| data_root=self.data_root, |
| device="cpu", |
| identity_model_type="mhr", |
| mode="warp", |
| ) |
| pose = self.poses[idx].unsqueeze(0) |
| id_coeffs = self.identity_coeffs[idx].unsqueeze(0) |
| scale = self.scale_params[idx].unsqueeze(0) |
| transl = self.transl[idx].unsqueeze(0) |
| with torch.no_grad(): |
| out = self._layer(pose, id_coeffs, scale, transl) |
| return out["vertices"].squeeze(0), out["joints"].squeeze(0) |
|
|
|
|
| class _WorkerInitDataset(Dataset): |
| """Dataset where SOMALayer is injected by worker_init_fn.""" |
|
|
| def __init__(self, id_coeffs_dim, scale_dim, size=4): |
| self.poses = torch.zeros(size, NUM_JOINTS, 3) |
| self.identity_coeffs = torch.zeros(size, id_coeffs_dim) |
| self.scale_params = torch.zeros(size, scale_dim) |
| self.transl = torch.zeros(size, 3) |
| self._layer = None |
|
|
| def __len__(self): |
| return len(self.poses) |
|
|
| def __getitem__(self, idx): |
| pose = self.poses[idx].unsqueeze(0) |
| id_coeffs = self.identity_coeffs[idx].unsqueeze(0) |
| scale = self.scale_params[idx].unsqueeze(0) |
| transl = self.transl[idx].unsqueeze(0) |
| with torch.no_grad(): |
| out = self._layer(pose, id_coeffs, scale, transl) |
| return { |
| "vertices": out["vertices"].squeeze(0), |
| "joints": out["joints"].squeeze(0), |
| "pose": pose.squeeze(0), |
| "id_coeffs": id_coeffs.squeeze(0), |
| "scale": scale.squeeze(0), |
| "transl": transl.squeeze(0), |
| } |
|
|
|
|
def _soma_worker_init(worker_id):
    """worker_init_fn hook: attach a freshly built SOMALayer to this worker's dataset.

    ``worker_id`` is part of the ``worker_init_fn`` signature but is not used;
    the dataset is reached through ``torch.utils.data.get_worker_info()``.
    """
    from soma import SOMALayer

    layer = SOMALayer(
        data_root=str(ASSETS_DIR),
        device="cpu",
        identity_model_type="mhr",
        mode="warp",
    )
    # get_worker_info().dataset is the dataset object held by the current worker.
    torch.utils.data.get_worker_info().dataset._layer = layer
|
|
|
|
class TestSomaLayerDataLoader(unittest.TestCase):
    """Exercise SOMALayer (``mode="warp"``) together with multi-worker DataLoaders.

    The whole class is skipped when the asset files are missing. Tests that
    need a GPU skip themselves when CUDA is unavailable.
    """

    @classmethod
    def setUpClass(cls):
        """Skip without assets; otherwise record the asset path and input widths."""
        if not _assets_available():
            raise unittest.SkipTest(
                "Assets not found. Run `git lfs pull` to fetch SOMA_neutral.npz."
            )
        cls.data_root = str(ASSETS_DIR)

        # A throwaway layer is built only to read the identity model's sizes.
        layer = cls._make_layer_static(cls.data_root, "cpu")
        im = layer.identity_model
        cls.id_coeffs_dim = im.num_identity_coeffs
        cls.scale_dim = im.num_scale_params

    @staticmethod
    def _make_layer_static(data_root, device):
        """Build a SOMALayer with the configuration shared by every test."""
        from soma import SOMALayer

        return SOMALayer(
            data_root=data_root,
            device=device,
            identity_model_type="mhr",
            mode="warp",
        )

    def _make_layer(self, device="cpu"):
        """Build a SOMALayer on ``device`` using the class-level asset path."""
        return self._make_layer_static(self.data_root, device)

    @staticmethod
    def _pick_cuda_device():
        """Return ``"cuda:1"`` when more than one GPU is visible, else ``"cuda:0"``."""
        return "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"

    def _assert_output_shapes(self, vertices, joints, batch_size, num_verts):
        """Assert vertices is (batch_size, num_verts, 3) and joints is (batch_size, NUM_JOINTS, 3)."""
        self.assertEqual(vertices.dim(), 3)
        self.assertEqual(vertices.shape[0], batch_size)
        self.assertEqual(vertices.shape[2], 3)
        self.assertEqual(vertices.shape[1], num_verts)
        self.assertEqual(joints.shape, (batch_size, NUM_JOINTS, 3))

    def _assert_layer_handles_loader(self, layer, loader, to_cuda=False, batch_size=2):
        """Run ``layer`` in this process on every batch of a tensor-only loader.

        Args:
            layer: SOMALayer called on each batch.
            loader: DataLoader yielding ``(poses, id_coeffs, scale_params, transl)``.
            to_cuda: When True, each tensor is moved with ``.cuda()`` first.
            batch_size: Expected leading dimension of every output.
        """
        num_verts = layer.bind_shape.shape[0]
        num_batches = 0
        for batch in loader:
            num_batches += 1
            if to_cuda:
                batch = [tensor.cuda() for tensor in batch]
            poses, id_coeffs, scale_params, transl = batch
            with torch.no_grad():
                out = layer(poses, id_coeffs, scale_params, transl)
            self.assertIn("vertices", out)
            self.assertIn("joints", out)
            self._assert_output_shapes(
                out["vertices"], out["joints"], batch_size, num_verts
            )
        # Guard against the loop passing vacuously.
        self.assertGreater(num_batches, 0, "DataLoader yielded no batches")

    def _assert_workers_match_gpu_layer(self, loader, soma_layer, device):
        """Compare worker-computed joints with ``soma_layer`` and scan stderr.

        File descriptor 2 is redirected to a temporary file while the loader is
        iterated, so text written to stderr at the OS level (by this process or
        by worker processes) can be checked for "Warp CUDA error 3".

        Args:
            loader: DataLoader yielding dicts with ``vertices``, ``joints``,
                ``pose``, ``id_coeffs``, ``scale`` and ``transl``.
            soma_layer: Reference layer living on ``device``.
            device: Device that each batch is moved to before comparison.
        """
        import sys

        num_batches = 0
        stderr_text = ""
        with tempfile.TemporaryFile() as captured:
            # Flush Python-level buffering so earlier output is not captured.
            sys.stderr.flush()
            saved_fd = os.dup(2)
            os.dup2(captured.fileno(), 2)
            completed = False
            try:
                for data in loader:
                    num_batches += 1
                    data = {key: value.to(device) for key, value in data.items()}
                    vertices = data["vertices"]
                    joints = data["joints"]
                    batch_size = vertices.shape[0]
                    with torch.no_grad():
                        out = soma_layer(
                            data["pose"], data["id_coeffs"], data["scale"], data["transl"]
                        )
                    max_joint_diff = (joints - out["joints"]).abs().max().item()
                    self.assertLess(max_joint_diff, 1e-3)
                    self.assertEqual(vertices.dim(), 3)
                    self.assertEqual(vertices.shape[2], 3)
                    self.assertEqual(joints.shape, (batch_size, NUM_JOINTS, 3))
                completed = True
            finally:
                # Always restore the real stderr, even when an assertion failed.
                sys.stderr.flush()
                os.dup2(saved_fd, 2)
                os.close(saved_fd)
                captured.seek(0)
                stderr_text = captured.read().decode("utf-8", errors="replace")
                if not completed and stderr_text:
                    # Replay swallowed output so a failure is not debugged blind.
                    sys.stderr.write(stderr_text)

        self.assertGreater(num_batches, 0, "DataLoader yielded no batches")
        self.assertNotIn(
            "Warp CUDA error 3",
            stderr_text,
            "CUDA error 3 appeared in worker stderr — fork hook may not be working",
        )

    def test_no_workers_baseline(self):
        """Sanity check: single-process DataLoader, Warp ops work correctly."""
        layer = self._make_layer("cpu")
        dataset = SomaPoseDataset(self.id_coeffs_dim, self.scale_dim, size=4)
        loader = DataLoader(dataset, batch_size=2, num_workers=0)
        self._assert_layer_handles_loader(layer, loader)

    def test_multi_worker_warp_in_main_process(self):
        """Safe pattern: workers only load tensors; Warp called only in the main process."""
        layer = self._make_layer("cpu")
        dataset = SomaPoseDataset(self.id_coeffs_dim, self.scale_dim, size=4)
        loader = DataLoader(dataset, batch_size=2, num_workers=2)
        self._assert_layer_handles_loader(layer, loader)

    def test_multi_worker_init_at_construction(self):
        """SOMALayer built in the dataset constructor, then used inside forked workers.

        A second layer is created on a CUDA device in the main process before
        the loader is iterated; worker results are compared against it.
        """
        import multiprocessing

        start_method = multiprocessing.get_start_method()
        if start_method != "fork":
            self.skipTest(
                "test requires fork-based DataLoader workers; "
                f"current start method is {start_method!r}"
            )
        # Check for CUDA before building the dataset, which constructs a full SOMALayer.
        if not torch.cuda.is_available():
            self.skipTest("CUDA not available")

        dataset = SomaDataset(self.data_root, size=4)
        loader = DataLoader(dataset, batch_size=2, num_workers=2)
        device = self._pick_cuda_device()
        soma_layer = self._make_layer(device)
        self._assert_workers_match_gpu_layer(loader, soma_layer, device)

    def test_multi_worker_lazy_init_in_worker(self):
        """Warp is initialized fresh inside each forked worker via lazy SOMALayer init."""
        dataset = _LazySomaDataset(
            self.data_root, self.id_coeffs_dim, self.scale_dim, size=4
        )
        loader = DataLoader(dataset, batch_size=2, num_workers=2)

        num_batches = 0
        for vertices, joints in loader:
            num_batches += 1
            batch_size = vertices.shape[0]
            self.assertEqual(vertices.dim(), 3)
            self.assertEqual(vertices.shape[2], 3)
            self.assertEqual(joints.shape[0], batch_size)
            self.assertEqual(joints.shape[1], NUM_JOINTS)
            self.assertEqual(joints.shape[2], 3)
        self.assertGreater(num_batches, 0, "DataLoader yielded no batches")

    def test_multi_worker_worker_init_fn(self):
        """Recommended pattern: SOMALayer initialized once per worker via worker_init_fn."""
        # Check for CUDA first so nothing is built just to be thrown away.
        if not torch.cuda.is_available():
            self.skipTest("CUDA not available")

        dataset = _WorkerInitDataset(self.id_coeffs_dim, self.scale_dim, size=4)
        loader = DataLoader(
            dataset,
            batch_size=2,
            num_workers=2,
            worker_init_fn=_soma_worker_init,
        )
        device = self._pick_cuda_device()
        soma_layer = self._make_layer(device)
        self._assert_workers_match_gpu_layer(loader, soma_layer, device)

    def test_spawn_context(self):
        """spawn multiprocessing context: fresh processes, no fork state inheritance."""
        layer = self._make_layer("cpu")
        dataset = SomaPoseDataset(self.id_coeffs_dim, self.scale_dim, size=4)
        loader = DataLoader(
            dataset,
            batch_size=2,
            num_workers=2,
            multiprocessing_context="spawn",
        )
        self._assert_layer_handles_loader(layer, loader)

    def test_cuda_spawn_context(self):
        """CUDA-safe pattern: spawn context avoids CUDA fork issues. Skipped if no GPU."""
        if not torch.cuda.is_available():
            self.skipTest("CUDA not available")

        layer = self._make_layer("cuda")
        dataset = SomaPoseDataset(self.id_coeffs_dim, self.scale_dim, size=4)
        loader = DataLoader(
            dataset,
            batch_size=2,
            num_workers=2,
            multiprocessing_context="spawn",
        )
        self._assert_layer_handles_loader(layer, loader, to_cuda=True)
|
|
|
|
# Run the tests through unittest's command-line runner when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
|
|