import torch
import torch.nn as nn
import torch.nn.functional as F

import lightning as L


class BasicBlock(nn.Module):
    """Standard two-conv residual block (ResNet v1, non-bottleneck)."""

    expansion = 1  # BasicBlock keeps the channel count; Bottleneck would use 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # First 3x3 conv; `stride` > 1 downsamples at the start of a stage.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3,
            stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)

        # Second 3x3 conv always preserves the spatial resolution.
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3,
            stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Identity shortcut, or a 1x1 projection when shape changes.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1,
                    stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)  # residual connection
        out = F.relu(out)
        return out
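
# Informal shape check (a sketch; not executed at import): a stride-2 block
# halves the spatial resolution while the 1x1 projection on the shortcut
# matches the channel count, e.g.
#
#   >>> blk = BasicBlock(64, 128, stride=2)
#   >>> blk(torch.randn(1, 64, 32, 32)).shape
#   torch.Size([1, 128, 16, 16])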


class ResNet18_CIFAR10(nn.Module):
    """ResNet-18 adapted for 32x32 CIFAR-10 inputs: the ImageNet stem
    (7x7/stride-2 conv + max-pool) is replaced by a single 3x3/stride-1
    conv so the early layers keep full resolution."""

    def __init__(self, num_classes=10):
        super().__init__()

        # CIFAR stem: 3x3 conv, no downsampling.
        self.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Four stages of two BasicBlocks each; stages 2-4 downsample by 2
        # in their first block while doubling the channel count.
        self.layer1 = self._make_layer(64, 64, num_blocks=2, stride=1)
        self.layer2 = self._make_layer(64, 128, num_blocks=2, stride=2)
        self.layer3 = self._make_layer(128, 256, num_blocks=2, stride=2)
        self.layer4 = self._make_layer(256, 512, num_blocks=2, stride=2)

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(512 * BasicBlock.expansion, num_classes)
        )

    def _make_layer(self, in_c, out_c, num_blocks, stride):
        # Only the first block in a stage downsamples; the rest use stride 1.
        layers = []
        layers.append(BasicBlock(in_c, out_c, stride))
        for _ in range(1, num_blocks):
            layers.append(BasicBlock(out_c, out_c, stride=1))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))

        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        out = self.avg_pool(out)
        out = torch.flatten(out, 1)  # (N, 512, 1, 1) -> (N, 512)
        out = self.fc(out)
        return out
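
# Shape walkthrough for a 32x32 CIFAR input (N = batch size), read off the
# layers above:
#   stem    (N,  64, 32, 32)
#   layer1  (N,  64, 32, 32)
#   layer2  (N, 128, 16, 16)
#   layer3  (N, 256,  8,  8)
#   layer4  (N, 512,  4,  4)
#   avgpool (N, 512,  1,  1) -> flatten -> (N, 512) -> fc -> (N, num_classes)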


class CIFARCNN(L.LightningModule):
    """LightningModule wrapping ResNet18_CIFAR10 with train/val/test/predict
    steps and an SGD + cosine-annealing training recipe."""

    def __init__(self, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()  # stores `lr` for logging and checkpoints
        # Dummy input so Lightning can log the model graph and layer shapes
        # (randn avoids the uninitialized memory torch.Tensor(...) would give).
        self.example_input_array = torch.randn(64, 3, 32, 32)

        self.net = ResNet18_CIFAR10(num_classes=10)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)

        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        self.log("train_loss", loss, on_step=True, prog_bar=True)
        self.log("train_acc", acc, on_step=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)

        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        # sync_dist averages the metric across processes under DDP.
        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        self.log("val_acc", acc, prog_bar=True, sync_dist=True)
        return {"val_loss": loss, "val_acc": acc}

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)

        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", acc, prog_bar=True)
        return {"test_loss": loss, "test_acc": acc}

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        x, _ = batch
        return self(x)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams.lr,
            momentum=0.9,
            weight_decay=5e-4
        )
        # Decay the LR along a cosine curve over the full run; assumes the
        # Trainer is constructed with an explicit max_epochs.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.trainer.max_epochs
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}


if __name__ == "__main__":
    # Quick smoke test: push a random batch through the untrained model
    # and confirm the output shape is (batch, num_classes).
    model = CIFARCNN()
    x = torch.randn(4, 3, 32, 32).to(model.device)
    logits = model(x)
    print(logits.shape)  # expected: torch.Size([4, 10])
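
    # A minimal end-to-end training sketch, assuming CIFAR-10 DataLoaders are
    # built elsewhere (e.g. from torchvision.datasets.CIFAR10); the names
    # `train_loader`, `val_loader`, and `test_loader` are placeholders,
    # not defined in this file:
    #
    #   trainer = L.Trainer(max_epochs=50, accelerator="auto", devices=1)
    #   trainer.fit(model, train_loader, val_loader)
    #   trainer.test(model, test_loader)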