| import os |
| import torch |
| import argparse |
| import torch.nn as nn |
| from tqdm.auto import tqdm |
| from torch.utils.data import DataLoader |
| import torch.nn as nn |
| import torch |
| from torch import nn |
| from torch.optim import * |
| from torch.optim.lr_scheduler import * |
| from torch.utils.data import DataLoader |
| from torchprofile import profile_macs |
| from torchvision.datasets import * |
| from torchvision.transforms import * |
| from proard.classification.data_providers.imagenet import ImagenetDataProvider |
| from proard.classification.run_manager import DistributedClassificationRunConfig, DistributedRunManager |
| from proard.model_zoo import DYN_net |
| from proard.nas.accuracy_predictor import AccuracyDataset,AccuracyPredictor,ResNetArchEncoder,RobustnessPredictor,MobileNetArchEncoder,AccuracyRobustnessDataset,Accuracy_Robustness_Predictor |
# Command-line parser for this evaluation script; its arguments are
# registered further down, just before ``parser.parse_args()`` is called.
parser = argparse.ArgumentParser(
    description=(
        "Evaluate a trained accuracy/robustness predictor on the validation "
        "split of its collected architecture dataset."
    )
)
|
|
|
|
def RMSELoss(yhat, y):
    """Return the root-mean-square error between ``yhat`` and ``y``.

    The residual ``yhat - y`` is formed elementwise (ordinary tensor
    broadcasting applies), squared, averaged over every element and
    square-rooted. The result is a 0-d tensor.
    """
    residual = yhat - y
    mean_squared = residual.square().mean()
    return mean_squared.sqrt()
def train(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    optimizer: Optimizer,
    callbacks=None,
    epochs=10,
    save_path=None,
) -> nn.Module:
    """Train a two-output (accuracy, robustness) predictor and return it.

    :param model: network whose output column 0 is regressed against the
        accuracy target and column 1 against the robustness target.
    :param dataloader: yields ``(inputs, targets_acc, targets_rob)`` triples.
    :param criterion: loss applied separately to each output column; the two
        losses are summed.
    :param optimizer: optimizer over ``model``'s parameters.
    :param callbacks: optional iterable of zero-argument callables, invoked
        after every optimizer step.
    :param epochs: number of passes over ``dataloader``.
    :param save_path: if given, ``model.state_dict()`` is written there at the
        end of every epoch (so an interrupted run keeps the latest weights).
        When ``None`` nothing is saved.
    :returns: the trained ``model`` (moved to CUDA).
    """
    model.cuda()
    model.train()
    for epoch in range(epochs):
        print(epoch)
        for inputs, targets_acc, targets_rob in tqdm(dataloader, desc='train', leave=False):
            # Encodings are cast to float so they match the model's weights.
            inputs = inputs.float().cuda()
            targets_acc = targets_acc.cuda()
            targets_rob = targets_rob.cuda()

            optimizer.zero_grad()

            outputs = model(inputs)
            # Joint objective: column 0 -> accuracy, column 1 -> robustness.
            loss = criterion(outputs[:, 0], targets_acc) + criterion(outputs[:, 1], targets_rob)

            loss.backward()
            optimizer.step()

            if callbacks is not None:
                for callback in callbacks:
                    callback()

        # torch.save(..., None) raises, so only checkpoint when a path was given.
        if save_path is not None:
            torch.save(model.state_dict(), save_path)
    return model
