| """A lookahead optimization wrapper.""" |
|
|
| from typing import NamedTuple, Union |
|
|
| from absl import logging |
| import jax |
| import jax.numpy as jnp |
| from optax._src import base |
|
|
|
|
class LookaheadState(NamedTuple):
  """State of the `GradientTransformation` returned by `lookahead`.

  Attributes:
    fast_state: Optimizer state of the fast optimizer.
    steps_since_sync: Number of fast optimizer steps taken since slow and fast
      parameters were synchronized.
  """

  fast_state: base.OptState
  steps_since_sync: jnp.ndarray


class LookaheadParams(NamedTuple):
  """Holds a pair of slow and fast parameters for the lookahead optimizer.

  Gradients should always be calculated with the fast parameters. The slow
  parameters should be used for testing and inference as they generalize
  better. See the reference for a detailed discussion.

  References:
    [Zhang et al, 2019](https://arxiv.org/pdf/1907.08610v1.pdf)

  Attributes:
    fast: Fast parameters.
    slow: Slow parameters.
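
  Example:
    A small illustrative sketch (the quadratic loss below is a placeholder,
    not part of the optimizer): gradients are taken with respect to the fast
    parameters, while the slow parameters are the ones to evaluate.

    >>> import jax
    >>> import jax.numpy as jnp
    >>> import optax
    >>> params = optax.LookaheadParams.init_synced(jnp.zeros(2))
    >>> loss = lambda p: jnp.sum(p ** 2)  # hypothetical loss
    >>> grads = jax.grad(loss)(params.fast)  # gradients use the fast params
    >>> eval_params = params.slow  # slow params for testing/inference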
| """ |
|
|
| fast: base.Params |
| slow: base.Params |
|
|
| @classmethod |
| def init_synced(cls, params: base.Params) -> 'LookaheadParams': |
| """Initialize a pair of synchronized lookahead parameters.""" |
| return cls(slow=params, fast=params) |
|
|
|
|
def lookahead(
    fast_optimizer: base.GradientTransformation,
    sync_period: int,
    slow_step_size: float,
    reset_state: bool = False,
) -> base.GradientTransformation:
  """Lookahead optimizer.

  Performs steps with a fast optimizer and periodically updates a set of slow
  parameters. Optionally resets the fast optimizer state after synchronization
  by calling the init function of the fast optimizer.

  Updates returned by the lookahead optimizer should not be modified before
  they are applied, otherwise fast and slow parameters are not synchronized
  correctly.

  References:
    [Zhang et al, 2019](https://arxiv.org/pdf/1907.08610v1.pdf)

  Args:
    fast_optimizer: The optimizer to use in the inner loop of lookahead.
    sync_period: Number of fast optimizer steps to take before synchronizing
      parameters. Must be >= 1.
    slow_step_size: Step size of the slow parameter updates.
    reset_state: Whether to reset the optimizer state of the fast optimizer
      after each synchronization.

  Returns:
    A `GradientTransformation` with init and update functions. The updates
    passed to the update function should be calculated using the fast lookahead
    parameters only.
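
  Example:
    A minimal usage sketch, assuming `optax.sgd` as the fast optimizer and
    placeholder gradients (both are illustrative choices, not requirements):

    >>> import jax.numpy as jnp
    >>> import optax
    >>> fast_opt = optax.sgd(learning_rate=0.1)
    >>> opt = optax.lookahead(fast_opt, sync_period=5, slow_step_size=0.5)
    >>> params = optax.LookaheadParams.init_synced(jnp.zeros(3))
    >>> state = opt.init(params)
    >>> grads = jnp.ones(3)  # gradients w.r.t. the fast parameters
    >>> updates, state = opt.update(grads, state, params)
    >>> params = optax.apply_updates(params, updates)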
| """ |
| if sync_period < 1: |
| raise ValueError('Synchronization period must be >= 1.') |
|
|
  def init_fn(params: base.Params) -> LookaheadState:
    fast_params = getattr(params, 'fast', None)
    if fast_params is None:
      logging.warning(
          '`params` has no attribute `fast`. Continuing by assuming that '
          'only fast parameters were passed to lookahead init.'
      )
      fast_params = params

    return LookaheadState(
        fast_state=fast_optimizer.init(fast_params),
        steps_since_sync=jnp.zeros(shape=(), dtype=jnp.int32),
    )

  def update_fn(
      updates: base.Updates, state: LookaheadState, params: LookaheadParams
  ) -> tuple[LookaheadParams, LookaheadState]:
    updates, fast_state = fast_optimizer.update(
        updates, state.fast_state, params.fast
    )

    sync_next = state.steps_since_sync == (sync_period - 1)
    updates = _lookahead_update(updates, sync_next, params, slow_step_size)
    if reset_state:
      # Jittable way of resetting the fast optimizer state: interpolate
      # between the current and freshly initialized state, using `sync_next`
      # as a 0/1 flag, so the reset only takes effect on synchronization steps.
      initial_state = fast_optimizer.init(params.fast)
      fast_state = jax.tree_util.tree_map(
          lambda current, init: (1 - sync_next) * current + sync_next * init,
          fast_state,
          initial_state,
      )

    steps_since_sync = (state.steps_since_sync + 1) % sync_period
    return updates, LookaheadState(fast_state, steps_since_sync)

  return base.GradientTransformation(init_fn, update_fn)


def _lookahead_update(
    updates: base.Updates,
    sync_next: Union[bool, jax.Array],
    params: LookaheadParams,
    slow_step_size: float,
) -> LookaheadParams:
  """Returns the updates corresponding to one lookahead step.

  References:
    [Zhang et al, 2019](https://arxiv.org/pdf/1907.08610v1.pdf)

  Args:
    updates: Updates returned by the fast optimizer.
    sync_next: Whether fast and slow parameters should be synchronized after
      the fast optimizer step.
    params: Current fast and slow parameters as `LookaheadParams` object.
    slow_step_size: Step size of the slow optimizer.

  Returns:
    The updates for the lookahead parameters.
  """
  last_difference = jax.tree_util.tree_map(
      lambda f, u, s: f + u - s, params.fast, updates, params.slow
  )
  slow_updates = jax.tree_util.tree_map(
      lambda diff: slow_step_size * sync_next * diff, last_difference
  )
  fast_updates = jax.tree_util.tree_map(
      lambda up, diff: up - sync_next * (1 - slow_step_size) * diff,
      updates,
      last_difference,
  )

  return LookaheadParams(fast=fast_updates, slow=slow_updates)