import datetime
import json
import math
import os
import sys
import time
from pathlib import Path
from typing import Iterable

import numpy as np
import timm.optim.optim_factory as optim_factory
import torch
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from accelerate import Accelerator, DistributedDataParallelKwargs
from torch.utils.tensorboard import SummaryWriter

import mae_model
import util.lr_sched as lr_sched
import util.misc as misc
from get_args import get_args_pretrain
from SARdatasets import Multi_task_SARImageFolder
from util.misc import NativeScalerWithGradNormCount as NativeScaler
from util.pos_embed import interpolate_pos_embed


def train_one_epoch(model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    log_writer=None,
                    args=None,
                    accelerator=None):
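    """Train for one epoch and return the averaged metrics as a dict.

    Backward and gradient scaling run through `accelerator`; `loss_scaler` is
    accepted only for checkpoint compatibility and is not used here.
    """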
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    for data_iter_step, (samples, target) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

        # Per-iteration (rather than per-epoch) LR schedule, as in the MAE
        # reference recipe; assumes util.lr_sched.adjust_learning_rate from
        # MAE's util package.
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        samples = samples.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        with torch.cuda.amp.autocast():
            loss, channel_loss, _, _ = model(samples, target)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Scale by accum_iter so gradients accumulated over the window average out.
        accelerator.backward(loss / accum_iter)

        # Step and reset gradients only at the end of each accumulation window.
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.step()
            optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def main(args):
    misc.init_distributed_mode(args)
    torch.multiprocessing.set_start_method('spawn', force=True)
    print('work_dir: {}'.format(os.path.realpath(__file__)))
    # prepare() wraps the model in DDP later on; pass the original
    # find_unused_parameters=True setting through a kwargs handler.
    accelerator = Accelerator(kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
    # Train on whatever device Accelerate assigns; args.device is superseded.
    device = accelerator.device

    # Offset the seed by rank so each process gets a distinct random stream.
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

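    # Pretraining augmentation: random resized crop and horizontal flip only.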
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    dataset_train = Multi_task_SARImageFolder(root=args.data_path, transform=transform_train)

    print(dataset_train)

    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()
    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)
    print("Sampler_train = %s" % str(sampler_train))

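    # TensorBoard writer only on the global main process.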
    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train, batch_size=args.batch_size,
        num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=True)

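    # Build the MAE-style model by name from mae_model's module namespace.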
    model = mae_model.__dict__[args.model](norm_pix_loss=args.norm_pix_loss)

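    # Optionally warm-start from a checkpoint supplied via args.finetune.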
    if args.finetune:
        checkpoint = torch.load(args.finetune, map_location='cpu')

        print("Load pre-trained checkpoint from: %s" % args.finetune)
        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()
        # Drop head weights whose shapes do not match the current model.
        for k in ['head.weight', 'head.bias']:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # Interpolate position embeddings to the current input resolution.
        interpolate_pos_embed(model, checkpoint_model)

        msg = model.load_state_dict(checkpoint_model, strict=False)
        print(msg)

    model.to(device)
    model_without_ddp = model
    print("Model = %s" % str(model_without_ddp))

    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()

    # Linear lr scaling: blr is defined relative to a reference effective batch size of 80.
    if args.lr is None:
        args.lr = args.blr * eff_batch_size / 80

    print("base lr: %.2e" % (args.lr * 80 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)

    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

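    # Optimizer: AdamW over timm weight-decay parameter groups.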
    # DDP wrapping is left entirely to accelerator.prepare().
    param_groups = optim_factory.param_groups_weight_decay(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    print(optimizer)
    loss_scaler = NativeScaler()

    model, optimizer, data_loader_train = accelerator.prepare(model, optimizer, data_loader_train)
    model_without_ddp = accelerator.unwrap_model(model)

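    # Train; checkpoint every 50 epochs (and at the end) and append per-epoch stats to log.txt.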
    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        # Reseed the distributed sampler so each epoch uses a different shuffle.
        if hasattr(data_loader_train.sampler, 'set_epoch'):
            data_loader_train.sampler.set_epoch(epoch)
        train_stats = train_one_epoch(
            model, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            log_writer=log_writer,
            args=args,
            accelerator=accelerator
        )
        if args.output_dir and (epoch % 50 == 0 or epoch + 1 == args.epochs):
            misc.save_model(
                args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                loss_scaler=loss_scaler, epoch=epoch)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch}

        if args.output_dir and misc.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    parser = get_args_pretrain()
    args = parser.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
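
# Typical launch under Accelerate (script name and flag values are illustrative):
#   accelerate launch <this_script>.py --data_path /path/to/sar_data \
#       --model <mae_variant> --batch_size 32 --output_dir ./output --log_dir ./output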
|
|