| """Training callbacks.""" | |
| from __future__ import annotations | |
| from typing import Callable | |
def every_n_steps(n: int, fn: Callable[[int], None]) -> Callable[[int], None]:
    """Wrap ``fn`` so it only runs on steps that are multiples of ``n``.

    Args:
        n: Call period. Must be non-zero. A negative value fires on
            multiples of ``abs(n)``, since ``step % n == 0`` ignores the
            divisor's sign.
        fn: Callback invoked with the step number on matching steps.

    Returns:
        A callback taking the current step. It calls ``fn(step)`` whenever
        ``step % n == 0`` (this includes step 0) and does nothing otherwise.

    Raises:
        ValueError: If ``n`` is zero.
    """
    if n == 0:
        # Fail fast: with n == 0 every invocation of the returned callback
        # would raise ZeroDivisionError from ``step % 0``.
        raise ValueError("n must be non-zero")

    def _callback(step: int) -> None:
        if step % n == 0:
            fn(step)

    return _callback
def every_n_steps(n: int, fn: Callable[[int], None]) -> Callable[[int], None]:
    """Wrap ``fn`` so it only runs on steps that are multiples of ``n``.

    Args:
        n: Call period. Must be non-zero. A negative value fires on
            multiples of ``abs(n)``, since ``step % n == 0`` ignores the
            divisor's sign.
        fn: Callback invoked with the step number on matching steps.

    Returns:
        A callback taking the current step. It calls ``fn(step)`` whenever
        ``step % n == 0`` (this includes step 0) and does nothing otherwise.

    Raises:
        ValueError: If ``n`` is zero.
    """
    if n == 0:
        # Fail fast: with n == 0 every invocation of the returned callback
        # would raise ZeroDivisionError from ``step % 0``.
        raise ValueError("n must be non-zero")

    def _callback(step: int) -> None:
        if step % n == 0:
            fn(step)

    return _callback