| import os.path as osp |
| import logging |
| import warnings |
| import sys |
|
|
| from detectron2.checkpoint import DetectionCheckpointer |
| from detectron2.config import LazyConfig, instantiate |
| from detectron2.engine import ( |
| default_argument_parser, |
| default_setup, |
| default_writers, |
| hooks, |
| launch, |
| ) |
| from detectron2.engine.defaults import create_ddp_model |
| from detectron2.utils import comm |
|
|
|
|
# Make the repository root importable so the local ``engine`` package below
# resolves regardless of the current working directory.
sys.path.append(osp.dirname(osp.dirname(__file__)))
# NOTE(review): this silences *all* warnings (including deprecations); consider
# narrowing the filter to specific categories.
warnings.filterwarnings("ignore")
logger = logging.getLogger("detectron2")




# Must come after the sys.path.append above, or this import fails when the
# script is run from outside the repository root.
from engine import CycleTrainer
|
|
|
|
def do_train(args, cfg):
    """
    Build the model, optimizer, dataloader and trainer from a lazy config,
    then run (or resume) training.

    Args:
        args: parsed CLI namespace from ``default_argument_parser``; only
            ``args.resume`` is read here.
        cfg: an object with the following attributes:
            model: instantiate to a module
            dataloader.{train,test}: instantiate to dataloaders
            dataloader.evaluator: instantiate to evaluator for test set
            optimizer: instantiate to an optimizer
            lr_multiplier: instantiate to a fvcore scheduler
            train: other misc config defined in `configs/common/train.py`, including:
                output_dir (str)
                init_checkpoint (str)
                amp.enabled (bool)
                max_iter (int)
                eval_period, log_period (int)
                device (str)
                checkpointer (dict)
                ddp (dict)
    """
    model = instantiate(cfg.model)
    # Use the module-level "detectron2" logger configured by default_setup().
    logger.info("Model:\n{}".format(model))
    model.to(cfg.train.device)

    # The optimizer must be built from the raw (non-DDP-wrapped) model.
    cfg.optimizer.params.model = model
    optim = instantiate(cfg.optimizer)

    train_loader = instantiate(cfg.dataloader.train)

    # Wrap in DistributedDataParallel only after the optimizer has been built.
    model = create_ddp_model(model, **cfg.train.ddp)
    trainer = CycleTrainer(model, train_loader, optim)
    checkpointer = DetectionCheckpointer(
        model,
        cfg.train.output_dir,
        trainer=trainer,
    )
    # register_hooks() drops None entries, so the main-process-only hooks are
    # simply skipped on worker ranks.
    trainer.register_hooks(
        [
            hooks.IterationTimer(),
            hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
            hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
            if comm.is_main_process()
            else None,

            hooks.PeriodicWriter(
                default_writers(cfg.train.output_dir, cfg.train.max_iter),
                period=cfg.train.log_period,
            )
            if comm.is_main_process()
            else None,
        ]
    )

    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)
    if args.resume and checkpointer.has_checkpoint():
        # The checkpoint stores the iteration that just finished, so training
        # resumes at the next one.
        start_iter = trainer.iter + 1
    else:
        start_iter = 0
    trainer.train(start_iter, cfg.train.max_iter)
|
|
|
|
def main(args):
    """Per-process entry point: load the lazy config, apply any command-line
    overrides, run detectron2's default setup, then start training."""
    cfg = LazyConfig.apply_overrides(LazyConfig.load(args.config_file), args.opts)
    default_setup(cfg, args)
    do_train(args, cfg)
|
|
|
|
if __name__ == "__main__":
    # Parse the standard detectron2 CLI and fan out to one process per GPU
    # (and per machine) via the distributed launcher.
    cli_args = default_argument_parser().parse_args()
    launch(
        main,
        cli_args.num_gpus,
        num_machines=cli_args.num_machines,
        machine_rank=cli_args.machine_rank,
        dist_url=cli_args.dist_url,
        args=(cli_args,),
    )
|
|