# xiangzai's picture
# Add files using upload-large-folder tool
# 3e4f775 verified
from torch import nn
class SimpleDenseNet(nn.Module):
    """Fully connected network: three Linear -> BatchNorm1d -> ReLU stages and a linear head.

    The input is flattened per sample before it reaches the first linear
    layer, so the number of elements per sample must equal ``input_size``.
    The output of the final ``nn.Linear`` is returned as-is (no activation
    or normalisation is applied to it).
    """

    def __init__(
        self,
        input_size: int = 784,
        lin1_size: int = 256,
        lin2_size: int = 256,
        lin3_size: int = 256,
        output_size: int = 10,
    ) -> None:
        """Build the layer stack.

        Args:
            input_size: Number of features per sample after flattening.
            lin1_size: Width of the first hidden layer.
            lin2_size: Width of the second hidden layer.
            lin3_size: Width of the third hidden layer.
            output_size: Number of values produced per sample.
        """
        super().__init__()

        # Layer order (and therefore the state-dict keys ``model.0`` ...
        # ``model.9``) is kept stable so existing checkpoints still load.
        self.model = nn.Sequential(
            nn.Linear(input_size, lin1_size),
            nn.BatchNorm1d(lin1_size),
            nn.ReLU(),
            nn.Linear(lin1_size, lin2_size),
            nn.BatchNorm1d(lin2_size),
            nn.ReLU(),
            nn.Linear(lin2_size, lin3_size),
            nn.BatchNorm1d(lin3_size),
            nn.ReLU(),
            nn.Linear(lin3_size, output_size),
        )

    def forward(self, x):
        """Flatten each sample and run it through the dense stack.

        Args:
            x: Tensor whose first dimension is the batch and which has at
                least one further dimension, e.g. ``(batch, 1, width, height)``
                or an already flattened ``(batch, input_size)``.

        Returns:
            Tensor of shape ``(batch, output_size)``.
        """
        # (batch, d1, d2, ...) -> (batch, d1*d2*...). ``flatten`` does not
        # hard-code a 4-D layout and, unlike ``view``, also accepts
        # non-contiguous inputs (copying only when a view is impossible).
        x = x.flatten(start_dim=1)

        return self.model(x)
if __name__ == "__main__":
    # Smoke test: constructing the network with its default layer sizes
    # must not raise; the instance itself is discarded.
    _ = SimpleDenseNet()