| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
import torch
import torch.nn as nn
from functools import partialmethod
from typing import Union, List
|
|
|
|
class Dropout(nn.Module):
    """
    Dropout with the ability to share the dropout mask along particular
    dimension(s).

    A mask of ones is built with the shared dimensions collapsed to size
    1, passed through ``nn.Dropout`` (which zeroes entries with
    probability ``r`` and rescales survivors by ``1 / (1 - r)``), and
    then broadcast-multiplied with the input.

    If not in training mode, this module computes the identity function
    (``nn.Dropout`` is the identity in eval mode).
    """

    def __init__(self, r: float, batch_dim: Union[int, List[int]]):
        """
        Args:
            r:
                Dropout rate
            batch_dim:
                Dimension(s) along which the dropout mask is shared
        """
        super(Dropout, self).__init__()

        self.r = r
        # Normalize a single dimension to a list so forward() can
        # iterate uniformly. isinstance (rather than ``type(...) ==``)
        # also accepts int subclasses (e.g. bool, IntEnum members).
        if isinstance(batch_dim, int):
            batch_dim = [batch_dim]
        self.batch_dim = batch_dim
        self.dropout = nn.Dropout(self.r)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x:
                Tensor to which dropout is applied. Can have any shape
                compatible with self.batch_dim
        Returns:
            A new tensor with dropout applied; the input is not
            modified in place.
        """
        shape = list(x.shape)
        if self.batch_dim is not None:
            # Collapse shared dimensions so the sampled mask broadcasts
            # across them, i.e. the same entries are dropped along each
            # shared dimension.
            for bd in self.batch_dim:
                shape[bd] = 1
        mask = x.new_ones(shape)
        mask = self.dropout(mask)
        # Out-of-place multiply: the original ``x *= mask`` mutated the
        # caller's tensor and fails on leaf tensors requiring grad.
        return x * mask
|
|
|
|
class DropoutRowwise(Dropout):
    """
    Convenience wrapper around Dropout that shares the mask along
    dimension -3 (rowwise dropout, subsection 1.11.6).
    """

    def __init__(self, r: float, **kwargs):
        # Preset the shared dimension; an explicit batch_dim keyword
        # from the caller still takes precedence, mirroring the
        # behavior of the partialmethod form.
        kwargs.setdefault("batch_dim", -3)
        super().__init__(r, **kwargs)
|
|
|
|
class DropoutColumnwise(Dropout):
    """
    Convenience wrapper around Dropout that shares the mask along
    dimension -2 (columnwise dropout, subsection 1.11.6).
    """

    def __init__(self, r: float, **kwargs):
        # Preset the shared dimension; an explicit batch_dim keyword
        # from the caller still takes precedence, mirroring the
        # behavior of the partialmethod form.
        kwargs.setdefault("batch_dim", -2)
        super().__init__(r, **kwargs)
|
|