|
|
@torch.inference_mode()
def evaluate(
    model: nn.Module,
    dataloader: DataLoader,
) -> float:
    """Score a two-output (accuracy, robustness) predictor over a whole dataloader.

    The squared errors of output column 0 (vs. the accuracy target) and column 1
    (vs. the robustness target) are accumulated over *every* batch, so the
    reported RMSEs describe the full dataset, not a single batch.

    :param model: predictor; assumed to already live on the CUDA device.
    :param dataloader: yields ``(inputs, targets_acc, targets_rob)`` triples.
    :returns: ``rmse_accuracy + rmse_robustness`` as a Python float. Both
        components are printed before returning.
    :raises ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()

    sq_err_acc = 0.0
    sq_err_rob = 0.0
    count_acc = 0
    count_rob = 0
    for inputs, targets_acc, targets_rob in tqdm(dataloader, desc="eval", leave=False):
        # Cast to float exactly as train() does, so both see the same input dtype.
        inputs = inputs.float().cuda()
        targets_acc = targets_acc.cuda()
        targets_rob = targets_rob.cuda()

        outputs = model(inputs)

        diff_acc = outputs[:, 0] - targets_acc
        diff_rob = outputs[:, 1] - targets_rob
        # Accumulate sums (not per-batch means) so uneven final batches are
        # weighted correctly in the dataset-level RMSE.
        sq_err_acc += diff_acc.pow(2).sum().item()
        sq_err_rob += diff_rob.pow(2).sum().item()
        count_acc += diff_acc.numel()
        count_rob += diff_rob.numel()

    if count_acc == 0 or count_rob == 0:
        raise ValueError("evaluate() received a dataloader with no samples")

    rmse_acc = (sq_err_acc / count_acc) ** 0.5
    rmse_rob = (sq_err_rob / count_rob) ** 0.5
    print(rmse_acc, rmse_rob)
    return rmse_acc + rmse_rob
|
|
|
|
def get_model_flops(model, inputs):
    """Return the multiply-accumulate count of ``model`` evaluated on ``inputs``.

    The value comes straight from ``torchprofile.profile_macs``. Despite this
    function's name it is a MAC count, not a FLOP count.
    """
    return profile_macs(model, inputs)
|
|
|
|
def get_model_size(model: nn.Module, data_width=32):
    """Return the storage size of ``model``'s parameters, in bits.

    :param model: module whose ``parameters()`` are counted.
    :param data_width: number of bits assumed per parameter element.
    :returns: total element count multiplied by ``data_width``.
    """
    total_elements = sum(p.numel() for p in model.parameters())
    return data_width * total_elements
|
|
|
|
|
|
|
|
|
|
def _str2bool(value):
    """Parse a command-line boolean such as "true"/"false", "1"/"0", "yes"/"no".

    ``type=bool`` cannot be used for this: argparse would call ``bool("False")``,
    which is ``True`` because every non-empty string is truthy.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


# NOTE(review): --path, --gpu, --batch_size, --workers and --robust_mode are
# parsed but not read by the evaluation code below.
parser.add_argument(
    "-p", "--path", help="The path of cifar10", type=str, default="/dataset/cifar10"
)
parser.add_argument("-g", "--gpu", help="The gpu(s) to use", type=str, default="all")
parser.add_argument(
    "-b",
    "--batch_size",
    help="The batch on every device for validation",
    type=int,
    default=32,
)
parser.add_argument("-j", "--workers", help="Number of workers", type=int, default=20)
parser.add_argument(
    "-n",
    "--net",
    metavar="DYNNET",
    default="ResNet50",
    choices=[
        "ResNet50",
        "MBV3",
        "ProxylessNASNet",
    ],
    help="Dynamic networks",
)
parser.add_argument(
    "--dataset", type=str, default="cifar10", choices=["cifar10", "cifar100", "imagenet"]
)
parser.add_argument("--train_criterion", type=str, default="trades", choices=["trades", "sat", "mart", "hat"])
parser.add_argument(
    "--robust_mode",
    type=_str2bool,
    default=True,
    help="Boolean flag (true/false, yes/no, 1/0)",
)
args = parser.parse_args()

# ImageNet models use 224px inputs; the CIFAR variants use 32px.
image_size = 224 if args.dataset == 'imagenet' else 32
if args.net == "ResNet50":
    arch = ResNetArchEncoder(
        image_size_list=[image_size],
        depth_list=[0, 1, 2],
        expand_list=[0.2, 0.25, 0.35],
        width_mult_list=[0.65, 0.8, 1.0],
    )
else:
    # MBV3 and ProxylessNASNet share the MobileNet-style encoder.
    arch = MobileNetArchEncoder(
        image_size_list=[image_size],
        depth_list=[2, 3, 4],
        expand_list=[3, 4, 6],
        ks_list=[3, 5, 7],
    )
print(arch)

# Directory holding the collected (architecture, accuracy, robustness) data
# and, under src/, the trained predictor weights.
data_dir = "./acc_rob_data_{}_{}_{}".format(args.dataset, args.net, args.train_criterion)
acc_data = AccuracyRobustnessDataset(data_dir)
train_loader, valid_loader, base_acc, base_rob = acc_data.build_acc_data_loader(arch)
acc_pred_network = Accuracy_Robustness_Predictor(arch_encoder=arch, base_acc_val=None)

checkpoint_path = os.path.join(data_dir, "src", "model_acc_rob.pth")
# Load onto CPU first; load_state_dict copies each tensor onto the device of
# the matching parameter, so the checkpoint's original device does not matter.
acc_pred_network.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
print(evaluate(acc_pred_network, valid_loader))
|
|
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
|
|