| import torch |
| from timm.loss import SoftTargetCrossEntropy |
| from timm.models.layers import drop |
|
|
|
|
| from models.network.hymamba import Encoder |
|
|
|
|
|
|
|
|
| |
| def _ex_repr(self): |
| return ', '.join( |
| f'{k}=' + (f'{v:g}' if isinstance(v, float) else str(v)) |
| for k, v in vars(self).items() |
| if not k.startswith('_') and k != 'training' |
| and not isinstance(v, (torch.nn.Module, torch.Tensor)) |
| ) |
# Monkey-patch these classes so their hyper-parameters show up when printed:
# classes exposing nn.Module's `extra_repr` hook get it replaced with the
# attribute dump above; anything else gets a full custom `__repr__`.
for _patched_cls in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy, drop.DropPath):
    if not hasattr(_patched_cls, 'extra_repr'):
        _patched_cls.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})'
        continue
    _patched_cls.extra_repr = _ex_repr
|
|
|
|
# Per-model default constructor kwargs used for masked pretraining.
pretrain_default_model_kwargs = {
    'mambamim': dict(sparse=True, drop_path_rate=0.1),
}
# Every pretraining config shares these: never load supervised-pretrained
# weights, and strip the classification head / global pooling.
for _kw in pretrain_default_model_kwargs.values():
    _kw.update(pretrained=False, num_classes=0, global_pool='')
|
|
|
|
|
|
def build_sparse_encoder(name: str, input_size: int, sbn=False, drop_path_rate=0.0, verbose=False):
    """Build a sparse encoder wrapped for masked-image pretraining.

    Args:
        name: key into ``pretrain_default_model_kwargs`` (e.g. 'mambamim').
        input_size: spatial size of the input volume, forwarded to SparseEncoder.
        sbn: whether SparseEncoder should use synchronized batch norm.
        drop_path_rate: if non-zero, overrides the model's default stochastic
            depth rate in the logged kwargs.
        verbose: forwarded to SparseEncoder for extra logging.

    Returns:
        A ``SparseEncoder`` wrapping a freshly constructed ``Encoder``.
    """
    from models.encoder import SparseEncoder
    # Copy the defaults: the original code aliased the shared module-level
    # dict, so a non-zero drop_path_rate here would permanently mutate
    # pretrain_default_model_kwargs for every later call.
    kwargs = dict(pretrain_default_model_kwargs[name])
    if drop_path_rate != 0:
        kwargs['drop_path_rate'] = drop_path_rate
    print(f'[build_sparse_encoder] model kwargs={kwargs}')
    # NOTE(review): kwargs is only logged — the Encoder below is built from
    # hard-coded hyper-parameters and ignores it entirely (including any
    # drop_path_rate override); confirm whether kwargs should be forwarded.
    encoder = Encoder(
        in_channel=1,
        channels=(32, 64, 128, 192, 384),
        depths=(1, 2, 2, 2, 1),
        kernels=(3, 3, 3, 3, 3),
        exp_r=(2, 2, 4, 4, 4),
        img_size=96,
        depth=4,
        sparse=True)
    return SparseEncoder(encoder=encoder, input_size=input_size, sbn=sbn, verbose=verbose)