| import os.path as osp |
| import warnings |
| warnings.filterwarnings('ignore') |
| from typing import Optional |
| from pathlib import Path |
| from models.maplocnet import MapLocNet |
| import hydra |
| import pytorch_lightning as pl |
| import torch |
| from omegaconf import DictConfig, OmegaConf |
| from pytorch_lightning.utilities import rank_zero_only |
| from module import GenericModule |
| from logger import logger, pl_logger, EXPERIMENTS_PATH |
| from module import GenericModule |
| from dataset import UavMapDatasetModule |
| from pytorch_lightning.callbacks.early_stopping import EarlyStopping |
| |
|
|
|
|
class CleanProgressBar(pl.callbacks.TQDMProgressBar):
    """TQDM progress bar that hides the ``v_num`` and ``loss`` entries.

    All other metrics reported by the base progress bar are shown unchanged.
    """

    def get_metrics(self, trainer, model):
        """Return the base bar's metrics minus ``v_num`` and ``loss`` (if present)."""
        metrics = super().get_metrics(trainer, model)
        for hidden_key in ("v_num", "loss"):
            # pop with a default: a missing key is not an error here.
            metrics.pop(hidden_key, None)
        return metrics
|
|
|
|
class SeedingCallback(pl.callbacks.Callback):
    """Re-seed all RNGs (``pl.seed_everything``) at the start of each epoch.

    Hooks into the train, validation and test epoch starts. During a training
    epoch (``trainer.training`` true) the seed is ``experiment.seed`` offset by
    the current epoch index, so every training epoch gets different randomness.
    When ``training.trainer.overfit_batches`` is > 0, or the trainer is not in
    training mode, the base ``experiment.seed`` is used as-is.
    """

    def on_epoch_start_(self, trainer, module):
        """Shared implementation for all ``on_*_epoch_start`` hooks below."""
        base_seed = module.cfg.experiment.seed
        overfitting = module.cfg.training.trainer.get("overfit_batches", 0) > 0
        offset_by_epoch = trainer.training and not overfitting
        seed = base_seed + trainer.current_epoch if offset_by_epoch else base_seed

        # pl_logger is muted around the call, presumably to silence the
        # per-epoch seed message; it is always re-enabled, even on error.
        pl_logger.disabled = True
        try:
            pl.seed_everything(seed, workers=True)
        finally:
            pl_logger.disabled = False

    def on_train_epoch_start(self, *args, **kwargs):
        """Re-seed before each training epoch."""
        self.on_epoch_start_(*args, **kwargs)

    def on_validation_epoch_start(self, *args, **kwargs):
        """Re-seed before each validation epoch."""
        self.on_epoch_start_(*args, **kwargs)

    def on_test_epoch_start(self, *args, **kwargs):
        """Re-seed before each test epoch."""
        self.on_epoch_start_(*args, **kwargs)
|
|
|
|
class ConsoleLogger(pl.callbacks.Callback):
    """Log a one-line console message at the start of every training epoch."""

    @rank_zero_only
    def on_train_epoch_start(self, trainer, module):
        """Announce the new epoch number and experiment name (rank 0 only)."""
        epoch_index = module.current_epoch
        experiment_name = module.cfg.experiment.name
        logger.info(
            "New training epoch %d for experiment '%s'.",
            epoch_index,
            experiment_name,
        )
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
def find_last_checkpoint_path(experiment_dir):
    """Return the path of the "last" Lightning checkpoint in ``experiment_dir``.

    The file name is built from ``ModelCheckpoint``'s class-level
    ``CHECKPOINT_NAME_LAST`` and ``FILE_EXTENSION``. It therefore matches a
    checkpoint callback that keeps the default "last" name.

    Returns:
        The checkpoint path if that file exists, otherwise ``None``.
    """
    checkpoint_cls = pl.callbacks.ModelCheckpoint
    file_name = f"{checkpoint_cls.CHECKPOINT_NAME_LAST}{checkpoint_cls.FILE_EXTENSION}"
    candidate = osp.join(experiment_dir, file_name)
    return candidate if osp.exists(candidate) else None
