# Migrate parameter scheduler from MMCV to MMEngine

MMCV 1.x uses [LrUpdaterHook](https://mmcv.readthedocs.io/en/v1.6.0/api.html#mmcv.runner.LrUpdaterHook) and [MomentumUpdaterHook](https://mmcv.readthedocs.io/en/v1.6.0/api.html#mmcv.runner.MomentumUpdaterHook) to adjust the learning rate and momentum.
However, as training strategies have evolved, the design of LrUpdaterHook has made it increasingly difficult to meet more diverse customization requirements. Hence, MMEngine introduces parameter schedulers (ParamScheduler).

The interface of the parameter scheduler is consistent with PyTorch's learning rate scheduler (LRScheduler), and the parameter scheduler also provides additional functionality. For details, please refer to the [Parameter Scheduler User Guide](../tutorials/param_scheduler.md).

## Learning rate scheduler (LrUpdater) migration

MMEngine replaces LrUpdaterHook with LRScheduler, and the config field changes from `lr_config` to `param_scheduler`.
The learning rate configs in MMCV map to parameter scheduler configs in MMEngine as follows, where `num_epochs` denotes the total number of training epochs:

### Learning rate warm-up migration

In MMEngine, learning rate warm-up is achieved by combining schedulers, each restricted to an effective range specified by `begin` and `end`. MMCV provides three warm-up methods: `'constant'`, `'linear'`, and `'exp'`. Their MMEngine equivalents are configured as follows:

#### Constant warm-up

<table class="docutils">
  <thead>
  <tr>
      <th>MMCV-1.x</th>
      <th>MMEngine</th>
  <tbody>
  <tr>
  <td>

```python
lr_config = dict(
    warmup='constant',
    warmup_ratio=0.1,
    warmup_iters=500,
    warmup_by_epoch=False
)
```

</td>
  <td>

```python
param_scheduler = [
    dict(type='ConstantLR',
         factor=0.1,
         begin=0,
         end=500,
         by_epoch=False),
    dict(...) # the main learning rate scheduler
]
```

</td>
  </tr>
  </thead>
  </table>
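
To check this mapping numerically, the following sketch (plain Python; the helper names are hypothetical and not part of MMCV or MMEngine) reproduces the factor that MMCV's `'constant'` warm-up applies to the regular learning rate, and the factor a `ConstantLR`-style scheduler applies inside its `[begin, end)` range:

```python
def mmcv_constant_warmup_factor(cur_iter: int, warmup_iters: int,
                                warmup_ratio: float) -> float:
    """Factor applied to the regular lr by MMCV's 'constant' warm-up."""
    # Inside the warm-up window the ratio is fixed; afterwards the regular lr is used.
    return warmup_ratio if cur_iter < warmup_iters else 1.0


def constant_lr_factor(step: int, begin: int, end: int, factor: float) -> float:
    """Factor applied by a ConstantLR-style scheduler active in [begin, end)."""
    return factor if begin <= step < end else 1.0


# With warmup_ratio=0.1 and warmup_iters=500 the two factors agree at every iteration.
for it in range(1000):
    assert mmcv_constant_warmup_factor(it, 500, 0.1) == constant_lr_factor(it, 0, 500, 0.1)
```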

#### Linear warm-up

<table class="docutils">
  <thead>
  <tr>
      <th>MMCV-1.x</th>
      <th>MMEngine</th>
  <tbody>
  <tr>
  <td>

```python
lr_config = dict(
    warmup='linear',
    warmup_ratio=0.1,
    warmup_iters=500,
    warmup_by_epoch=False
)
```

</td>
  <td>

```python
param_scheduler = [
    dict(type='LinearLR',
         start_factor=0.1,
         begin=0,
         end=500,
         by_epoch=False),
    dict(...) # the main learning rate scheduler
]
```

</td>
  </tr>
  </thead>
  </table>
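
MMCV's `'linear'` warm-up computes `k = (1 - cur_iter / warmup_iters) * (1 - warmup_ratio)` and scales the regular learning rate by `1 - k`. The sketch below (hypothetical helper names; the `LinearLR` closed form is our understanding of its linear interpolation from `start_factor` to `end_factor`) shows that the two describe the same straight line from 0.1 to 1.0:

```python
def mmcv_linear_warmup_factor(cur_iter: int, warmup_iters: int,
                              warmup_ratio: float) -> float:
    """Factor applied to the regular lr by MMCV's 'linear' warm-up."""
    k = (1 - cur_iter / warmup_iters) * (1 - warmup_ratio)
    return 1 - k


def linear_lr_factor(step: int, total_steps: int, start_factor: float,
                     end_factor: float = 1.0) -> float:
    """Closed-form factor of a LinearLR-style scheduler after `step` updates."""
    progress = min(step, total_steps) / total_steps
    return start_factor + (end_factor - start_factor) * progress


# Both start at 0.1 and reach 1.0 after 500 iterations, along the same line.
for it in range(501):
    diff = mmcv_linear_warmup_factor(it, 500, 0.1) - linear_lr_factor(it, 500, 0.1)
    assert abs(diff) < 1e-12
```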

#### Exponential warm-up

<table class="docutils">
  <thead>
  <tr>
      <th>MMCV-1.x</th>
      <th>MMEngine</th>
  <tbody>
  <tr>
  <td>

```python
lr_config = dict(
    warmup='exp',
    warmup_ratio=0.1,
    warmup_iters=500,
    warmup_by_epoch=False
)
```

</td>
  <td>

```python
param_scheduler = [
    dict(type='ExponentialLR',
         gamma=0.1,
         begin=0,
         end=500,
         by_epoch=False),
    dict(...) # the main learning rate scheduler
]
```

</td>
  </tr>
  </thead>
  </table>
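
MMCV's `'exp'` warm-up scales the regular learning rate by `warmup_ratio ** (1 - cur_iter / warmup_iters)`, so the factor rises geometrically from `warmup_ratio` to 1. An `ExponentialLR`-style scheduler instead multiplies the current learning rate by `gamma` once per update, so with `gamma=0.1` the learning rate shrinks at every iteration. The sketch below (plain Python, all helper names hypothetical) compares the two curves and computes the per-iteration multiplier implied by the MMCV curve; it may be worth checking that the `ExponentialLR` config above reproduces the warm-up you intend before relying on it.

```python
def mmcv_exp_warmup_factor(cur_iter: int, warmup_iters: int,
                           warmup_ratio: float) -> float:
    """Factor applied to the regular lr by MMCV's 'exp' warm-up."""
    return warmup_ratio ** (1 - cur_iter / warmup_iters)


def cumulative_gamma(step: int, gamma: float) -> float:
    """Cumulative factor after `step` updates that each multiply the lr by gamma."""
    return gamma ** step


def implied_gamma(warmup_iters: int, warmup_ratio: float) -> float:
    """Per-iteration multiplier that carries the MMCV curve from one step to the next."""
    return (1 / warmup_ratio) ** (1 / warmup_iters)


print(mmcv_exp_warmup_factor(0, 500, 0.1))    # starts at warmup_ratio
print(mmcv_exp_warmup_factor(500, 500, 0.1))  # ends at 1.0
print(cumulative_gamma(3, 0.1))               # shrinks at every step when gamma=0.1
print(implied_gamma(500, 0.1))                # slightly above 1
```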

### Fixed learning rate (FixedLrUpdaterHook) migration

<table class="docutils">
<thead>
<tr>
    <th>MMCV-1.x</th>
    <th>MMEngine</th>
<tbody>
<tr>
<td>

```python
lr_config = dict(policy='fixed')
```

</td>
<td>

```python
param_scheduler = [
    dict(type='ConstantLR', factor=1)
]
```

</td>
</tr>
</thead>
</table>

### Step learning rate (StepLrUpdaterHook) migration

<table class="docutils">
<thead>
<tr>
    <th>MMCV-1.x</th>
    <th>MMEngine</th>
<tbody>
<tr>
<td>

```python
lr_config = dict(
    policy='step',
    step=[8, 11],
    gamma=0.1,
    by_epoch=True
)
```

</td>
<td>

```python
param_scheduler = [
    dict(type='MultiStepLR',
         milestones=[8, 11],
         gamma=0.1,
         by_epoch=True)
]
```

</td>
</tr>
</thead>
</table>
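
Both configs decay the learning rate by `gamma` once for every milestone that has been reached. A minimal sketch of that rule (the helper name is hypothetical), using `bisect_right` to count the milestones that are less than or equal to the current epoch:

```python
from bisect import bisect_right
from typing import Sequence


def multistep_lr(base_lr: float, epoch: int, milestones: Sequence[int],
                 gamma: float) -> float:
    """lr after decaying by gamma once for every milestone already reached."""
    # bisect_right counts milestones <= epoch, so the decay takes effect at epoch 8, 11, ...
    return base_lr * gamma ** bisect_right(milestones, epoch)


for epoch in (0, 7, 8, 10, 11):
    print(epoch, multistep_lr(0.02, epoch, [8, 11], 0.1))
```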

### Poly learning rate (PolyLrUpdaterHook) migration

<table class="docutils">
<thead>
<tr>
    <th>MMCV-1.x</th>
    <th>MMEngine</th>
<tbody>
<tr>
<td>

```python
lr_config = dict(
    policy='poly',
    power=0.7,
    min_lr=0.001,
    by_epoch=True
)
```

</td>
<td>

```python
param_scheduler = [
    dict(type='PolyLR',
         power=0.7,
         eta_min=0.001,
         begin=0,
         end=num_epochs,
         by_epoch=True)
]
```

</td>
</tr>
</thead>
</table>
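
MMCV's poly policy computes `(base_lr - min_lr) * (1 - progress / max_progress) ** power + min_lr`. With `begin=0` and `end=num_epochs`, `PolyLR` is intended to trace the same curve, with `eta_min` playing the role of `min_lr`. A minimal sketch of the formula (hypothetical helper name and training length):

```python
def poly_lr(base_lr: float, progress: float, max_progress: float,
            power: float, min_lr: float) -> float:
    """Polynomial decay from base_lr to min_lr over max_progress steps."""
    coeff = (1 - progress / max_progress) ** power
    return (base_lr - min_lr) * coeff + min_lr


num_epochs = 12  # hypothetical training length
for epoch in (0, 6, 12):
    print(epoch, round(poly_lr(0.01, epoch, num_epochs, 0.7, 0.001), 6))
```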

### Exponential learning rate (ExpLrUpdaterHook) migration

<table class="docutils">
<thead>
<tr>
    <th>MMCV-1.x</th>
    <th>MMEngine</th>
<tbody>
<tr>
<td>

```python
lr_config = dict(
    policy='exp',
    gamma=0.5,
    by_epoch=True
)
```

</td>
<td>

```python
param_scheduler = [
    dict(type='ExponentialLR',
         gamma=0.5,
         begin=0,
         end=num_epochs,
         by_epoch=True)
]
```

</td>
</tr>
</thead>
</table>
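
MMCV's exp policy uses the closed form `base_lr * gamma ** progress`, while an `ExponentialLR`-style scheduler multiplies the current learning rate by `gamma` at each update. The sketch below (hypothetical helper names) checks that the two produce the same values when the scheduler starts from `base_lr` at step 0:

```python
def mmcv_exp_lr(base_lr: float, progress: int, gamma: float) -> float:
    """Closed form used by MMCV's 'exp' policy."""
    return base_lr * gamma ** progress


def chained_exponential_lr(base_lr: float, steps: int, gamma: float) -> float:
    """ExponentialLR-style update: multiply the current lr by gamma once per step."""
    lr = base_lr
    for _ in range(steps):
        lr *= gamma
    return lr


for epoch in range(5):
    assert abs(mmcv_exp_lr(0.1, epoch, 0.5) - chained_exponential_lr(0.1, epoch, 0.5)) < 1e-15
```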

### Cosine annealing learning rate (CosineAnnealingLrUpdaterHook) migration

<table class="docutils">
<thead>
<tr>
    <th>MMCV-1.x</th>
    <th>MMEngine</th>
<tbody>
<tr>
<td>

```python
lr_config = dict(
    policy='CosineAnnealing',
    min_lr=0.5,
    by_epoch=True
)
```

</td>
<td>

```python
param_scheduler = [
    dict(type='CosineAnnealingLR',
         eta_min=0.5,
         T_max=num_epochs,
         begin=0,
         end=num_epochs,
         by_epoch=True)
]
```

</td>
</tr>
</thead>
</table>
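
Both configs anneal the learning rate from its base value to the minimum along half a cosine wave. The sketch below uses the standard cosine annealing formula `eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * t / T_max))`; the helper names are hypothetical, and `T_max = num_epochs` corresponds to MMCV annealing over the whole run:

```python
import math


def annealing_cos(start: float, end: float, factor: float, weight: float = 1.0) -> float:
    """Cosine interpolation from start (factor=0) to end (factor=1)."""
    cos_out = math.cos(math.pi * factor) + 1
    return end + 0.5 * weight * (start - end) * cos_out


def cosine_annealing_lr(base_lr: float, progress: float, max_progress: float,
                        min_lr: float) -> float:
    """lr after `progress` of `max_progress` steps of cosine annealing."""
    return annealing_cos(base_lr, min_lr, progress / max_progress)


num_epochs = 12  # hypothetical training length
for epoch in (0, 6, 12):
    print(epoch, round(cosine_annealing_lr(1.0, epoch, num_epochs, 0.5), 4))
```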

### FlatCosineAnnealingLrUpdaterHook migration

In MMCV, a learning rate strategy made up of multiple phases, such as FlatCosineAnnealing, required a dedicated hook. In MMEngine, the same strategy can be achieved by combining two parameter scheduler configs:

<table class="docutils">
<thead>
<tr>
    <th>MMCV-1.x</th>
    <th>MMEngine</th>
<tbody>
<tr>
<td>

```python
lr_config = dict(
    policy='FlatCosineAnnealing',
    start_percent=0.75,
    min_lr=0.005,
    by_epoch=True
)
```

</td>
<td>

```python
param_scheduler = [
    dict(type='ConstantLR', factor=1, begin=0, end=num_epochs * 0.75),
    dict(type='CosineAnnealingLR',
         eta_min=0.005,
         begin=num_epochs * 0.75,
         end=num_epochs,
         T_max=num_epochs * 0.25,
         by_epoch=True)
]
```

</td>
</tr>
</thead>
</table>
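
The resulting schedule keeps the base learning rate for the first `start_percent` of training and then cosine-anneals it to the minimum over the remaining epochs. A minimal sketch of that piecewise curve (hypothetical helper name; it assumes the switch epoch is rounded to an integer):

```python
import math


def flat_cosine_lr(base_lr: float, epoch: int, max_epochs: int,
                   start_percent: float, min_lr: float) -> float:
    """Keep base_lr for the first start_percent of training, then cosine-anneal."""
    start = round(max_epochs * start_percent)  # first epoch of the cosine phase
    if epoch < start:
        return base_lr  # flat phase, like ConstantLR(factor=1)
    # Cosine phase over the remaining epochs, like CosineAnnealingLR(T_max=max_epochs - start).
    factor = (epoch - start) / (max_epochs - start)
    return min_lr + 0.5 * (base_lr - min_lr) * (math.cos(math.pi * factor) + 1)


for epoch in (0, 8, 9, 12):
    print(epoch, round(flat_cosine_lr(0.1, epoch, 12, 0.75, 0.005), 5))
```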

### CosineRestartLrUpdaterHook migration

<table class="docutils">
<thead>
<tr>
    <th>MMCV-1.x</th>
    <th>MMEngine</th>
<tbody>
<tr>
<td>

```python
lr_config = dict(policy='CosineRestart',
                 periods=[5, 10, 15],
                 restart_weights=[1, 0.7, 0.3],
                 min_lr=0.001,
                 by_epoch=True)
```

</td>
<td>

```python
param_scheduler = [
    dict(type='CosineRestartLR',
         periods=[5, 10, 15],
         restart_weights=[1, 0.7, 0.3],
         eta_min=0.001,
         by_epoch=True)
]
```

</td>
</tr>
</thead>
</table>
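
Here `periods` are the lengths of successive cycles (so `[5, 10, 15]` restarts at epochs 5 and 15 and covers 30 epochs), and each `restart_weights` entry scales the peak of its cycle. A minimal sketch of that rule, assuming the curve follows the cosine formula within each cycle (hypothetical helper name):

```python
import math
from itertools import accumulate
from typing import Sequence


def cosine_restart_lr(base_lr: float, progress: int, periods: Sequence[int],
                      restart_weights: Sequence[float], min_lr: float) -> float:
    """Cosine annealing that restarts at the end of every period."""
    cumulative = list(accumulate(periods))  # [5, 10, 15] -> [5, 15, 30]
    for idx, boundary in enumerate(cumulative):
        if progress < boundary:
            break
    else:
        raise ValueError(f"progress {progress} is beyond the last period ({cumulative[-1]})")
    period_start = 0 if idx == 0 else cumulative[idx - 1]
    alpha = min((progress - period_start) / periods[idx], 1)
    weight = restart_weights[idx]  # scales the peak of this cycle
    return min_lr + 0.5 * weight * (base_lr - min_lr) * (math.cos(math.pi * alpha) + 1)


for epoch in (0, 4, 5, 15, 29):
    print(epoch, round(cosine_restart_lr(0.1, epoch, [5, 10, 15], [1, 0.7, 0.3], 0.001), 5))
```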

### OneCycleLrUpdaterHook migration

<table class="docutils">
<thead>
<tr>
    <th>MMCV-1.x</th>
    <th>MMEngine</th>
<tbody>
<tr>
<td>

```python
lr_config = dict(policy='OneCycle',
                 max_lr=0.02,
                 total_steps=90000,
                 pct_start=0.3,
                 anneal_strategy='cos',
                 div_factor=25,
                 final_div_factor=1e4,
                 three_phase=True,
                 by_epoch=False)
```

</td>
<td>

```python
param_scheduler = [
    dict(type='OneCycleLR',
         eta_max=0.02,
         total_steps=90000,
         pct_start=0.3,
         anneal_strategy='cos',
         div_factor=25,
         final_div_factor=1e4,
         three_phase=True,
         by_epoch=False)
]
```

</td>
</tr>
</thead>
</table>

Note: for the one-cycle policy, `by_epoch` defaults to `False` in MMCV, whereas it defaults to `True` in MMEngine, so set `by_epoch=False` explicitly as shown above.
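
With the values above, the schedule starts at `max_lr / div_factor` and ends at that value divided by `final_div_factor`. The sketch below works through these numbers and the phase boundaries for `three_phase=True`, assuming MMEngine's `OneCycleLR` follows the same conventions as PyTorch's `OneCycleLR` (helper names are hypothetical):

```python
def one_cycle_lr_bounds(max_lr: float, div_factor: float, final_div_factor: float):
    """Return (initial_lr, final_lr) of a one-cycle schedule."""
    initial_lr = max_lr / div_factor          # lr at step 0
    final_lr = initial_lr / final_div_factor  # lr at the last step
    return initial_lr, final_lr


def three_phase_end_steps(total_steps: int, pct_start: float):
    """Last step of each phase: initial->max, max->initial, initial->final."""
    return (pct_start * total_steps - 1,
            2 * pct_start * total_steps - 2,
            total_steps - 1)


print(one_cycle_lr_bounds(0.02, 25, 1e4))
print(three_phase_end_steps(90000, 0.3))
```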

### LinearAnnealingLrUpdaterHook migration

<table class="docutils">
<thead>
<tr>
    <th>MMCV-1.x</th>
    <th>MMEngine</th>
<tbody>
<tr>
<td>

```python
lr_config = dict(
    policy='LinearAnnealing',
    min_lr_ratio=0.01,
    by_epoch=True
)
```

</td>
<td>

```python
param_scheduler = [
    dict(type='LinearLR',
         start_factor=1,
         end_factor=0.01,
         begin=0,
         end=num_epochs,
         by_epoch=True)
]
```

</td>
</tr>
</thead>
</table>
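
MMCV's linear annealing moves from `base_lr` to `base_lr * min_lr_ratio` in a straight line, which is what `LinearLR` does with `start_factor=1` and `end_factor=min_lr_ratio`. The sketch below (hypothetical helper names; the `LinearLR` closed form is our understanding of its interpolation) checks that the two agree:

```python
def mmcv_linear_annealing_lr(base_lr: float, progress: float, max_progress: float,
                             min_lr_ratio: float) -> float:
    """Linear interpolation from base_lr to base_lr * min_lr_ratio."""
    target_lr = base_lr * min_lr_ratio
    return base_lr + (target_lr - base_lr) * progress / max_progress


def linear_lr(base_lr: float, step: int, total_steps: int,
              start_factor: float, end_factor: float) -> float:
    """Closed-form lr of a LinearLR-style scheduler after `step` updates."""
    progress = min(step, total_steps) / total_steps
    return base_lr * (start_factor + (end_factor - start_factor) * progress)


num_epochs = 12  # hypothetical training length
for epoch in range(num_epochs + 1):
    a = mmcv_linear_annealing_lr(0.1, epoch, num_epochs, 0.01)
    b = linear_lr(0.1, epoch, num_epochs, 1.0, 0.01)
    assert abs(a - b) < 1e-12
```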

## MomentumUpdater migration

MMCV uses the `momentum_config` field and MomentumUpdaterHook to adjust momentum. In MMEngine, momentum is also controlled by parameter schedulers: to apply the same strategy to momentum, replace the `LR` suffix in a learning rate scheduler's type with `Momentum` (for example, `CosineAnnealingLR` becomes `CosineAnnealingMomentum`). Momentum schedulers share the `param_scheduler` config field with learning rate schedulers:

<table class="docutils">
<thead>
<tr>
    <th>MMCV-1.x</th>
    <th>MMEngine</th>
<tbody>
<tr>
<td>

```python
lr_config = dict(...)
momentum_config = dict(
    policy='CosineAnnealing',
    min_momentum=0.1,
    by_epoch=True
)
```

</td>
<td>

```python
param_scheduler = [
    # config of learning rate schedulers
    dict(...),
    # config of momentum schedulers
    dict(type='CosineAnnealingMomentum',
         eta_min=0.1,
         T_max=num_epochs,
         begin=0,
         end=num_epochs,
         by_epoch=True)
]
```

</td>
</tr>
</thead>
</table>
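
The renaming rule can be sketched as a small config helper. This is a hypothetical convenience, not an MMEngine API: it only renames the type and applies overrides, because values such as `eta_min` must be re-expressed as momentum values, and not every learning rate scheduler necessarily has a momentum counterpart, so the resulting type name should be checked against the API documentation.

```python
from typing import Any, Dict


def lr_type_to_momentum(scheduler_type: str) -> str:
    """Hypothetical helper: 'CosineAnnealingLR' -> 'CosineAnnealingMomentum'."""
    suffix = "LR"
    if not scheduler_type.endswith(suffix):
        raise ValueError(f"{scheduler_type!r} is not an LR scheduler type name")
    return scheduler_type[: -len(suffix)] + "Momentum"


def momentum_cfg_from_lr_cfg(lr_cfg: Dict[str, Any], **overrides: Any) -> Dict[str, Any]:
    """Copy an LR scheduler config, rename its type and apply momentum-specific values."""
    cfg = dict(lr_cfg, type=lr_type_to_momentum(lr_cfg["type"]))
    cfg.update(overrides)  # e.g. eta_min must now be a momentum value
    return cfg


lr_cfg = dict(type="CosineAnnealingLR", eta_min=0.5, T_max=12, begin=0, end=12, by_epoch=True)
print(momentum_cfg_from_lr_cfg(lr_cfg, eta_min=0.1))
```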

## Migrate parameter update frequency related config

In MMCV, if you train with an epoch-based loop but want the learning rate to be updated every iteration, you need to set `by_epoch` to `False`.

In MMEngine, however, `by_epoch` should stay `True`, so that the effective range (`begin`, `end`), the period (`T_max`), and other variables are still specified in epochs. Instead, add `convert_to_iter_based=True` to the config to build a parameter scheduler that updates every iteration. See the [Parameter Scheduler Tutorial](../tutorials/param_scheduler.md) for more details.

Take the migration of CosineAnnealing as an example:

<table class="docutils">
<thead>
<tr>
    <th>MMCV-1.x</th>
    <th>MMEngine</th>
<tbody>
<tr>
<td>

```python
lr_config = dict(
    policy='CosineAnnealing',
    min_lr=0.5,
    by_epoch=False
)
```

</td>
<td>

```python
param_scheduler = [
    dict(
        type='CosineAnnealingLR',
        eta_min=0.5,
        T_max=num_epochs,
        by_epoch=True,  # Note: by_epoch needs to be set to True
        convert_to_iter_based=True  # convert to an iter-based scheduler
    )
]
```

</td>
</tr>
</thead>
</table>
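
Conceptually, the conversion rescales the epoch-based fields by the number of iterations per epoch. The sketch below is a hypothetical illustration of that idea, assuming `begin`, `end`, and `T_max` are simply multiplied by `iters_per_epoch`; it is not MMEngine's implementation, which performs this conversion internally when the scheduler is built.

```python
import math
from typing import Any, Dict

INF = math.inf  # stands in for an open-ended `end`


def to_iter_based(cfg: Dict[str, Any], iters_per_epoch: int) -> Dict[str, Any]:
    """Hypothetical sketch: rescale epoch-based fields of a scheduler config to iterations."""
    assert cfg.get("by_epoch", True), "only epoch-based configs are converted"
    out = dict(cfg)  # leave the caller's config untouched
    out["begin"] = int(cfg.get("begin", 0) * iters_per_epoch)
    end = cfg.get("end", INF)
    out["end"] = end if end == INF else int(end * iters_per_epoch)
    if "T_max" in cfg:
        out["T_max"] = cfg["T_max"] * iters_per_epoch
    out["by_epoch"] = False
    out.pop("convert_to_iter_based", None)
    return out


epoch_cfg = dict(type="CosineAnnealingLR", eta_min=0.5, T_max=12,
                 by_epoch=True, convert_to_iter_based=True)
print(to_iter_based(epoch_cfg, iters_per_epoch=100))
```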

You may also want to read the [parameter scheduler tutorial](../tutorials/param_scheduler.md) or the [parameter scheduler API documentation](mmengine.optim.scheduler).