""" mini-style-transfer — PyTorch style filter model Author: your-username HuggingFace: huggingface.co/your-username/mini-style-transfer Architecture: Feed-forward CNN (Johnson et al. 2016) - No slow per-image optimisation — runs in under 1 second - One model file per style (starry, mosaic, candy, sketch) """ import torch import torch.nn as nn # ── Residual Block ──────────────────────────────────────────────────────────── # The core building block. Learns fine style details without losing content. class ResidualBlock(nn.Module): def __init__(self, channels): super().__init__() self.block = nn.Sequential( nn.ReflectionPad2d(1), # padding that avoids edge artifacts nn.Conv2d(channels, channels, kernel_size=3), nn.InstanceNorm2d(channels), # normalise per-image (better for style) nn.ReLU(inplace=True), nn.ReflectionPad2d(1), nn.Conv2d(channels, channels, kernel_size=3), nn.InstanceNorm2d(channels), ) def forward(self, x): return x + self.block(x) # skip connection — keeps original content # ── StyleNet ────────────────────────────────────────────────────────────────── # Full model: Encoder → Residual blocks → Decoder # Input: (B, 3, H, W) — any image size # Output: (B, 3, H, W) — same size, styled class StyleNet(nn.Module): def __init__(self, num_residual_blocks=5): super().__init__() # Encoder: shrinks image, learns features self.encoder = nn.Sequential( nn.ReflectionPad2d(4), nn.Conv2d(3, 32, kernel_size=9, stride=1), # 32 feature maps nn.InstanceNorm2d(32), nn.ReLU(inplace=True), nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # downsample nn.InstanceNorm2d(64), nn.ReLU(inplace=True), nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # downsample nn.InstanceNorm2d(128), nn.ReLU(inplace=True), ) # Residual blocks: learn style patterns at compressed resolution (fast!) self.residuals = nn.Sequential( *[ResidualBlock(128) for _ in range(num_residual_blocks)] ) # Decoder: upscale back to original resolution with style applied self.decoder = nn.Sequential( nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1), nn.InstanceNorm2d(64), nn.ReLU(inplace=True), nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1), nn.InstanceNorm2d(32), nn.ReLU(inplace=True), nn.ReflectionPad2d(4), nn.Conv2d(32, 3, kernel_size=9, stride=1), # back to 3 colour channels nn.Sigmoid(), # pixel values → 0–1 range ) def forward(self, x): x = self.encoder(x) x = self.residuals(x) x = self.decoder(x) return x # ── Quick test ──────────────────────────────────────────────────────────────── if __name__ == "__main__": model = StyleNet() total_params = sum(p.numel() for p in model.parameters()) print(f"StyleNet ready — {total_params:,} parameters ({total_params/1e6:.1f}M)") # Test with a dummy 512x512 image dummy = torch.randn(1, 3, 512, 512) with torch.no_grad(): out = model(dummy) print(f"Input: {tuple(dummy.shape)}") print(f"Output: {tuple(out.shape)}") print("Model works correctly!")