| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Tests for optax.transforms._clipping.""" |
|
|
| from absl.testing import absltest |
| import chex |
| import jax |
| import jax.numpy as jnp |
|
|
| from optax._src import linear_algebra |
| from optax.transforms import _clipping |
|
|
|
|
# Number of clipping-threshold steps exercised by the loop-based tests below.
STEPS = 50
# NOTE(review): LR is not referenced anywhere in the visible tests — possibly
# a leftover from a shared test template; confirm before removing.
LR = 1e-2
|
|
|
|
class ClippingTest(absltest.TestCase):
  """Unit tests for the gradient-clipping transformations in `_clipping`."""

  def setUp(self):
    super().setUp()
    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

  def test_clip(self):
    updates = self.per_step_updates
    # A threshold far above every update magnitude must leave updates intact.
    # BUG FIX: the original asserted `clipped_updates` against itself, which
    # is a tautology; the pass-through behavior requires comparing against
    # the unclipped `updates`.
    clipper = _clipping.clip(1e6)
    clipped_updates, _ = clipper.update(updates, None)
    chex.assert_trees_all_close(clipped_updates, updates)
    # With a threshold of 1., every coordinate here (all >= 3.) is clipped
    # down to exactly 1.
    clipper = _clipping.clip(1.)
    clipped_updates, _ = clipper.update(updates, None)
    chex.assert_trees_all_close(
        clipped_updates, jax.tree_util.tree_map(jnp.ones_like, updates))

  def test_clip_by_block_rms(self):
    rmf_fn = lambda t: jnp.sqrt(jnp.mean(t**2))
    updates = self.per_step_updates
    for i in range(1, STEPS + 1):
      clipper = _clipping.clip_by_block_rms(1. / i)
      # Check that the clipper actually caps each block's RMS at the
      # requested threshold.
      updates, _ = clipper.update(updates, None)
      self.assertAlmostEqual(rmf_fn(updates[0]), 1. / i)
      self.assertAlmostEqual(rmf_fn(updates[1]), 1. / i)
      # Clipping the raw per-step updates directly must agree with the
      # iteratively re-clipped updates (clipping is idempotent at a
      # given threshold).
      updates_step, _ = clipper.update(self.per_step_updates, None)
      chex.assert_trees_all_close(updates, updates_step)

  def test_clip_by_global_norm(self):
    updates = self.per_step_updates
    for i in range(1, STEPS + 1):
      clipper = _clipping.clip_by_global_norm(1. / i)
      # The global (tree-wide) L2 norm must equal the threshold after
      # clipping, since the raw updates exceed it.
      updates, _ = clipper.update(updates, None)
      self.assertAlmostEqual(
          linear_algebra.global_norm(updates), 1. / i, places=6)
      # Clipping the raw updates directly must match the iterated result.
      updates_step, _ = clipper.update(self.per_step_updates, None)
      chex.assert_trees_all_close(updates, updates_step)

  def test_adaptive_grad_clip(self):
    updates = self.per_step_updates
    params = self.init_params
    for i in range(1, STEPS + 1):
      clip_r = 1. / i
      clipper = _clipping.adaptive_grad_clip(clip_r)

      # After clipping, each unit-wise update norm must be at most
      # clip_r times the corresponding unit-wise parameter norm
      # (up to numerical tolerance).
      updates, _ = clipper.update(updates, None, params)
      u_norm, p_norm = jax.tree_util.tree_map(
          _clipping.unitwise_norm, (updates, params))
      # Bind clip_r as a default to avoid late-binding closure issues.
      cmp = jax.tree_util.tree_map(
          lambda u, p, c=clip_r: u - c * p < 1e-6, u_norm, p_norm)
      for leaf in jax.tree_util.tree_leaves(cmp):
        self.assertTrue(leaf.all())

      # Clipping the raw updates directly must match the iterated result.
      updates_step, _ = clipper.update(self.per_step_updates, None, params)
      chex.assert_trees_all_close(updates, updates_step)

  def test_per_example_layer_norm_clip(self):
    # Four examples, two layers. Layer 0 has one scalar per example with
    # per-example norms [0.5, 1.5, 2.0, 3.0]; layer 1 is all-ones of shape
    # (3, 3, 2) per example, i.e. 18 entries with per-example norm sqrt(18).
    grads_flat = [
        jnp.array([[0.5], [1.5], [-2.0], [3.0]]),
        jnp.ones([4, 3, 3, 2], dtype=jnp.float32),
    ]

    with self.subTest(name='Uniform Variant'):
      # Uniform: the global budget sqrt(2) is split evenly over the 2
      # layers, giving a per-example, per-layer clip of 1.0.
      sum_clipped_grads, num_clipped = _clipping.per_example_layer_norm_clip(
          grads_flat, global_l2_norm_clip=jnp.sqrt(2), uniform=True
      )

      # Layer 0: norms [0.5, 1.5, 2.0, 3.0] clip to [0.5, 1.0, 1.0, 1.0]
      # with signs preserved -> 0.5 + 1.0 - 1.0 + 1.0 = 1.5.
      # Layer 1: each example is scaled from norm sqrt(18) down to 1, so
      # every entry becomes 1/sqrt(18); summed over 4 examples that is
      # 4/sqrt(18).
      self.assertAlmostEqual(sum_clipped_grads[0], 1.5)
      for element in sum_clipped_grads[1].flatten():
        self.assertAlmostEqual(element, 4 / jnp.sqrt(18), places=4)

      # Layer 0 clips 3 of 4 examples (norms 1.5, 2.0, 3.0 exceed 1.0);
      # layer 1 clips all 4 (norm sqrt(18) > 1.0).
      self.assertEqual(num_clipped[0], 3)
      self.assertEqual(num_clipped[1], 4)

    with self.subTest(name='Scaled Variant'):
      # Scaled: the budget is divided in proportion to layer size. With
      # 1 + 18 = 19 total entries, layer 0 gets sqrt(19) * sqrt(1/19) = 1
      # and layer 1 gets sqrt(19) * sqrt(18/19) = sqrt(18).
      sum_clipped_grads, num_clipped = _clipping.per_example_layer_norm_clip(
          grads_flat, global_l2_norm_clip=jnp.sqrt(19), uniform=False
      )

      # Layer 0 behaves as in the uniform case (threshold 1.0) -> 1.5.
      # Layer 1's per-example norm sqrt(18) exactly meets its threshold,
      # so entries stay 1.0 and the sum over 4 examples is 4.0.
      self.assertAlmostEqual(sum_clipped_grads[0], 1.5)
      for element in sum_clipped_grads[1].flatten():
        self.assertAlmostEqual(element, 4.0)

      # Layer 0 still clips 3 examples; layer 1 clips none since its norm
      # does not exceed sqrt(18).
      self.assertEqual(num_clipped[0], 3)
      self.assertEqual(num_clipped[1], 0)
|
|
|
|
# Allow the test module to be run directly as a script via absl's test runner.
if __name__ == '__main__':
  absltest.main()
|
|