| import pytest |
| import torch.nn as nn |
|
|
| from kornia.metrics import AverageMeter |
| from kornia.x import EarlyStopping, ModelCheckpoint |
| from kornia.x.utils import TrainerState |
|
|
|
|
@pytest.fixture
def model():
    """Provide a tiny module for the callbacks under test to receive.

    A single ``nn.Conv2d`` with 3 input channels, 10 output channels and a
    1x1 kernel; the callback tests only pass it through, so its size is
    deliberately minimal.
    """
    conv = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=1)
    return conv
|
|
|
|
def test_callback_modelcheckpoint(tmp_path, model):
    """ModelCheckpoint tracks the monitored value and writes a checkpoint file.

    One call at epoch 0, with ``test_monitor`` averaging 1.0, must leave
    ``best_metric`` at 1.0 and produce ``model_0.pt`` inside ``tmp_path``.
    """
    checkpoint = ModelCheckpoint(tmp_path, 'test_monitor')
    assert checkpoint is not None

    # Set the meter's average directly so the monitored value is exactly 1.0.
    meter = AverageMeter()
    meter.avg = 1.0
    valid_metric = {'test_monitor': meter}

    checkpoint(model, epoch=0, valid_metric=valid_metric)

    expected_file = tmp_path / "model_0.pt"
    assert checkpoint.best_metric == 1.0
    assert expected_file.is_file()
|
|
|
|
def test_callback_earlystopping(model):
    """EarlyStopping terminates after ``patience`` consecutive non-improving epochs.

    As encoded by the assertions below, a lower monitored value is better: the
    first value (1) becomes the best score, reported as ``-1``; a higher value
    (2) counts as no improvement and increments ``counter``. Once ``counter``
    reaches ``patience`` the callback returns ``TrainerState.TERMINATE``.
    """
    patience = 2
    cb = EarlyStopping('test_monitor', patience=patience)
    assert cb.counter == 0

    # Set the meter's average directly so each call sees exactly the value
    # under test.
    metric = {'test_monitor': AverageMeter()}
    metric['test_monitor'].avg = 1

    # Epoch 0: the first observation becomes the best score.
    state = cb(model, epoch=0, valid_metric=metric)
    assert state == TrainerState.TRAINING
    assert cb.best_score == -1
    assert cb.counter == 0

    # Epoch 1: a worse value does not update the best score and starts counting.
    metric['test_monitor'].avg = 2
    state = cb(model, epoch=1, valid_metric=metric)
    assert state == TrainerState.TRAINING
    assert cb.best_score == -1
    assert cb.counter == 1

    # Epoch 2: a second non-improving epoch reaches ``patience`` and stops training.
    state = cb(model, epoch=2, valid_metric=metric)
    assert state == TrainerState.TERMINATE
    assert cb.best_score == -1
    assert cb.counter == patience
|
|