| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import gradio as gr |
| from PIL import Image |
| import os |
|
|
| |
# Run on GPU when one is available; otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
class ConditionalVAE(nn.Module):
    """Conditional Variational Autoencoder (CVAE) for class-conditioned images.

    Both encoder and decoder are conditioned on a one-hot class label that is
    concatenated to their inputs.

    Args:
        input_dim: Flattened input size (784 for 28x28 MNIST).
        hidden_dim: Width of the single hidden layer in encoder and decoder.
        latent_dim: Dimensionality of the latent code z.
        num_classes: Number of label classes (one-hot vector length).
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20, num_classes=10):
        super(ConditionalVAE, self).__init__()

        # Encoder: (x, y) -> hidden -> (mu, logvar)
        self.fc1 = nn.Linear(input_dim + num_classes, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mu head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # logvar head

        # Decoder: (z, y) -> hidden -> reconstruction in [0, 1]
        self.fc3 = nn.Linear(latent_dim + num_classes, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

        # BUG FIX: store input_dim so forward() does not hard-code 784,
        # making non-default input sizes actually work.
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.num_classes = num_classes

    def encode(self, x, y):
        """Map a flattened input and one-hot label to (mu, logvar)."""
        inputs = torch.cat([x, y], 1)
        h1 = F.relu(self.fc1(inputs))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, sigma^2) via the reparameterization trick."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, y):
        """Map a latent code and one-hot label to a reconstruction in [0, 1]."""
        inputs = torch.cat([z, y], 1)
        h3 = F.relu(self.fc3(inputs))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x, y):
        """Encode, sample, and decode; returns (reconstruction, mu, logvar)."""
        # Flatten to the configured input size rather than a hard-coded 784.
        mu, logvar = self.encode(x.view(-1, self.input_dim), y)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, y), mu, logvar
|
|
| |
def load_model():
    """Load the trained CVAE checkpoint and prepare it for inference.

    Returns:
        A ConditionalVAE on `device`, switched to eval mode.

    Raises:
        FileNotFoundError: if 'mnist_cvae_model.pth' is missing.
    """
    model = ConditionalVAE(input_dim=784, hidden_dim=400, latent_dim=20, num_classes=10)
    # weights_only=True restricts unpickling to tensors and primitive
    # containers — safe and sufficient for a plain state_dict checkpoint,
    # unlike full pickle loading which can execute arbitrary code.
    state = torch.load('mnist_cvae_model.pth', map_location=device, weights_only=True)
    model.load_state_dict(state)
    model = model.to(device)
    model.eval()
    return model
|
|
def generate_digits(model, digit, num_samples=5):
    """Sample `num_samples` images of class `digit` from the CVAE decoder.

    Args:
        model: Trained ConditionalVAE (uses `.latent_dim` and `.decode()`).
        digit: Class index in [0, 9] to condition generation on.
        num_samples: Number of independent samples to draw.

    Returns:
        uint8 numpy array of shape (num_samples, 28, 28), values in [0, 255].
    """
    model.eval()
    # Resolve the device from the model itself instead of relying on the
    # module-level `device` global — works wherever the model lives and
    # makes this function self-contained.
    dev = next(model.parameters()).device
    with torch.no_grad():
        # One-hot label batch for the requested digit.
        label = torch.zeros(num_samples, 10, device=dev)
        label[:, digit] = 1

        # Draw latent codes from the standard normal prior and decode.
        z = torch.randn(num_samples, model.latent_dim, device=dev)
        generated = model.decode(z, label).view(num_samples, 28, 28)
        return (generated.cpu().numpy() * 255).astype(np.uint8)
|
|
def generate_digit_images(digit):
    """Generate five PIL images of the requested digit, upscaled to 112x112.

    Falls back to five gray placeholder images if generation fails (e.g. a
    missing checkpoint file) so the UI always has something to display.
    """
    try:
        # PERF: cache the loaded model on the function object so the
        # checkpoint is read from disk only once, not on every click.
        model = getattr(generate_digit_images, "_model", None)
        if model is None:
            model = load_model()
            generate_digit_images._model = model

        generated_images = generate_digits(model, int(digit), num_samples=5)

        pil_images = []
        for img in generated_images:
            pil_img = Image.fromarray(img, mode='L')
            # NEAREST keeps the crisp pixelated look when upscaling 4x.
            pil_img = pil_img.resize((112, 112), Image.NEAREST)
            pil_images.append(pil_img)

        return pil_images
    except Exception as e:
        # Best-effort fallback: log and return neutral placeholders so the
        # Gradio callback never raises into the UI.
        print(f"Error: {e}")
        placeholder = Image.new('L', (112, 112), color=128)
        return [placeholder] * 5
|
|
def generate_and_display(digit):
    """Gradio adapter: unpack the five generated images into five outputs."""
    first, second, third, fourth, fifth = generate_digit_images(digit)
    return first, second, third, fourth, fifth
|
|
| |
# --- Gradio UI: slider input, generate button, five image slots ------------
with gr.Blocks(title="MNIST Digit Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# π’ MNIST Handwritten Digit Generator")
    gr.Markdown("Select a digit (0-9) and generate 5 unique handwritten samples using a trained Conditional VAE model.")

    with gr.Row():
        digit_input = gr.Slider(
            minimum=0,
            maximum=9,
            step=1,
            value=0,
            label="Select Digit to Generate"
        )

    generate_btn = gr.Button("π¨ Generate 5 Digit Images", variant="primary", size="lg")

    gr.Markdown("## Generated Images")
    with gr.Row():
        # Five fixed-size preview slots, one per generated sample.
        sample_slots = [
            gr.Image(label=f"Sample {idx}", width=112, height=112)
            for idx in range(1, 6)
        ]

    # Wire the button: one slider in, five images out.
    generate_btn.click(
        fn=generate_and_display,
        inputs=[digit_input],
        outputs=sample_slots
    )

    with gr.Accordion("π Model Information", open=False):
        gr.Markdown("""
        ### Technical Details
        - **Architecture**: Conditional Variational Autoencoder (CVAE)
        - **Dataset**: MNIST (28Γ28 grayscale images)
        - **Training**: From scratch on Google Colab T4 GPU
        - **Latent Dimension**: 20
        - **Training Epochs**: 15
        - **Loss Function**: BCE + KL Divergence

        The model generates diverse samples by sampling from the learned latent space conditioned on digit labels.
        """)
|
|
if __name__ == "__main__":
    # Launch the Gradio app (blocking call; serves the UI on a local port).
    demo.launch()