| """Running mean/std normalization for RL observation vectors.""" |
| from __future__ import annotations |
|
|
| import numpy as np |
|
|
|
|
class RunningNormalizer:
    """Maintain running mean/variance of observations and normalize with them.

    Statistics are accumulated batch by batch: each call to :meth:`update`
    computes the batch mean and (population) variance and merges them into the
    running moments with the parallel form of Welford's algorithm, so feeding
    the data in any chunking yields the same mean/variance as a single pass.

    Attributes:
        mean: Running per-element mean, float64, with the observation shape.
        var: Running per-element population variance, float64, same shape.
            Starts at ones and is replaced by the first non-empty update.
        count: Number of observations merged so far.
        clip: Normalized values are clipped to ``[-clip, clip]``.
        epsilon: Added to the variance before the square root to avoid
            division by zero.
    """

    def __init__(self, shape: tuple[int, ...], clip: float = 10.0, epsilon: float = 1e-8) -> None:
        """Create a normalizer for observations of the given ``shape``."""
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        self.count = 0
        self.clip = clip
        self.epsilon = epsilon

    def update(self, x: np.ndarray) -> None:
        """Merge a single observation or a batch of observations into the statistics.

        ``x`` must end with the observation shape; any leading dimensions are
        treated as batch dimensions and flattened. An array whose shape equals
        the observation shape is a single observation. Empty batches are
        ignored.

        Raises:
            ValueError: If the trailing dimensions of ``x`` do not match the
                observation shape.
        """
        # Accumulate in float64 regardless of the input dtype.
        x = np.asarray(x, dtype=np.float64)
        obs_shape = self.mean.shape
        lead = x.ndim - len(obs_shape)
        if lead < 0 or x.shape[lead:] != obs_shape:
            raise ValueError(
                f"expected an array whose trailing dimensions are {obs_shape}, "
                f"got shape {x.shape}"
            )
        # Product of the leading dims; an empty product is 1 (single observation).
        batch_count = int(np.prod(x.shape[:lead], dtype=np.int64))
        if batch_count == 0:
            # mean/var of zero rows would be NaN and poison the running stats.
            return
        x = x.reshape((batch_count,) + obs_shape)
        self._update_from_moments(x.mean(axis=0), x.var(axis=0), batch_count)

    def _update_from_moments(self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: int) -> None:
        """Merge precomputed batch moments into the running moments.

        Args:
            batch_mean: Per-element mean of the batch.
            batch_var: Per-element population variance of the batch.
            batch_count: Number of observations in the batch; a non-positive
                count leaves the statistics untouched.
        """
        if batch_count <= 0:
            # Nothing to merge; also keeps the initial variance intact.
            return
        delta = np.asarray(batch_mean, dtype=np.float64) - self.mean
        total_count = self.count + batch_count
        new_mean = self.mean + delta * batch_count / total_count
        # Sum of squared deviations of both parts plus the between-part term.
        m_a = self.var * self.count
        m_b = np.asarray(batch_var, dtype=np.float64) * batch_count
        m2 = m_a + m_b + np.square(delta) * self.count * batch_count / total_count
        self.mean = np.asarray(new_mean, dtype=np.float64)
        self.var = np.asarray(m2 / total_count, dtype=np.float64)
        self.count = total_count

    def normalize(self, x: np.ndarray) -> np.ndarray:
        """Return ``(x - mean) / sqrt(var + epsilon)`` clipped to ``[-clip, clip]`` as float32."""
        scaled = (np.asarray(x) - self.mean) / np.sqrt(self.var + self.epsilon)
        return np.clip(scaled, -self.clip, self.clip).astype(np.float32)

    def state_dict(self) -> dict:
        """Return a copy of the running statistics under keys ``mean``, ``var``, ``count``."""
        return {"mean": self.mean.copy(), "var": self.var.copy(), "count": self.count}

    def load_state_dict(self, state: dict) -> None:
        """Restore statistics from a mapping produced by :meth:`state_dict`.

        ``mean`` and ``var`` may be arrays or nested sequences; they are copied
        into fresh float64 arrays so the caller's objects are never aliased.
        """
        self.mean = np.array(state["mean"], dtype=np.float64)
        self.var = np.array(state["var"], dtype=np.float64)
        self.count = state["count"]
|
|