# Ateshh's picture
# Upload 5 files
# 626b231 verified
"""
mini-style-transfer — PyTorch style filter model
Author: your-username
HuggingFace: huggingface.co/your-username/mini-style-transfer
Architecture: Feed-forward CNN (Johnson et al. 2016)
- No slow per-image optimisation — runs in under 1 second
- One model file per style (starry, mosaic, candy, sketch)
"""
import torch
import torch.nn as nn
# ── Residual Block ────────────────────────────────────────────────────────────
# The core building block. Learns fine style details without losing content.
class ResidualBlock(nn.Module):
    """Shape-preserving residual unit: ``x + F(x)``.

    ``F`` is pad -> 3x3 conv -> instance norm -> ReLU -> pad -> 3x3 conv ->
    instance norm. Each 3x3 convolution is preceded by one pixel of
    reflection padding, so the spatial size and the channel count of the
    input are both preserved.

    Args:
        channels: Number of input (and output) feature channels.
    """

    def __init__(self, channels: int) -> None:
        super().__init__()
        # NOTE: the attribute name ``block`` and the layer order determine the
        # state_dict keys (``block.1.*`` / ``block.5.*``); keep them stable so
        # saved checkpoints continue to load.
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),  # padding that avoids edge artifacts
            nn.Conv2d(channels, channels, kernel_size=3),
            nn.InstanceNorm2d(channels),  # normalise per-image (better for style)
            # In-place is safe here: it overwrites the norm output, not ``x``.
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3),
            nn.InstanceNorm2d(channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the learned residual; same shape as ``x``."""
        return x + self.block(x)  # skip connection — keeps original content
# ── StyleNet ──────────────────────────────────────────────────────────────────
# Full model: Encoder β†’ Residual blocks β†’ Decoder
# Input: (B, 3, H, W) β€” any image size
# Output: (B, 3, H, W) β€” same size, styled
class StyleNet(nn.Module):
    """Feed-forward style network: encoder -> residual blocks -> decoder.

    Maps a ``(B, 3, H, W)`` image batch to a ``(B, 3, H, W)`` batch whose
    values lie in ``[0, 1]`` (final sigmoid).

    The encoder downsamples twice with stride-2 convolutions (each one halves
    H and W, rounding up) and the decoder upsamples twice with transposed
    convolutions that exactly double H and W. When H or W is not a multiple
    of 4 the raw decoder output is therefore slightly larger than the input
    (e.g. 513 -> 516); ``forward`` crops it back so the output always matches
    the input size.

    Args:
        num_residual_blocks: How many ``ResidualBlock(128)`` units to stack
            between encoder and decoder. ``0`` yields an empty (identity)
            stage.
    """

    def __init__(self, num_residual_blocks: int = 5) -> None:
        super().__init__()
        # NOTE: attribute names and layer positions below define the
        # state_dict keys; do not reorder or checkpoints stop loading.

        # Encoder: shrinks image, learns features
        self.encoder = nn.Sequential(
            nn.ReflectionPad2d(4),  # keeps H x W unchanged through the 9x9 conv
            nn.Conv2d(3, 32, kernel_size=9, stride=1),  # 32 feature maps
            nn.InstanceNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # downsample
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # downsample
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),
        )
        # Residual blocks: learn style patterns at compressed resolution (fast!)
        self.residuals = nn.Sequential(
            *[ResidualBlock(128) for _ in range(num_residual_blocks)]
        )
        # Decoder: upscale back to original resolution with style applied
        self.decoder = nn.Sequential(
            # kernel 3 / stride 2 / padding 1 / output_padding 1 => exactly 2x
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(32),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(4),
            nn.Conv2d(32, 3, kernel_size=9, stride=1),  # back to 3 colour channels
            nn.Sigmoid(),  # pixel values → 0–1 range
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Stylise ``x`` and return a tensor with the same H x W as ``x``."""
        height, width = x.shape[-2:]
        x = self.encoder(x)
        x = self.residuals(x)
        x = self.decoder(x)
        # The decoder output is the input size rounded up to a multiple of 4
        # per axis; trim the surplus rows/columns (a no-op when H and W are
        # already multiples of 4).
        return x[..., :height, :width]
# ── Quick test ────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Smoke test: build the default model, report its parameter count, and
    # run one random batch through it.
    model = StyleNet()
    total_params = sum(p.numel() for p in model.parameters())
    print(f"StyleNet ready — {total_params:,} parameters ({total_params/1e6:.1f}M)")
    # Test with a dummy 512x512 image
    dummy = torch.randn(1, 3, 512, 512)
    with torch.no_grad():  # inference only; no autograd graph needed
        out = model(dummy)
    print(f"Input: {tuple(dummy.shape)}")
    print(f"Output: {tuple(out.shape)}")
    # Only claim success once the advertised contract (same shape in and out)
    # has actually been checked.
    if out.shape != dummy.shape:
        raise RuntimeError(
            f"Output shape {tuple(out.shape)} does not match input shape {tuple(dummy.shape)}"
        )
    print("Model works correctly!")