File size: 4,540 Bytes
6d5047c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Feature normalization statistics (mean/std) for motion representations."""

import logging
import os
from typing import Optional

import numpy as np
import torch

log = logging.getLogger(__name__)


class Stats(torch.nn.Module):
    """Utility module for feature normalization statistics.

    ``mean`` and ``std`` are stored as non-persistent buffers (they follow
    ``.to()``/``.cuda()`` but are not part of ``state_dict``). Normalization
    follows:
    ``(data - mean) / sqrt(std**2 + eps)``

    Feature indices are taken along the last dimension of ``mean``/``std``
    (see ``sliced`` and ``get_dim``).

    Args:
        folder: Directory containing ``mean.npy`` and ``std.npy``.
        load: If True and ``folder`` is given, load the statistics on init.
        eps: Added to ``std**2`` before the square root so a zero std does
            not cause a division by zero.
    """

    def __init__(
        self,
        folder: Optional[str] = None,
        load: bool = True,
        eps: float = 1e-05,
    ):
        super().__init__()
        self.folder = folder
        self.eps = eps
        if folder is not None and load:
            self.load()

    def sliced(self, indices) -> "Stats":
        """Return a new ``Stats`` object containing selected feature indices.

        ``indices`` is applied along the last dimension; the selected values
        are cloned so the new object does not share storage with this one.
        """
        new_stats = Stats(folder=self.folder, load=False, eps=self.eps)
        new_stats.register_from_tensors(
            self.mean[..., indices].clone(),
            self.std[..., indices].clone(),
        )
        return new_stats

    def load(self):
        """Load ``mean.npy`` and ``std.npy`` from ``self.folder``.

        Raises:
            ValueError: If ``self.folder`` is None.
            FileNotFoundError: If either stats file is missing.
        """
        if self.folder is None:
            raise ValueError("No folder to load stats from")

        mean_path = os.path.join(self.folder, "mean.npy")
        std_path = os.path.join(self.folder, "std.npy")
        if not os.path.exists(mean_path) or not os.path.exists(std_path):
            raise FileNotFoundError(
                f"Missing stats files in '{self.folder}'. Expected:\n"
                f"  - {mean_path}\n"
                f"  - {std_path}\n\n"
                "Make sure the checkpoint/stats have been downloaded and are mounted into the container.\n"
                "If you're using Docker Compose, run it from the repo root so `./:/workspace` mounts the correct directory."
            )

        mean = torch.from_numpy(np.load(mean_path))
        std = torch.from_numpy(np.load(std_path))
        self.register_from_tensors(mean, std)

    def register_from_tensors(self, mean: torch.Tensor, std: torch.Tensor):
        """Register mean/std tensors as non-persistent buffers.

        Calling this again replaces the previously registered buffers.
        """
        self.register_buffer("mean", mean, persistent=False)
        self.register_buffer("std", std, persistent=False)

    def _mean_and_scale(self, data: torch.Tensor):
        """Return ``(mean, sqrt(std**2 + eps))`` on ``data``'s device and dtype."""
        mean = self.mean.to(device=data.device, dtype=data.dtype)
        std = self.std.to(device=data.device, dtype=data.dtype)
        # adjust std with eps
        return mean, torch.sqrt(std**2 + self.eps)

    def normalize(self, data: torch.Tensor) -> torch.Tensor:
        """Normalize data using the stored statistics."""
        mean, scale = self._mean_and_scale(data)
        return (data - mean) / scale

    def unnormalize(self, data: torch.Tensor) -> torch.Tensor:
        """Undo normalization using the stored statistics."""
        mean, scale = self._mean_and_scale(data)
        return data * scale + mean

    def is_loaded(self) -> bool:
        """Return whether statistics are currently available."""
        # nn.Module.__getattr__ raises AttributeError for unregistered buffers.
        return hasattr(self, "mean")

    def get_dim(self) -> int:
        """Return feature dimensionality (size of the last dimension)."""
        # Last dim, matching the ``[..., indices]`` convention used by ``sliced``;
        # identical to ``shape[0]`` for 1-D statistics.
        return self.mean.shape[-1]

    @staticmethod
    def _to_numpy(value):
        """Convert a torch tensor to a CPU numpy array; pass anything else through."""
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return value

    def save(
        self,
        folder: Optional[str] = None,
        mean: Optional[torch.Tensor] = None,
        std: Optional[torch.Tensor] = None,
    ):
        """Save statistics to ``folder`` as ``mean.npy`` and ``std.npy``.

        Args:
            folder: Destination directory; defaults to ``self.folder``. It must
                not already exist (existing stats are never overwritten).
            mean: Values to save instead of ``self.mean`` (tensor or array).
            std: Values to save instead of ``self.std`` (tensor or array).

        Raises:
            ValueError: If no folder is available, or if a value is missing
                and the statistics were not loaded.
            FileExistsError: If ``folder`` already exists.
        """
        if folder is None:
            folder = self.folder
            if folder is None:
                raise ValueError("No folder to save stats")

        # Fill in whichever of mean/std was not provided from the stored buffers,
        # so a partial override never writes ``None`` to disk.
        if mean is None or std is None:
            if not self.is_loaded():
                raise ValueError("Stats were not loaded")
            if mean is None:
                mean = self.mean
            if std is None:
                std = self.std

        mean = self._to_numpy(mean)
        std = self._to_numpy(std)

        # don't override stats folder
        os.makedirs(folder, exist_ok=False)

        np.save(os.path.join(folder, "mean.npy"), mean)
        np.save(os.path.join(folder, "std.npy"), std)

    @staticmethod
    def _tensors_equal(a: torch.Tensor, b: torch.Tensor) -> bool:
        """Return True when ``a`` and ``b`` have the same shape and elements."""
        a = a.detach().cpu()
        b = b.detach().cpu()
        # Explicit shape check: elementwise ``==`` would otherwise broadcast.
        return a.shape == b.shape and bool((a == b).all())

    def __eq__(self, other):
        if not isinstance(other, Stats):
            return NotImplemented
        if not (self.is_loaded() and other.is_loaded()):
            # Without statistics there is nothing to compare; fall back to identity.
            return self is other
        return self._tensors_equal(self.mean, other.mean) and self._tensors_equal(self.std, other.std)

    # should define a hash value for pytorch, as we defined __eq__
    # (nn.Module.named_modules() stores modules in a set).
    def __hash__(self):
        if not self.is_loaded():
            # Identity hash, consistent with the identity fallback in __eq__.
            return super().__hash__()
        # Convert mean and std to bytes for a consistent hash value
        mean_hash = hash(self.mean.detach().cpu().numpy().tobytes())
        std_hash = hash(self.std.detach().cpu().numpy().tobytes())
        return hash((mean_hash, std_hash))

    def __repr__(self):
        return f'Stats(folder="{self.folder}")'