|
|
|
|
def prepare_experiment_dir(experiment_dir, cfg, rank):
    """Create/refresh ``experiment_dir`` and detect whether to resume training.

    If a last checkpoint exists and a ``config.yaml`` from a previous run is
    present, the ``experiment``, ``data``, ``model`` and ``training`` sections
    of ``cfg`` must equal those of the saved config. On rank 0 the directory is
    then created (if needed) and ``cfg`` is written to ``config.yaml``.

    Args:
        experiment_dir: Directory holding checkpoints and the saved config.
        cfg: The current OmegaConf config.
        rank: Process rank; only rank 0 logs and writes to disk.

    Returns:
        Path of the last checkpoint to resume from, or ``None``.

    Raises:
        ValueError: if resuming with a config whose compared sections differ
            from the saved one.
    """
    resume_path = find_last_checkpoint_path(experiment_dir)
    saved_config_path = osp.join(experiment_dir, "config.yaml")

    if resume_path is not None:
        if rank == 0:
            logger.info("Resuming the training from checkpoint %s", resume_path)
        if osp.exists(saved_config_path):
            with open(saved_config_path, "r") as stream:
                previous_cfg = OmegaConf.create(stream.read())
            sections = ["experiment", "data", "model", "training"]
            current_view = OmegaConf.masked_copy(cfg, sections)
            previous_view = OmegaConf.masked_copy(previous_cfg, sections)
            if current_view != previous_view:
                raise ValueError(
                    "Attempting to resume training with a different config: "
                    f"{current_view} vs "
                    f"{previous_view}"
                )

    if rank == 0:
        Path(experiment_dir).mkdir(exist_ok=True, parents=True)
        with open(saved_config_path, "w") as stream:
            OmegaConf.save(cfg, stream)
    return resume_path
|
|
|
|
def _split_data_cfg_across_gpus(data_cfg, num_gpus):
    """Turn global ``train``/``val`` data settings into per-process settings.

    ``batch_size`` is treated as the global batch size and divided evenly by
    ``num_gpus``. ``num_workers`` is divided rounding up, so a non-zero worker
    count never drops to 0. ``data_cfg`` is mutated in place.

    Raises:
        ValueError: if a split's batch size is smaller than ``num_gpus``,
            which would otherwise give a per-process batch size of 0.
    """
    for split in ["train", "val"]:
        split_cfg = data_cfg[split]
        per_gpu_batch_size = split_cfg.batch_size // num_gpus
        if per_gpu_batch_size < 1:
            raise ValueError(
                f"Batch size {split_cfg.batch_size} of split '{split}' is smaller "
                f"than the number of GPUs ({num_gpus})."
            )
        split_cfg.batch_size = per_gpu_batch_size
        # Exact integer ceiling division (no float round-trip).
        split_cfg.num_workers = (split_cfg.num_workers + num_gpus - 1) // num_gpus


