Upload 5 files
Browse files- README.md +158 -3
- model.py +99 -0
- requirements.txt +3 -0
- run.py +60 -0
- train.py +165 -0
README.md
CHANGED
|
@@ -1,3 +1,158 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: mit
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
tags:
|
| 4 |
+
- image-to-image
|
| 5 |
+
- style-transfer
|
| 6 |
+
- pytorch
|
| 7 |
+
- beginner
|
| 8 |
+
- fast-inference
|
| 9 |
+
pipeline_tag: image-to-image
|
| 10 |
+
datasets:
|
| 11 |
+
- coco
|
| 12 |
+
metrics:
|
| 13 |
+
- perceptual-loss
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
# mini-style-transfer
|
| 17 |
+
|
| 18 |
+
A small, fast artistic style transfer model built with PyTorch as a learning project.
|
| 19 |
+
Applies 4 artistic styles to any photo in **under 1 second on CPU**.
|
| 20 |
+
|
| 21 |
+
Based on [Johnson et al. (2016) β Perceptual Losses for Real-Time Style Transfer](https://arxiv.org/abs/1603.08155).
|
| 22 |
+
|
| 23 |
+
---
|
| 24 |
+
|
| 25 |
+
## What it does
|
| 26 |
+
|
| 27 |
+
| Input photo | + Style painting | β Output |
|
| 28 |
+
|---|---|---|
|
| 29 |
+
| Any photo (any size) | Starry Night / Mosaic / Candy / Sketch | Stylised version |
|
| 30 |
+
|
| 31 |
+
---
|
| 32 |
+
|
| 33 |
+
## Styles available
|
| 34 |
+
|
| 35 |
+
| File | Style |
|
| 36 |
+
|---|---|
|
| 37 |
+
| `starry_night.pth` | Van Gogh β Starry Night |
|
| 38 |
+
| `mosaic.pth` | Classic mosaic tile pattern |
|
| 39 |
+
| `candy.pth` | Bright candy colours |
|
| 40 |
+
| `sketch.pth` | Pencil sketch look |
|
| 41 |
+
|
| 42 |
+
---
|
| 43 |
+
|
| 44 |
+
## Quick start
|
| 45 |
+
|
| 46 |
+
```python
|
| 47 |
+
import torch
|
| 48 |
+
from torchvision import transforms
|
| 49 |
+
from PIL import Image
|
| 50 |
+
from model import StyleNet
|
| 51 |
+
|
| 52 |
+
# 1. Load model
|
| 53 |
+
model = StyleNet()
|
| 54 |
+
model.load_state_dict(torch.load("starry_night.pth", map_location="cpu"))
|
| 55 |
+
model.eval()
|
| 56 |
+
|
| 57 |
+
# 2. Prepare your image
|
| 58 |
+
img = Image.open("my_photo.jpg").convert("RGB")
|
| 59 |
+
to_tensor = transforms.Compose([
|
| 60 |
+
transforms.ToTensor(),
|
| 61 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 62 |
+
std=[0.229, 0.224, 0.225]),
|
| 63 |
+
])
|
| 64 |
+
tensor = to_tensor(img).unsqueeze(0)
|
| 65 |
+
|
| 66 |
+
# 3. Run inference
|
| 67 |
+
with torch.no_grad():
|
| 68 |
+
output = model(tensor).squeeze(0).clamp(0, 1)
|
| 69 |
+
|
| 70 |
+
# 4. Save result
|
| 71 |
+
result = transforms.ToPILImage()(output)
|
| 72 |
+
result.save("styled_output.jpg")
|
| 73 |
+
print("Done! Open styled_output.jpg")
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
Or use the included `run.py` script:
|
| 77 |
+
|
| 78 |
+
```bash
|
| 79 |
+
python run.py --model starry_night.pth --input my_photo.jpg --output result.jpg
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
---
|
| 83 |
+
|
| 84 |
+
## Model details
|
| 85 |
+
|
| 86 |
+
| Property | Value |
|
| 87 |
+
|---|---|
|
| 88 |
+
| Architecture | Feed-forward CNN (Encoder β 5Γ ResBlock β Decoder) |
|
| 89 |
+
| Parameters | ~450K |
|
| 90 |
+
| Model size | ~1.7 MB per style |
|
| 91 |
+
| Input | Any RGB image, any resolution |
|
| 92 |
+
| Output | Same size as input, styled |
|
| 93 |
+
| Framework | PyTorch 2.x |
|
| 94 |
+
| Normalisation | ImageNet mean/std |
|
| 95 |
+
|
| 96 |
+
---
|
| 97 |
+
|
| 98 |
+
## Training details
|
| 99 |
+
|
| 100 |
+
| Property | Value |
|
| 101 |
+
|---|---|
|
| 102 |
+
| Content dataset | MS-COCO train2017 (subset) |
|
| 103 |
+
| Style images | 4 artwork images |
|
| 104 |
+
| Epochs | 2 per style |
|
| 105 |
+
| Batch size | 4 |
|
| 106 |
+
| Image size (training) | 256 Γ 256 |
|
| 107 |
+
| Optimizer | Adam, lr=1e-3 |
|
| 108 |
+
| Loss | Perceptual (VGG16) β content + style |
|
| 109 |
+
| Content weight | 1.0 |
|
| 110 |
+
| Style weight | 1e5 |
|
| 111 |
+
| Training time | ~45 min per style (GPU) |
|
| 112 |
+
|
| 113 |
+
---
|
| 114 |
+
|
| 115 |
+
## Repository structure
|
| 116 |
+
|
| 117 |
+
```
|
| 118 |
+
mini-style-transfer/
|
| 119 |
+
βββ model.py β StyleNet architecture
|
| 120 |
+
βββ train.py β Training script
|
| 121 |
+
βββ run.py β Inference script
|
| 122 |
+
βββ starry_night.pth β Trained weights (starry night style)
|
| 123 |
+
βββ mosaic.pth β Trained weights (mosaic style)
|
| 124 |
+
βββ candy.pth β Trained weights (candy style)
|
| 125 |
+
βββ sketch.pth β Trained weights (sketch style)
|
| 126 |
+
βββ README.md β This file
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
---
|
| 130 |
+
|
| 131 |
+
## Limitations
|
| 132 |
+
|
| 133 |
+
- Each style is a **separate model file** β there is no single multi-style model yet
|
| 134 |
+
- Works best on **natural photos** (landscapes, portraits, cities)
|
| 135 |
+
- Cartoons, diagrams, and text-heavy images may give unexpected results
|
| 136 |
+
- Training images were 256Γ256; very high-resolution outputs may look slightly blurry
|
| 137 |
+
- Not suitable for commercial use without further evaluation
|
| 138 |
+
|
| 139 |
+
---
|
| 140 |
+
|
| 141 |
+
## What I learned building this
|
| 142 |
+
|
| 143 |
+
- How **convolutional encoders and decoders** work together
|
| 144 |
+
- What **Instance Normalisation** does vs Batch Normalisation
|
| 145 |
+
- How **Gram matrices** capture texture and style
|
| 146 |
+
- What **perceptual loss** is and why pixel-level loss looks bad for style transfer
|
| 147 |
+
- How to use a **pretrained VGG** network as a feature extractor without training it
|
| 148 |
+
|
| 149 |
+
---
|
| 150 |
+
|
| 151 |
+
## References
|
| 152 |
+
|
| 153 |
+
- Johnson, J., Alahi, A., & Fei-Fei, L. (2016). [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155)
|
| 154 |
+
- Gatys, L., Ecker, A., & Bethge, M. (2015). [A Neural Algorithm of Artistic Style](https://arxiv.org/abs/1508.06576)
|
| 155 |
+
|
| 156 |
+
---
|
| 157 |
+
|
| 158 |
+
*Built as a learning project. Feedback and suggestions welcome!*
|
model.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
mini-style-transfer β PyTorch style filter model
|
| 3 |
+
Author: your-username
|
| 4 |
+
HuggingFace: huggingface.co/your-username/mini-style-transfer
|
| 5 |
+
|
| 6 |
+
Architecture: Feed-forward CNN (Johnson et al. 2016)
|
| 7 |
+
- No slow per-image optimisation β runs in under 1 second
|
| 8 |
+
- One model file per style (starry, mosaic, candy, sketch)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# ββ Residual Block ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 16 |
+
# The core building block. Learns fine style details without losing content.
|
| 17 |
+
|
| 18 |
+
class ResidualBlock(nn.Module):
|
| 19 |
+
def __init__(self, channels):
|
| 20 |
+
super().__init__()
|
| 21 |
+
self.block = nn.Sequential(
|
| 22 |
+
nn.ReflectionPad2d(1), # padding that avoids edge artifacts
|
| 23 |
+
nn.Conv2d(channels, channels, kernel_size=3),
|
| 24 |
+
nn.InstanceNorm2d(channels), # normalise per-image (better for style)
|
| 25 |
+
nn.ReLU(inplace=True),
|
| 26 |
+
nn.ReflectionPad2d(1),
|
| 27 |
+
nn.Conv2d(channels, channels, kernel_size=3),
|
| 28 |
+
nn.InstanceNorm2d(channels),
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
def forward(self, x):
|
| 32 |
+
return x + self.block(x) # skip connection β keeps original content
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ββ StyleNet ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 36 |
+
# Full model: Encoder β Residual blocks β Decoder
|
| 37 |
+
# Input: (B, 3, H, W) β any image size
|
| 38 |
+
# Output: (B, 3, H, W) β same size, styled
|
| 39 |
+
|
| 40 |
+
class StyleNet(nn.Module):
|
| 41 |
+
def __init__(self, num_residual_blocks=5):
|
| 42 |
+
super().__init__()
|
| 43 |
+
|
| 44 |
+
# Encoder: shrinks image, learns features
|
| 45 |
+
self.encoder = nn.Sequential(
|
| 46 |
+
nn.ReflectionPad2d(4),
|
| 47 |
+
nn.Conv2d(3, 32, kernel_size=9, stride=1), # 32 feature maps
|
| 48 |
+
nn.InstanceNorm2d(32),
|
| 49 |
+
nn.ReLU(inplace=True),
|
| 50 |
+
|
| 51 |
+
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # downsample
|
| 52 |
+
nn.InstanceNorm2d(64),
|
| 53 |
+
nn.ReLU(inplace=True),
|
| 54 |
+
|
| 55 |
+
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # downsample
|
| 56 |
+
nn.InstanceNorm2d(128),
|
| 57 |
+
nn.ReLU(inplace=True),
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# Residual blocks: learn style patterns at compressed resolution (fast!)
|
| 61 |
+
self.residuals = nn.Sequential(
|
| 62 |
+
*[ResidualBlock(128) for _ in range(num_residual_blocks)]
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
# Decoder: upscale back to original resolution with style applied
|
| 66 |
+
self.decoder = nn.Sequential(
|
| 67 |
+
nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
|
| 68 |
+
nn.InstanceNorm2d(64),
|
| 69 |
+
nn.ReLU(inplace=True),
|
| 70 |
+
|
| 71 |
+
nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
|
| 72 |
+
nn.InstanceNorm2d(32),
|
| 73 |
+
nn.ReLU(inplace=True),
|
| 74 |
+
|
| 75 |
+
nn.ReflectionPad2d(4),
|
| 76 |
+
nn.Conv2d(32, 3, kernel_size=9, stride=1), # back to 3 colour channels
|
| 77 |
+
nn.Sigmoid(), # pixel values β 0β1 range
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
def forward(self, x):
|
| 81 |
+
x = self.encoder(x)
|
| 82 |
+
x = self.residuals(x)
|
| 83 |
+
x = self.decoder(x)
|
| 84 |
+
return x
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# ββ Quick test ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 88 |
+
if __name__ == "__main__":
|
| 89 |
+
model = StyleNet()
|
| 90 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 91 |
+
print(f"StyleNet ready β {total_params:,} parameters ({total_params/1e6:.1f}M)")
|
| 92 |
+
|
| 93 |
+
# Test with a dummy 512x512 image
|
| 94 |
+
dummy = torch.randn(1, 3, 512, 512)
|
| 95 |
+
with torch.no_grad():
|
| 96 |
+
out = model(dummy)
|
| 97 |
+
print(f"Input: {tuple(dummy.shape)}")
|
| 98 |
+
print(f"Output: {tuple(out.shape)}")
|
| 99 |
+
print("Model works correctly!")
|
requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
torchvision>=0.15.0
|
| 3 |
+
Pillow>=9.0.0
|
run.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
run.py β Apply your trained style to any photo
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python run.py --model starry_night.pth --input my_photo.jpg --output result.jpg
|
| 6 |
+
python run.py --model mosaic.pth --input my_photo.jpg --output result.jpg
|
| 7 |
+
|
| 8 |
+
No GPU needed β runs on CPU in under 1 second.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from torchvision import transforms
|
| 13 |
+
from PIL import Image
|
| 14 |
+
import argparse
|
| 15 |
+
from model import StyleNet
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def stylize(model_path, input_path, output_path):
|
| 19 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 20 |
+
print(f"Running on: {device}")
|
| 21 |
+
|
| 22 |
+
# Load trained model
|
| 23 |
+
model = StyleNet()
|
| 24 |
+
model.load_state_dict(torch.load(model_path, map_location=device))
|
| 25 |
+
model.eval()
|
| 26 |
+
model.to(device)
|
| 27 |
+
|
| 28 |
+
# Load and prepare input image
|
| 29 |
+
img = Image.open(input_path).convert("RGB")
|
| 30 |
+
original_size = img.size # save so we can restore it at the end
|
| 31 |
+
print(f"Input image: {input_path} ({img.width}x{img.height})")
|
| 32 |
+
|
| 33 |
+
to_tensor = transforms.Compose([
|
| 34 |
+
transforms.ToTensor(),
|
| 35 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 36 |
+
std=[0.229, 0.224, 0.225]),
|
| 37 |
+
])
|
| 38 |
+
tensor = to_tensor(img).unsqueeze(0).to(device) # shape: [1, 3, H, W]
|
| 39 |
+
|
| 40 |
+
# Run inference
|
| 41 |
+
with torch.no_grad():
|
| 42 |
+
output = model(tensor).squeeze(0).clamp(0, 1) # shape: [3, H, W]
|
| 43 |
+
|
| 44 |
+
# Convert back to PIL image and save
|
| 45 |
+
to_pil = transforms.ToPILImage()
|
| 46 |
+
result = to_pil(output)
|
| 47 |
+
result = result.resize(original_size, Image.LANCZOS) # restore original size
|
| 48 |
+
result.save(output_path, quality=95)
|
| 49 |
+
|
| 50 |
+
print(f"Styled image saved to: {output_path}")
|
| 51 |
+
print("Open the file to see your result!")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
if __name__ == "__main__":
|
| 55 |
+
parser = argparse.ArgumentParser()
|
| 56 |
+
parser.add_argument("--model", required=True, help="Path to your .pth model file")
|
| 57 |
+
parser.add_argument("--input", required=True, help="Path to your input photo")
|
| 58 |
+
parser.add_argument("--output", default="output.jpg", help="Where to save the result")
|
| 59 |
+
args = parser.parse_args()
|
| 60 |
+
stylize(args.model, args.input, args.output)
|
train.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
train.py β Train your mini-style-transfer model
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python train.py --style starry_night.jpg --output starry_night.pth
|
| 6 |
+
|
| 7 |
+
What this script does:
|
| 8 |
+
1. Loads your style image (the painting)
|
| 9 |
+
2. Loops over MS-COCO images (content images β everyday photos)
|
| 10 |
+
3. For each photo: runs it through StyleNet, compares result to style
|
| 11 |
+
4. Updates model weights so outputs look more like the style painting
|
| 12 |
+
5. Saves your trained model as a .pth file
|
| 13 |
+
|
| 14 |
+
Beginner tip: Think of training as teaching the model by example.
|
| 15 |
+
You show it thousands of photos and say "make them look like Van Gogh".
|
| 16 |
+
After enough examples, it learns to do it on its own.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.optim as optim
|
| 22 |
+
from torchvision import transforms, models
|
| 23 |
+
from torch.utils.data import DataLoader, Dataset
|
| 24 |
+
from PIL import Image
|
| 25 |
+
import os
|
| 26 |
+
import argparse
|
| 27 |
+
from model import StyleNet
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# ββ Settings ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 31 |
+
|
| 32 |
+
IMAGE_SIZE = 256 # train on 256x256 (faster); can run inference at any size
|
| 33 |
+
BATCH_SIZE = 4
|
| 34 |
+
EPOCHS = 2 # 2 epochs is enough for a recognisable style
|
| 35 |
+
LR = 1e-3
|
| 36 |
+
CONTENT_W = 1.0 # how much to preserve original content
|
| 37 |
+
STYLE_W = 1e5 # how strongly to apply the style (very high is normal)
|
| 38 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# ββ Dataset βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 42 |
+
|
| 43 |
+
class ImageFolderDataset(Dataset):
|
| 44 |
+
"""Loads all images from a folder. Use MS-COCO train2017 images."""
|
| 45 |
+
def __init__(self, folder, transform):
|
| 46 |
+
self.paths = [
|
| 47 |
+
os.path.join(folder, f) for f in os.listdir(folder)
|
| 48 |
+
if f.lower().endswith(('.jpg', '.jpeg', '.png'))
|
| 49 |
+
]
|
| 50 |
+
self.transform = transform
|
| 51 |
+
|
| 52 |
+
def __len__(self):
|
| 53 |
+
return len(self.paths)
|
| 54 |
+
|
| 55 |
+
def __getitem__(self, idx):
|
| 56 |
+
img = Image.open(self.paths[idx]).convert("RGB")
|
| 57 |
+
return self.transform(img)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# ββ Perceptual Loss (VGG16) βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 61 |
+
# Instead of comparing pixels directly, we compare how images "feel"
|
| 62 |
+
# using a pretrained VGG network. This is what makes the style look good.
|
| 63 |
+
|
| 64 |
+
class VGGLoss(nn.Module):
|
| 65 |
+
def __init__(self):
|
| 66 |
+
super().__init__()
|
| 67 |
+
vgg = models.vgg16(weights=models.VGG16_Weights.DEFAULT).features
|
| 68 |
+
# relu2_2 for content, relu1_2 + relu2_2 + relu3_3 for style
|
| 69 |
+
self.slice1 = nn.Sequential(*list(vgg)[:4]).eval() # relu1_2
|
| 70 |
+
self.slice2 = nn.Sequential(*list(vgg)[4:9]).eval() # relu2_2 β content
|
| 71 |
+
self.slice3 = nn.Sequential(*list(vgg)[9:16]).eval() # relu3_3
|
| 72 |
+
for p in self.parameters():
|
| 73 |
+
p.requires_grad = False
|
| 74 |
+
|
| 75 |
+
def forward(self, x):
|
| 76 |
+
h1 = self.slice1(x)
|
| 77 |
+
h2 = self.slice2(h1)
|
| 78 |
+
h3 = self.slice3(h2)
|
| 79 |
+
return h1, h2, h3
|
| 80 |
+
|
| 81 |
+
def gram_matrix(feat):
|
| 82 |
+
"""Style is captured as correlations between feature maps (Gram matrix)."""
|
| 83 |
+
B, C, H, W = feat.shape
|
| 84 |
+
feat = feat.view(B, C, H * W)
|
| 85 |
+
return torch.bmm(feat, feat.transpose(1, 2)) / (C * H * W)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# ββ Training loop βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 89 |
+
|
| 90 |
+
def train(style_image_path, content_folder, output_path):
|
| 91 |
+
print(f"Device: {DEVICE}")
|
| 92 |
+
print(f"Style: {style_image_path}")
|
| 93 |
+
print(f"Output: {output_path}\n")
|
| 94 |
+
|
| 95 |
+
transform = transforms.Compose([
|
| 96 |
+
transforms.Resize(IMAGE_SIZE),
|
| 97 |
+
transforms.CenterCrop(IMAGE_SIZE),
|
| 98 |
+
transforms.ToTensor(),
|
| 99 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 100 |
+
std=[0.229, 0.224, 0.225]),
|
| 101 |
+
])
|
| 102 |
+
|
| 103 |
+
# Load style image and precompute its Gram matrices (done once)
|
| 104 |
+
style_img = transform(Image.open(style_image_path).convert("RGB"))
|
| 105 |
+
style_img = style_img.unsqueeze(0).to(DEVICE)
|
| 106 |
+
|
| 107 |
+
dataset = ImageFolderDataset(content_folder, transform)
|
| 108 |
+
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
|
| 109 |
+
|
| 110 |
+
model = StyleNet().to(DEVICE)
|
| 111 |
+
vgg = VGGLoss().to(DEVICE)
|
| 112 |
+
optimizer = optim.Adam(model.parameters(), lr=LR)
|
| 113 |
+
mse = nn.MSELoss()
|
| 114 |
+
|
| 115 |
+
# Precompute style Gram matrices
|
| 116 |
+
with torch.no_grad():
|
| 117 |
+
s1, s2, s3 = vgg(style_img)
|
| 118 |
+
style_grams = [gram_matrix(s1), gram_matrix(s2), gram_matrix(s3)]
|
| 119 |
+
|
| 120 |
+
print(f"Training on {len(dataset)} images for {EPOCHS} epochs...")
|
| 121 |
+
print("β" * 50)
|
| 122 |
+
|
| 123 |
+
for epoch in range(EPOCHS):
|
| 124 |
+
for i, content in enumerate(loader):
|
| 125 |
+
content = content.to(DEVICE)
|
| 126 |
+
optimizer.zero_grad()
|
| 127 |
+
|
| 128 |
+
# Forward pass
|
| 129 |
+
styled = model(content)
|
| 130 |
+
|
| 131 |
+
# Content loss β styled image should still look like the photo
|
| 132 |
+
_, c_feat, _ = vgg(content)
|
| 133 |
+
_, s_feat, _ = vgg(styled)
|
| 134 |
+
content_loss = mse(s_feat, c_feat.detach())
|
| 135 |
+
|
| 136 |
+
# Style loss β styled image should look like the painting
|
| 137 |
+
o1, o2, o3 = vgg(styled)
|
| 138 |
+
style_loss = (
|
| 139 |
+
mse(gram_matrix(o1), style_grams[0].expand(content.size(0), -1, -1)) +
|
| 140 |
+
mse(gram_matrix(o2), style_grams[1].expand(content.size(0), -1, -1)) +
|
| 141 |
+
mse(gram_matrix(o3), style_grams[2].expand(content.size(0), -1, -1))
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
loss = CONTENT_W * content_loss + STYLE_W * style_loss
|
| 145 |
+
loss.backward()
|
| 146 |
+
optimizer.step()
|
| 147 |
+
|
| 148 |
+
if i % 100 == 0:
|
| 149 |
+
print(f"Epoch {epoch+1}/{EPOCHS} Batch {i:4d}/{len(loader)}"
|
| 150 |
+
f" Loss: {loss.item():.2f}"
|
| 151 |
+
f" (content {content_loss.item():.3f}"
|
| 152 |
+
f" style {style_loss.item():.2f})")
|
| 153 |
+
|
| 154 |
+
torch.save(model.state_dict(), output_path)
|
| 155 |
+
print(f"\nDone! Model saved to: {output_path}")
|
| 156 |
+
print(f"Upload to HuggingFace: huggingface-cli upload your-username/mini-style-transfer {output_path}")
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
if __name__ == "__main__":
|
| 160 |
+
parser = argparse.ArgumentParser()
|
| 161 |
+
parser.add_argument("--style", required=True, help="Path to your style painting image")
|
| 162 |
+
parser.add_argument("--content", default="coco/", help="Folder of training photos (MS-COCO)")
|
| 163 |
+
parser.add_argument("--output", default="style_model.pth", help="Output .pth file name")
|
| 164 |
+
args = parser.parse_args()
|
| 165 |
+
train(args.style, args.content, args.output)
|