| import torch |
| from timm.loss import SoftTargetCrossEntropy |
|
|
| from timm.models.layers import DropPath |
|
|
| from .infinity import Infinity, CInfinity, FInfinity, FAInfinity, AInfinity,BInfinity, sample_with_top_k_top_p_also_inplace_modifying_logits_ |
|
|
| def _ex_repr(self): |
| return ', '.join( |
| f'{k}=' + (f'{v:g}' if isinstance(v, float) else str(v)) |
| for k, v in vars(self).items() |
| if not k.startswith('_') and k != 'training' |
| and not isinstance(v, (torch.nn.Module, torch.Tensor)) |
| ) |
# Shorten the printed form of the loss classes: a class that exposes an
# ``extra_repr`` hook gets that hook replaced by ``_ex_repr``; otherwise the
# whole ``__repr__`` is replaced with ``ClassName(<_ex_repr output>)``.
for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy):
    if not hasattr(clz, 'extra_repr'):
        clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})'
        continue
    clz.extra_repr = _ex_repr


# DropPath instances print as just their class name followed by ``(...)``.
DropPath.__repr__ = lambda self: f'{type(self).__name__}(...)'
|
|
# Short model aliases: 'd6', 'd8', ..., 'd40' (even numbers only) map to
# 'infinity_d6', ..., 'infinity_d40'. Built with a comprehension so the loop
# variable does not linger in the module namespace.
alias_dict = {f'd{d}': f'infinity_d{d}' for d in range(6, 40 + 2, 2)}
# Reverse lookup: full name -> short alias.
alias_dict_inv = {v: k for k, v in alias_dict.items()}
|
|