| import argparse |
| import torch |
| import os |
| import sys |
| from thop import profile, clever_format |
|
|
import warnings

# Suppress every Python warning for the whole process (not just this module)
# to keep the benchmark output uncluttered.
warnings.filterwarnings("ignore")


# Make the repository root (the parent of this script's directory) importable
# so the project packages imported below (`lib`, `data`, `utils`) resolve.
# abspath() matters when __file__ is relative: for a bare "Benchmark.py",
# os.path.split would yield '' for both paths and silently fall back to the
# current working directory instead of the repository root.
filepath = os.path.dirname(os.path.abspath(__file__))
repopath = os.path.dirname(filepath)
if repopath not in sys.path:
    sys.path.append(repopath)
|
|
| from lib import * |
| from lib.optim import * |
| from data.dataloader import * |
| from utils.misc import * |
|
|
| def _args(): |
| parser = argparse.ArgumentParser() |
| parser.add_argument('--config', type=str, default='configs/InSPyReNet_SwinB.yaml') |
| parser.add_argument('--input_size', type=int, nargs='+', default=[384, 384]) |
| parser.add_argument('--verbose', action='store_true', default=False) |
| return parser.parse_args() |
|
|
def _time_forward(model, sample, n_iter, n_warmup):
    """Return the mean GPU time, in milliseconds, of one forward pass.

    Runs ``n_warmup`` untimed passes, then times ``n_iter`` passes between two
    CUDA events. Everything runs under ``torch.no_grad()``.
    """
    with torch.no_grad():
        # Untimed passes absorb one-off costs (e.g. lazy CUDA/cuDNN setup) so
        # they are not folded into the average.
        for _ in range(n_warmup):
            model(sample)

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        # Let queued warm-up work finish before the timed region begins.
        torch.cuda.synchronize()
        start.record()
        for _ in range(n_iter):
            model(sample)
        end.record()

    # elapsed_time() is only valid once the end event has completed on the GPU.
    end.synchronize()
    return start.elapsed_time(end) / n_iter


def benchmark(opt, args, n_iter=10, n_warmup=10):
    """Print MACs, parameter count and average forward latency of a model.

    The model class is looked up from ``opt.Model.name``, constructed with
    ``opt.Model`` unpacked as keyword arguments, wrapped in ``Simplify`` and
    moved to the current CUDA device. It is then profiled with thop and timed
    on a random input of shape ``(1, 3, *args.input_size)``.

    Args:
        opt: Loaded configuration. ``opt.Model.name`` names the model class
            and ``opt.Model`` supplies its constructor kwargs.
        args: Parsed CLI namespace. ``input_size`` is read, and ``verbose``
            is read if present.
        n_iter: Number of timed forward passes to average over.
        n_warmup: Number of untimed forward passes run before timing.

    Raises:
        ValueError: If ``n_iter`` is less than 1.
    """
    if n_iter < 1:
        raise ValueError('n_iter must be a positive integer')

    # NOTE(review): eval() executes an arbitrary expression taken from the
    # config file; only run this with trusted configs.
    model = Simplify(eval(opt.Model.name)(**opt.Model))
    model = model.cuda()
    # Benchmark in inference mode (dropout off, BatchNorm uses running stats).
    model.eval()

    sample = torch.rand(1, 3, *args.input_size, device='cuda')

    verbose = getattr(args, 'verbose', False)
    macs, params = profile(model, inputs=(sample, ), verbose=verbose)
    macs, params = clever_format([macs, params], "%.3f")

    latency_ms = _time_forward(model, sample, n_iter, n_warmup)
    fps = 1000.0 / latency_ms if latency_ms > 0 else float('inf')

    print('Model:', opt.Model.name)
    print('MACs:', macs, 'Params:', params)
    # The averaged value is per-forward latency; throughput is derived from it
    # for a batch size of 1.
    print('Latency:', f'{latency_ms:.3f}', 'msec')
    print('Throughput:', f'{fps:.2f}', 'FPS')
|
|
if __name__ == '__main__':
    # The config path arrives on the command line, so parse the CLI first,
    # then load that config and hand both to the benchmark.
    args = _args()
    opt = load_config(args.config)
    benchmark(opt=opt, args=args)
|
|