| import re |
| import torch.nn as nn |
|
|
| class BaseObject(nn.Module): |
|
|
| def __init__(self, name=None): |
| super().__init__() |
| self._name = name |
|
|
| @property |
| def __name__(self): |
| if self._name is None: |
| name = self.__class__.__name__ |
| s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) |
| return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() |
| else: |
| return self._name |
|
|
|
|
| class Metric(BaseObject): |
| pass |
|
|
|
|
| class Loss(BaseObject): |
|
|
| def __add__(self, other): |
| if isinstance(other, Loss): |
| return SumOfLosses(self, other) |
| else: |
| raise ValueError('Loss should be inherited from `Loss` class') |
|
|
| def __radd__(self, other): |
| return self.__add__(other) |
|
|
| def __mul__(self, value): |
| if isinstance(value, (int, float)): |
| return MultipliedLoss(self, value) |
| else: |
| raise ValueError('Loss should be inherited from `BaseLoss` class') |
|
|
| def __rmul__(self, other): |
| return self.__mul__(other) |
|
|
|
|
| class SumOfLosses(Loss): |
|
|
| def __init__(self, l1, l2): |
| name = '{} + {}'.format(l1.__name__, l2.__name__) |
| super().__init__(name=name) |
| self.l1 = l1 |
| self.l2 = l2 |
|
|
| def __call__(self, *inputs): |
| return self.l1.forward(*inputs) + self.l2.forward(*inputs) |
|
|
|
|
| class MultipliedLoss(Loss): |
|
|
| def __init__(self, loss, multiplier): |
|
|
| |
| if len(loss.__name__.split('+')) > 1: |
| name = '{} * ({})'.format(multiplier, loss.__name__) |
| else: |
| name = '{} * {}'.format(multiplier, loss.__name__) |
| super().__init__(name=name) |
| self.loss = loss |
| self.multiplier = multiplier |
|
|
| def __call__(self, *inputs): |
| return self.multiplier * self.loss.forward(*inputs) |
|
|