somatosmpl / tools /batch_soma2smpl.py
zirobtc's picture
Upload folder using huggingface_hub
bd95c9c verified
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Batch BONES-SEED SOMA BVH -> SMPL-X conversion.
This keeps the expensive SOMA/SMPL-X objects resident and loops over BVHs,
instead of spawning one Python process per motion.
"""
from __future__ import annotations
import argparse
import gc
import json
import os
import resource
import sys
import time
from pathlib import Path
import numpy as np
import torch
from tqdm import tqdm
# Repository root (two levels above this file). It is prepended to sys.path so
# that the ``tools.soma2smpl`` import below resolves when this file is run
# directly as a script rather than as a module of an installed package.
repo_root = Path(__file__).resolve().parents[1]
if str(repo_root) not in sys.path:
    sys.path[0:0] = [str(repo_root)]
from tools.soma2smpl import (
BVHMotion,
SMPLXInversion,
_create_bones_soma,
_fit_smplx_betas_to_bones_soma,
_make_soma_to_smplx_transfer,
_parse_betas_arg,
_parse_bvh,
_save_smplx_npz,
_smplx_forward_from_result,
)
def _iter_bvhs(dataset_root: Path, limit: int | None) -> list[Path]:
root = dataset_root / "bvh"
files = sorted(root.rglob("*.bvh") if root.exists() else dataset_root.rglob("*.bvh"))
if limit is not None:
files = files[:limit]
return files
def _output_path(dataset_root: Path, output_root: Path, bvh: Path) -> Path:
    """Map a source BVH path to its ``.npz`` destination under ``output_root``.

    The directory layout relative to ``dataset_root / "bvh"`` is mirrored.
    Files outside that folder are made relative to ``dataset_root`` instead.

    Raises:
        ValueError: if ``bvh`` is not under ``dataset_root`` at all.
    """
    try:
        relative = bvh.relative_to(dataset_root / "bvh")
    except ValueError:
        relative = None
    if relative is None:
        relative = bvh.relative_to(dataset_root)
    return output_root.joinpath(relative.with_suffix(".npz"))
def _amass_payload(
    result: dict,
    fps: float,
    betas: torch.Tensor,
    inv: SMPLXInversion,
):
    """Build the AMASS-style NPZ payload (a dict of numpy arrays) for one motion.

    ``poses`` is ``root_orient`` and ``pose_body`` concatenated along the last
    axis, followed by 99 zero columns. The zeros presumably fill the jaw, eye
    and hand slots of a full SMPL-X pose vector, so only body pose is
    exported — TODO confirm against the consumer.

    Args:
        result: fit result accepted by ``_smplx_forward_params``.
        fps: frame rate stored as ``mocap_framerate``.
        betas: shape coefficients; ``betas[0]`` is stored (assumes a leading
            batch dim of 1).
        inv: the ``SMPLXInversion`` used to turn ``result`` into pose params.
    """
    params = _smplx_forward_params(result, betas, inv)
    root_orient = params["root_orient"]
    num_frames = root_orient.shape[0]
    # Allocate the padding next to the tensors it is concatenated with, so
    # device and dtype always match. Previously it followed ``betas.device``,
    # which breaks torch.cat if betas live on another device.
    zeros_99 = root_orient.new_zeros((num_frames, 99))
    poses = torch.cat([root_orient, params["pose_body"], zeros_99], dim=-1)
    return {
        "mocap_framerate": np.array(float(fps), dtype=np.float32),
        "gender": np.array("neutral"),
        "betas": betas[0].detach().cpu().numpy().astype(np.float32),
        "trans": params["trans"].detach().cpu().numpy().astype(np.float32),
        "poses": poses.detach().cpu().numpy().astype(np.float32),
    }
def _save_amass_npz(
    path: Path,
    result: dict,
    fps: float,
    betas: torch.Tensor,
    inv: SMPLXInversion,
    compressed: bool,
):
    """Atomically write the AMASS-style payload for one motion to ``path``.

    The payload is written to ``<path>.tmp`` and then moved into place with
    ``os.replace``, so readers (and ``--skip-existing``) never see a partial
    file. If writing fails, the temporary file is removed and the error is
    re-raised.

    Args:
        compressed: use ``np.savez_compressed`` instead of ``np.savez``.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    payload = _amass_payload(result, fps, betas, inv)
    tmp = path.with_suffix(path.suffix + ".tmp")
    writer = np.savez_compressed if compressed else np.savez
    try:
        with tmp.open("wb") as f:
            writer(f, **payload)
        os.replace(tmp, path)
    except BaseException:
        # Don't leave orphaned ``.npz.tmp`` files behind after a failed write.
        tmp.unlink(missing_ok=True)
        raise
