# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optax Schedules.

Schedules may be used to anneal the value of a hyper-parameter over time; for
instance, they may be used to anneal the learning rate used to update an agent's
parameters or the exploration factor used to select actions.
"""

from typing import Union, Optional, Iterable

from absl import logging
import chex
import jax.numpy as jnp
import numpy as np

from optax._src import base
from optax.schedules import _join


def constant_schedule(
    value: Union[float, int]
) -> base.Schedule:
  """Constructs a constant schedule.

  Args:
    value: value to be held constant throughout.

  Returns:
    schedule
      A function that maps step counts to values.
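
  Examples:
    A minimal usage sketch, assuming ``optax`` is already imported
    (``schedule_fn`` is an illustrative name):

    >>> schedule_fn = optax.constant_schedule(0.1)
    >>> schedule_fn(0)
    0.1
    >>> schedule_fn(1000)
    0.1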
  """
  return lambda count: value


def polynomial_schedule(
    init_value: chex.Scalar,
    end_value: chex.Scalar,
    power: chex.Scalar,
    transition_steps: int,
    transition_begin: int = 0
) -> base.Schedule:
  """Constructs a schedule with polynomial transition from init to end value.

  Args:
    init_value: initial value for the scalar to be annealed.
    end_value: end value of the scalar to be annealed.
    power: the power of the polynomial used to transition from init to end.
    transition_steps: number of steps over which annealing takes place.
      The scalar starts changing at ``transition_begin`` steps and completes
      the transition by ``transition_begin + transition_steps`` steps.
      If ``transition_steps <= 0``, then the entire annealing process is
      disabled and the value is held fixed at ``init_value``.
    transition_begin: must be non-negative. After how many steps to start
      annealing (before this many steps the scalar value is held fixed at
      ``init_value``). Negative values fall back to ``0``.

  Returns:
    schedule
      A function that maps step counts to values.
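
  Examples:
    A minimal usage sketch, assuming ``optax`` is already imported
    (``schedule_fn`` is an illustrative name; results are converted with
    ``float`` only to keep the printed output short):

    >>> schedule_fn = optax.polynomial_schedule(
    ...    init_value=1.0, end_value=0.0, power=2, transition_steps=10)
    >>> float(schedule_fn(0))
    1.0
    >>> float(schedule_fn(5))  # (1 - 5/10) ** 2
    0.25
    >>> float(schedule_fn(10))
    0.0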
  """
  if transition_steps <= 0:
    logging.info(
        'A polynomial schedule was set with a non-positive `transition_steps` '
        'value; this results in a constant schedule with value `init_value`.')
    return lambda count: init_value

  if transition_begin < 0:
    logging.info(
        'A polynomial schedule was set with a negative `transition_begin` '
        'value; this will result in `transition_begin` falling back to `0`.')
    transition_begin = 0

  def schedule(count):
    count = jnp.clip(count - transition_begin, 0, transition_steps)
    frac = 1 - count / transition_steps
    return (init_value - end_value) * (frac**power) + end_value
  return schedule


def linear_schedule(
    init_value: chex.Scalar,
    end_value: chex.Scalar,
    transition_steps: int,
    transition_begin: int = 0
) -> base.Schedule:
  r"""Schedule with linear transition from ``init_value`` to ``end_value``.

  More precisely, the learning rate at iteration :math:`t` is given by:

  .. math::
    \begin{cases}
      I, & \text{if } t < B \\
      I + \frac{t - B}{T} (E - I), & \text{if } B \leq t < B + T \\
      E, & \text{if } t \geq B + T
    \end{cases}

  where :math:`I` is the initial value, :math:`E` is the end value,
  :math:`B` is the transition begin, and :math:`T` is the transition steps.

  This schedule is equivalent to :func:`optax.polynomial_schedule` with
  ``power=1``.

  Examples:
    >>> schedule_fn = optax.linear_schedule(
    ...    init_value=1.0, end_value=0.01, transition_steps=100)
    >>> schedule_fn(0)  # learning rate on the first iteration
    Array(1., dtype=float32, weak_type=True)
    >>> schedule_fn(100)  # learning rate on the last iteration
    Array(0.01, dtype=float32, weak_type=True)

  Args:
    init_value: initial value for the scalar to be annealed.
    end_value: end value of the scalar to be annealed.
    transition_steps: number of steps over which annealing takes place. The
      scalar starts changing at ``transition_begin`` steps and completes the
      transition by ``transition_begin + transition_steps`` steps. If
      ``transition_steps <= 0``, then the entire annealing process is disabled
      and the value is held fixed at ``init_value``.
    transition_begin: must be non-negative. After how many steps to start
      annealing (before this many steps the scalar value is held fixed at
      ``init_value``). Negative values fall back to ``0``.

  Returns:
    schedule
      A function that maps step counts to values.
  """
  return polynomial_schedule(
      init_value=init_value, end_value=end_value, power=1,
      transition_steps=transition_steps, transition_begin=transition_begin)


def piecewise_constant_schedule(
    init_value: float,
    boundaries_and_scales: Optional[dict[int, float]] = None
) -> base.Schedule:
  """Returns a function which implements a piecewise constant schedule.

  Args:
    init_value: An initial value ``init_v``.
    boundaries_and_scales: A map from boundaries ``b_i`` to non-negative scaling
      factors ``f_i``. For any step count ``s``, the schedule returns
      ``init_v`` scaled by the product of all factors ``f_i`` such that
      ``b_i <= s``.

  Returns:
    schedule
      A function that maps step counts to values.
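
  Examples:
    A minimal usage sketch, assuming ``optax`` is already imported
    (``schedule_fn`` is an illustrative name; results are converted with
    ``float`` only to keep the printed output short):

    >>> schedule_fn = optax.piecewise_constant_schedule(
    ...    init_value=1.0, boundaries_and_scales={10: 0.5, 20: 0.5})
    >>> float(schedule_fn(5))
    1.0
    >>> float(schedule_fn(15))  # scaled by 0.5 once past step 10
    0.5
    >>> float(schedule_fn(25))  # scaled by 0.5 again once past step 20
    0.25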
  """
  if boundaries_and_scales is not None:
    all_non_negative = all(
        scale >= 0. for scale in boundaries_and_scales.values())
    if not all_non_negative:
      raise ValueError(
          '`piecewise_constant_schedule` expects non-negative scale factors')

  def schedule(count):
    v = init_value
    if boundaries_and_scales is not None:
      for threshold, scale in sorted(boundaries_and_scales.items()):
        indicator = jnp.maximum(0., jnp.sign(threshold - count))
        v = v * indicator + (1 - indicator) * scale * v
    return v

  return schedule


def exponential_decay(
    init_value: float,
    transition_steps: int,
    decay_rate: float,
    transition_begin: int = 0,
    staircase: bool = False,
    end_value: Optional[float] = None
) -> base.Schedule:
  """Constructs a schedule with either continuous or discrete exponential decay.

  This function applies an exponential decay function to a provided initial
  value. When ``count >= transition_begin`` the function returns the decayed
  value as:

  .. code-block::

    rate_factor = ((count - transition_begin) / transition_steps)
    decayed_value = init_value * (decay_rate ** rate_factor)

  If the argument ``staircase`` is ``True``, then ``rate_factor`` is rounded
  down to an integer and the decayed value follows a staircase function.

  Args:
    init_value: the initial learning rate.
    transition_steps: must be positive. See the decay computation above.
    decay_rate: must not be zero. The decay rate.
    transition_begin: must be non-negative. After how many steps to start
      annealing (before this many steps the scalar value is held fixed at
      ``init_value``). Negative values fall back to ``0``.
    staircase: if ``True``, decay the values at discrete intervals.
    end_value: the value at which the exponential decay stops. When
      ``decay_rate < 1``, ``end_value`` is treated as a lower bound, otherwise
      as an upper bound. Has no effect when ``decay_rate = 0``.

  Returns:
    schedule
      A function that maps step counts to values.
  """

  if transition_steps <= 0:
    logging.info(
        'An exponential schedule was set with a non-positive `transition_steps`'
        ' value; this will result in a constant schedule with value '
        '`init_value`.')
    return lambda count: init_value

  if decay_rate == 0:
    logging.info(
        'An exponential schedule was set with a zero `decay_rate` value; '
        'this will result in a constant schedule with value `init_value`.')
    return lambda count: init_value

  if transition_begin < 0:
    logging.info(
        'An exponential schedule was set with a negative `transition_begin` '
        'value; this will result in `transition_begin` falling back to `0`.')
    transition_begin = 0

  if end_value is not None:
    clip_fn = jnp.maximum if decay_rate < 1.0 else jnp.minimum

  def schedule(count):
    decreased_count = count - transition_begin
    p = decreased_count / transition_steps
    if staircase:
      p = jnp.floor(p)
    decayed_value = jnp.where(
        decreased_count <= 0, init_value, init_value * jnp.power(decay_rate, p))
    if end_value is not None:
      decayed_value = clip_fn(decayed_value, end_value)
    return decayed_value

  return schedule


def cosine_decay_schedule(
    init_value: float,
    decay_steps: int,
    alpha: float = 0.0,
    exponent: float = 1.0,
) -> base.Schedule:
  r"""Returns a function which implements cosine learning rate decay.

  This schedule smoothly decreases the learning rate over a specified number of
  steps (``decay_steps``). The decay follows a cosine function, with an optional
  exponent to modify the decay curve. A minimum multiplier (``alpha``) sets a
  floor of ``alpha * init_value``; with the default ``alpha=0.0`` the learning
  rate decays all the way to zero.

  More precisely, the learning rate at iteration :math:`t` is given by:

  .. math::

     I \left((1 - \alpha) \left(\frac{1 + \cos(\pi\,\frac{\min(t, T)}{T})}{2}
     \right)^p + \alpha\right)\,,

  where :math:`T` is the number of decay steps (``decay_steps``), :math:`p` is
  the ``exponent``, :math:`\alpha` is the minimum multiplier (``alpha``) and
  :math:`I` is the initial value (``init_value``).

  References:
    Loshchilov et al., `SGDR: Stochastic Gradient Descent with Warm Restarts
    <https://arxiv.org/abs/1608.03983>`_, 2017

  Args:
    init_value: An initial value for the learning rate.
    decay_steps: Positive integer - the number of steps over which to apply
      the decay.
    alpha: The minimum value of the multiplier used to adjust the
      learning rate. Defaults to 0.0.
    exponent: The default decay is ``0.5 * (1 + cos(pi * t/T))``, where
      ``t`` is the current timestep and ``T`` is the ``decay_steps``. The
      exponent modifies this to be ``(0.5 * (1 + cos(pi * t/T))) ** exponent``.
      Defaults to 1.0.

  Returns:
    schedule
      A function that maps step counts to values.
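
  Examples:
    A minimal usage sketch, assuming ``optax`` is already imported
    (``schedule_fn`` is an illustrative name; results are rounded only to hide
    float32 noise):

    >>> schedule_fn = optax.cosine_decay_schedule(
    ...    init_value=1.0, decay_steps=100, alpha=0.1)
    >>> round(float(schedule_fn(0)), 3)
    1.0
    >>> round(float(schedule_fn(50)), 3)  # halfway: 0.9 * 0.5 + 0.1
    0.55
    >>> round(float(schedule_fn(100)), 3)  # floor of alpha * init_value
    0.1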
  """
  if not decay_steps > 0:
    raise ValueError(
        'The cosine_decay_schedule requires positive decay_steps, got'
        f' {decay_steps=}.'
    )

  def schedule(count):
    count = jnp.minimum(count, decay_steps)
    cosine_decay = 0.5 * (1 + jnp.cos(jnp.pi * count / decay_steps))
    decayed = (1 - alpha) * cosine_decay ** exponent + alpha
    return init_value * decayed

  return schedule


def _linear_interpolate(start: float, end: float, pct: float):
  return (end-start) * pct + start


def _cosine_interpolate(start: float, end: float, pct: float):
  return end + (start-end) / 2.0 * (jnp.cos(jnp.pi * pct) + 1)


def piecewise_interpolate_schedule(
    interpolate_type: str,
    init_value: float,
    boundaries_and_scales: Optional[dict[int, float]] = None
) -> base.Schedule:
  """Returns a function which implements a piecewise interpolated schedule.

  Args:
    interpolate_type: 'linear' or 'cosine', specifying the interpolation
      strategy.
    init_value: An initial value ``init_v``.
    boundaries_and_scales: A map from boundaries ``b_i`` to non-negative scaling
      factors ``f_i``. At boundary step ``b_i``, the schedule returns ``init_v``
      scaled by the product of all factors ``f_j`` such that ``b_j <= b_i``.
      The values between boundaries are interpolated as per
      ``interpolate_type``.

  Returns:
    schedule
      A function that maps step counts to values.
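
  Examples:
    A minimal usage sketch, assuming ``optax`` is already imported
    (``schedule_fn`` is an illustrative name; results are converted with
    ``float`` only to keep the printed output short):

    >>> schedule_fn = optax.piecewise_interpolate_schedule(
    ...    'linear', init_value=1.0, boundaries_and_scales={10: 0.5})
    >>> float(schedule_fn(0))
    1.0
    >>> float(schedule_fn(5))  # halfway between 1.0 and 0.5
    0.75
    >>> float(schedule_fn(10))
    0.5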
  """
  if interpolate_type == 'linear':
    interpolate_fn = _linear_interpolate
  elif interpolate_type == 'cosine':
    interpolate_fn = _cosine_interpolate
  else:
    raise ValueError(
        '`interpolate_type` must be either \'cosine\' or \'linear\'')

  if boundaries_and_scales:
    boundaries, scales = zip(*sorted(boundaries_and_scales.items()))
    if not all(scale >= 0. for scale in scales):
      raise ValueError(
          '`piecewise_interpolate_schedule` expects non-negative scale factors')
  else:
    boundaries, scales = (), ()

  bounds = np.stack((0,) + boundaries)
  values = np.cumprod(np.stack((init_value,) + scales))
  interval_sizes = bounds[1:] - bounds[:-1]

  def schedule(count):
    indicator = (bounds[:-1] <= count) & (count < bounds[1:])
    pct = (count - bounds[:-1]) / interval_sizes
    interp_vals = interpolate_fn(values[:-1], values[1:], pct)
    return indicator.dot(interp_vals) + (bounds[-1] <= count) * values[-1]

  return schedule


def linear_onecycle_schedule(
    transition_steps: int,
    peak_value: float,
    pct_start: float = 0.3,
    pct_final: float = 0.85,
    div_factor: float = 25.0,
    final_div_factor: float = 1e4
) -> base.Schedule:
  """Returns a function which implements the onecycle learning rate schedule.

  This function uses a linear annealing strategy.

  References:
    Smith et al., `Super-Convergence: Very Fast Training of Neural Networks
    Using Large Learning Rates <https://arxiv.org/abs/1708.07120>`_, 2017

  Args:
    transition_steps: Number of steps over which annealing takes place.
    peak_value: Maximum value attained by schedule at pct_start percent
      of the cycle (in number of steps).
    pct_start: The percentage of the cycle (in number of steps) spent
      increasing the learning rate.
    pct_final: The percentage of the cycle (in number of steps) spent
      increasing to ``peak_value`` then decreasing back to ``init_value``.
    div_factor: Determines the initial value via ``init_value =
      peak_value / div_factor``.
    final_div_factor: Determines the final value via ``final_value =
      init_value / final_div_factor``.

  Returns:
    schedule
      A function that maps step counts to values.
  """
  if transition_steps <= 0:
    raise ValueError(
        'A linear onecycle schedule was set with a non-positive '
        '`transition_steps`')

  return piecewise_interpolate_schedule(
      'linear',
      peak_value / div_factor,
      {int(pct_start * transition_steps): div_factor,
       int(pct_final * transition_steps): 1. / div_factor,
       transition_steps: 1. / final_div_factor})


def cosine_onecycle_schedule(
    transition_steps: int,
    peak_value: float,
    pct_start: float = 0.3,
    div_factor: float = 25.0,
    final_div_factor: float = 1e4
) -> base.Schedule:
  """Returns a function which implements the onecycle learning rate schedule.

  This schedule increases the learning rate and then decreases it in a
  cosine-like manner. The number of steps over which the learning rate increases
  is determined by the ``pct_start`` argument. The maximum value of the learning
  rate is determined by the ``peak_value`` argument, the initial value of the
  learning rate is determined through the formula ``init_value = peak_value /
  div_factor``, and the final value is determined by the ``final_div_factor``
  argument.

  References:
    Smith et al., `Super-Convergence: Very Fast Training of Neural Networks
    Using Large Learning Rates <https://arxiv.org/abs/1708.07120>`_, 2017

  Args:
    transition_steps: Number of steps over which annealing takes place.
    peak_value: Maximum value attained by schedule at pct_start percent of the
      cycle (in number of steps).
    pct_start: The percentage of the cycle (in number of steps) spent increasing
      the learning rate.
    div_factor: Determines the initial value via ``init_value = peak_value /
      div_factor``.
    final_div_factor: Determines the final value via ``final_value = init_value
      / final_div_factor``.

  Returns:
    schedule
      A function that maps step counts to values.
  """
  if transition_steps <= 0:
    raise ValueError(
        'A cosine onecycle schedule was set with a non-positive '
        '`transition_steps`')

  return piecewise_interpolate_schedule(
      'cosine',
      peak_value / div_factor,
      {int(pct_start * transition_steps): div_factor,
       int(transition_steps): 1. / (div_factor * final_div_factor)})


def warmup_constant_schedule(
    init_value: float,
    peak_value: float,
    warmup_steps: int,
) -> base.Schedule:
  r"""Linear warmup followed by a constant schedule, i.e. no decay.

  Args:
    init_value: Initial value for the scalar to be annealed.
    peak_value: Peak value for scalar to be annealed at end of warmup.
    warmup_steps: Positive integer, the length of the linear warmup.

  Returns:
    schedule
      A function that maps step counts to values.
  """
  return linear_schedule(
      init_value=init_value,
      end_value=peak_value,
      transition_steps=warmup_steps,
  )


def warmup_cosine_decay_schedule(
    init_value: float,
    peak_value: float,
    warmup_steps: int,
    decay_steps: int,
    end_value: float = 0.0,
    exponent: float = 1.0,
) -> base.Schedule:
  r"""Linear warmup followed by cosine decay.

  Args:
    init_value: Initial value for the scalar to be annealed.
    peak_value: Peak value for scalar to be annealed at end of warmup.
    warmup_steps: Positive integer, the length of the linear warmup.
    decay_steps: Positive integer, the total length of the schedule. Note that
      this includes the warmup time, so the number of steps during which cosine
      annealing is applied is ``decay_steps - warmup_steps``.
    end_value: End value of the scalar to be annealed.
    exponent: Float. The default decay is ``0.5 * (1 + cos(pi * t/T))``,
      where ``t`` is the number of steps since the end of warmup and ``T`` is
      ``decay_steps - warmup_steps``. The exponent modifies this to be
      ``(0.5 * (1 + cos(pi * t/T))) ** exponent``.
      Defaults to 1.0.

  Returns:
    schedule
      A function that maps step counts to values
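
  Examples:
    A minimal usage sketch, assuming ``optax`` is already imported
    (``schedule_fn`` is an illustrative name; results are rounded only to hide
    float32 noise):

    >>> schedule_fn = optax.warmup_cosine_decay_schedule(
    ...    init_value=0.0, peak_value=1.0, warmup_steps=10, decay_steps=110)
    >>> round(float(schedule_fn(0)), 3)
    0.0
    >>> round(float(schedule_fn(10)), 3)  # end of the linear warmup
    1.0
    >>> round(float(schedule_fn(110)), 3)  # end of the cosine decay
    0.0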
  """
  alpha = 0. if peak_value == 0. else end_value / peak_value
  schedules = [
      linear_schedule(
          init_value=init_value,
          end_value=peak_value,
          transition_steps=warmup_steps,
      ),
      cosine_decay_schedule(
          init_value=peak_value,
          decay_steps=decay_steps - warmup_steps,
          alpha=alpha,
          exponent=exponent,
      ),
  ]
  return _join.join_schedules(schedules, [warmup_steps])


def warmup_exponential_decay_schedule(
    init_value: float,
    peak_value: float,
    warmup_steps: int,
    transition_steps: int,
    decay_rate: float,
    transition_begin: int = 0,
    staircase: bool = False,
    end_value: Optional[float] = None
) -> base.Schedule:
  """Linear warmup followed by exponential decay.

  Args:
    init_value: Initial value for the scalar to be annealed.
    peak_value: Peak value for scalar to be annealed at end of warmup.
    warmup_steps: Positive integer, the length of the linear warmup.
    transition_steps: must be positive. See :func:`optax.exponential_decay`
      for more details.
    decay_rate: must not be zero. The decay rate.
    transition_begin: must be non-negative. After how many steps to start
      annealing (before this many steps the scalar value is held fixed at
      ``peak_value``). Negative values fall back to ``0``.
    staircase: if ``True``, decay the values at discrete intervals.
    end_value: the value at which the exponential decay stops. When
      ``decay_rate < 1``, ``end_value`` is treated as a lower bound, otherwise
      as an upper bound. Has no effect when ``decay_rate = 0``.

  Returns:
    schedule
      A function that maps step counts to values.
  """
  schedules = [
      linear_schedule(
          init_value=init_value,
          end_value=peak_value,
          transition_steps=warmup_steps),
      exponential_decay(
          init_value=peak_value,
          transition_steps=transition_steps,
          decay_rate=decay_rate,
          transition_begin=transition_begin,
          staircase=staircase,
          end_value=end_value)]
  return _join.join_schedules(schedules, [warmup_steps])


def sgdr_schedule(cosine_kwargs: Iterable[dict[str, chex.Numeric]]
                  ) -> base.Schedule:
  """SGD with warm restarts.

  This learning rate schedule applies multiple joined cosine decay cycles.

  References:
    Loshchilov et al., `SGDR: Stochastic Gradient Descent with Warm Restarts
    <https://arxiv.org/abs/1608.03983>`_, 2017

  Args:
    cosine_kwargs: An Iterable of dicts, where each element specifies the
      arguments to pass to :func:`optax.warmup_cosine_decay_schedule` for one
      cycle. The ``decay_steps`` kwarg specifies how long each cycle lasts, and
      therefore when to transition to the next cycle.

  Returns:
    schedule
      A function that maps step counts to values.
  """
  boundaries = []
  schedules = []
  step = 0
  for kwargs in cosine_kwargs:
    schedules += [warmup_cosine_decay_schedule(**kwargs)]
    boundaries += [step + kwargs['decay_steps']]
    step += kwargs['decay_steps']
  return _join.join_schedules(schedules, boundaries[:-1])