# Ateshh's picture
# Upload 5 files
# 626b231 verified
"""
run.py — Apply your trained style to any photo
Usage:
python run.py --model starry_night.pth --input my_photo.jpg --output result.jpg
python run.py --model mosaic.pth --input my_photo.jpg --output result.jpg
No GPU needed — runs on CPU in under 1 second.
"""
import torch
from torchvision import transforms
from PIL import Image
import argparse
from model import StyleNet
def stylize(model_path, input_path, output_path):
    """Apply a trained StyleNet model to one image and save the result.

    Args:
        model_path: Path to the ``.pth`` file containing the state dict that
            is loaded into ``StyleNet``.
        input_path: Path to the image to stylize.
        output_path: Path the stylized image is written to.

    The stylized image is resized back to the dimensions of the
    (orientation-corrected) input before it is saved.
    """
    # Imported locally so this function does not depend on a change to the
    # file-level import block, which only brings in PIL.Image.
    from PIL import ImageOps

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Running on: {device}")

    # Load trained model.
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from a trusted source. If the installed PyTorch supports it, consider
    # passing weights_only=True, since only a state dict is needed here.
    model = StyleNet()
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    model.to(device)

    # Load and prepare input image. The context manager releases the file
    # handle deterministically; exif_transpose applies the EXIF Orientation
    # tag so photos taken on phones/cameras are not stylized (and saved)
    # sideways, because the saved result carries no EXIF data.
    with Image.open(input_path) as source:
        img = ImageOps.exif_transpose(source).convert("RGB")
    original_size = img.size  # save so we can restore it at the end
    print(f"Input image: {input_path} ({img.width}x{img.height})")

    # Standard ImageNet channel statistics; assumed to match the
    # preprocessing used when the model was trained — confirm against the
    # training script.
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    tensor = to_tensor(img).unsqueeze(0).to(device)  # shape: [1, 3, H, W]

    # Run inference without tracking gradients. The output is clamped to
    # [0, 1] because ToPILImage scales float tensors by 255.
    with torch.no_grad():
        output = model(tensor).squeeze(0).clamp(0, 1)  # shape: [3, H, W]

    # Convert back to PIL image and save
    to_pil = transforms.ToPILImage()
    result = to_pil(output)
    result = result.resize(original_size, Image.LANCZOS)  # restore original size
    result.save(output_path, quality=95)
    print(f"Styled image saved to: {output_path}")
    print("Open the file to see your result!")
if __name__ == "__main__":
    # Command-line entry point: collect the three paths and hand them to
    # stylize(). --model and --input are mandatory; --output falls back to
    # "output.jpg" when omitted.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model",
        required=True,
        help="Path to your .pth model file",
    )
    cli.add_argument(
        "--input",
        required=True,
        help="Path to your input photo",
    )
    cli.add_argument(
        "--output",
        default="output.jpg",
        help="Where to save the result",
    )
    options = cli.parse_args()
    stylize(
        model_path=options.model,
        input_path=options.input,
        output_path=options.output,
    )