""" train.py — Train your mini-style-transfer model Usage: python train.py --style starry_night.jpg --output starry_night.pth What this script does: 1. Loads your style image (the painting) 2. Loops over MS-COCO images (content images — everyday photos) 3. For each photo: runs it through StyleNet, compares result to style 4. Updates model weights so outputs look more like the style painting 5. Saves your trained model as a .pth file Beginner tip: Think of training as teaching the model by example. You show it thousands of photos and say "make them look like Van Gogh". After enough examples, it learns to do it on its own. """ import torch import torch.nn as nn import torch.optim as optim from torchvision import transforms, models from torch.utils.data import DataLoader, Dataset from PIL import Image import os import argparse from model import StyleNet # ── Settings ────────────────────────────────────────────────────────────────── IMAGE_SIZE = 256 # train on 256x256 (faster); can run inference at any size BATCH_SIZE = 4 EPOCHS = 2 # 2 epochs is enough for a recognisable style LR = 1e-3 CONTENT_W = 1.0 # how much to preserve original content STYLE_W = 1e5 # how strongly to apply the style (very high is normal) DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # ── Dataset ─────────────────────────────────────────────────────────────────── class ImageFolderDataset(Dataset): """Loads all images from a folder. Use MS-COCO train2017 images.""" def __init__(self, folder, transform): self.paths = [ os.path.join(folder, f) for f in os.listdir(folder) if f.lower().endswith(('.jpg', '.jpeg', '.png')) ] self.transform = transform def __len__(self): return len(self.paths) def __getitem__(self, idx): img = Image.open(self.paths[idx]).convert("RGB") return self.transform(img) # ── Perceptual Loss (VGG16) ─────────────────────────────────────────────────── # Instead of comparing pixels directly, we compare how images "feel" # using a pretrained VGG network. 
# This is what makes the style look good.
class VGGLoss(nn.Module):
    """Frozen VGG16 feature extractor used as a perceptual-loss backbone.

    forward(x) returns activations at three depths:
      relu1_2, relu2_2, relu3_3
    relu2_2 serves as the content target; all three feed the style loss
    (via Gram matrices).
    """

    def __init__(self):
        super().__init__()
        vgg = models.vgg16(weights=models.VGG16_Weights.DEFAULT).features
        # relu2_2 for content, relu1_2 + relu2_2 + relu3_3 for style
        self.slice1 = nn.Sequential(*list(vgg)[:4]).eval()    # relu1_2
        self.slice2 = nn.Sequential(*list(vgg)[4:9]).eval()   # relu2_2 ← content
        self.slice3 = nn.Sequential(*list(vgg)[9:16]).eval()  # relu3_3
        # Freeze the loss network: only StyleNet's weights should learn.
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, x):
        h1 = self.slice1(x)
        h2 = self.slice2(h1)   # slices are chained, so each runs once
        h3 = self.slice3(h2)
        return h1, h2, h3


def gram_matrix(feat):
    """Style is captured as correlations between feature maps (Gram matrix).

    feat: (B, C, H, W) activations. Returns (B, C, C), normalized by
    C*H*W so the scale is comparable across layers and image sizes.
    """
    B, C, H, W = feat.shape
    feat = feat.view(B, C, H * W)
    return torch.bmm(feat, feat.transpose(1, 2)) / (C * H * W)


# ── Training loop ─────────────────────────────────────────────────────────────
def train(style_image_path, content_folder, output_path):
    """Train StyleNet on one style image and save its state_dict.

    style_image_path: path to the style painting image.
    content_folder:   folder of training photos (e.g. MS-COCO train2017).
    output_path:      where the trained .pth weights are written.
    """
    print(f"Device: {DEVICE}")
    print(f"Style: {style_image_path}")
    print(f"Output: {output_path}\n")

    transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Load style image and precompute its Gram matrices (done once)
    style_img = transform(Image.open(style_image_path).convert("RGB"))
    style_img = style_img.unsqueeze(0).to(DEVICE)

    dataset = ImageFolderDataset(content_folder, transform)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,
                        num_workers=2)

    model = StyleNet().to(DEVICE)
    vgg = VGGLoss().to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LR)
    mse = nn.MSELoss()

    # Precompute style Gram matrices — the style target never changes.
    with torch.no_grad():
        s1, s2, s3 = vgg(style_img)
        style_grams = [gram_matrix(s1), gram_matrix(s2), gram_matrix(s3)]

    print(f"Training on {len(dataset)} images for {EPOCHS} epochs...")
    print("─" * 50)

    model.train()  # explicit, in case StyleNet uses dropout/batchnorm
    for epoch in range(EPOCHS):
        for i, content in enumerate(loader):
            content = content.to(DEVICE)
            optimizer.zero_grad()

            # Forward pass
            styled = model(content)

            # Content target: no gradients needed (it was detached anyway),
            # so run it under no_grad to save memory and compute.
            with torch.no_grad():
                _, c_feat, _ = vgg(content)

            # Single VGG forward on the styled batch. The previous version
            # ran vgg(styled) twice — once for the content loss and again
            # for the style loss — doubling the most expensive part of each
            # step; o2 is exactly the relu2_2 features the content loss needs.
            # NOTE(review): `styled` goes into VGG without ImageNet
            # re-normalization; whether that matches the (normalized) inputs
            # depends on StyleNet's output range — confirm against model.py.
            o1, o2, o3 = vgg(styled)

            # Content loss — styled image should still look like the photo
            content_loss = mse(o2, c_feat)

            # Style loss — styled image should look like the painting.
            # expand() broadcasts the batch-1 style Grams to the batch size.
            style_loss = (
                mse(gram_matrix(o1), style_grams[0].expand(content.size(0), -1, -1))
                + mse(gram_matrix(o2), style_grams[1].expand(content.size(0), -1, -1))
                + mse(gram_matrix(o3), style_grams[2].expand(content.size(0), -1, -1))
            )

            loss = CONTENT_W * content_loss + STYLE_W * style_loss
            loss.backward()
            optimizer.step()

            if i % 100 == 0:
                print(f"Epoch {epoch+1}/{EPOCHS} Batch {i:4d}/{len(loader)}"
                      f" Loss: {loss.item():.2f}"
                      f" (content {content_loss.item():.3f}"
                      f" style {style_loss.item():.2f})")

    torch.save(model.state_dict(), output_path)
    print(f"\nDone! Model saved to: {output_path}")
    print(f"Upload to HuggingFace: huggingface-cli upload your-username/mini-style-transfer {output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--style", required=True,
                        help="Path to your style painting image")
    parser.add_argument("--content", default="coco/",
                        help="Folder of training photos (MS-COCO)")
    parser.add_argument("--output", default="style_model.pth",
                        help="Output .pth file name")
    args = parser.parse_args()
    train(args.style, args.content, args.output)