import torch
import torch.nn as nn


def _conv_block(in_channels: int, out_channels: int, kernel_size: int, padding: int = 0) -> list:
    """Return one conv stage as a flat list: Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d(2) -> Dropout2d(0.25).

    A flat list (rather than a nested ``nn.Sequential``) is returned so that
    splicing several stages into one ``nn.Sequential`` keeps flat, consecutive
    layer indices, i.e. the same ``state_dict`` keys as writing the layers out
    by hand.
    """
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),  # halves each spatial side (floor division)
        nn.Dropout2d(0.25),
    ]


class Model(nn.Module):
    """Four-stage CNN classifier followed by a two-layer MLP head.

    With the default arguments the network takes ``(N, 1, H, W)`` tensors
    and returns ``(N, 10)`` logits. The spatial sizes in the comments below
    trace the 28x28 case. Any input whose spatial sides are each at least
    16 is accepted: the four floor-mode ``MaxPool2d(2)`` stages need a
    non-empty map at every step, and a final adaptive pool reduces whatever
    remains to 1x1.

    Args:
        in_channels: Number of input channels (default 1).
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10) -> None:
        super().__init__()
        self.conv_layers = nn.Sequential(
            # Block 1: in_channels -> 32 channels, 28x28 -> 14x14
            *_conv_block(in_channels, 32, kernel_size=3, padding=1),
            # Block 2: 32 -> 64 channels, 14x14 -> 7x7
            *_conv_block(32, 64, kernel_size=3, padding=1),
            # Block 3: 64 -> 128 channels, 7x7 -> 3x3 (floor division: the
            # last row/column of the 7x7 map is not covered by the pool)
            *_conv_block(64, 128, kernel_size=3, padding=1),
            # Block 4: 128 -> 256 channels via a 1x1 conv, 3x3 -> 1x1
            # NOTE(review): MaxPool2d(2) over a 3x3 map only sees its top-left
            # 2x2 window; kept unchanged to preserve existing outputs/weights.
            *_conv_block(128, 256, kernel_size=1),
            # Collapse any remaining spatial extent to 1x1 so the head accepts
            # inputs larger than 28x28. For 28x28 inputs the map is already
            # 1x1 here, so this is an exact identity. It has no parameters, so
            # state_dict keys and shapes are unchanged.
            nn.AdaptiveMaxPool2d(1),
        )
        self.fc_layers = nn.Sequential(
            nn.Flatten(),  # (N, 256, 1, 1) -> (N, 256)
            nn.Linear(256 * 1 * 1, 128),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(128, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map an ``(N, in_channels, H, W)`` batch to ``(N, num_classes)`` logits."""
        x = self.conv_layers(x)
        return self.fc_layers(x)