# prediff_code/datamodule/data_utils/augmentation.py
# Uploaded by weatherforecast1024 via huggingface_hub (commit 7667a87).
from typing import Sequence
import random
from torch import nn
import torchvision.transforms.functional as TF
class TransformsFixRotation(nn.Module):
    r"""
    Rotate the input by one angle drawn uniformly at random from a fixed set.

    Each call to ``forward`` picks one entry of ``angles`` with
    ``random.choice`` and applies ``torchvision.transforms.functional.rotate``
    with that angle.

    Example: `rotation_transform = TransformsFixRotation(angles=[-30, -15, 0, 15, 30])`

    Args:
        angles: A single angle, or a collection of angles.
            - Sequences (list, tuple, ...) are stored as given.
            - Other iterables (set, generator, 1-d array/tensor) are
              materialized into a list of plain floats, because ``TF.rotate``
              only accepts int/float angles.
            - Non-iterable values are treated as a single angle.

    Raises:
        ValueError: If ``angles`` contains no angles. Without this check the
            failure would only surface later, as an ``IndexError`` from
            ``random.choice`` inside ``forward``.
    """

    def __init__(self, angles):
        super().__init__()
        if not isinstance(angles, Sequence):
            try:
                angle_iter = iter(angles)
            except TypeError:
                # Not iterable at all: a single scalar angle.
                angles = [angles]
            else:
                # Non-sequence collection. Convert elements to float so values
                # such as numpy ints or 0-d tensors are accepted by TF.rotate.
                angles = [float(a) for a in angle_iter]
        if len(angles) == 0:
            raise ValueError("TransformsFixRotation requires at least one angle")
        self.angles = angles

    def forward(self, x):
        """Rotate ``x`` by one randomly chosen angle from ``self.angles``.

        ``x`` is passed unchanged to ``TF.rotate``, so it can be any input that
        function accepts. ``TF.rotate``'s defaults apply (no ``expand``,
        default interpolation).
        """
        angle = random.choice(self.angles)
        return TF.rotate(x, angle)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(angles={self.angles})"