File size: 1,993 Bytes
4fbc241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""Running mean/std normalization for RL observation vectors."""
from __future__ import annotations

import numpy as np


class RunningNormalizer:
    """Track a running per-element mean/variance and use it to normalize observations.

    Statistics are merged batch-by-batch with the parallel (pairwise) form of
    Welford's algorithm (Chan et al.), so feeding observations one at a time or
    in batches gives the same result up to floating-point rounding. The tracked
    variance is the population variance (``ddof=0``).

    Args:
        shape: Shape of a single observation.
        clip: Normalized values are clipped to ``[-clip, clip]``.
        epsilon: Added to the variance before the square root in ``normalize``
            so a zero-variance element does not cause division by zero.
    """

    def __init__(self, shape: tuple[int, ...], clip: float = 10.0, epsilon: float = 1e-8) -> None:
        self.mean = np.zeros(shape, dtype=np.float64)
        # Unit variance (and zero mean) means normalize() only clips until the
        # first update() has been applied.
        self.var = np.ones(shape, dtype=np.float64)
        self.count = 0
        self.clip = clip
        self.epsilon = epsilon

    def update(self, x: np.ndarray) -> None:
        """Update running statistics with a single observation or a batch.

        Args:
            x: Either one observation with shape ``self.mean.shape``, or a batch
                with shape ``(batch, *self.mean.shape)``. Array-likes are
                accepted and converted to float64. An empty batch is a no-op.

        Raises:
            ValueError: If ``x`` matches neither the observation shape nor a
                batch of observations.
        """
        x = np.asarray(x, dtype=np.float64)
        obs_shape = self.mean.shape
        if x.shape == obs_shape:
            # A single observation: add a leading batch axis of length 1.
            x = x[np.newaxis]
        elif x.shape[1:] != obs_shape:
            raise ValueError(
                f"expected observation shape {obs_shape} or batch shape "
                f"(N, *{obs_shape}), got {x.shape}"
            )
        batch_count = x.shape[0]
        if batch_count == 0:
            # mean/var of an empty batch are NaN and would poison the stats.
            return
        self._update_from_moments(x.mean(axis=0), x.var(axis=0), batch_count)

    def _update_from_moments(self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: int) -> None:
        """Merge a batch's mean, population variance and size into the running stats."""
        if batch_count <= 0:
            return
        total_count = self.count + batch_count
        delta = batch_mean - self.mean
        # Combined sum of squared deviations (M2): each part's own M2 plus a
        # correction for the gap between the two means (Chan et al.).
        m2 = (
            self.var * self.count
            + batch_var * batch_count
            + np.square(delta) * (self.count * batch_count / total_count)
        )
        self.mean = self.mean + delta * (batch_count / total_count)
        self.var = m2 / total_count
        self.count = int(total_count)

    def normalize(self, x: np.ndarray) -> np.ndarray:
        """Return ``(x - mean) / sqrt(var + epsilon)`` clipped to ``[-clip, clip]`` as float32."""
        return np.clip(
            (x - self.mean) / np.sqrt(self.var + self.epsilon),
            -self.clip,
            self.clip,
        ).astype(np.float32)

    def state_dict(self) -> dict:
        """Return a snapshot of the statistics; arrays are copied so later updates don't alter it."""
        return {"mean": self.mean.copy(), "var": self.var.copy(), "count": self.count}

    def load_state_dict(self, state: dict) -> None:
        """Restore statistics from a mapping produced by ``state_dict``.

        The arrays are copied (as float64) so the caller's objects are never
        aliased by this normalizer.
        """
        self.mean = np.array(state["mean"], dtype=np.float64)
        self.var = np.array(state["var"], dtype=np.float64)
        self.count = int(state["count"])