| |
| |
| |
|
|
| import copy |
| import math |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from src.efficientvit.models.utils import is_parallel |
|
|
| __all__ = ["EMA"] |
|
|
|
|
def update_ema(
    ema: nn.Module, new_state_dict: dict[str, torch.Tensor], decay: float
) -> None:
    """Blend ``new_state_dict`` into the state tensors of ``ema``, in place.

    Each floating-point entry ``v`` of ``ema.state_dict()`` is updated as::

        v <- v - (1 - decay) * (v - new)

    where ``new`` is the entry of the same name in ``new_state_dict``.
    Entries whose dtype is not floating point are left untouched.

    Args:
        ema: Module whose state tensors are overwritten.
        new_state_dict: Source tensors, looked up by the key names of
            ``ema.state_dict()``.
        decay: Fraction of the current value that is kept; ``0`` copies the
            new value, ``1`` leaves the current value unchanged.

    Raises:
        KeyError: If a floating-point key of ``ema`` is absent from
            ``new_state_dict``.
    """
    blend = 1.0 - decay
    for name, shadow in ema.state_dict().items():
        if not shadow.dtype.is_floating_point:
            continue
        incoming = new_state_dict[name].detach()
        # In-place subtraction so the module's own storage is modified.
        shadow.sub_(blend * (shadow - incoming))
|
|
|
|
class EMA:
    """Exponential moving average of a model's state.

    Holds a deep copy of the model (``self.shadows``), switched to eval mode
    with gradients disabled on its parameters, and blends the live model's
    state into that copy on every call to :meth:`step`.
    """

    def __init__(self, model: nn.Module, decay: float, warmup_steps: int = 2000):
        """Create the shadow copy of ``model``.

        Args:
            model: Model to track. When ``is_parallel(model)`` is true, the
                wrapped ``model.module`` is copied instead of the wrapper.
            decay: Target decay used once warm-up has saturated.
            warmup_steps: Time constant of the exponential warm-up ramp
                applied to ``decay``. A value ``<= 0`` disables warm-up.
        """
        base = model.module if is_parallel(model) else model
        self.shadows = copy.deepcopy(base).eval()
        self.decay = decay
        self.warmup_steps = warmup_steps

        # The shadow copy is never trained directly.
        for p in self.shadows.parameters():
            p.requires_grad = False

    def step(self, model: nn.Module, global_step: int) -> None:
        """Fold the current state of ``model`` into the shadow copy.

        The effective decay is
        ``decay * (1 - exp(-global_step / warmup_steps))``, so it starts at
        zero and approaches ``decay`` as ``global_step`` grows. With warm-up
        disabled (``warmup_steps <= 0``) the full ``decay`` is used on every
        step.

        Args:
            model: Live model; unwrapped via ``model.module`` when
                ``is_parallel(model)`` is true.
            global_step: Step counter that drives the warm-up ramp.
        """
        with torch.no_grad():
            msd = (model.module if is_parallel(model) else model).state_dict()
            if self.warmup_steps > 0:
                ramp = 1 - math.exp(-global_step / self.warmup_steps)
            else:
                # Avoid dividing by zero (or producing a negative decay) when
                # warm-up is switched off.
                ramp = 1.0
            update_ema(self.shadows, msd, self.decay * ramp)

    def state_dict(self) -> dict[float, dict[str, torch.Tensor]]:
        """Return the shadow model's state dict keyed by this EMA's decay."""
        return {self.decay: self.shadows.state_dict()}

    def load_state_dict(self, state_dict: dict[float, dict[str, torch.Tensor]]) -> None:
        """Restore the shadow model from a mapping produced by ``state_dict``.

        Only the entry whose key equals ``self.decay`` is loaded; entries
        stored under any other decay are ignored, and nothing is loaded when
        no key matches.
        """
        for decay in state_dict:
            if decay == self.decay:
                self.shadows.load_state_dict(state_dict[decay])
|
|