File size: 6,397 Bytes
09d8e80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for optax.projections."""

from absl.testing import absltest
from absl.testing import parameterized
import chex
import jax
import jax.numpy as jnp
import numpy as np
from optax import projections as proj
import optax.tree_utils as otu


def projection_simplex_jacobian(projection):
  """Closed-form Jacobian of the simplex projection, evaluated at its output.

  With ``s`` the 0/1 indicator of the strictly positive entries of
  ``projection`` and ``k`` the number of such entries, the Jacobian is
  ``diag(s) - outer(s, s) / k``.

  Args:
    projection: the output of ``projection_simplex``; expected to be 1-D,
      since it is fed to ``jnp.diag`` to build a square matrix.

  Returns:
    A ``(len(projection), len(projection))`` matrix.
  """
  indicator = jnp.asarray(projection > 0, dtype=jnp.int32)
  num_active = jnp.count_nonzero(indicator)
  rank_one_term = jnp.outer(indicator, indicator) / num_active
  return jnp.diag(indicator) - rank_one_term


class ProjectionsTest(parameterized.TestCase):
  """Tests for the projection operators exposed by ``optax.projections``."""

  def test_projection_non_negative(self):
    """Negative entries are clipped to zero for arrays, tuples and pytrees."""
    # Inputs shared by several subTests are defined before any subTest so
    # that each subTest is self-contained.
    x = jnp.array([-1.0, 2.0, 3.0])
    expected = jnp.array([0, 2.0, 3.0])

    with self.subTest('with an array'):
      np.testing.assert_array_equal(proj.projection_non_negative(x), expected)

    with self.subTest('with a tuple'):
      # Compare as trees so the tuple structure of the output is checked too,
      # matching how tuple outputs are compared elsewhere in this class.
      chex.assert_trees_all_equal(
          proj.projection_non_negative((x, x)), (expected, expected)
      )

    with self.subTest('with nested pytree'):
      tree_x = (-1.0, {'k1': 1.0, 'k2': (1.0, 1.0)}, 1.0)
      tree_expected = (0.0, {'k1': 1.0, 'k2': (1.0, 1.0)}, 1.0)
      chex.assert_trees_all_equal(
          proj.projection_non_negative(tree_x), tree_expected
      )

  def test_projection_box(self):
    """Values are clipped into [lower, upper] for several bound formats."""
    lower, upper = 0.0, 2.0
    x = jnp.array([-1.0, 2.0, 3.0])
    expected = jnp.array([0, 2.0, 2.0])
    lower_arr = jnp.ones(len(x)) * lower
    upper_arr = jnp.ones(len(x)) * upper

    with self.subTest('lower and upper are scalars'):
      np.testing.assert_array_equal(
          proj.projection_box(x, lower, upper), expected
      )

    with self.subTest('lower and upper values are arrays'):
      np.testing.assert_array_equal(
          proj.projection_box(x, lower_arr, upper_arr), expected
      )

    with self.subTest('lower and upper are tuples of scalars'):
      chex.assert_trees_all_equal(
          proj.projection_box((x, x), (lower, lower), (upper, upper)),
          (expected, expected),
      )

    with self.subTest('lower and upper are tuples of arrays'):
      # Per-leaf array bounds whose structure matches the input tuple.
      chex.assert_trees_all_equal(
          proj.projection_box(
              (x, x), (lower_arr, lower_arr), (upper_arr, upper_arr)
          ),
          (expected, expected),
      )

    with self.subTest('lower and upper are pytrees'):
      tree = (-1.0, {'k1': 2.0, 'k2': (2.0, 3.0)}, 3.0)
      tree_expected = (0.0, {'k1': 2.0, 'k2': (2.0, 2.0)}, 2.0)
      lower_tree = (0.0, {'k1': 0.0, 'k2': (0.0, 0.0)}, 0.0)
      upper_tree = (2.0, {'k1': 2.0, 'k2': (2.0, 2.0)}, 2.0)
      chex.assert_trees_all_equal(
          proj.projection_box(tree, lower_tree, upper_tree), tree_expected
      )

  def test_projection_hypercube(self):
    """Values are clipped into [0, scale]; the default scale gives [0, 1]."""
    x = jnp.array([-1.0, 2.0, 0.5])
    expected_scaled = jnp.array([0, 0.8, 0.5])

    with self.subTest('with default scale'):
      np.testing.assert_array_equal(
          proj.projection_hypercube(x), jnp.array([0, 1.0, 0.5])
      )

    with self.subTest('with scalar scale'):
      np.testing.assert_array_equal(
          proj.projection_hypercube(x, 0.8), expected_scaled
      )

    with self.subTest('with array scales'):
      scales = jnp.ones(len(x)) * 0.8
      np.testing.assert_array_equal(
          proj.projection_hypercube(x, scales), expected_scaled
      )

  @parameterized.parameters(1.0, 0.8)
  def test_projection_simplex_array(self, scale):
    """Projected entries are in [0, scale] and sum to ``scale``."""
    rng = np.random.RandomState(0)
    x = rng.randn(50).astype(np.float32)
    p = proj.projection_simplex(x, scale)

    np.testing.assert_almost_equal(jnp.sum(p), scale, decimal=4)
    self.assertTrue(jnp.all(0 <= p))
    self.assertTrue(jnp.all(p <= scale))

  @parameterized.parameters(1.0, 0.8)
  def test_projection_simplex_pytree(self, scale):
    """All leaves of a projected pytree jointly sum to ``scale``."""
    pytree = {'w': jnp.array([2.5, 3.2]), 'b': 0.5}
    new_pytree = proj.projection_simplex(pytree, scale)
    np.testing.assert_almost_equal(otu.tree_sum(new_pytree), scale, decimal=4)
    # As in the array case, every projected entry must be non-negative.
    for leaf in jax.tree_util.tree_leaves(new_pytree):
      self.assertTrue(jnp.all(leaf >= 0))

  @parameterized.parameters(1.0, 0.8)
  def test_projection_simplex_edge_case(self, scale):
    """Entries at -inf get zero mass; the remaining mass is split evenly."""
    p = proj.projection_simplex(jnp.array([0.0, 0.0, -jnp.inf]), scale)
    np.testing.assert_array_almost_equal(
        p, jnp.array([scale / 2, scale / 2, 0.0])
    )

  def test_projection_simplex_jacobian(self):
    """Autodiff of projection_simplex agrees with theory and finite diffs."""
    rng = np.random.RandomState(0)

    x = rng.rand(5).astype(np.float32)
    v = rng.randn(5).astype(np.float32)

    # Quantities reused by several subTests are computed once, outside any
    # subTest. Otherwise an error inside one subTest could leave them
    # undefined and make the later subTests fail with NameError.
    jac_rev = jax.jacrev(proj.projection_simplex)(x)
    jac_fwd = jax.jacfwd(proj.projection_simplex)(x)
    jac_true = projection_simplex_jacobian(proj.projection_simplex(x))
    jvp = jax.jvp(proj.projection_simplex, (x,), (v,))[1]

    with self.subTest('Check against theoretical expression'):
      np.testing.assert_array_almost_equal(jac_true, jac_fwd)
      np.testing.assert_array_almost_equal(jac_true, jac_rev)

    with self.subTest('Check against finite difference'):
      eps = 1e-4
      # Central difference of the projection along direction v.
      jvp_finite_diff = (
          proj.projection_simplex(x + eps * v)
          - proj.projection_simplex(x - eps * v)
      ) / (2 * eps)
      np.testing.assert_array_almost_equal(jvp, jvp_finite_diff, decimal=3)

    with self.subTest('Check vector-Jacobian product'):
      (vjp,) = jax.vjp(proj.projection_simplex, x)[1](v)
      np.testing.assert_array_almost_equal(vjp, jnp.dot(v, jac_true))

    with self.subTest('Check Jacobian-vector product'):
      np.testing.assert_array_almost_equal(jvp, jnp.dot(jac_true, v))

  @parameterized.parameters(1.0, 0.8)
  def test_projection_simplex_vmap(self, scale):
    """projection_simplex can be vmapped over rows with per-row scales."""
    rng = np.random.RandomState(0)
    x = rng.randn(3, 50).astype(np.float32)
    scales = jnp.full(len(x), scale)

    p = jax.vmap(proj.projection_simplex)(x, scales)
    np.testing.assert_array_almost_equal(jnp.sum(p, axis=1), scales)
    np.testing.assert_array_equal(True, 0 <= p)
    np.testing.assert_array_equal(True, p <= scale)


if __name__ == '__main__':
  # Run this module's tests through absl's test runner when executed directly.
  absltest.main()