def _save_legacy_full_npz(
    path: Path,
    result: dict,
    fps: float,
    betas: torch.Tensor,
    inv: SMPLXInversion,
    source_bvh: Path,
    mean_error: float,
    max_error: float,
):
    """Atomically write a compressed "legacy" SMPL-X NPZ with full pose params.

    Stores trans, root_orient, pose_body, pose_hand, pose_jaw and pose_eye as
    float32, plus betas and metadata: source BVH path, frame rate, duration,
    and mean/max fit error. Writes go to ``<path>.tmp`` and are moved into
    place with ``os.replace``; the temporary file is removed if writing fails.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    params = _smplx_forward_params(result, betas, inv)
    num_frames = params["root_orient"].shape[0]

    def _np32(t: torch.Tensor) -> np.ndarray:
        # Detach from autograd, move to host, and store as float32.
        return t.detach().cpu().numpy().astype(np.float32)

    # Assumes betas has a leading batch dim of 1 — betas[0] is the vector saved.
    betas_np = _np32(betas[0])
    tmp = path.with_suffix(path.suffix + ".tmp")
    try:
        with tmp.open("wb") as f:
            np.savez_compressed(
                f,
                trans=_np32(params["trans"]),
                root_orient=_np32(params["root_orient"]),
                pose_body=_np32(params["pose_body"]),
                pose_hand=_np32(params["pose_hand"]),
                pose_jaw=_np32(params["pose_jaw"]),
                pose_eye=_np32(params["pose_eye"]),
                betas=betas_np,
                # Record the coefficient count actually saved instead of a
                # hard-coded 10, so non-10-beta shapes stay self-describing.
                num_betas=np.array(betas_np.shape[-1], dtype=np.int32),
                gender=np.array("neutral"),
                surface_model_type=np.array("smplx"),
                mocap_frame_rate=np.array(float(fps), dtype=np.float32),
                mocap_time_length=np.array(num_frames / float(fps), dtype=np.float32),
                source_bvh=np.array(str(source_bvh)),
                fit_error_mean=np.array(mean_error, dtype=np.float32),
                fit_error_max=np.array(max_error, dtype=np.float32),
            )
        os.replace(tmp, path)
    except BaseException:
        # Don't leave orphaned ``.npz.tmp`` files behind after a failed write.
        tmp.unlink(missing_ok=True)
        raise
def _smplx_forward_params(result: dict, betas: torch.Tensor, inv: SMPLXInversion):
    """Return the SMPL-X pose-parameter dict for a fit ``result``.

    This is a thin adapter over ``tools.soma2smpl._smplx_pose_params_from_result``.
    Note that it swaps the argument order to ``(inv, result, betas)``. Callers
    in this file read keys such as ``trans``, ``root_orient`` and
    ``pose_body`` from the returned dict.
    """
    # Resolved at call time rather than with the top-level soma2smpl imports.
    from tools import soma2smpl

    return soma2smpl._smplx_pose_params_from_result(inv, result, betas)
def _concat_results(parts: list[dict]) -> dict:
    """Concatenate per-chunk fit results along dim 0 (the frame axis).

    ``rotations``, ``root_translation`` and ``per_vertex_error`` are always
    concatenated. ``vertices`` is included only if the first part has it.
    """
    merged = {
        key: torch.cat([part[key] for part in parts], dim=0)
        for key in ("rotations", "root_translation", "per_vertex_error")
    }
    if "vertices" in parts[0]:
        merged["vertices"] = torch.cat([part["vertices"] for part in parts], dim=0)
    return merged
def _slice_result(result: dict, start: int, end: int) -> dict:
    """Return the ``[start:end]`` window (along dim 0) of a fit result dict.

    Keeps the ``rotations``, ``root_translation`` and ``per_vertex_error``
    entries, plus ``vertices`` when present. Slicing follows normal
    tensor/sequence semantics, so ends past the length are clamped.
    """
    keys = ["rotations", "root_translation", "per_vertex_error"]
    if "vertices" in result:
        keys.append("vertices")
    window = slice(start, end)
    return {key: result[key][window] for key in keys}
def _append_manifest(path: Path | None, row: dict):
if path is None:
return
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("a", encoding="utf-8") as f:
f.write(json.dumps(row, sort_keys=True) + "\n")
def _forward_soma_bvh_resident(soma, motion: BVHMotion, batch_size: int):
    """Pose the resident SOMA model over every frame of ``motion``.

    Frames are fed to ``soma.pose`` in windows of ``batch_size`` under
    ``torch.no_grad()``, with rotation matrices (``pose2rot=False``) and
    ``absolute_pose=True``. The per-window ``"vertices"`` outputs are
    concatenated along dim 0.
    """
    total = motion.local_rot_mats.shape[0]
    chunks = []
    with torch.no_grad():
        for lo in range(0, total, batch_size):
            hi = min(lo + batch_size, total)
            posed = soma.pose(
                motion.local_rot_mats[lo:hi],
                transl=motion.root_trans[lo:hi],
                pose2rot=False,
                absolute_pose=True,
            )
            chunks.append(posed["vertices"])
    return torch.cat(chunks, dim=0)
def _compact_result(result: dict) -> dict:
    """Return a copy of ``result`` reduced to the keys needed downstream.

    Only ``rotations``, ``root_translation`` and ``per_vertex_error`` are kept.
    Any fitted ``vertices`` are dropped, since holding them for every frame
    adds VRAM pressure and the saved outputs do not need them.
    """
    kept = ("rotations", "root_translation", "per_vertex_error")
    return {name: result[name] for name in kept}
def _memory_snapshot(device: torch.device) -> str:
    """Return a one-line summary of peak host RSS and, on CUDA, GPU memory (GB).

    ``rss_max`` is this process's peak resident set size from ``getrusage``.
    On CUDA devices the current allocated/reserved bytes and the peak
    allocated bytes for ``device`` are appended.
    """
    maxrss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in kilobytes on Linux but in bytes on macOS.
    bytes_per_unit = 1 if sys.platform == "darwin" else 1024
    rss_gb = maxrss * bytes_per_unit / 1024**3
    if device.type != "cuda":
        return f"rss_max={rss_gb:.2f}GB"
    allocated = torch.cuda.memory_allocated(device) / 1024**3
    reserved = torch.cuda.memory_reserved(device) / 1024**3
    max_allocated = torch.cuda.max_memory_allocated(device) / 1024**3
    return (
        f"rss_max={rss_gb:.2f}GB "
        f"cuda_alloc={allocated:.2f}GB "
        f"cuda_reserved={reserved:.2f}GB "
        f"cuda_max_alloc={max_allocated:.2f}GB"
    )
def _save_one(
    out: Path,
    result: dict,
    fps: float,
    betas: torch.Tensor,
    inv: SMPLXInversion,
    source_bvh: Path,
    mean_error: float,
    max_error: float,
    args,
):
    """Write one converted motion to ``out`` in the format selected by ``args``.

    * ``--output-format amass``: ``_save_amass_npz``, honouring ``--compressed``.
    * legacy, uncompressed: ``_save_smplx_npz`` from ``tools.soma2smpl``.
    * legacy, compressed: ``_save_legacy_full_npz``. This is the only writer
      that also receives the source path and fit errors.
    """
    if args.output_format == "amass":
        _save_amass_npz(out, result, fps, betas, inv, compressed=args.compressed)
        return
    if not args.compressed:
        _save_smplx_npz(out, result, fps, betas, inv)
        return
    _save_legacy_full_npz(out, result, fps, betas, inv, source_bvh, mean_error, max_error)
def _convert_one(
    bvh: Path,
    out: Path,
    dataset_root: Path,
    soma,
    soma_to_smplx,
    inv: SMPLXInversion,
    betas: torch.Tensor,
    args,
):
    """Convert a single BVH to an SMPL-X NPZ and return its manifest row.

    Frames are processed in chunks of ``args.batch_size``, or all at once if
    it is falsy. Each chunk goes SOMA forward -> SOMA-to-SMPL-X transfer ->
    ``inv.fit``. The concatenated fit is then run through the SMPL-X forward
    pass, and per-vertex error is measured against transferred targets that
    are recomputed chunk by chunk, so transferred vertices for the whole
    motion are never held at once.

    ``dataset_root`` is not used by this function. ``main`` in this file
    converts through ``_convert_group`` instead.

    Returns:
        dict with source/output paths, frame count, fps, fit timing and
        mean/max error.
    """
    motion = _parse_bvh(bvh, inv.device, args.subsample, args.max_frames)
    num_frames = int(motion.local_rot_mats.shape[0])
    batch_size = args.batch_size or num_frames

    def _sync():
        # Flush queued CUDA work so wall-clock fit timings are accurate.
        if inv.device.type == "cuda":
            torch.cuda.synchronize()

    def _target_vertices(start: int, end: int):
        # SOMA-forward frames [start:end] and transfer them to SMPL-X topology.
        chunk_motion = BVHMotion(
            path=motion.path,
            local_rot_mats=motion.local_rot_mats[start:end],
            root_trans=motion.root_trans[start:end],
            fps=motion.fps,
            joint_offsets=motion.joint_offsets,
            parents=motion.parents,
        )
        soma_vertices = _forward_soma_bvh_resident(soma, chunk_motion, end - start)
        with torch.no_grad():
            return soma_to_smplx(soma_vertices)

    results = []
    total_fit_time = 0.0
    for start in range(0, num_frames, batch_size):
        end = min(start + batch_size, num_frames)
        target_smplx = _target_vertices(start, end)
        _sync()
        t0 = time.perf_counter()
        result = inv.fit(
            target_smplx,
            body_iters=args.body_iters,
            finger_iters=args.finger_iters,
            full_iters=args.full_iters,
        )
        _sync()
        total_fit_time += time.perf_counter() - t0
        results.append(result)

    result = _concat_results(results)
    smplx_out = _smplx_forward_from_result(inv, result, betas)
    err_parts = []
    with torch.no_grad():
        # Recompute targets in chunks only for the official-forward error.
        # This avoids storing transferred vertices for very long motions.
        for start in range(0, num_frames, batch_size):
            end = min(start + batch_size, num_frames)
            target_smplx = _target_vertices(start, end)
            err_parts.append(torch.norm(smplx_out.vertices[start:end] - target_smplx, dim=-1).detach().cpu())
    smplx_err = torch.cat(err_parts, dim=0)
    mean_error = float(smplx_err.mean().item())
    max_error = float(smplx_err.max().item())

    # The previous dispatch called ``_save_body_only_npz`` and
    # ``_save_full_compressed_npz``, which exist nowhere, so it always ended in
    # NameError. Use the same format dispatch as the grouped path; ``--body-only``
    # is deprecated because AMASS output is already body-only.
    _save_one(out, result, motion.fps, betas, inv, bvh, mean_error, max_error, args)
    return {
        "source_bvh": str(bvh),
        "output_npz": str(out),
        "frames": num_frames,
        "fps": float(motion.fps),
        "fit_seconds": total_fit_time,
        "fit_fps": num_frames / max(total_fit_time, 1e-9),
        "mean_error": mean_error,
        "max_error": max_error,
    }
def _make_groups(entries: list[tuple[int, Path, Path]], args) -> list[list[tuple[int, Path, Path]]]:
    """Split ``entries`` into consecutive groups of ``args.files_per_batch``.

    The last group may be shorter. A ``files_per_batch`` of zero or less
    yields one entry per group. An empty ``entries`` yields no groups.
    """
    size = max(1, args.files_per_batch)
    ordered = list(entries)
    return [ordered[i:i + size] for i in range(0, len(ordered), size)]
def _cuda_sync(device: torch.device) -> None:
    """Wait for queued CUDA work so host-side timings are accurate; no-op off CUDA."""
    if device.type == "cuda":
        torch.cuda.synchronize()


def _convert_group(
    entries: list[tuple[int, Path, Path]],
    dataset_root: Path,
    soma,
    soma_to_smplx,
    inv: SMPLXInversion,
    betas: torch.Tensor,
    args,
):
    """Convert several BVHs in one pass by concatenating their frames.

    Every motion in ``entries`` is parsed and concatenated along the frame
    axis. The combined frames go through SOMA forward -> SOMA-to-SMPL-X
    transfer -> ``inv.fit`` in chunks of ``args.batch_size`` (all at once if
    falsy). The combined result is then split back per file and saved with
    ``_save_one``.

    Error per frame/vertex is either the inverse-fit ``per_vertex_error``
    (``--skip-official-error``) or the distance between the SMPL-X forward
    output and the transferred targets.

    Args:
        entries: ``(index, bvh_path, output_path)`` tuples, in output order.
        dataset_root: not used by this function.
        soma: resident SOMA model (used via ``_forward_soma_bvh_resident``).
        soma_to_smplx: callable mapping SOMA vertices to SMPL-X targets.
        inv: prepared ``SMPLXInversion``; ``inv.device`` controls CUDA syncing.
        betas: SMPL-X shape coefficients shared by every file in the group.
        args: parsed CLI namespace.

    Returns:
        One manifest row per entry, in input order. All rows reference the same
        ``profile`` dict of group-level timings (seconds).

    Raises:
        ValueError: if the files disagree on FPS after subsampling, or if any
            file yields zero frames.
    """
    device = inv.device
    profile = {
        "parse": 0.0,
        "soma_forward": 0.0,
        "transfer": 0.0,
        "fit": 0.0,
        "smplx_forward_error": 0.0,
        "save": 0.0,
    }

    t0 = time.perf_counter()
    motions = [_parse_bvh(bvh, device, args.subsample, args.max_frames) for _, bvh, _ in entries]
    profile["parse"] += time.perf_counter() - t0
    frame_counts = [int(m.local_rot_mats.shape[0]) for m in motions]
    if len({round(float(m.fps), 6) for m in motions}) != 1:
        raise ValueError("Grouped BVHs have mismatched FPS after subsampling.")
    empty = [str(bvh) for (_, bvh, _), n in zip(entries, frame_counts) if n == 0]
    if empty:
        # Without this check, an empty per-file slice crashes later in
        # ``err.max()`` with an opaque torch reduction error.
        raise ValueError(f"BVHs yielded zero frames after subsampling/max-frames: {empty}")

    # Skeleton metadata is taken from the first motion; the SOMA forward pass
    # only reads local_rot_mats and root_trans.
    batch_motion = BVHMotion(
        path=entries[0][1],
        local_rot_mats=torch.cat([m.local_rot_mats for m in motions], dim=0),
        root_trans=torch.cat([m.root_trans for m in motions], dim=0),
        fps=motions[0].fps,
        joint_offsets=motions[0].joint_offsets,
        parents=motions[0].parents,
    )
    total_frames = int(batch_motion.local_rot_mats.shape[0])
    batch_size = args.batch_size or total_frames
    results = []
    error_parts = []
    total_fit_time = 0.0
    for start in range(0, total_frames, batch_size):
        end = min(start + batch_size, total_frames)
        chunk_motion = BVHMotion(
            path=batch_motion.path,
            local_rot_mats=batch_motion.local_rot_mats[start:end],
            root_trans=batch_motion.root_trans[start:end],
            fps=batch_motion.fps,
            joint_offsets=batch_motion.joint_offsets,
            parents=batch_motion.parents,
        )

        _cuda_sync(device)
        t0 = time.perf_counter()
        soma_vertices = _forward_soma_bvh_resident(soma, chunk_motion, end - start)
        _cuda_sync(device)
        profile["soma_forward"] += time.perf_counter() - t0

        t0 = time.perf_counter()
        with torch.no_grad():
            target_smplx = soma_to_smplx(soma_vertices)
        _cuda_sync(device)
        profile["transfer"] += time.perf_counter() - t0

        t0 = time.perf_counter()
        result = inv.fit(
            target_smplx,
            body_iters=args.body_iters,
            finger_iters=args.finger_iters,
            full_iters=args.full_iters,
        )
        _cuda_sync(device)
        fit_dt = time.perf_counter() - t0
        total_fit_time += fit_dt
        profile["fit"] += fit_dt

        if args.skip_official_error:
            error_parts.append(result["per_vertex_error"].detach().cpu())
        else:
            t0 = time.perf_counter()
            smplx_out = _smplx_forward_from_result(inv, result, betas)
            with torch.no_grad():
                error_parts.append(torch.norm(smplx_out.vertices - target_smplx, dim=-1).detach().cpu())
            _cuda_sync(device)
            profile["smplx_forward_error"] += time.perf_counter() - t0
            del smplx_out
        # Keep only what saving needs; dropping fitted vertices limits VRAM use.
        results.append(_compact_result(result))
        del result, soma_vertices, target_smplx

    batch_result = _concat_results(results)
    batch_errors = torch.cat(error_parts, dim=0)
    rows = []
    cursor = 0
    for (idx, bvh, out), motion, frames in zip(entries, motions, frame_counts):
        end = cursor + frames
        result = _slice_result(batch_result, cursor, end)
        err = batch_errors[cursor:end]
        mean_error = float(err.mean().item())
        max_error = float(err.max().item())
        t0 = time.perf_counter()
        _save_one(out, result, motion.fps, betas, inv, bvh, mean_error, max_error, args)
        profile["save"] += time.perf_counter() - t0
        rows.append(
            {
                "index": idx,
                "source_bvh": str(bvh),
                "output_npz": str(out),
                "frames": frames,
                "fps": float(motion.fps),
                # Group fit time apportioned by this file's share of frames.
                "fit_seconds": total_fit_time * (frames / max(total_frames, 1)),
                "fit_fps": total_frames / max(total_fit_time, 1e-9),
                "mean_error": mean_error,
                "max_error": max_error,
                "group_files": len(entries),
                "group_frames": total_frames,
                "profile": profile,
            }
        )
        cursor = end
    return rows
def main():
    """CLI entry point: parse arguments, build resident models, convert all pending BVHs.

    BVHs are converted in groups of ``--files-per-batch`` via ``_convert_group``.
    If a group raises, every file in it is counted as failed, one manifest row
    per file is written (with the exception repr and traceback), and
    processing continues with the next group.
    """
    import traceback

    parser = argparse.ArgumentParser(description="Batch convert BONES-SEED SOMA BVHs to SMPL-X NPZs.")
    parser.add_argument("--dataset-root", default="/home/ziro/workspace/experimental/bones-seed/soma_uniform")
    parser.add_argument("--output-root", required=True)
    parser.add_argument("--manifest", default=None)
    parser.add_argument("--subsample", type=int, default=4)
    parser.add_argument("--body-iters", type=int, default=2)
    parser.add_argument("--finger-iters", type=int, default=0)
    parser.add_argument("--full-iters", type=int, default=1)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--files-per-batch", type=int, default=8)
    parser.add_argument("--max-frames", type=int, default=None)
    parser.add_argument("--limit", type=int, default=None)
    parser.add_argument("--skip-existing", action="store_true")
    parser.add_argument("--no-progress", action="store_true", help="Disable tqdm progress bar.")
    parser.add_argument("--profile", action="store_true", help="Print per-batch timing breakdown.")
    parser.add_argument("--profile-memory", action="store_true", help="Print RSS/CUDA memory after each batch.")
    parser.add_argument(
        "--empty-cache-every",
        type=int,
        default=0,
        help="Run gc.collect() and torch.cuda.empty_cache() every N grouped batches. 0 disables.",
    )
    parser.add_argument(
        "--skip-official-error",
        action="store_true",
        help="Skip official SMPL-X forward error metric and use inverse-LBS fit error in the manifest.",
    )
    parser.add_argument("--output-format", choices=["amass", "legacy"], default="amass")
    parser.add_argument("--body-only", action="store_true", help="Deprecated: AMASS output is body-only SMPL-X pose by default.")
    parser.add_argument("--compressed", action="store_true", default=True, help="Use compressed NPZ output.")
    parser.add_argument("--uncompressed", dest="compressed", action="store_false", help="Use np.savez instead of np.savez_compressed.")
    parser.add_argument("--betas", default="cached")
    parser.add_argument("--beta-fit-iters", type=int, default=20)
    parser.add_argument("--beta-fit-lr", type=float, default=1.0)
    parser.add_argument("--beta-fit-l2", type=float, default=0.0003)
    parser.add_argument("--device", default="cuda:0")
    args = parser.parse_args()

    # Fall back to CPU when CUDA is unavailable, whatever --device requested.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    dataset_root = Path(args.dataset_root)
    output_root = Path(args.output_root)
    manifest = Path(args.manifest) if args.manifest else None
    files = _iter_bvhs(dataset_root, args.limit)
    print(f"BVHs: {len(files)}")
    print(f"Dataset: {dataset_root}")
    print(f"Output: {output_root}")
    print(f"Device: {device}")
    print(f"Subsample: {args.subsample}")
    print(f"Files per batch: {args.files_per_batch}")
    print(f"Frames per kernel batch: {args.batch_size}")

    soma = _create_bones_soma(dataset_root, device)
    soma_to_smplx = _make_soma_to_smplx_transfer(device)
    inv = SMPLXInversion(device)
    # Betas resolution order: fixed value from --betas (unless "fit"), then a
    # fit against the BONES SOMA body, then neutral zeros.
    fixed_betas = _parse_betas_arg(None if args.betas == "fit" else args.betas, device)
    if fixed_betas is not None:
        betas = fixed_betas
        print(f"Using fixed SMPL-X betas: {betas.detach().cpu().numpy()[0].round(4).tolist()}")
    elif args.beta_fit_iters > 0:
        betas = _fit_smplx_betas_to_bones_soma(
            dataset_root,
            device,
            steps=args.beta_fit_iters,
            lr=args.beta_fit_lr,
            l2=args.beta_fit_l2,
        )
    else:
        betas = torch.zeros(1, 10, device=device)
        print("SMPL-X beta fitting disabled; using neutral betas.")
    inv.prepare_identity(betas)

    ok = 0
    skipped = 0
    failed = 0
    start_all = time.perf_counter()
    pending = []
    for idx, bvh in enumerate(files, start=1):
        out = _output_path(dataset_root, output_root, bvh)
        if args.skip_existing and out.exists():
            skipped += 1
            continue
        pending.append((idx, bvh, out))
    groups = _make_groups(pending, args)

    # Context manager guarantees the bar is closed (terminal line finalized)
    # even if an exception escapes the loop.
    with tqdm(total=len(files), initial=skipped, unit="file", dynamic_ncols=True, disable=args.no_progress) as progress:
        progress.set_postfix(ok=ok, skip=skipped, fail=failed)
        for group_idx, group in enumerate(groups, start=1):
            try:
                rows = _convert_group(group, dataset_root, soma, soma_to_smplx, inv, betas, args)
                ok += len(rows)
                for row in rows:
                    row["total"] = len(files)
                    _append_manifest(manifest, row)
                progress.update(len(rows))
                progress.set_postfix(ok=ok, skip=skipped, fail=failed)
                first = Path(rows[0]["source_bvh"]).name
                last = Path(rows[-1]["source_bvh"]).name
                mean_err = sum(row["mean_error"] for row in rows) / len(rows)
                max_err = max(row["max_error"] for row in rows)
                frames = sum(row["frames"] for row in rows)
                progress.write(
                    f"[batch {rows[0]['index']}-{rows[-1]['index']}/{len(files)}] ok "
                    f"files={len(rows)} frames={frames} fit_fps={rows[0]['fit_fps']:.0f} "
                    f"err={mean_err:.5f}/{max_err:.5f} {first} ... {last}"
                )
                if args.profile:
                    prof = rows[0]["profile"]
                    total_measured = sum(prof.values())
                    progress.write(
                        "[profile] "
                        f"parse={prof['parse']:.3f}s "
                        f"soma={prof['soma_forward']:.3f}s "
                        f"transfer={prof['transfer']:.3f}s "
                        f"fit={prof['fit']:.3f}s "
                        f"smplx_err={prof['smplx_forward_error']:.3f}s "
                        f"save={prof['save']:.3f}s "
                        f"total_measured={total_measured:.3f}s"
                    )
                if args.profile_memory:
                    progress.write(f"[memory] {_memory_snapshot(device)}")
            except Exception as exc:
                # Top-level boundary: record the failure for every file in the
                # group, including the traceback for debugging, and keep going.
                failed += len(group)
                tb = traceback.format_exc()
                for idx, bvh, out in group:
                    row = {
                        "index": idx,
                        "total": len(files),
                        "source_bvh": str(bvh),
                        "output_npz": str(out),
                        "error": repr(exc),
                        "traceback": tb,
                    }
                    _append_manifest(manifest, row)
                progress.update(len(group))
                progress.set_postfix(ok=ok, skip=skipped, fail=failed)
                progress.write(f"[batch {group[0][0]}-{group[-1][0]}/{len(files)}] failed: {exc}")
            finally:
                if args.empty_cache_every > 0 and group_idx % args.empty_cache_every == 0:
                    gc.collect()
                    if device.type == "cuda":
                        torch.cuda.empty_cache()
                        torch.cuda.reset_peak_memory_stats(device)
    dt = time.perf_counter() - start_all
    print(f"Done ok={ok} skipped={skipped} failed={failed} seconds={dt:.1f}")


if __name__ == "__main__":
    main()