"""Training callbacks."""

from __future__ import annotations

from typing import Callable


def every_n_steps(n: int, fn: Callable[[int], object]) -> Callable[[int], None]:
    """Return a callback that forwards only every *n*-th step to *fn*.

    The returned callable takes the current step number and calls
    ``fn(step)`` whenever ``step % n == 0``. Step 0 therefore triggers *fn*.
    Any value returned by *fn* is discarded.

    Args:
        n: Interval between invocations; must be a positive integer.
        fn: Function to call with the step number on matching steps.

    Returns:
        A callback accepting a single ``step`` argument.

    Raises:
        ValueError: If *n* is less than 1. Checking here surfaces the
            mistake when the callback is built. Otherwise ``n == 0`` would
            only fail later with ``ZeroDivisionError`` inside the training
            loop, and negative values would silently act like ``abs(n)``.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n!r}")

    def _callback(step: int) -> None:
        if step % n == 0:
            fn(step)

    return _callback