| import torch |
| import torch.nn as nn |
|
|
|
|
class EuroSATCNN(nn.Module):
    """Four-stage convolutional classifier for EuroSAT-style image patches.

    Each feature stage is ``Conv2d(kernel_size=4, padding=1) -> ReLU ->
    MaxPool2d(2)``, with channel widths 128 -> 64 -> 32 -> 16. The flattened
    feature map feeds a two-layer MLP head (hidden width 64) that produces
    unnormalised class scores (no softmax is applied).

    Args:
        num_classes: Number of output scores per sample.
        img_height: Input image height in pixels. Only used to size the
            first fully-connected layer.
        img_width: Input image width in pixels. Only used to size the
            first fully-connected layer.
        in_channels: Number of input channels. Defaults to 13, the value the
            model previously hard-coded (presumably the 13 Sentinel-2 bands
            of multispectral EuroSAT -- confirm for your data).

    Raises:
        RuntimeError: Propagated from PyTorch when ``img_height`` or
            ``img_width`` is too small to survive the four conv/pool stages.
            With these layer settings, each side must be at least 31 px.
    """

    def __init__(
        self,
        num_classes,
        img_height=64,
        img_width=64,
        in_channels=13,
    ):
        super().__init__()

        # Layer order is significant: the Sequential indices form the
        # state_dict keys (features.0, features.3, ...), so reordering would
        # break loading of existing checkpoints.
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 128, kernel_size=4, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),

            nn.Conv2d(128, 64, kernel_size=4, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),

            nn.Conv2d(64, 32, kernel_size=4, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),

            nn.Conv2d(32, 16, kernel_size=4, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        # Infer the flattened feature size from one probe pass, so it always
        # matches the layers above. Only the output *shape* matters, so the
        # probe is all zeros. Using torch.randn here would advance the global
        # RNG and shift the initialisation of the classifier layers built
        # below, which hurts seeded reproducibility.
        with torch.no_grad():
            probe = torch.zeros(1, in_channels, img_height, img_width)
            # The batch size is 1, so numel() equals C * H' * W'.
            fc1_input_size = self.features(probe).numel()

        self.classifier = nn.Sequential(
            nn.Flatten(),  # (batch, C, H', W') -> (batch, C * H' * W')
            nn.Linear(fc1_input_size, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """Map a batch of images to class scores.

        Args:
            x: Tensor of shape ``(batch, in_channels, img_height, img_width)``.
                The spatial size must match the values given at construction,
                otherwise the first Linear layer receives the wrong width.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalised scores.
        """
        x = self.features(x)
        x = self.classifier(x)
        return x