File size: 6,052 Bytes
09d8e80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for optax.transforms._clipping."""

from absl.testing import absltest
import chex
import jax
import jax.numpy as jnp

from optax._src import linear_algebra
from optax.transforms import _clipping


# Number of iterations swept by the progressive-threshold tests, which clip
# with thresholds 1/1, 1/2, ..., 1/STEPS.
STEPS = 50
# Learning-rate constant; not referenced by any test in this module.
LR = 0.01


class ClippingTest(absltest.TestCase):
  """Tests for the gradient-clipping transforms exposed by `_clipping`."""

  def setUp(self):
    super().setUp()
    # Shared fixtures. Every entry of `per_step_updates` has magnitude >= 1,
    # which `test_clip` relies on when clipping at delta=1.
    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

  def test_clip(self):
    """Checks `clip` is a no-op for a huge delta and caps values at delta=1."""
    updates = self.per_step_updates
    # For a sufficiently high delta the update should not be changed.
    clipper = _clipping.clip(1e6)
    clipped_updates, _ = clipper.update(updates, None)
    # Compare against the *input* updates; comparing the result with itself
    # would make this check vacuous.
    chex.assert_trees_all_close(clipped_updates, updates)
    # Every entry of `updates` is >= 1, so clipping at delta=1 should make all
    # updates exactly 1.
    clipper = _clipping.clip(1.)
    clipped_updates, _ = clipper.update(updates, None)
    chex.assert_trees_all_close(
        clipped_updates, jax.tree_util.tree_map(jnp.ones_like, updates))

  def test_clip_by_block_rms(self):
    """Checks `clip_by_block_rms` brings each block's rms to the threshold."""
    rms_fn = lambda t: jnp.sqrt(jnp.mean(t**2))
    updates = self.per_step_updates
    for i in range(1, STEPS + 1):
      clipper = _clipping.clip_by_block_rms(1. / i)
      # Check that the clipper actually works and block rms is <= threshold.
      # The threshold shrinks every iteration, so each block is always clipped
      # and its rms is expected to land exactly on the threshold.
      updates, _ = clipper.update(updates, None)
      self.assertAlmostEqual(rms_fn(updates[0]), 1. / i)
      self.assertAlmostEqual(rms_fn(updates[1]), 1. / i)
      # Check that continuously clipping won't cause numerical issues:
      # re-clipping the already-clipped updates must match clipping the raw
      # updates once.
      updates_step, _ = clipper.update(self.per_step_updates, None)
      chex.assert_trees_all_close(updates, updates_step)

  def test_clip_by_global_norm(self):
    """Checks `clip_by_global_norm` brings the global norm to `max_norm`."""
    updates = self.per_step_updates
    for i in range(1, STEPS + 1):
      clipper = _clipping.clip_by_global_norm(1. / i)
      # Check that the clipper actually works and global norm is <= max_norm.
      # As above, the shrinking threshold means the norm should equal it.
      updates, _ = clipper.update(updates, None)
      self.assertAlmostEqual(
          linear_algebra.global_norm(updates), 1. / i, places=6)
      # Check that continuously clipping won't cause numerical issues.
      updates_step, _ = clipper.update(self.per_step_updates, None)
      chex.assert_trees_all_close(updates, updates_step)

  def test_adaptive_grad_clip(self):
    """Checks `adaptive_grad_clip` bounds unit-wise update/param norm ratios."""
    updates = self.per_step_updates
    params = self.init_params
    for i in range(1, STEPS + 1):
      clip_r = 1. / i
      clipper = _clipping.adaptive_grad_clip(clip_r)

      # Check that the clipper actually works and upd_norm is < c * param_norm.
      updates, _ = clipper.update(updates, None, params)
      u_norm, p_norm = jax.tree_util.tree_map(
          _clipping.unitwise_norm, (updates, params))
      # `c=clip_r` binds the current loop value as a default argument so the
      # lambda does not late-bind the loop variable.
      cmp = jax.tree_util.tree_map(
          lambda u, p, c=clip_r: u - c * p < 1e-6, u_norm, p_norm)
      for leaf in jax.tree_util.tree_leaves(cmp):
        self.assertTrue(leaf.all())

      # Check that continuously clipping won't cause numerical issues.
      updates_step, _ = clipper.update(self.per_step_updates, None, params)
      chex.assert_trees_all_close(updates, updates_step)

  def test_per_example_layer_norm_clip(self):
    """Checks the uniform and scaled variants of per-example layer clipping."""
    # Test data for a model with two layers and a batch size of 4. The
    # 0th layer has one parameter (shape (1)), and the 1st layer has shape
    # (3, 3, 2).
    grads_flat = [
        jnp.array([[0.5], [1.5], [-2.0], [3.0]]),
        jnp.ones([4, 3, 3, 2], dtype=jnp.float32),
    ]

    with self.subTest(name='Uniform Variant'):
      sum_clipped_grads, num_clipped = _clipping.per_example_layer_norm_clip(
          grads_flat, global_l2_norm_clip=jnp.sqrt(2), uniform=True
      )

      # For the uniform variant, with global_l2_norm_clip=sqrt(2), the per-layer
      # clip norm is 1.0. Thus the per-example per-layer clipped grads are
      # [[0.5], [1.0], [-1.0], [1.0]] and [1 / sqrt(18) ...]. The sums of
      # these over the 4 input gradients are [1.5] and [4 / sqrt(18) ...].
      self.assertAlmostEqual(sum_clipped_grads[0], 1.5)
      for element in sum_clipped_grads[1].flatten():
        self.assertAlmostEqual(element, 4 / jnp.sqrt(18), places=4)

      # The three values in grads_flat[0] with magnitude > 1.0 are clipped, as
      # are all four values in grads_flat[1].
      self.assertEqual(num_clipped[0], 3)
      self.assertEqual(num_clipped[1], 4)

    with self.subTest(name='Scaled Variant'):
      sum_clipped_grads, num_clipped = _clipping.per_example_layer_norm_clip(
          grads_flat, global_l2_norm_clip=jnp.sqrt(19), uniform=False
      )

      # For the scaled variant, with global_l2_norm_clip=sqrt(19), the per-layer
      # clip norm for the 0th layer is 1.0, and the per-layer clip norm for
      # the 1st layer is sqrt(18). Thus the per-example per-layer clipped grads
      # are [[0.5], [1.0], [-1.0], [1.0]] and [1.0 ...]. The sums of these
      # over the 4 input gradients are [1.5] and [4.0 ...].
      self.assertAlmostEqual(sum_clipped_grads[0], 1.5)
      for element in sum_clipped_grads[1].flatten():
        self.assertAlmostEqual(element, 4.0)

      # The three values in grads_flat[0] with magnitude > 1.0 are clipped. The
      # grad norms for grads_flat[1] are all equal to the per-layer clip norm,
      # so none of these grads are clipped.
      self.assertEqual(num_clipped[0], 3)
      self.assertEqual(num_clipped[1], 0)


if __name__ == '__main__':
  # Run this module's tests through absl's test runner when the file is
  # executed directly (not when it is imported).
  absltest.main()