File size: 7,917 Bytes
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Smooth root trajectory: ADMM-based smoother with margin constraints and get_smooth_root_pos helper."""
import math
import numpy as np
import torch
from scipy import sparse
from scipy.sparse.linalg import splu
from kimodo.tools import ensure_batched
class TrajectorySmoother:
    """Smooth a trajectory with ADMM while keeping frames near their targets.

    The smoother minimizes squared second differences (accelerations) of the
    trajectory, optionally penalizes deviation from the targets through
    ``pos_weight``, and restricts every frame to a ball of radius
    ``margins[i]`` around its target.
    """

    def __init__(
        self,
        margins,
        pos_weight=0.0,
        loop=False,
        admm_iters=100,
        alpha_overrelax=1.0,
        circle_project=False,
    ):
        """Initialize the TrajectorySmoother.

        Args:
            margins: Sequence of margin values, one per frame.
                margins[i] < 0: unconstrained
                margins[i] == 0: pinned on this frame
                margins[i] > 0: can deviate within the margin
            pos_weight: Weight for position preservation.
            loop: Whether the trajectory should loop (adds wrap-around
                acceleration rows for the first and last frame).
            admm_iters: Number of ADMM iterations run by ``smooth``.
            alpha_overrelax: Factor applied to the residual ``x - z`` in the
                dual update.
            circle_project: If ``True``, every row of ``z`` is rescaled to
                (approximately) unit norm after the margin projection.
        """
        self.pos_weight = pos_weight
        self.admm_iters = admm_iters
        self.alpha_overrelax = alpha_overrelax
        self.circle_project = circle_project

        # Store margins as a float array: z_update indexes them with a boolean
        # mask, which a plain Python list does not support.
        self.margin_vals = np.asarray(margins, dtype=float)
        N = len(self.margin_vals)

        # Build acceleration matrix A (second-difference stencil per row).
        stencil = [-1.0, 2.0, -1.0]
        a_data = []
        a_rows = []
        a_cols = []
        for i in range(1, N - 1):
            a_data.extend(stencil)
            a_rows.extend([i, i, i])
            a_cols.extend([i - 1, i, i + 1])
        if loop:
            # Add periodic accelerations for the first and last frame.
            a_data.extend(stencil)
            a_rows.extend([0, 0, 0])
            a_cols.extend([N - 1, 0, 1])
            a_data.extend(stencil)
            a_rows.extend([N - 1, N - 1, N - 1])
            a_cols.extend([N - 2, N - 1, 0])
        A = sparse.csr_matrix((a_data, (a_rows, a_cols)), shape=(N, N))

        identity_matrix = sparse.eye(N)

        # System matrix of the quadratic part of the objective.
        M = pos_weight * identity_matrix + A.T @ A

        # ADMM step size derived from the largest diagonal entry.
        diag_max = max(abs(M.diagonal()))
        self.admm_stepsize = 0.25 * np.sqrt(diag_max)
        if self.admm_stepsize == 0.0:
            # With fewer than three frames (no loop, zero pos_weight) M is all
            # zeros; a zero step would leave the system singular, so use 1.
            self.admm_stepsize = 1.0
        M = M + self.admm_stepsize * identity_matrix

        # Factorize once; every x_update reuses the factorization.
        self.system_lu = splu(M.tocsc())

    def smooth(self, targets, x0):
        """Run the ADMM iterations and return the smoothed trajectory.

        Args:
            targets: Target position per frame (2-D numpy array, one row per
                frame). Margins are measured around these rows.
            x0: Initial guess with the same shape as ``targets``.

        Returns:
            Smoothed positions (numpy array, same shape as ``x0``). The
            inputs are copied and not modified.
        """
        x_target = targets.copy()
        x = x0.copy()
        z = np.zeros_like(x)
        u = np.zeros_like(x)
        for _ in range(self.admm_iters):
            self.z_update(z, x, x_target, u)
            self.u_update(u, x, z)
            self.x_update(x, z, u, x_target)
        return x

    def x_update(self, x, z, u, x_t):
        """Update x in place by solving the prefactorized linear system."""
        # x = (wp * I + A^T A + p I)^-1 (wp * x_t + p (z - u))
        r = self.pos_weight * x_t + self.admm_stepsize * (z - u)
        x[:] = self.system_lu.solve(r)

    def z_update(self, z, x, z_t, u):
        """Update z in place: project ``x + u`` onto the per-frame margin balls."""
        # Offset from the target for all frames at once.
        z[:] = x + u - z_t

        # Frames with a negative margin are unconstrained and must be left
        # untouched; the rest are pulled back when they leave their margin.
        z_diff_norms = np.linalg.norm(z, axis=1)
        margin_vals = self.margin_vals
        mask = (margin_vals >= 0) & (z_diff_norms > margin_vals)
        if np.any(mask):
            # norm > margin >= 0 on masked rows, so the division is safe.
            scale_factors = margin_vals[mask] / z_diff_norms[mask]
            z[mask] *= scale_factors[:, np.newaxis]

        # Add back the target.
        z[:] += z_t
        if self.circle_project:
            # Small epsilon avoids division by zero for zero-length rows.
            z[:] = z / (np.linalg.norm(z, axis=1, keepdims=True) + 1.0e-6)

    def u_update(self, u, x, z):
        """Update the dual variable u in place from the residual ``x - z``."""
        u[:] += self.alpha_overrelax * (x - z)
def smooth_signal(x, margins, pos_weight=0, alpha_overrelax=1.8, admm_iters=500, circle_project=False):
    """Multigrid trajectory smoothing with margin constraints.

    The signal is first smoothed on a coarse subsampling; the result is
    linearly interpolated to twice the resolution and used as the initial
    guess for the next solve, until the full resolution has been solved.

    Args:
        x: Input trajectory ``[T, D]`` as a NumPy array.
        margins: Allowed radius around each target frame ``[T]``.
        pos_weight: Weight for staying close to the original signal.
        alpha_overrelax: ADMM over-relaxation coefficient.
        admm_iters: ADMM iterations per multigrid level.
        circle_project: If ``True``, project each vector to the unit sphere.

    Returns:
        Smoothed trajectory of shape ``[T, D]``. ``x`` itself is not modified.
    """
    if len(x) == 0:
        # Nothing to smooth; also avoids log2(0) below.
        return x.copy()

    # Start from a constant signal at the mean of the input.
    x_smoothed = x.copy()
    x_smoothed[:] = x.mean(axis=0, keepdims=True)

    # Smooth the signal multigrid style: start out coarse, double the
    # resolution and repeat until we are at the full resolution, using the
    # previous result as the initial guess.
    levels = int(math.floor(math.log2(len(x))))
    levels = max(levels - 4, 1)
    stepsize = 2**levels
    while True:
        # Smooth the samples of this level (basic slices are views, so
        # writing into `coarse` writes into x_smoothed).
        coarse = x_smoothed[::stepsize]
        smoother = TrajectorySmoother(
            margins=margins[::stepsize],
            pos_weight=pos_weight,
            alpha_overrelax=alpha_overrelax,
            admm_iters=admm_iters,
            circle_project=circle_project,
        )
        coarse[:] = smoother.smooth(x[::stepsize], coarse)

        # The full-resolution solve is the result. There is no finer level to
        # interpolate to, and interpolating with a zero offset would overwrite
        # the solution with averages of neighbouring frames.
        if stepsize == 1:
            break

        # Interpolate to the next level: fill the samples halfway between the
        # ones that were just smoothed.
        next_stepsize = stepsize // 2
        midpoints = x_smoothed[next_stepsize::stepsize]
        num_interleaved = len(midpoints)
        if num_interleaved == len(coarse):
            # The last midpoint has no right neighbour: extrapolate linearly,
            # or hold the value when only a single coarse sample exists.
            if len(coarse) >= 2:
                midpoints[-1] = coarse[-1] + (coarse[-1] - coarse[-2]) / 2
            else:
                midpoints[-1] = coarse[-1]
            num_interleaved -= 1
        # Linearly interpolate the remaining values.
        midpoints[:num_interleaved] = (coarse[:-1] + coarse[1:]) / 2

        stepsize = next_stepsize
    return x_smoothed
@ensure_batched(hip_translations=3)
def get_smooth_root_pos(hip_translations):
    """Smooth the root trajectory in the ground plane and keep the height.

    Channels 0 and 2 of every sequence are passed through ``smooth_signal``
    with a constant per-frame margin of 0.06; channel 1 is copied through
    unchanged.

    Args:
        hip_translations: Root translations ``[B, T, 3]`` (torch tensor).
            NOTE(review): batching of unbatched inputs is presumably handled
            by ``ensure_batched`` - confirm against its implementation.

    Returns:
        Root translations ``[B, T, 3]`` with smoothed ``x/z`` and the original
        ``y``, on the device of the input.
    """
    planar = hip_translations[..., [0, 2]]
    height = hip_translations[..., [1]]

    num_sequences, num_frames = planar.shape[:2]
    frame_margins = np.full(num_frames, 0.06)

    # smooth_signal works on NumPy arrays, one sequence at a time; keep a
    # leading axis on each result so they can be concatenated into a batch.
    smoothed_per_sequence = [
        smooth_signal(planar[idx].detach().cpu().numpy(), frame_margins)[None]
        for idx in range(num_sequences)
    ]
    smoothed_planar = torch.tensor(np.concatenate(smoothed_per_sequence))

    # Concatenation yields (x, z, y); reorder the channels back to (x, y, z).
    stacked = torch.cat([smoothed_planar.to(height.device), height], dim=-1)
    return stacked[..., [0, 2, 1]]
|