| """ |
| mini-style-transfer β PyTorch style filter model |
| Author: your-username |
| HuggingFace: huggingface.co/your-username/mini-style-transfer |
| |
| Architecture: Feed-forward CNN (Johnson et al. 2016) |
| - No slow per-image optimisation β runs in under 1 second |
| - One model file per style (starry, mosaic, candy, sketch) |
| """ |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
| |
| |
|
|
class ResidualBlock(nn.Module):
    """Residual unit used at the bottleneck of the style network.

    Two reflection-padded 3x3 convolutions with instance normalisation,
    with the input added back via an identity skip connection, so the
    output has the same shape as the input.
    """

    def __init__(self, channels):
        """channels: number of feature maps in and out (unchanged by the block)."""
        super().__init__()
        # Same layer order as a classic residual unit:
        # pad -> conv -> norm -> relu -> pad -> conv -> norm
        layers = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3),
            nn.InstanceNorm2d(channels),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return x plus the transformed features (identity skip connection)."""
        transformed = self.block(x)
        return transformed + x
|
|
|
|
| |
| |
| |
| |
|
|
class StyleNet(nn.Module):
    """Feed-forward style-transfer network (Johnson et al. 2016).

    Pipeline: downsampling encoder -> stack of residual blocks at 128
    channels -> upsampling decoder. Maps an (N, 3, H, W) image batch to
    an (N, 3, H, W) batch squashed into [0, 1] by the final sigmoid.
    NOTE(review): the two stride-2 down/up stages mean spatial size is
    preserved exactly when H and W are multiples of 4 — confirm callers
    feed such sizes.
    """

    def __init__(self, num_residual_blocks=5):
        """num_residual_blocks: number of ResidualBlock units at the bottleneck."""
        super().__init__()

        # Encoder: stride-1 9x9 conv (reflection-padded to keep size),
        # then two stride-2 convs that each halve the resolution.
        encoder_layers = [
            nn.ReflectionPad2d(4),
            nn.Conv2d(3, 32, kernel_size=9, stride=1),
            nn.InstanceNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Transformation stage at the bottleneck resolution (128 channels).
        self.residuals = nn.Sequential(
            *(ResidualBlock(128) for _ in range(num_residual_blocks))
        )

        # Decoder mirrors the encoder: two stride-2 transposed convs to
        # restore resolution, then a 9x9 conv back to 3 channels.
        decoder_layers = [
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2,
                               padding=1, output_padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2,
                               padding=1, output_padding=1),
            nn.InstanceNorm2d(32),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(4),
            nn.Conv2d(32, 3, kernel_size=9, stride=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Stylise the image batch x; output values lie in [0, 1]."""
        features = self.encoder(x)
        features = self.residuals(features)
        return self.decoder(features)
|
|
|
|
| |
if __name__ == "__main__":
    # Smoke test: report parameter count, then run one gradient-free
    # forward pass to verify the network end to end.
    model = StyleNet()
    total_params = sum(p.numel() for p in model.parameters())
    # Fix: the banner previously printed a mojibake "β" (a mis-decoded
    # em dash, matching the corruption in the module docstring).
    print(f"StyleNet ready — {total_params:,} parameters ({total_params/1e6:.1f}M)")

    # One 512x512 RGB image (multiple of 4, so shape is preserved).
    dummy = torch.randn(1, 3, 512, 512)
    with torch.no_grad():
        out = model(dummy)
    print(f"Input: {tuple(dummy.shape)}")
    print(f"Output: {tuple(out.shape)}")
    print("Model works correctly!")
|
|