| import streamlit as st |
| import torch |
| from models_conv import ConvGenerator |
| import numpy as np |
|
|
| |
# Browser-tab title and page layout for the app.
_PAGE_CONFIG = dict(layout="centered", page_title="MNIST Digit Generator")
st.set_page_config(**_PAGE_CONFIG)
|
|
| |
# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Trained WGAN checkpoint; it must contain a 'generator_state_dict' entry.
CHECKPOINT_PATH = 'checkpoints/wgan_checkpoint_epoch_190.pt'


@st.cache_resource
def _load_generator(checkpoint_path, _device):
    """Build a ConvGenerator, load trained weights into it, and return it in eval mode.

    Cached with ``st.cache_resource`` so the network is constructed and the
    checkpoint read from disk once, instead of on every script rerun.

    Args:
        checkpoint_path: Path of the ``torch.save``-d checkpoint dict.
        _device: Device to place the model on. The leading underscore tells
            Streamlit not to hash this argument when building the cache key.
    """
    model = ConvGenerator().to(_device)
    checkpoint = torch.load(checkpoint_path, map_location=_device)
    model.load_state_dict(checkpoint['generator_state_dict'])
    # Inference only: put layers such as batch-norm/dropout into eval behaviour.
    model.eval()
    return model


generator = _load_generator(CHECKPOINT_PATH, device)
|
|
| |
# Page heading followed by a one-line description of the app.
_APP_HEADING = "MNIST Digit Generator"
_APP_TAGLINE = "Generate MNIST-like digits using a Wasserstein GAN"
st.title(_APP_HEADING)
st.write(_APP_TAGLINE)
|
|
| |
# Sidebar controls: latent-noise seed, batch size, and a regenerate button.
st.sidebar.header("Generation Controls")
noise_seed = st.sidebar.slider("Noise Seed", min_value=1, max_value=1000, value=42)
num_images = st.sidebar.slider("Number of Images", min_value=1, max_value=16, value=4)
generate_button = st.sidebar.button("Generate New Images")
|
|
| |
def generate_images(noise_seed, num_images, latent_dim=100):
    """Sample a batch of images from the module-level ``generator``.

    Args:
        noise_seed: Seed for the latent-noise RNG; the same seed (and batch
            size) always yields the same images.
        num_images: Number of latent vectors, i.e. images, to sample.
        latent_dim: Width of each latent vector. Defaults to 100, the width
            this app has always fed to the generator.

    Returns:
        A NumPy array of the generator's outputs mapped from [-1, 1] to
        [0, 1] and clipped to that range. Its shape is whatever the
        generator returns for a batch of ``num_images`` -- presumably
        (num_images, 1, 28, 28); confirm against ConvGenerator.
    """
    # A private RNG keeps the seed from clobbering torch's global RNG state.
    # It draws the same CPU noise that torch.manual_seed + torch.randn did.
    rng = torch.Generator().manual_seed(noise_seed)
    # Sample on the CPU, then move: a given seed gives identical noise on any device.
    z = torch.randn(num_images, latent_dim, generator=rng).to(device)

    with torch.no_grad():
        imgs = generator(z)

    imgs = imgs.cpu().numpy()
    # Rescale the (assumed) [-1, 1] output to [0, 1]. Clip because st.image
    # refuses float arrays outside [0, 1] unless clamp=True is passed.
    return np.clip((imgs + 1) / 2, 0.0, 1.0)
|
|
| |
# Sample a new batch when the button was clicked or nothing is stored yet;
# otherwise reuse the batch kept in session_state from an earlier run.
needs_fresh_batch = generate_button or 'generated_images' not in st.session_state
if not needs_fresh_batch:
    images = st.session_state.generated_images
else:
    images = generate_images(noise_seed, num_images)
    st.session_state.generated_images = images
|
|
| |
# Show the images in a grid of at most four columns. The grid is sized from
# the images actually displayed, not from the current "Number of Images"
# slider: `images` may be a batch cached in session_state from an earlier run
# with a different count. max(1, ...) guards st.columns against zero columns.
num_cols = max(1, min(4, len(images)))
cols = st.columns(num_cols)
for idx, img in enumerate(images):
    with cols[idx % num_cols]:
        st.image(img.squeeze(), caption=f"Generated Image {idx+1}", use_column_width=True)
|
|
| |
# Footer text describing the model behind the app.
_ABOUT_MODEL_TEXT = """
This is a Wasserstein GAN (WGAN) model trained on the MNIST dataset.
The model generates 28x28 grayscale images that resemble handwritten digits.
"""

# Divider, section heading, then the description.
for _footer_md in ("---", "### About the Model"):
    st.markdown(_footer_md)
st.write(_ABOUT_MODEL_TEXT)