Spaces:
Runtime error
Runtime error
File size: 4,540 Bytes
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Feature normalization statistics (mean/std) for motion representations."""
import logging
import os
from typing import Optional
import numpy as np
import torch
# Logger named after this module so its records can be filtered per module.
log = logging.getLogger(name=__name__)
class Stats(torch.nn.Module):
    """Utility module for feature normalization statistics.

    Normalization follows:

        ``(data - mean) / sqrt(std**2 + eps)``

    ``mean`` and ``std`` are registered as non-persistent buffers: they follow
    the module through ``.to()`` but are excluded from ``state_dict()``; they
    are read from ``mean.npy`` / ``std.npy`` in ``folder`` instead. The
    feature axis is taken to be the LAST axis of ``mean`` / ``std`` (this is
    the axis ``sliced`` indexes and ``get_dim`` reports).
    """

    def __init__(
        self,
        folder: Optional[str] = None,
        load: bool = True,
        eps: float = 1e-05,
    ):
        """Create the module and optionally load statistics from disk.

        Args:
            folder: Directory containing ``mean.npy`` and ``std.npy``.
            load: When True and ``folder`` is given, load the stats right away.
            eps: Added to ``std**2`` before the square root so that features
                with zero std do not cause a division by zero.
        """
        super().__init__()
        self.folder = folder
        self.eps = eps
        if folder is not None and load:
            self.load()

    def sliced(self, indices) -> "Stats":
        """Return a new ``Stats`` object containing selected feature indices.

        ``indices`` is applied to the last axis of ``mean`` and ``std``. The
        selected values are cloned, so the result does not share storage
        with ``self``.
        """
        new_stats = Stats(folder=self.folder, load=False, eps=self.eps)
        new_stats.register_from_tensors(
            self.mean[..., indices].clone(),
            self.std[..., indices].clone(),
        )
        return new_stats

    def load(self) -> None:
        """Load ``mean.npy`` and ``std.npy`` from ``self.folder``.

        Raises:
            ValueError: If ``self.folder`` is None.
            FileNotFoundError: If either statistics file is missing.
        """
        if self.folder is None:
            raise ValueError("No folder to load stats from")
        mean_path = os.path.join(self.folder, "mean.npy")
        std_path = os.path.join(self.folder, "std.npy")
        if not os.path.exists(mean_path) or not os.path.exists(std_path):
            raise FileNotFoundError(
                f"Missing stats files in '{self.folder}'. Expected:\n"
                f" - {mean_path}\n"
                f" - {std_path}\n\n"
                "Make sure the checkpoint/stats have been downloaded and are mounted into the container.\n"
                "If you're using Docker Compose, run it from the repo root so `./:/workspace` mounts the correct directory."
            )
        mean = torch.from_numpy(np.load(mean_path))
        std = torch.from_numpy(np.load(std_path))
        self.register_from_tensors(mean, std)

    def register_from_tensors(self, mean: torch.Tensor, std: torch.Tensor) -> None:
        """Register mean/std tensors as non-persistent buffers.

        Calling this again replaces previously registered statistics.
        """
        self.register_buffer("mean", mean, persistent=False)
        self.register_buffer("std", std, persistent=False)

    def _mean_and_scale(self, data: torch.Tensor):
        """Return ``mean`` and ``sqrt(std**2 + eps)`` cast to ``data``'s device/dtype."""
        mean = self.mean.to(device=data.device, dtype=data.dtype)
        std = self.std.to(device=data.device, dtype=data.dtype)
        # eps keeps the scale strictly positive for zero-variance features
        return mean, torch.sqrt(std**2 + self.eps)

    def normalize(self, data: torch.Tensor) -> torch.Tensor:
        """Normalize data using the stored statistics."""
        mean, scale = self._mean_and_scale(data)
        return (data - mean) / scale

    def unnormalize(self, data: torch.Tensor) -> torch.Tensor:
        """Undo normalization using the stored statistics."""
        mean, scale = self._mean_and_scale(data)
        return data * scale + mean

    def is_loaded(self) -> bool:
        """Return whether statistics are currently available."""
        return hasattr(self, "mean")

    def get_dim(self) -> int:
        """Return feature dimensionality (size of the last axis of ``mean``)."""
        return self.mean.shape[-1]

    @staticmethod
    def _to_numpy(value) -> np.ndarray:
        """Convert a tensor (any device, possibly requiring grad) or array-like to numpy."""
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    def save(
        self,
        folder: Optional[str] = None,
        mean: Optional[torch.Tensor] = None,
        std: Optional[torch.Tensor] = None,
    ) -> None:
        """Save statistics to ``folder`` as ``mean.npy`` and ``std.npy``.

        Args:
            folder: Target directory; defaults to ``self.folder``. It must not
                exist yet, so existing statistics are never overwritten.
            mean: Tensor or array-like to save; defaults to the stored mean.
            std: Tensor or array-like to save; defaults to the stored std.

        Raises:
            ValueError: If no folder is available, or a value has to be taken
                from the stored stats but none are loaded.
            FileExistsError: If ``folder`` already exists.
        """
        if folder is None:
            folder = self.folder
        if folder is None:
            raise ValueError("No folder to save stats")
        # Fill each missing value individually; previously passing only one
        # of them wrote ``None`` to disk, producing an unloadable file.
        if mean is None or std is None:
            if not self.is_loaded():
                raise ValueError("Stats were not loaded")
            if mean is None:
                mean = self.mean
            if std is None:
                std = self.std
        mean_np = self._to_numpy(mean)
        std_np = self._to_numpy(std)
        # don't override stats folder
        os.makedirs(folder, exist_ok=False)
        np.save(os.path.join(folder, "mean.npy"), mean_np)
        np.save(os.path.join(folder, "std.npy"), std_np)

    @staticmethod
    def _tensors_equal(a: torch.Tensor, b: torch.Tensor) -> bool:
        """Element-wise equality that requires identical shapes (no broadcasting)."""
        a = a.detach().cpu()
        b = b.detach().cpu()
        return a.shape == b.shape and bool((a == b).all())

    @staticmethod
    def _hash_tensor(t: torch.Tensor) -> int:
        """Hash a tensor consistently with ``_tensors_equal``.

        Values are hashed as float64 so tensors that compare equal across
        dtypes hash equally; ``+ 0.0`` folds ``-0.0`` into ``0.0``, which
        ``==`` also treats as equal.
        """
        values = t.detach().cpu().to(torch.float64) + 0.0
        return hash((tuple(values.shape), values.numpy().tobytes()))

    def __eq__(self, other):
        if not isinstance(other, Stats):
            return NotImplemented
        if not (self.is_loaded() and other.is_loaded()):
            # without statistics there is nothing to compare: fall back to identity
            return self is other
        return self._tensors_equal(self.mean, other.mean) and self._tensors_equal(
            self.std, other.std
        )

    # should define a hash value for pytorch, as we defined __eq__
    # (nn.Module traversal such as ``named_modules`` puts modules in sets).
    def __hash__(self):
        if not self.is_loaded():
            # must not raise: an un-loaded Stats can still live inside a model
            return object.__hash__(self)
        return hash((self._hash_tensor(self.mean), self._hash_tensor(self.std)))

    def __repr__(self):
        return f'Stats(folder="{self.folder}")'
|