File size: 16,210 Bytes
09d8e80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optax: composable gradient processing and optimization, in JAX."""

# pylint: disable=wrong-import-position
# pylint: disable=g-importing-member

from optax import contrib
from optax import losses
from optax import monte_carlo
from optax import projections
from optax import schedules
from optax import second_order
from optax import transforms
from optax import tree_utils
from optax._src.alias import adabelief
from optax._src.alias import adadelta
from optax._src.alias import adafactor
from optax._src.alias import adagrad
from optax._src.alias import adam
from optax._src.alias import adamax
from optax._src.alias import adamaxw
from optax._src.alias import adamw
from optax._src.alias import amsgrad
from optax._src.alias import fromage
from optax._src.alias import lamb
from optax._src.alias import lars
from optax._src.alias import lion
from optax._src.alias import MaskOrFn
from optax._src.alias import nadam
from optax._src.alias import nadamw
from optax._src.alias import noisy_sgd
from optax._src.alias import novograd
from optax._src.alias import optimistic_gradient_descent
from optax._src.alias import polyak_sgd
from optax._src.alias import radam
from optax._src.alias import rmsprop
from optax._src.alias import rprop
from optax._src.alias import sgd
from optax._src.alias import sm3
from optax._src.alias import yogi
from optax._src.base import EmptyState
from optax._src.base import GradientTransformation
from optax._src.base import GradientTransformationExtraArgs
from optax._src.base import identity
from optax._src.base import OptState
from optax._src.base import Params
from optax._src.base import ScalarOrSchedule
from optax._src.base import Schedule
from optax._src.base import set_to_zero
from optax._src.base import stateless
from optax._src.base import stateless_with_tree_map
from optax._src.base import TransformInitFn
from optax._src.base import TransformUpdateExtraArgsFn
from optax._src.base import TransformUpdateFn
from optax._src.base import Updates
from optax._src.base import with_extra_args_support
from optax._src.clipping import adaptive_grad_clip
from optax._src.clipping import AdaptiveGradClipState
from optax._src.clipping import clip
from optax._src.clipping import clip_by_block_rms
from optax._src.clipping import clip_by_global_norm
from optax._src.clipping import ClipByGlobalNormState
from optax._src.clipping import ClipState
from optax._src.clipping import per_example_global_norm_clip
from optax._src.clipping import per_example_layer_norm_clip
from optax._src.combine import chain
from optax._src.combine import multi_transform
from optax._src.combine import MultiTransformState
from optax._src.combine import named_chain
from optax._src.constrain import keep_params_nonnegative
from optax._src.constrain import NonNegativeParamsState
from optax._src.constrain import zero_nans
from optax._src.constrain import ZeroNansState
from optax._src.factorized import FactoredState
from optax._src.factorized import scale_by_factored_rms
from optax._src.linear_algebra import global_norm
from optax._src.linear_algebra import matrix_inverse_pth_root
from optax._src.linear_algebra import power_iteration
from optax._src.linesearch import scale_by_backtracking_linesearch
from optax._src.linesearch import ScaleByBacktrackingLinesearchState
from optax._src.lookahead import lookahead
from optax._src.lookahead import LookaheadParams
from optax._src.lookahead import LookaheadState
from optax._src.numerics import safe_int32_increment
from optax._src.numerics import safe_norm
from optax._src.numerics import safe_root_mean_squares
from optax._src.transform import add_decayed_weights
from optax._src.transform import add_noise
from optax._src.transform import AddDecayedWeightsState
from optax._src.transform import AddNoiseState
from optax._src.transform import apply_every
from optax._src.transform import ApplyEvery
from optax._src.transform import centralize
from optax._src.transform import ema
from optax._src.transform import EmaState
from optax._src.transform import scale
from optax._src.transform import scale_by_adadelta
from optax._src.transform import scale_by_adam
from optax._src.transform import scale_by_adamax
from optax._src.transform import scale_by_amsgrad
from optax._src.transform import scale_by_belief
from optax._src.transform import scale_by_distance_over_gradients
from optax._src.transform import scale_by_learning_rate
from optax._src.transform import scale_by_lion
from optax._src.transform import scale_by_novograd
from optax._src.transform import scale_by_optimistic_gradient
from optax._src.transform import scale_by_param_block_norm
from optax._src.transform import scale_by_param_block_rms
from optax._src.transform import scale_by_polyak
from optax._src.transform import scale_by_radam
from optax._src.transform import scale_by_rms
from optax._src.transform import scale_by_rprop
from optax._src.transform import scale_by_rss
from optax._src.transform import scale_by_schedule
from optax._src.transform import scale_by_sm3
from optax._src.transform import scale_by_stddev
from optax._src.transform import scale_by_trust_ratio
from optax._src.transform import scale_by_yogi
from optax._src.transform import ScaleByAdaDeltaState
from optax._src.transform import ScaleByAdamState
from optax._src.transform import ScaleByAmsgradState
from optax._src.transform import ScaleByBeliefState
from optax._src.transform import ScaleByLionState
from optax._src.transform import ScaleByNovogradState
from optax._src.transform import ScaleByRmsState
from optax._src.transform import ScaleByRpropState
from optax._src.transform import ScaleByRssState
from optax._src.transform import ScaleByRStdDevState
from optax._src.transform import ScaleByScheduleState
from optax._src.transform import ScaleBySM3State
from optax._src.transform import ScaleByTrustRatioState
from optax._src.transform import ScaleState
from optax._src.transform import trace
from optax._src.transform import TraceState
from optax._src.update import apply_updates
from optax._src.update import incremental_update
from optax._src.update import periodic_update
from optax._src.utils import multi_normal
from optax._src.utils import scale_gradient
from optax._src.utils import value_and_grad_from_state
from optax._src.wrappers import apply_if_finite
from optax._src.wrappers import ApplyIfFiniteState
from optax._src.wrappers import conditionally_mask
from optax._src.wrappers import conditionally_transform
from optax._src.wrappers import ConditionallyMaskState
from optax._src.wrappers import ConditionallyTransformState
from optax._src.wrappers import flatten
from optax._src.wrappers import masked
from optax._src.wrappers import MaskedNode
from optax._src.wrappers import MaskedState
from optax._src.wrappers import maybe_update
from optax._src.wrappers import MaybeUpdateState
from optax._src.wrappers import MultiSteps
from optax._src.wrappers import MultiStepsState
from optax._src.wrappers import ShouldSkipUpdateFunction
from optax._src.wrappers import skip_large_updates
from optax._src.wrappers import skip_not_finite


# Flat-namespace aliases (`optax.<name>`) for helpers that are defined in
# `optax.tree_utils`. Every alias except `tree_map_params` drops the `tree_`
# prefix of the function it points to. Presumably kept for backward
# compatibility until callers use `optax.tree_utils` directly.
# TODO(mtthss): remove tree_utils aliases after updates.
tree_map_params = tree_utils.tree_map_params
bias_correction = tree_utils.tree_bias_correction
update_infinity_moment = tree_utils.tree_update_infinity_moment
update_moment = tree_utils.tree_update_moment
update_moment_per_elem_norm = tree_utils.tree_update_moment_per_elem_norm

# Flat-namespace aliases (`optax.<name>`) for the schedule utilities defined
# in `optax.schedules`; each alias keeps the name it has in that submodule.
# TODO(mtthss): remove schedules aliases from flat namespaces after user
# updates.
constant_schedule = schedules.constant_schedule
cosine_decay_schedule = schedules.cosine_decay_schedule
cosine_onecycle_schedule = schedules.cosine_onecycle_schedule
exponential_decay = schedules.exponential_decay
inject_hyperparams = schedules.inject_hyperparams
InjectHyperparamsState = schedules.InjectHyperparamsState
join_schedules = schedules.join_schedules
linear_onecycle_schedule = schedules.linear_onecycle_schedule
linear_schedule = schedules.linear_schedule
piecewise_constant_schedule = schedules.piecewise_constant_schedule
piecewise_interpolate_schedule = schedules.piecewise_interpolate_schedule
polynomial_schedule = schedules.polynomial_schedule
sgdr_schedule = schedules.sgdr_schedule
warmup_cosine_decay_schedule = schedules.warmup_cosine_decay_schedule
warmup_exponential_decay_schedule = schedules.warmup_exponential_decay_schedule
inject_stateful_hyperparams = schedules.inject_stateful_hyperparams
InjectStatefulHyperparamsState = schedules.InjectStatefulHyperparamsState
WrappedSchedule = schedules.WrappedSchedule

# Flat-namespace aliases (`optax.<name>`) for the loss functions defined in
# `optax.losses`; each alias keeps the name it has in that submodule.
# TODO(mtthss): remove loss aliases from flat namespace once users have updated.
convex_kl_divergence = losses.convex_kl_divergence
cosine_distance = losses.cosine_distance
cosine_similarity = losses.cosine_similarity
ctc_loss = losses.ctc_loss
ctc_loss_with_forward_probs = losses.ctc_loss_with_forward_probs
hinge_loss = losses.hinge_loss
huber_loss = losses.huber_loss
kl_divergence = losses.kl_divergence
l2_loss = losses.l2_loss
log_cosh = losses.log_cosh
ntxent = losses.ntxent
sigmoid_binary_cross_entropy = losses.sigmoid_binary_cross_entropy
smooth_labels = losses.smooth_labels
safe_softmax_cross_entropy = losses.safe_softmax_cross_entropy
softmax_cross_entropy = losses.softmax_cross_entropy
# Parenthesised only to keep the long name within the line-length limit.
softmax_cross_entropy_with_integer_labels = (
    losses.softmax_cross_entropy_with_integer_labels
)
squared_error = losses.squared_error
sigmoid_focal_loss = losses.sigmoid_focal_loss

# pylint: disable=g-import-not-at-top
# TODO(mtthss): remove contrib aliases from flat namespace once users updated.
# Deprecated symbols: these three objects live in `optax.contrib`. They are
# imported here under private `_deprecated_*` names so that the public names
# are NOT bound as ordinary module attributes at runtime; lookups of the
# public names therefore fall through to the module-level `__getattr__`
# installed further down.
from optax.contrib import differentially_private_aggregate as _deprecated_differentially_private_aggregate
from optax.contrib import DifferentiallyPrivateAggregateState as _deprecated_DifferentiallyPrivateAggregateState
from optax.contrib import dpsgd as _deprecated_dpsgd

# Maps each deprecated flat-namespace name to a
# `(deprecation message, replacement object)` pair. The table is handed to
# `deprecation_getattr` below.
_deprecations = {
    # Added Apr 2024
    "differentially_private_aggregate": (
        (
            "optax.differentially_private_aggregate is deprecated: use"
            " optax.contrib.differentially_private_aggregate (optax v0.1.8 or"
            " newer)."
        ),
        _deprecated_differentially_private_aggregate,
    ),
    "DifferentiallyPrivateAggregateState": (
        (
            "optax.DifferentiallyPrivateAggregateState is deprecated: use"
            " optax.contrib.DifferentiallyPrivateAggregateState (optax v0.1.8"
            " or newer)."
        ),
        _deprecated_DifferentiallyPrivateAggregateState,
    ),
    "dpsgd": (
        (
            "optax.dpsgd is deprecated: use optax.contrib.dpsgd (optax v0.1.8"
            " or newer)."
        ),
        _deprecated_dpsgd,
    ),
}
# pylint: disable=g-bad-import-order
# Imported under a private alias and deleted again below so that `typing`
# does not leak into the public `optax` namespace.
import typing as _typing

if _typing.TYPE_CHECKING:
  # Static type checkers take this branch: the deprecated names are imported
  # directly so that they resolve to real symbols during analysis.
  # pylint: disable=reimported
  from optax.contrib import differentially_private_aggregate
  from optax.contrib import DifferentiallyPrivateAggregateState
  from optax.contrib import dpsgd
  # pylint: enable=reimported

else:
  # Runtime branch: install a module-level `__getattr__` built from the table
  # above. NOTE(review): presumably the returned hook warns with the stored
  # message and returns the stored object when a deprecated name is accessed;
  # confirm in `optax/_src/deprecations.py`.
  from optax._src.deprecations import deprecation_getattr as _deprecation_getattr

  __getattr__ = _deprecation_getattr(__name__, _deprecations)
  # The factory is only needed once; drop it from the module namespace.
  del _deprecation_getattr
del _typing
# pylint: enable=g-bad-import-order
# pylint: enable=g-import-not-at-top
# pylint: enable=g-importing-member


__version__ = "0.2.3.dev"

# Names exported by `from optax import *`.
#
# Every public symbol that this module imports from `optax._src` or aliases
# from `optax.losses`, `optax.schedules` and `optax.tree_utils` is listed, so
# that the star-import surface matches what `optax.<name>` attribute access
# offers. Submodules (`contrib`, `losses`, ...) are not listed here.
#
# NOTE: `differentially_private_aggregate`,
# `DifferentiallyPrivateAggregateState` and `dpsgd` are not bound as plain
# module attributes at runtime; they are resolved through the module-level
# `__getattr__` that handles deprecated names.
__all__ = (
    "adabelief",
    "adadelta",
    "adafactor",
    "adagrad",
    "adam",
    "adamax",
    "adamaxw",
    "adamw",
    "adaptive_grad_clip",
    "AdaptiveGradClipState",
    "add_decayed_weights",
    "add_noise",
    "AddDecayedWeightsState",
    "AddNoiseState",
    "amsgrad",
    "apply_every",
    "apply_if_finite",
    "apply_updates",
    "ApplyEvery",
    "ApplyIfFiniteState",
    "bias_correction",
    "centralize",
    "chain",
    "clip_by_block_rms",
    "clip_by_global_norm",
    "clip",
    "ClipByGlobalNormState",
    "ClipState",
    "conditionally_mask",
    "ConditionallyMaskState",
    "conditionally_transform",
    "ConditionallyTransformState",
    "constant_schedule",
    "ctc_loss",
    "ctc_loss_with_forward_probs",
    "convex_kl_divergence",
    "cosine_decay_schedule",
    "cosine_distance",
    "cosine_onecycle_schedule",
    "cosine_similarity",
    "differentially_private_aggregate",
    "DifferentiallyPrivateAggregateState",
    "dpsgd",
    "ema",
    "EmaState",
    "EmptyState",
    "exponential_decay",
    "FactoredState",
    "flatten",
    "fromage",
    "global_norm",
    "GradientTransformation",
    "GradientTransformationExtraArgs",
    "hinge_loss",
    "huber_loss",
    "identity",
    "incremental_update",
    "inject_hyperparams",
    "inject_stateful_hyperparams",
    "InjectHyperparamsState",
    "InjectStatefulHyperparamsState",
    "join_schedules",
    "keep_params_nonnegative",
    "kl_divergence",
    "l2_loss",
    "lamb",
    "lars",
    "lion",
    "linear_onecycle_schedule",
    "linear_schedule",
    "log_cosh",
    "lookahead",
    "LookaheadParams",
    "LookaheadState",
    "masked",
    "MaskOrFn",
    "MaskedNode",
    "MaskedState",
    "matrix_inverse_pth_root",
    "maybe_update",
    "MaybeUpdateState",
    "multi_normal",
    "multi_transform",
    "MultiSteps",
    "MultiStepsState",
    "MultiTransformState",
    "nadam",
    "nadamw",
    "named_chain",
    "noisy_sgd",
    "novograd",
    "NonNegativeParamsState",
    "ntxent",
    "optimistic_gradient_descent",
    "OptState",
    "Params",
    "periodic_update",
    "per_example_global_norm_clip",
    "per_example_layer_norm_clip",
    "piecewise_constant_schedule",
    "piecewise_interpolate_schedule",
    "polynomial_schedule",
    "power_iteration",
    "polyak_sgd",
    "radam",
    "rmsprop",
    "rprop",
    "safe_int32_increment",
    "safe_norm",
    "safe_root_mean_squares",
    "safe_softmax_cross_entropy",
    "ScalarOrSchedule",
    "scale_by_adadelta",
    "scale_by_adam",
    "scale_by_adamax",
    "scale_by_amsgrad",
    "scale_by_backtracking_linesearch",
    "scale_by_belief",
    "scale_by_distance_over_gradients",
    "scale_by_lion",
    "scale_by_factored_rms",
    "scale_by_learning_rate",
    "scale_by_novograd",
    "scale_by_optimistic_gradient",
    "scale_by_param_block_norm",
    "scale_by_param_block_rms",
    "scale_by_polyak",
    "scale_by_radam",
    "scale_by_rms",
    "scale_by_rprop",
    "scale_by_rss",
    "scale_by_schedule",
    "scale_by_sm3",
    "scale_by_stddev",
    "scale_by_trust_ratio",
    "scale_by_yogi",
    "scale_gradient",
    "scale",
    "ScaleByAdaDeltaState",
    "ScaleByAdamState",
    "ScaleByAmsgradState",
    "ScaleByBacktrackingLinesearchState",
    "ScaleByBeliefState",
    "ScaleByLionState",
    "ScaleByNovogradState",
    "ScaleByRmsState",
    "ScaleByRpropState",
    "ScaleByRssState",
    "ScaleByRStdDevState",
    "ScaleByScheduleState",
    "ScaleBySM3State",
    "ScaleByTrustRatioState",
    "ScaleState",
    "Schedule",
    "set_to_zero",
    "sgd",
    "sgdr_schedule",
    "ShouldSkipUpdateFunction",
    "sigmoid_binary_cross_entropy",
    "sigmoid_focal_loss",
    "skip_large_updates",
    "skip_not_finite",
    "sm3",
    "smooth_labels",
    "softmax_cross_entropy",
    "softmax_cross_entropy_with_integer_labels",
    "squared_error",
    "stateless",
    "stateless_with_tree_map",
    "trace",
    "TraceState",
    "TransformInitFn",
    "TransformUpdateFn",
    "TransformUpdateExtraArgsFn",
    "tree_map_params",
    "update_infinity_moment",
    "update_moment",
    "update_moment_per_elem_norm",
    "Updates",
    "value_and_grad_from_state",
    "warmup_cosine_decay_schedule",
    "warmup_exponential_decay_schedule",
    "with_extra_args_support",
    "WrappedSchedule",
    "yogi",
    "zero_nans",
    "ZeroNansState",
)

#  _________________________________________
# / Please don't use symbols in `_src` they \
# \ are not part of the Optax public API.   /
#  -----------------------------------------
#         \   ^__^
#          \  (oo)\_______
#             (__)\       )\/\
#                 ||----w |
#                 ||     ||
#