| import timm |
|
|
| import torch.nn as nn |
|
|
| from pathlib import Path |
| from .utils import activations, forward_default, get_activation |
|
|
| from ..external.next_vit.classification.nextvit import * |
|
|
|
|
def forward_next_vit(pretrained, x):
    """Run input ``x`` through the hooked Next-ViT wrapper ``pretrained``.

    Thin adapter over the shared ``forward_default`` helper from ``.utils``.
    It selects ``"forward"`` as the function name to invoke, presumably
    ``pretrained.model.forward`` (confirm in ``.utils``). The helper's return
    value is passed through unchanged.
    """
    entry_point = "forward"
    return forward_default(pretrained, x, entry_point)
|
|
|
|
def _make_next_vit_backbone(
    model,
    hooks=(2, 6, 36, 39),
):
    """Wrap a Next-ViT ``model`` so four intermediate outputs are captured.

    A forward hook is registered on ``model.features[hooks[i]]`` for
    ``i`` in ``0..3``. The hook is ``get_activation("1")`` through
    ``get_activation("4")`` respectively. The returned container exposes the
    model as ``.model`` and the ``activations`` object from ``.utils`` as
    ``.activations``.

    Args:
        model: Next-ViT network with an indexable ``features`` attribute
            (presumably an ``nn.Sequential``/``ModuleList`` of blocks; TODO
            confirm).
        hooks: Indices into ``model.features`` whose outputs are captured.
            Only the first four entries are used. Fewer than four raises
            ``IndexError``. Defaults to ``(2, 6, 36, 39)``. The default is an
            immutable tuple so it cannot be mutated and leak across calls;
            lists are still accepted.

    Returns:
        ``nn.Module`` container carrying ``model`` and ``activations``.
    """
    pretrained = nn.Module()
    pretrained.model = model

    features = pretrained.model.features
    # Register hooks under keys "1".."4". Only hooks[0..3] are consulted,
    # exactly as the original four explicit calls did.
    for slot in range(4):
        features[hooks[slot]].register_forward_hook(get_activation(str(slot + 1)))

    # NOTE(review): ``activations`` is a module-level object imported from
    # .utils, so it is presumably shared by every backbone built here; confirm.
    pretrained.activations = activations

    return pretrained
|
|
|
|
def _make_pretrained_next_vit_large_6m(hooks=None):
    """Build a Next-ViT-Large model and wrap it with activation hooks.

    Args:
        hooks: Indices into ``model.features`` whose outputs are captured and
            forwarded to ``_make_next_vit_backbone``. ``None`` selects
            ``[2, 6, 36, 39]``.

    Returns:
        The hooked container produced by ``_make_next_vit_backbone``.
    """
    # timm.create_model defaults to pretrained=False, so no weights are
    # downloaded here. Presumably the caller loads a checkpoint; confirm.
    # "nextvit_large" must be registered with timm. Presumably the star import
    # from ..external.next_vit at the top of this module does that; verify.
    model = timm.create_model("nextvit_large")

    # Identity check against None: ``hooks == None`` is evaluated elementwise
    # for array-like inputs (e.g. a NumPy array), and the truth test on the
    # result then raises "truth value ... is ambiguous".
    if hooks is None:
        hooks = [2, 6, 36, 39]
    return _make_next_vit_backbone(model, hooks=hooks)
|
|