Ateshh commited on
Commit
626b231
Β·
verified Β·
1 Parent(s): 902fb7b

Upload 5 files

Browse files
Files changed (5) hide show
  1. README.md +158 -3
  2. model.py +99 -0
  3. requirements.txt +3 -0
  4. run.py +60 -0
  5. train.py +165 -0
README.md CHANGED
@@ -1,3 +1,158 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - image-to-image
5
+ - style-transfer
6
+ - pytorch
7
+ - beginner
8
+ - fast-inference
9
+ pipeline_tag: image-to-image
10
+ datasets:
11
+ - coco
12
+ metrics:
13
+ - perceptual-loss
14
+ ---
15
+
16
+ # mini-style-transfer
17
+
18
+ A small, fast artistic style transfer model built with PyTorch as a learning project.
19
+ Applies 4 artistic styles to any photo in **under 1 second on CPU**.
20
+
21
+ Based on [Johnson et al. (2016) β€” Perceptual Losses for Real-Time Style Transfer](https://arxiv.org/abs/1603.08155).
22
+
23
+ ---
24
+
25
+ ## What it does
26
+
27
+ | Input photo | + Style painting | β†’ Output |
28
+ |---|---|---|
29
+ | Any photo (any size) | Starry Night / Mosaic / Candy / Sketch | Stylised version |
30
+
31
+ ---
32
+
33
+ ## Styles available
34
+
35
+ | File | Style |
36
+ |---|---|
37
+ | `starry_night.pth` | Van Gogh β€” Starry Night |
38
+ | `mosaic.pth` | Classic mosaic tile pattern |
39
+ | `candy.pth` | Bright candy colours |
40
+ | `sketch.pth` | Pencil sketch look |
41
+
42
+ ---
43
+
44
+ ## Quick start
45
+
46
+ ```python
47
+ import torch
48
+ from torchvision import transforms
49
+ from PIL import Image
50
+ from model import StyleNet
51
+
52
+ # 1. Load model
53
+ model = StyleNet()
54
+ model.load_state_dict(torch.load("starry_night.pth", map_location="cpu"))
55
+ model.eval()
56
+
57
+ # 2. Prepare your image
58
+ img = Image.open("my_photo.jpg").convert("RGB")
59
+ to_tensor = transforms.Compose([
60
+ transforms.ToTensor(),
61
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
62
+ std=[0.229, 0.224, 0.225]),
63
+ ])
64
+ tensor = to_tensor(img).unsqueeze(0)
65
+
66
+ # 3. Run inference
67
+ with torch.no_grad():
68
+ output = model(tensor).squeeze(0).clamp(0, 1)
69
+
70
+ # 4. Save result
71
+ result = transforms.ToPILImage()(output)
72
+ result.save("styled_output.jpg")
73
+ print("Done! Open styled_output.jpg")
74
+ ```
75
+
76
+ Or use the included `run.py` script:
77
+
78
+ ```bash
79
+ python run.py --model starry_night.pth --input my_photo.jpg --output result.jpg
80
+ ```
81
+
82
+ ---
83
+
84
+ ## Model details
85
+
86
+ | Property | Value |
87
+ |---|---|
88
+ | Architecture | Feed-forward CNN (Encoder β†’ 5Γ— ResBlock β†’ Decoder) |
89
+ | Parameters | ~450K |
90
+ | Model size | ~1.7 MB per style |
91
+ | Input | Any RGB image, any resolution |
92
+ | Output | Same size as input, styled |
93
+ | Framework | PyTorch 2.x |
94
+ | Normalisation | ImageNet mean/std |
95
+
96
+ ---
97
+
98
+ ## Training details
99
+
100
+ | Property | Value |
101
+ |---|---|
102
+ | Content dataset | MS-COCO train2017 (subset) |
103
+ | Style images | 4 artwork images |
104
+ | Epochs | 2 per style |
105
+ | Batch size | 4 |
106
+ | Image size (training) | 256 Γ— 256 |
107
+ | Optimizer | Adam, lr=1e-3 |
108
+ | Loss | Perceptual (VGG16) β€” content + style |
109
+ | Content weight | 1.0 |
110
+ | Style weight | 1e5 |
111
+ | Training time | ~45 min per style (GPU) |
112
+
113
+ ---
114
+
115
+ ## Repository structure
116
+
117
+ ```
118
+ mini-style-transfer/
119
+ β”œβ”€β”€ model.py ← StyleNet architecture
120
+ β”œβ”€β”€ train.py ← Training script
121
+ β”œβ”€β”€ run.py ← Inference script
122
+ β”œβ”€β”€ starry_night.pth ← Trained weights (starry night style)
123
+ β”œβ”€β”€ mosaic.pth ← Trained weights (mosaic style)
124
+ β”œβ”€β”€ candy.pth ← Trained weights (candy style)
125
+ β”œβ”€β”€ sketch.pth ← Trained weights (sketch style)
126
+ └── README.md ← This file
127
+ ```
128
+
129
+ ---
130
+
131
+ ## Limitations
132
+
133
+ - Each style is a **separate model file** β€” there is no single multi-style model yet
134
+ - Works best on **natural photos** (landscapes, portraits, cities)
135
+ - Cartoons, diagrams, and text-heavy images may give unexpected results
136
+ - Training images were 256Γ—256; very high-resolution outputs may look slightly blurry
137
+ - Not suitable for commercial use without further evaluation
138
+
139
+ ---
140
+
141
+ ## What I learned building this
142
+
143
+ - How **convolutional encoders and decoders** work together
144
+ - What **Instance Normalisation** does vs Batch Normalisation
145
+ - How **Gram matrices** capture texture and style
146
+ - What **perceptual loss** is and why pixel-level loss looks bad for style transfer
147
+ - How to use a **pretrained VGG** network as a feature extractor without training it
148
+
149
+ ---
150
+
151
+ ## References
152
+
153
+ - Johnson, J., Alahi, A., & Fei-Fei, L. (2016). [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155)
154
+ - Gatys, L., Ecker, A., & Bethge, M. (2015). [A Neural Algorithm of Artistic Style](https://arxiv.org/abs/1508.06576)
155
+
156
+ ---
157
+
158
+ *Built as a learning project. Feedback and suggestions welcome!*
model.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ mini-style-transfer β€” PyTorch style filter model
3
+ Author: your-username
4
+ HuggingFace: huggingface.co/your-username/mini-style-transfer
5
+
6
+ Architecture: Feed-forward CNN (Johnson et al. 2016)
7
+ - No slow per-image optimisation β€” runs in under 1 second
8
+ - One model file per style (starry, mosaic, candy, sketch)
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+
15
+ # ── Residual Block ────────────────────────────────────────────────────────────
16
+ # The core building block. Learns fine style details without losing content.
17
+
18
+ class ResidualBlock(nn.Module):
19
+ def __init__(self, channels):
20
+ super().__init__()
21
+ self.block = nn.Sequential(
22
+ nn.ReflectionPad2d(1), # padding that avoids edge artifacts
23
+ nn.Conv2d(channels, channels, kernel_size=3),
24
+ nn.InstanceNorm2d(channels), # normalise per-image (better for style)
25
+ nn.ReLU(inplace=True),
26
+ nn.ReflectionPad2d(1),
27
+ nn.Conv2d(channels, channels, kernel_size=3),
28
+ nn.InstanceNorm2d(channels),
29
+ )
30
+
31
+ def forward(self, x):
32
+ return x + self.block(x) # skip connection β€” keeps original content
33
+
34
+
35
+ # ── StyleNet ──────────────────────────────────────────────────────────────────
36
+ # Full model: Encoder β†’ Residual blocks β†’ Decoder
37
+ # Input: (B, 3, H, W) β€” any image size
38
+ # Output: (B, 3, H, W) β€” same size, styled
39
+
40
+ class StyleNet(nn.Module):
41
+ def __init__(self, num_residual_blocks=5):
42
+ super().__init__()
43
+
44
+ # Encoder: shrinks image, learns features
45
+ self.encoder = nn.Sequential(
46
+ nn.ReflectionPad2d(4),
47
+ nn.Conv2d(3, 32, kernel_size=9, stride=1), # 32 feature maps
48
+ nn.InstanceNorm2d(32),
49
+ nn.ReLU(inplace=True),
50
+
51
+ nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # downsample
52
+ nn.InstanceNorm2d(64),
53
+ nn.ReLU(inplace=True),
54
+
55
+ nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # downsample
56
+ nn.InstanceNorm2d(128),
57
+ nn.ReLU(inplace=True),
58
+ )
59
+
60
+ # Residual blocks: learn style patterns at compressed resolution (fast!)
61
+ self.residuals = nn.Sequential(
62
+ *[ResidualBlock(128) for _ in range(num_residual_blocks)]
63
+ )
64
+
65
+ # Decoder: upscale back to original resolution with style applied
66
+ self.decoder = nn.Sequential(
67
+ nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
68
+ nn.InstanceNorm2d(64),
69
+ nn.ReLU(inplace=True),
70
+
71
+ nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
72
+ nn.InstanceNorm2d(32),
73
+ nn.ReLU(inplace=True),
74
+
75
+ nn.ReflectionPad2d(4),
76
+ nn.Conv2d(32, 3, kernel_size=9, stride=1), # back to 3 colour channels
77
+ nn.Sigmoid(), # pixel values β†’ 0–1 range
78
+ )
79
+
80
+ def forward(self, x):
81
+ x = self.encoder(x)
82
+ x = self.residuals(x)
83
+ x = self.decoder(x)
84
+ return x
85
+
86
+
87
+ # ── Quick test ────────────────────────────────────────────────────────────────
88
+ if __name__ == "__main__":
89
+ model = StyleNet()
90
+ total_params = sum(p.numel() for p in model.parameters())
91
+ print(f"StyleNet ready β€” {total_params:,} parameters ({total_params/1e6:.1f}M)")
92
+
93
+ # Test with a dummy 512x512 image
94
+ dummy = torch.randn(1, 3, 512, 512)
95
+ with torch.no_grad():
96
+ out = model(dummy)
97
+ print(f"Input: {tuple(dummy.shape)}")
98
+ print(f"Output: {tuple(out.shape)}")
99
+ print("Model works correctly!")
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ Pillow>=9.0.0
run.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ run.py β€” Apply your trained style to any photo
3
+
4
+ Usage:
5
+ python run.py --model starry_night.pth --input my_photo.jpg --output result.jpg
6
+ python run.py --model mosaic.pth --input my_photo.jpg --output result.jpg
7
+
8
+ No GPU needed β€” runs on CPU in under 1 second.
9
+ """
10
+
11
+ import torch
12
+ from torchvision import transforms
13
+ from PIL import Image
14
+ import argparse
15
+ from model import StyleNet
16
+
17
+
18
+ def stylize(model_path, input_path, output_path):
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
20
+ print(f"Running on: {device}")
21
+
22
+ # Load trained model
23
+ model = StyleNet()
24
+ model.load_state_dict(torch.load(model_path, map_location=device))
25
+ model.eval()
26
+ model.to(device)
27
+
28
+ # Load and prepare input image
29
+ img = Image.open(input_path).convert("RGB")
30
+ original_size = img.size # save so we can restore it at the end
31
+ print(f"Input image: {input_path} ({img.width}x{img.height})")
32
+
33
+ to_tensor = transforms.Compose([
34
+ transforms.ToTensor(),
35
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
36
+ std=[0.229, 0.224, 0.225]),
37
+ ])
38
+ tensor = to_tensor(img).unsqueeze(0).to(device) # shape: [1, 3, H, W]
39
+
40
+ # Run inference
41
+ with torch.no_grad():
42
+ output = model(tensor).squeeze(0).clamp(0, 1) # shape: [3, H, W]
43
+
44
+ # Convert back to PIL image and save
45
+ to_pil = transforms.ToPILImage()
46
+ result = to_pil(output)
47
+ result = result.resize(original_size, Image.LANCZOS) # restore original size
48
+ result.save(output_path, quality=95)
49
+
50
+ print(f"Styled image saved to: {output_path}")
51
+ print("Open the file to see your result!")
52
+
53
+
54
+ if __name__ == "__main__":
55
+ parser = argparse.ArgumentParser()
56
+ parser.add_argument("--model", required=True, help="Path to your .pth model file")
57
+ parser.add_argument("--input", required=True, help="Path to your input photo")
58
+ parser.add_argument("--output", default="output.jpg", help="Where to save the result")
59
+ args = parser.parse_args()
60
+ stylize(args.model, args.input, args.output)
train.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ train.py β€” Train your mini-style-transfer model
3
+
4
+ Usage:
5
+ python train.py --style starry_night.jpg --output starry_night.pth
6
+
7
+ What this script does:
8
+ 1. Loads your style image (the painting)
9
+ 2. Loops over MS-COCO images (content images β€” everyday photos)
10
+ 3. For each photo: runs it through StyleNet, compares result to style
11
+ 4. Updates model weights so outputs look more like the style painting
12
+ 5. Saves your trained model as a .pth file
13
+
14
+ Beginner tip: Think of training as teaching the model by example.
15
+ You show it thousands of photos and say "make them look like Van Gogh".
16
+ After enough examples, it learns to do it on its own.
17
+ """
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.optim as optim
22
+ from torchvision import transforms, models
23
+ from torch.utils.data import DataLoader, Dataset
24
+ from PIL import Image
25
+ import os
26
+ import argparse
27
+ from model import StyleNet
28
+
29
+
30
+ # ── Settings ──────────────────────────────────────────────────────────────────
31
+
32
+ IMAGE_SIZE = 256 # train on 256x256 (faster); can run inference at any size
33
+ BATCH_SIZE = 4
34
+ EPOCHS = 2 # 2 epochs is enough for a recognisable style
35
+ LR = 1e-3
36
+ CONTENT_W = 1.0 # how much to preserve original content
37
+ STYLE_W = 1e5 # how strongly to apply the style (very high is normal)
38
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
39
+
40
+
41
+ # ── Dataset ───────────────────────────────────────────────────────────────────
42
+
43
+ class ImageFolderDataset(Dataset):
44
+ """Loads all images from a folder. Use MS-COCO train2017 images."""
45
+ def __init__(self, folder, transform):
46
+ self.paths = [
47
+ os.path.join(folder, f) for f in os.listdir(folder)
48
+ if f.lower().endswith(('.jpg', '.jpeg', '.png'))
49
+ ]
50
+ self.transform = transform
51
+
52
+ def __len__(self):
53
+ return len(self.paths)
54
+
55
+ def __getitem__(self, idx):
56
+ img = Image.open(self.paths[idx]).convert("RGB")
57
+ return self.transform(img)
58
+
59
+
60
+ # ── Perceptual Loss (VGG16) ───────────────────────────────────────────────────
61
+ # Instead of comparing pixels directly, we compare how images "feel"
62
+ # using a pretrained VGG network. This is what makes the style look good.
63
+
64
+ class VGGLoss(nn.Module):
65
+ def __init__(self):
66
+ super().__init__()
67
+ vgg = models.vgg16(weights=models.VGG16_Weights.DEFAULT).features
68
+ # relu2_2 for content, relu1_2 + relu2_2 + relu3_3 for style
69
+ self.slice1 = nn.Sequential(*list(vgg)[:4]).eval() # relu1_2
70
+ self.slice2 = nn.Sequential(*list(vgg)[4:9]).eval() # relu2_2 ← content
71
+ self.slice3 = nn.Sequential(*list(vgg)[9:16]).eval() # relu3_3
72
+ for p in self.parameters():
73
+ p.requires_grad = False
74
+
75
+ def forward(self, x):
76
+ h1 = self.slice1(x)
77
+ h2 = self.slice2(h1)
78
+ h3 = self.slice3(h2)
79
+ return h1, h2, h3
80
+
81
+ def gram_matrix(feat):
82
+ """Style is captured as correlations between feature maps (Gram matrix)."""
83
+ B, C, H, W = feat.shape
84
+ feat = feat.view(B, C, H * W)
85
+ return torch.bmm(feat, feat.transpose(1, 2)) / (C * H * W)
86
+
87
+
88
+ # ── Training loop ─────────────────────────────────────────────────────────────
89
+
90
+ def train(style_image_path, content_folder, output_path):
91
+ print(f"Device: {DEVICE}")
92
+ print(f"Style: {style_image_path}")
93
+ print(f"Output: {output_path}\n")
94
+
95
+ transform = transforms.Compose([
96
+ transforms.Resize(IMAGE_SIZE),
97
+ transforms.CenterCrop(IMAGE_SIZE),
98
+ transforms.ToTensor(),
99
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
100
+ std=[0.229, 0.224, 0.225]),
101
+ ])
102
+
103
+ # Load style image and precompute its Gram matrices (done once)
104
+ style_img = transform(Image.open(style_image_path).convert("RGB"))
105
+ style_img = style_img.unsqueeze(0).to(DEVICE)
106
+
107
+ dataset = ImageFolderDataset(content_folder, transform)
108
+ loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
109
+
110
+ model = StyleNet().to(DEVICE)
111
+ vgg = VGGLoss().to(DEVICE)
112
+ optimizer = optim.Adam(model.parameters(), lr=LR)
113
+ mse = nn.MSELoss()
114
+
115
+ # Precompute style Gram matrices
116
+ with torch.no_grad():
117
+ s1, s2, s3 = vgg(style_img)
118
+ style_grams = [gram_matrix(s1), gram_matrix(s2), gram_matrix(s3)]
119
+
120
+ print(f"Training on {len(dataset)} images for {EPOCHS} epochs...")
121
+ print("─" * 50)
122
+
123
+ for epoch in range(EPOCHS):
124
+ for i, content in enumerate(loader):
125
+ content = content.to(DEVICE)
126
+ optimizer.zero_grad()
127
+
128
+ # Forward pass
129
+ styled = model(content)
130
+
131
+ # Content loss β€” styled image should still look like the photo
132
+ _, c_feat, _ = vgg(content)
133
+ _, s_feat, _ = vgg(styled)
134
+ content_loss = mse(s_feat, c_feat.detach())
135
+
136
+ # Style loss β€” styled image should look like the painting
137
+ o1, o2, o3 = vgg(styled)
138
+ style_loss = (
139
+ mse(gram_matrix(o1), style_grams[0].expand(content.size(0), -1, -1)) +
140
+ mse(gram_matrix(o2), style_grams[1].expand(content.size(0), -1, -1)) +
141
+ mse(gram_matrix(o3), style_grams[2].expand(content.size(0), -1, -1))
142
+ )
143
+
144
+ loss = CONTENT_W * content_loss + STYLE_W * style_loss
145
+ loss.backward()
146
+ optimizer.step()
147
+
148
+ if i % 100 == 0:
149
+ print(f"Epoch {epoch+1}/{EPOCHS} Batch {i:4d}/{len(loader)}"
150
+ f" Loss: {loss.item():.2f}"
151
+ f" (content {content_loss.item():.3f}"
152
+ f" style {style_loss.item():.2f})")
153
+
154
+ torch.save(model.state_dict(), output_path)
155
+ print(f"\nDone! Model saved to: {output_path}")
156
+ print(f"Upload to HuggingFace: huggingface-cli upload your-username/mini-style-transfer {output_path}")
157
+
158
+
159
+ if __name__ == "__main__":
160
+ parser = argparse.ArgumentParser()
161
+ parser.add_argument("--style", required=True, help="Path to your style painting image")
162
+ parser.add_argument("--content", default="coco/", help="Folder of training photos (MS-COCO)")
163
+ parser.add_argument("--output", default="style_model.pth", help="Output .pth file name")
164
+ args = parser.parse_args()
165
+ train(args.style, args.content, args.output)