import os.path from get_args import get_args_pretrain import mae_model # import mae_ori_model import numpy as np import datetime import time import json import math import sys from typing import Iterable from pathlib import Path from accelerate import Accelerator import torch import torch.backends.cudnn as cudnn import torch.nn as nn from torch.utils.tensorboard import SummaryWriter import torchvision.transforms as transforms import torchvision.datasets as datasets import timm.optim.optim_factory as optim_factory from SARdatasets import SARImageFolder, build_coed_SARImageFolder, Multi_task_SARImageFolder import util.misc as misc import util.lr_sched as lr_sched from util.pos_embed import interpolate_pos_embed from util.misc import NativeScalerWithGradNormCount as NativeScaler def train_one_epoch(model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, device: torch.device, epoch: int, loss_scaler, log_writer=None, args=None, accelerator=None): model.train(True) metric_logger = misc.MetricLogger(delimiter=" ") metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) header = 'Epoch: [{}]'.format(epoch) print_freq = 20 accum_iter = args.accum_iter optimizer.zero_grad() if log_writer is not None: print('log_dir: {}'.format(log_writer.log_dir)) for data_iter_step, (samples, target) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): samples = samples.to(device, non_blocking=True) target = target.to(device, non_blocking=True) with torch.cuda.amp.autocast(): loss, channel_loss, _, _ = model(samples, target) #, mask_ratio=args.mask_ratio) loss_value = loss.item() if not math.isfinite(loss_value): print("Loss is {}, stopping training".format(loss_value)) sys.exit(1) accelerator.backward(loss) if (data_iter_step + 1) % accum_iter == 0: optimizer.zero_grad() torch.cuda.synchronize() metric_logger.update(loss=loss_value) lr = optimizer.param_groups[0]["lr"] metric_logger.update(lr=lr) loss_value_reduce = misc.all_reduce_mean(loss_value) if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: """ We use epoch_1000x as the x-axis in tensorboard. This calibrates different curves when batch size changes. """ epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) log_writer.add_scalar('lr', lr, epoch_1000x) # log_writer.add_scalar('Channel Loss Mean', channel_loss, epoch_1000x) # print(f"Channel Loss Mean: {channel_loss}") # gather the stats from all processes metric_logger.synchronize_between_processes() print("Averaged stats:", metric_logger) return {k: meter.global_avg for k, meter in metric_logger.meters.items()} def main(args): misc.init_distributed_mode(args) torch.multiprocessing.set_start_method('spawn', force=True) print ('work_dir:{}'.format(os.path.realpath(__file__))) accelerator = Accelerator() device = torch.device(args.device) device = accelerator.device # fix the seed for reproducibility seed = args.seed + misc.get_rank() torch.manual_seed(seed) np.random.seed(seed) cudnn.benchmark = True # simple augmentation transform_train = transforms.Compose([ transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0)), # 3 is bicubicinterpolation=3 transforms.RandomHorizontalFlip(), transforms.ToTensor(), ]) dataset_train = Multi_task_SARImageFolder(root=args.data_path, transform=transform_train) print(dataset_train) if True: num_tasks = misc.get_world_size() global_rank = misc.get_rank() sampler_train = torch.utils.data.DistributedSampler( dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True) print("Sampler_train = %s" % str(sampler_train)) else: sampler_train = torch.utils.data.RandomSampler(dataset_train) if global_rank == 0 and args.log_dir is not None: os.makedirs(args.log_dir, exist_ok=True) log_writer = SummaryWriter(log_dir=args.log_dir) else: log_writer = None data_loader_train = torch.utils.data.DataLoader(dataset_train, sampler=sampler_train, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=True, shuffle=False ) model = mae_model.__dict__[args.model](norm_pix_loss=args.norm_pix_loss) # load pretrain checkpoint of Imagenet checkpoint = torch.load(args.finetune, map_location='cpu') print("Load pre-trained checkpoint from: %s" % args.finetune) checkpoint_model = checkpoint['model'] state_dict = model.state_dict() for k in ['head.weight', 'head.bias']: if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: print(f"Removing key {k} from pretrained checkpoint") del checkpoint_model[k] # interpolate position embedding interpolate_pos_embed(model, checkpoint_model) # load pre-trained model msg = model.load_state_dict(checkpoint_model, strict=False) print(msg) model.to(device) model_without_ddp = model print("Model = %s" % str(model_without_ddp)) eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() if args.lr is None: # only base_lr is specified args.lr = args.blr * eff_batch_size / 80 # 256 print("base lr: %.2e" % (args.lr * 80 / eff_batch_size)) print("actual lr: %.2e" % args.lr) print("accumulate grad iterations: %d" % args.accum_iter) print("effective batch size: %d" % eff_batch_size) if args.distributed: model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) model_without_ddp = model.module # following timm: set wd as 0 for bias and norm layers param_groups = optim_factory.param_groups_weight_decay(model_without_ddp, args.weight_decay) # add_weight_decay optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) print(optimizer) loss_scaler = NativeScaler() model, optimizer, data_loader_train = accelerator.prepare(model, optimizer, data_loader_train) print(f"Start training for {args.epochs} epochs") start_time = time.time() for epoch in range(args.start_epoch, args.epochs): train_stats = train_one_epoch( model, data_loader_train, optimizer, device, epoch, loss_scaler, log_writer=log_writer, args=args, accelerator=accelerator ) if args.output_dir and (epoch % 50 == 0 or epoch + 1 == args.epochs): misc.save_model( args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch) log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 'epoch': epoch, } if args.output_dir and misc.is_main_process(): if log_writer is not None: log_writer.flush() with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: f.write(json.dumps(log_stats) + "\n") total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str)) if __name__ == '__main__': args = get_args_pretrain() args = args.parse_args() if args.output_dir: Path(args.output_dir).mkdir(parents=True, exist_ok=True) main(args)