import sys
from collections import OrderedDict

import numpy as np
import torch

# Modules summarized as a single layer even though they contain submodules;
# their internals run through functional calls, so hooks on their children
# would not fire (e.g. MultiheadAttention's output projection).
layer_modules = (torch.nn.MultiheadAttention,)


def summary(model, input_data=None, input_data_args=None, input_shape=None, input_dtype=torch.FloatTensor,
            batch_size=-1,
            *args, **kwargs):
    """
    Give example input data in at least one of the following ways:
    ① input_data ---> model.forward(input_data)
    ② input_data_args ---> model.forward(*input_data_args)
    ③ input_shape & input_dtype ---> model.forward(*[torch.rand(2, *size).type(input_dtype) for size in input_shape])
    """
    hooks = []
    summary = OrderedDict()  # note: shadows the enclosing function name within this scope

    def register_hook(module):
        def hook(module, inputs, outputs):
            # "<class 'torch.nn.modules.linear.Linear'>" -> "Linear"
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(summary)

            key = "%s-%i" % (class_name, module_idx + 1)

            info = OrderedDict()
            info["id"] = id(module)
            if isinstance(outputs, (list, tuple)):
                try:
                    info["out"] = [batch_size] + list(outputs[0].size())[1:]
                except AttributeError:
                    # e.g. a PackedSequence exposes its tensor via .data
                    info["out"] = [batch_size] + list(outputs[0].data.size())[1:]
            else:
                info["out"] = [batch_size] + list(outputs.size())[1:]

| info["params_nt"], info["params"] = 0, 0 |
| for name, param in module.named_parameters(): |
| info["params"] += param.nelement() * param.requires_grad |
| info["params_nt"] += param.nelement() * (not param.requires_grad) |
|
|
| summary[key] = info |
|
        # hook leaf modules, plus the special cases in layer_modules,
        # so each parameterized layer is counted exactly once
        if isinstance(module, layer_modules) or not module._modules:
            hooks.append(module.register_forward_hook(hook))

    model.apply(register_hook)

    # a single shape tuple is shorthand for a one-element list of shapes
    if isinstance(input_shape, tuple):
        input_shape = [input_shape]

    if input_data is not None:
        x = [input_data]
    elif input_shape is not None:
        # a batch dimension of 2 keeps batch-sensitive layers (e.g. BatchNorm) working
        x = [torch.rand(2, *size).type(input_dtype) for size in input_shape]
    elif input_data_args is not None:
        x = input_data_args
    else:
        x = []
    try:
        with torch.no_grad():
            # extra positional/keyword arguments are forwarded to the model
            model(*x, *args, **kwargs)
    except Exception:
        print("Failed to run summary...")
        raise
    finally:
        # always remove the hooks, even if the forward pass fails
        for hook in hooks:
            hook.remove()

    summary_logs = []
    summary_logs.append("--------------------------------------------------------------------------")
    line_new = "{:<30} {:>20} {:>20}".format("Layer (type)", "Output Shape", "Param #")
    summary_logs.append(line_new)
    summary_logs.append("==========================================================================")
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer in summary:
        # layer name, output shape, and total (trainable + non-trainable) params
        line_new = "{:<30} {:>20} {:>20}".format(
            layer,
            str(summary[layer]["out"]),
            "{0:,}".format(summary[layer]["params"] + summary[layer]["params_nt"])
        )
        total_params += (summary[layer]["params"] + summary[layer]["params_nt"])
        total_output += np.prod(summary[layer]["out"])
        trainable_params += summary[layer]["params"]
        summary_logs.append(line_new)

    # size estimates assume 4 bytes (float32) per element
    if input_data is not None:
        if torch.is_tensor(input_data):
            total_input_size = abs(input_data.nelement() * input_data.element_size() / (1024 ** 2.))
        else:
            # sys.getsizeof measures only the Python object, not tensor storage
            total_input_size = abs(sys.getsizeof(input_data) / (1024 ** 2.))
    elif input_shape is not None:
        # sum per-input sizes; batch_size may be -1, hence the abs()
        total_input_size = abs(sum(np.prod(s) for s in input_shape) * batch_size * 4. / (1024 ** 2.))
    else:
        total_input_size = 0.0
    # x2: activations are kept for both the forward and backward passes
    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))
    total_params_size = abs(total_params * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size

    summary_logs.append("==========================================================================")
    summary_logs.append("Total params: {0:,}".format(total_params))
    summary_logs.append("Trainable params: {0:,}".format(trainable_params))
    summary_logs.append("Non-trainable params: {0:,}".format(total_params - trainable_params))
    summary_logs.append("--------------------------------------------------------------------------")
    summary_logs.append("Input size (MB): %0.6f" % total_input_size)
    summary_logs.append("Forward/backward pass size (MB): %0.6f" % total_output_size)
    summary_logs.append("Params size (MB): %0.6f" % total_params_size)
    summary_logs.append("Estimated Total Size (MB): %0.6f" % total_size)
    summary_logs.append("--------------------------------------------------------------------------")

    summary_info = "\n".join(summary_logs)

    print(summary_info)
    return summary_info
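

# ---------------------------------------------------------------------------
# Usage sketch -- not part of the original module. The toy model below is a
# hypothetical stand-in, used only to illustrate the three input styles
# described in summary()'s docstring.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    net = nn.Sequential(
        nn.Linear(128, 64),
        nn.ReLU(),
        nn.Linear(64, 10),
    )
    # freeze the first layer to exercise the non-trainable parameter count
    for p in net[0].parameters():
        p.requires_grad = False

    # ① a ready-made example input tensor
    summary(net, input_data=torch.rand(2, 128))

    # ② positional forward() arguments as a sequence
    summary(net, input_data_args=[torch.rand(2, 128)])

    # ③ per-input shapes without the batch dimension (dtype defaults to FloatTensor)
    summary(net, input_shape=(128,))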