File size: 10,092 Bytes
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradient clipping transformations.
Note that complex numbers are also supported, see
https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29
"""
import chex
import jax
from jax import tree_util as jtu
import jax.numpy as jnp
from optax import tree_utils as otu
from optax._src import base
from optax._src import linear_algebra
from optax._src import numerics
def clip(max_delta: chex.Numeric) -> base.GradientTransformation:
  """Clips every element of the updates to ``[-max_delta, +max_delta]``.

  Args:
    max_delta: Largest absolute value any single update element may take.

  Returns:
    A `GradientTransformation` object.
  """

  def _clip_update(updates, state, params=None):
    # `params` belongs to the update signature but plays no role here.
    del params
    clipped = otu.tree_clip(updates, -max_delta, max_delta)
    return clipped, state

  return base.GradientTransformation(base.init_empty_state, _clip_update)
def clip_by_block_rms(threshold: float) -> base.GradientTransformation:
  """Limits the root-mean-square of each leaf of the updates to `threshold`.

  A `block` is a single leaf of the grads/param pytree, e.g. the weight vector
  of a Linear layer or the weight matrix of a convolutional layer.

  Args:
    threshold: The maximum rms for the gradient of each param vector or matrix.

  Returns:
    A `GradientTransformation` object.
  """

  def _rescale_block(block):
    block_rms = jnp.sqrt(jnp.mean(numerics.abs_sq(block)))
    # The divisor is floored at 1, so blocks whose rms is already under the
    # threshold pass through unchanged.
    divisor = jnp.maximum(1.0, block_rms / threshold)
    return block / divisor

  def update_fn(updates, state, params=None):
    del params
    return jtu.tree_map(_rescale_block, updates), state

  return base.GradientTransformation(base.init_empty_state, update_fn)
def clip_by_global_norm(max_norm: float) -> base.GradientTransformation:
  """Rescales the whole update tree so its global norm is at most `max_norm`.

  References:
    [Pascanu et al, 2012](https://arxiv.org/abs/1211.5063)

  Args:
    max_norm: The maximum global norm for an update.

  Returns:
    A `GradientTransformation` object.
  """

  def update_fn(updates, state, params=None):
    del params
    total_norm = linear_algebra.global_norm(updates)
    # TODO(b/163995078): revert back to the following (faster) implementation
    # once analysed how it affects backprop through update (e.g. meta-gradients)
    # g_norm = jnp.maximum(max_norm, g_norm)
    # updates = jtu.tree_map(lambda t: (t / g_norm) * max_norm, updates)
    below_limit = jnp.squeeze(total_norm < max_norm)
    chex.assert_shape(below_limit, ())  # A scalar.

    def _rescale(leaf):
      # Both candidates are computed; `select` keeps the untouched leaf when
      # the global norm is already under the limit.
      shrunk = (leaf / total_norm.astype(leaf.dtype)) * max_norm
      return jax.lax.select(below_limit, leaf, shrunk)

    return jtu.tree_map(_rescale, updates), state

  return base.GradientTransformation(base.init_empty_state, update_fn)
def per_example_global_norm_clip(
    grads: list[chex.Array], l2_norm_clip: float
) -> tuple[list[chex.Array], jax.Array]:
  """Applies gradient clipping per-example using their global norm.

  References:
    [Abadi et al, 2016](https://arxiv.org/abs/1607.00133)

  Args:
    grads: flattened update; the function expects these to have a batch
      dimension on the 0th axis.
    l2_norm_clip: maximum L2 norm of the per-example gradients.

  Returns:
    A tuple containing sum of the clipped per-example grads, and the number of
    per-example grads that were clipped.

  Raises:
    ValueError: if any entry of `grads` is a scalar, or if the entries do not
      all share the same size along the 0th axis.
  """
  # Check the rank of every entry before reading `shape[0]`, so that a scalar
  # in any position (including the first) triggers the ValueError below rather
  # than an IndexError.
  if any(g.ndim == 0 for g in grads) or len({g.shape[0] for g in grads}) > 1:
    raise ValueError(
        'Unlike other transforms, `per_example_global_norm_clip` expects'
        ' `grads` to have a batch dimension in the 0th axis.')

  # One global norm per example, computed across all entries of `grads`.
  global_grad_norms = jax.vmap(linear_algebra.global_norm)(grads)
  # Flooring at 1 means examples already within the bound are left unscaled.
  divisors = jnp.maximum(global_grad_norms / l2_norm_clip, 1.0)
  num_clipped = jnp.greater(divisors, 1.0).sum()
  # Move the batch axis last so the per-example divisors broadcast against it,
  # then sum that axis away.
  clipped_sum = [(jnp.moveaxis(g, 0, -1) / divisors).sum(-1) for g in grads]
  return clipped_sum, num_clipped
def per_example_layer_norm_clip(
    grads: list[chex.Array],
    global_l2_norm_clip: float,
    uniform: bool = True,
    eps: float = 1e-8,
) -> tuple[list[chex.Array], list[chex.Array]]:
  """Applies gradient clipping per-example using per-layer norms.

  Let C = `global_l2_norm_clip value`. Then per-layer clipping is done as
  follows:

  (1) If `uniform` is `True`, each of the K layers has an individual clip
      norm of C / sqrt(K).
  (2) If `uniform` is `False`, each of the K layers has an individual clip
      norm of C * sqrt(D_i / D) where D_i is the number of parameters in
      layer i, and D is the total number of parameters in the model.

  References:
    [McMahan et al, 2017](https://arxiv.org/abs/1710.06963)

  Args:
    grads: flattened update; i.e. a list of gradients in which each item is
      the gradient for one layer; the function expects these to have a batch
      dimension on the 0th axis.
    global_l2_norm_clip: overall L2 clip norm to use.
    uniform: If `True`, per-layer clip norm is global_l2_norm_clip/sqrt(L),
      where L is the number of layers. Otherwise, per-layer clip norm is
      global_l2_norm_clip * sqrt(f), where f is the fraction of total model
      parameters that are in this layer.
    eps: Small positive value added to each per-layer clip norm before
      dividing by it, to avoid possible division by zero.

  Returns:
    A tuple containing sum of the clipped per-example grads and the number of
    per-example grads that were clipped for each layer.

  Raises:
    ValueError: if any entry of `grads` is a scalar, or if the entries do not
      all share the same size along the 0th axis.
  """
  # Check the rank of every entry before reading `shape[0]`, so that a scalar
  # in any position (including the first) triggers the ValueError below rather
  # than an IndexError.
  if any(g.ndim == 0 for g in grads) or len({g.shape[0] for g in grads}) > 1:
    raise ValueError(
        'Unlike other transforms, `per_example_layer_norm_clip` expects'
        ' `grads` to have a batch dimension in the 0th axis; got shapes:'
        # A list (not a generator) so the shapes themselves are printed.
        f' {[g.shape for g in grads]}.'
    )

  num_layers = len(grads)

  # Compute per-layer clip norms, based on whether we are using uniform
  # variant or not.
  if uniform:
    # Create tuple of length `num_layers` of per-layer clip norm.
    layer_clip_norms = (
        global_l2_norm_clip * (1.0 / num_layers) ** 0.5,
    ) * num_layers
  else:
    # `g[0]` drops the batch axis, so `.size` counts parameters per example.
    total_params = sum(g[0].size for g in grads)
    layer_clip_norms = tuple(
        global_l2_norm_clip * (g[0].size / float(total_params)) ** 0.5
        for g in grads
    )

  # Compute per-layer grad norms; vmap maps this over the 0th axis of every
  # entry, giving one norm per example for each layer.
  def map_layer_norm(grads_list):
    return [jnp.linalg.norm(g, ord=None, axis=None) for g in grads_list]

  layer_grad_norms_per_example = jax.vmap(map_layer_norm)(grads)

  # Perform clipping. Flooring the divisor at 1 leaves examples that are
  # already within the layer's clip norm unscaled.
  divisors = tuple(
      jnp.maximum(layer_grad_norm / (layer_clip_norm + eps), 1.0)
      for layer_grad_norm, layer_clip_norm in zip(
          layer_grad_norms_per_example, layer_clip_norms
      )
  )
  num_clipped = [jnp.greater(divisor, 1.0).sum() for divisor in divisors]
  # Append singleton axes to each per-example divisor so it broadcasts
  # against the trailing dimensions of its layer, then sum over the batch.
  clipped_sum = [
      (g / jnp.expand_dims(d, axis=tuple(range(1, g.ndim)))).sum(0)
      for g, d in zip(grads, divisors)
  ]
  return clipped_sum, num_clipped
def unitwise_norm(x: chex.Array) -> chex.Array:
  """Computes norms of each output unit separately.

  Args:
    x: array whose unit-wise norms are wanted. Scalars and vectors (after
      squeezing singleton axes) are treated as a single unit; 2D and 3D arrays
      are reduced over axis 0; 4D arrays are reduced over axes (0, 1, 2).

  Returns:
    The unit-wise norms, broadcast back to the shape of `x`.

  Raises:
    ValueError: if `x` does not squeeze to a scalar or vector and its number
      of dimensions is not 2, 3 or 4.
  """
  if jnp.squeeze(x).ndim <= 1:  # Scalars and vectors
    squared_norm = jnp.sum(numerics.abs_sq(x), keepdims=True)
  elif x.ndim in (2, 3):  # Linear layers of shape IO or multihead linear
    # Note that this assumes parameters with a shape of length 3 are multihead
    # linear parameters--if you wish to apply AGC to 1D convs, you may need
    # to modify this line.
    squared_norm = jnp.sum(numerics.abs_sq(x), axis=0, keepdims=True)
  elif x.ndim == 4:  # Conv kernels of shape HWIO
    squared_norm = jnp.sum(numerics.abs_sq(x), axis=(0, 1, 2), keepdims=True)
  else:
    # The set of accepted ranks is kept in a plain (non-f) string so the braces
    # are printed literally instead of being evaluated as an expression.
    raise ValueError(
        'Expected parameter with number of dimensions in {1, 2, 3, 4},'
        f' got shape {x.shape}.')
  chex.assert_is_broadcastable(squared_norm.shape, x.shape)
  return jnp.broadcast_to(jnp.sqrt(squared_norm), x.shape)
def unitwise_clip(g_norm: chex.Array,
                  max_norm: chex.Array,
                  grad: chex.Array,
                  div_eps: float = 1e-6) -> chex.Array:
  """Rescales `grad` wherever its unit-wise norm is not below `max_norm`.

  Args:
    g_norm: unit-wise norms of `grad`, with the same shape as `grad`.
    max_norm: largest permitted norm per entry, with the same shape as `grad`.
    grad: the gradient to clip.
    div_eps: lower bound applied to `g_norm` before dividing by it.

  Returns:
    `grad` where `g_norm < max_norm`, and `grad * max_norm / g_norm` elsewhere.
  """
  # Flooring the norm at `div_eps` is separate from the usual eps and only
  # guards the division against zero; it should technically never engage.
  safe_norm = jnp.maximum(g_norm, div_eps)
  rescaled = grad * (max_norm / safe_norm)
  chex.assert_equal_shape((g_norm, max_norm, grad, rescaled))
  within_limit = g_norm < max_norm
  return jnp.where(within_limit, grad, rescaled)
def adaptive_grad_clip(clipping: float,
                       eps: float = 1e-3) -> base.GradientTransformation:
  """Clips updates unit-wise to at most ``clipping * parameter_norm``.

  References:
    [Brock, Smith, De, Simonyan 2021] High-Performance Large-Scale Image
    Recognition Without Normalization. (https://arxiv.org/abs/2102.06171)

  Args:
    clipping: The maximum allowed ratio of update norm to parameter norm.
    eps: An epsilon term to prevent clipping of zero-initialized params.

  Returns:
    A `GradientTransformation` object.
  """

  def _norm_ceiling(param_norm):
    # Flooring the parameter norm at `eps` keeps the ceiling strictly positive.
    return clipping * jnp.maximum(param_norm, eps)

  def update_fn(updates, state, params):
    if params is None:
      raise ValueError(base.NO_PARAMS_MSG)
    update_norms = jtu.tree_map(unitwise_norm, updates)
    param_norms = jtu.tree_map(unitwise_norm, params)
    # Maximum allowable norm.
    ceilings = jtu.tree_map(_norm_ceiling, param_norms)
    # If grad norm > clipping * param_norm, rescale.
    clipped = jtu.tree_map(unitwise_clip, update_norms, ceilings, updates)
    return clipped, state

  return base.GradientTransformation(base.init_empty_state, update_fn)
|