def train(cfg: DictConfig) -> None:
    """Train the model described by ``cfg`` with PyTorch Lightning.

    Overall flow:
    - Resolve the config, validate the device request and seed all RNGs.
    - Build the ``GenericModule``, optionally initialised from
      ``training.finetune_from_checkpoint``.
    - Prepare ``EXPERIMENTS_PATH/<experiment.name>``, resuming from its last
      checkpoint when one exists.
    - Build the data module, callbacks and TensorBoard logger, then call
      ``Trainer.fit``.

    Args:
        cfg: The composed Hydra config. It is resolved and mutated in place:
            ``experiment.gpus`` is normalised to 0 for CPU runs. For
            multi-GPU runs, the ``data.train``/``data.val`` batch sizes and
            worker counts become per-process values.

    Raises:
        ValueError: if GPUs are requested but CUDA is unavailable, if a
            split's batch size is smaller than the number of GPUs, or (from
            ``prepare_experiment_dir``) if resuming with a different config.
    """
    torch.set_float32_matmul_precision("medium")
    OmegaConf.resolve(cfg)
    rank = rank_zero_only.rank

    if rank == 0:
        logger.info("Starting training with config:\n%s", OmegaConf.to_yaml(cfg))
    # Device validation runs on every rank (not only rank 0) so that all
    # processes agree on the normalised ``experiment.gpus`` used below.
    if cfg.experiment.gpus in (None, 0):
        logger.warning("Will train on CPU...")
        cfg.experiment.gpus = 0
    elif not torch.cuda.is_available():
        raise ValueError("Requested GPU but no NVIDIA drivers found.")
    num_gpus = cfg.experiment.gpus
    pl.seed_everything(cfg.experiment.seed, workers=True)

    init_checkpoint_path = cfg.training.get("finetune_from_checkpoint")
    if init_checkpoint_path is not None:
        if rank == 0:
            logger.info(
                "Initializing the model from checkpoint %s.", init_checkpoint_path
            )
        model = GenericModule.load_from_checkpoint(
            init_checkpoint_path, strict=True, find_best=False, cfg=cfg
        )
    else:
        model = GenericModule(cfg)
    if rank == 0:
        logger.info("Network:\n%s", model.model)

    experiment_dir = osp.join(EXPERIMENTS_PATH, cfg.experiment.name)
    last_checkpoint_path = prepare_experiment_dir(experiment_dir, cfg, rank)
    # Epoch-level checkpoints plus the default-named "last" checkpoint, which
    # is the file find_last_checkpoint_path resumes from.
    checkpointing_epoch = pl.callbacks.ModelCheckpoint(
        dirpath=experiment_dir,
        filename="checkpoint-epoch-{epoch:02d}-loss-{loss/total/val:02f}",
        auto_insert_metric_name=False,
        save_last=True,
        every_n_epochs=1,
        save_on_train_epoch_end=True,
        verbose=True,
        **cfg.training.checkpointing,
    )
    # Additional checkpoints every 1000 training steps.
    checkpointing_step = pl.callbacks.ModelCheckpoint(
        dirpath=experiment_dir,
        filename="checkpoint-step-{step}-{loss/total/val:02f}",
        auto_insert_metric_name=False,
        save_last=True,
        every_n_train_steps=1000,
        verbose=True,
        **cfg.training.checkpointing,
    )
    # Instance-level override: presumably so the step callback's "last" file
    # does not overwrite the epoch callback's "last" file in the same dir.
    checkpointing_step.CHECKPOINT_NAME_LAST = "last-step-checkpointing"

    strategy = None
    if num_gpus > 1:
        strategy = pl.strategies.DDPStrategy(find_unused_parameters=False)
        _split_data_cfg_across_gpus(cfg.data, num_gpus)

    datamodule = UavMapDatasetModule(cfg.data)

    tb_args = {"name": cfg.experiment.name, "version": ""}
    tb = pl.loggers.TensorBoardLogger(EXPERIMENTS_PATH, **tb_args)

    callbacks = [
        checkpointing_epoch,
        checkpointing_step,
        pl.callbacks.LearningRateMonitor(),
        SeedingCallback(),
        CleanProgressBar(),
        ConsoleLogger(),
    ]
    if num_gpus > 0:
        callbacks.append(pl.callbacks.DeviceStatsMonitor())

    trainer = pl.Trainer(
        default_root_dir=experiment_dir,
        detect_anomaly=False,
        enable_model_summary=True,
        sync_batchnorm=True,
        enable_checkpointing=True,
        logger=tb,
        callbacks=callbacks,
        strategy=strategy,
        check_val_every_n_epoch=1,
        # Honour a CPU request instead of always forcing the GPU accelerator.
        accelerator="gpu" if num_gpus > 0 else "cpu",
        num_nodes=1,
        **cfg.training.trainer,
    )
    trainer.fit(model=model, datamodule=datamodule, ckpt_path=last_checkpoint_path)
|
|
|
|
@hydra.main(
    config_name="maplocnet.yaml",
    config_path=osp.join(osp.dirname(__file__), "conf"),
)
def main(cfg: DictConfig) -> None:
    """Hydra entry point: save the composed config, then run ``train``.

    The config comes from ``conf/maplocnet.yaml`` next to this file, plus any
    command-line overrides parsed by Hydra.
    """
    # Writes to the current working directory.
    # NOTE(review): depending on the Hydra version/``hydra.job.chdir``, that
    # may be the run's output dir or the launch dir — confirm.
    OmegaConf.save(config=cfg, f="maplocnet.yaml")
    train(cfg)


if __name__ == "__main__":
    # The hydra.main decorator supplies ``cfg``; no arguments are passed here.
    main()
|
|
|
|