# ViTGaze / tools/train.py
# yhsong's picture
# initial commit
# f9561b9 verified
import os.path as osp
import logging
import warnings
import sys
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import LazyConfig, instantiate
from detectron2.engine import (
default_argument_parser,
default_setup,
default_writers,
hooks,
launch,
)
from detectron2.engine.defaults import create_ddp_model
from detectron2.utils import comm
# Add the parent of this script's directory to the import search path
# (presumably the repository root -- confirm against the project layout) so
# that the `engine` module imported below can be resolved.
sys.path.append(osp.dirname(osp.dirname(__file__)))
# Silence every Python warning for the lifetime of the process.
warnings.filterwarnings("ignore")
# Module-level handle on the logger named "detectron2".
logger = logging.getLogger("detectron2")
# NOTE: this import must stay below the sys.path.append() call above.
from engine import CycleTrainer
def do_train(args, cfg):
    """
    Build the model, optimizer, data loader and trainer from ``cfg``, then train.

    Args:
        args: parsed command-line namespace; only ``args.resume`` is read here.
        cfg: a lazy config object. The fields read by this function are:
            model: instantiated to the module to train
            optimizer: instantiated to an optimizer, after
                ``optimizer.params.model`` has been pointed at the model
            dataloader.train: instantiated to the training data loader
            lr_multiplier: instantiated and handed to ``hooks.LRScheduler``
            train.device: device the model is moved to
            train.ddp (dict): keyword arguments for ``create_ddp_model``
            train.output_dir: given to the checkpointer and the default writers
            train.checkpointer (dict): keyword arguments for
                ``hooks.PeriodicCheckpointer``
            train.init_checkpoint: given to ``resume_or_load``
            train.max_iter: last iteration, also given to the default writers
            train.log_period: period of the ``hooks.PeriodicWriter``
    """
    log = logging.getLogger("detectron2")

    # Build the network, log its structure and move it to the target device.
    net = instantiate(cfg.model)
    log.info("Model:\n{}".format(net))
    net.to(cfg.train.device)

    # The optimizer config needs the live model before it can be instantiated.
    cfg.optimizer.params.model = net
    optimizer = instantiate(cfg.optimizer)
    loader = instantiate(cfg.dataloader.train)

    # Wrap the model for distributed training only after the optimizer exists.
    net = create_ddp_model(net, **cfg.train.ddp)
    trainer = CycleTrainer(net, loader, optimizer)
    checkpointer = DetectionCheckpointer(net, cfg.train.output_dir, trainer=trainer)

    timer_hook = hooks.IterationTimer()
    lr_hook = hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier))

    # Checkpoint saving and metric writing are created on the main process
    # only; every other rank passes None in those two slots.
    # No EvalHook is registered: periodic evaluation is disabled in this script.
    checkpoint_hook = None
    writer_hook = None
    if comm.is_main_process():
        checkpoint_hook = hooks.PeriodicCheckpointer(
            checkpointer, **cfg.train.checkpointer
        )
        writer_hook = hooks.PeriodicWriter(
            default_writers(cfg.train.output_dir, cfg.train.max_iter),
            period=cfg.train.log_period,
        )
    trainer.register_hooks([timer_hook, lr_hook, checkpoint_hook, writer_hook])

    # Load the initial weights, or the latest checkpoint when resuming.
    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)

    # When a checkpoint was actually resumed, continue at the iteration after
    # the one recorded on the trainer; otherwise start from zero.
    resumed = args.resume and checkpointer.has_checkpoint()
    start_iter = trainer.iter + 1 if resumed else 0
    trainer.train(start_iter, cfg.train.max_iter)
def main(args):
    """
    Per-process entry point: load the config and start training.

    Loads the lazy config at ``args.config_file``, applies the ``args.opts``
    overrides to it, runs ``default_setup`` and then calls ``do_train``.

    Args:
        args: parsed command-line namespace providing ``config_file`` and
            ``opts``; it is also forwarded to ``default_setup`` and ``do_train``.
    """
    cfg = LazyConfig.apply_overrides(LazyConfig.load(args.config_file), args.opts)
    default_setup(cfg, args)
    do_train(args, cfg)
if __name__ == "__main__":
    # Parse the command line with detectron2's default argument parser.
    args = default_argument_parser().parse_args()

    # Multi-machine settings forwarded to the launcher unchanged.
    distributed_options = {
        "num_machines": args.num_machines,
        "machine_rank": args.machine_rank,
        "dist_url": args.dist_url,
    }

    # Run `main(args)` through the launcher with `args.num_gpus` GPUs.
    launch(main, args.num_gpus, args=(args,), **distributed_options)