# MNIST Handwritten Digit Generation Web App
# TensorFlow/Keras version using VAE and Gradio for Google Colab
# Auto-training version - model trains on startup

import numpy as np
import gradio as gr
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
from PIL import Image
import io
import threading
import time

# =============================================================================
# PART 1: VAE MODEL DEFINITION
# =============================================================================


class Sampling(layers.Layer):
    """Reparameterization trick: sample z = mean + std * eps, eps ~ N(0, I).

    Keeps the sampling step differentiable with respect to z_mean/z_log_var.
    """

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.random.normal(shape=(batch, dim))
        # exp(0.5 * log_var) is the standard deviation.
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon


class VAE(Model):
    """Variational autoencoder wrapper with a custom training step.

    Tracks total / reconstruction / KL losses via Keras Mean metrics so that
    `fit()` reports running averages per epoch.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them at each epoch start.
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        # `fit(x, x)` delivers an (inputs, targets) tuple; the VAE only needs
        # the inputs since it reconstructs them.
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # binary_crossentropy averages over the last (pixel) axis and
            # returns shape (batch,). Scale back up by the pixel count so the
            # loss is the per-sample *sum* of pixel BCE, then average over the
            # batch. (BUGFIX: the original applied reduce_sum over axis=-1 of
            # the already-reduced (batch,) tensor, i.e. summed over the batch,
            # which made the reconstruction/KL balance depend on batch size.)
            per_sample_bce = tf.keras.losses.binary_crossentropy(
                data, reconstruction
            )
            num_pixels = tf.cast(tf.shape(data)[-1], per_sample_bce.dtype)
            reconstruction_loss = tf.reduce_mean(per_sample_bce * num_pixels)
            # Analytic KL divergence between N(mean, var) and N(0, I),
            # summed over latent dimensions, averaged over the batch.
            kl_loss = -0.5 * (
                1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
            )
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }


def build_vae(input_shape=(784,), latent_dim=20):
    """Build the encoder, decoder, and compiled VAE.

    Args:
        input_shape: Shape of one flattened input image (default 28*28 = 784).
        latent_dim: Size of the latent code.

    Returns:
        Tuple of (vae, encoder, decoder) Keras models; `vae` is compiled
        with Adam and ready for `fit()`.
    """
    # Encoder: 784 -> 400 -> 400 -> (z_mean, z_log_var) -> sampled z.
    encoder_inputs = layers.Input(shape=input_shape)
    x = layers.Dense(400, activation='relu')(encoder_inputs)
    x = layers.Dense(400, activation='relu')(x)
    z_mean = layers.Dense(latent_dim, name='z_mean')(x)
    z_log_var = layers.Dense(latent_dim, name='z_log_var')(x)
    z = Sampling()([z_mean, z_log_var])
    encoder = Model(encoder_inputs, [z_mean, z_log_var, z], name='encoder')

    # Decoder: latent -> 400 -> 400 -> 784 pixels in [0, 1].
    latent_inputs = layers.Input(shape=(latent_dim,))
    x = layers.Dense(400, activation='relu')(latent_inputs)
    x = layers.Dense(400, activation='relu')(x)
    decoder_outputs = layers.Dense(784, activation='sigmoid')(x)
    decoder = Model(latent_inputs, decoder_outputs, name='decoder')

    vae = VAE(encoder, decoder)
    vae.compile(optimizer='adam')
    return vae, encoder, decoder


# =============================================================================
# PART 2: DATA LOADING AND TRAINING
# =============================================================================

# Shared state written by the background training thread and read by the
# Gradio callbacks. Simple attribute/dict assignments are atomic enough for
# this status-polling pattern; no lock is used.
encoder = None
decoder = None
digit_latents = None
model_ready = False
training_progress = "Initializing..."


def train_model_background():
    """Train the VAE in a daemon thread, updating module-level status globals.

    On success, publishes `encoder`, `decoder`, and `digit_latents` and sets
    `model_ready = True`. Exceptions are reported via `training_progress`
    rather than raised, so the UI thread keeps running.
    """
    global encoder, decoder, digit_latents, model_ready, training_progress
    try:
        training_progress = "Loading MNIST data..."
        print("Loading MNIST data...")
        (x_train, y_train), _ = mnist.load_data()
        x_train = x_train.astype('float32') / 255.0
        x_train = x_train.reshape((-1, 784))
        # Subsample for faster training on Colab-class hardware.
        x_train = x_train[:10000]
        y_train = y_train[:10000]

        training_progress = "Building VAE model..."
        print("Building VAE model...")
        vae, encoder_model, decoder_model = build_vae()

        training_progress = "Training VAE model (20 epochs)..."
        print("Training VAE model (20 epochs)...")

        class ProgressCallback(tf.keras.callbacks.Callback):
            # Pushes per-epoch progress into the global status string that
            # the UI polls.
            def on_epoch_end(self, epoch, logs=None):
                global training_progress
                training_progress = (
                    f"Training... Epoch {epoch + 1}/20 "
                    f"(Loss: {logs.get('loss', 0):.4f})"
                )
                print(f"Epoch {epoch + 1}/20 completed")

        vae.fit(
            x_train, x_train,
            epochs=20,
            batch_size=128,
            verbose=0,
            callbacks=[ProgressCallback()]
        )

        encoder = encoder_model
        decoder = decoder_model

        training_progress = "Computing digit latent representations..."
        print("Computing digit latent representations...")
        digit_latents = compute_digit_latents(encoder, x_train, y_train)

        training_progress = "✅ Model ready! You can now generate digits."
        model_ready = True
        print("Model training completed successfully!")
    except Exception as e:
        # Broad catch is deliberate: this runs in a background thread and the
        # only useful action is to surface the error in the UI status box.
        training_progress = f"❌ Error training model: {str(e)}"
        print(f"Error training model: {str(e)}")


def compute_digit_latents(encoder_model, x_train, y_train):
    """Group encoder z_mean vectors by digit label.

    Args:
        encoder_model: Trained encoder returning (z_mean, z_log_var, z).
        x_train: Flattened training images, shape (n, 784).
        y_train: Integer labels, shape (n,).

    Returns:
        Dict mapping digit 0-9 to an ndarray of latent means, or None on
        error. Digits with no samples get a single random latent fallback.
    """
    try:
        latents_by_digit = {i: [] for i in range(10)}
        z_means, _, _ = encoder_model.predict(x_train, verbose=0)
        for i, label in enumerate(y_train):
            latents_by_digit[int(label)].append(z_means[i])
        for i in range(10):
            if len(latents_by_digit[i]) > 0:
                latents_by_digit[i] = np.array(latents_by_digit[i])
            else:
                # Fallback: a standard-normal latent (matches latent_dim=20).
                latents_by_digit[i] = np.random.normal(0, 1, (1, 20))
        return latents_by_digit
    except Exception as e:
        print(f"Error computing digit latents: {str(e)}")
        return None


def get_training_status():
    """Return the current training-progress string for the UI."""
    return training_progress


# =============================================================================
# PART 3: IMAGE GENERATION
# =============================================================================


def generate_digit_images(digit, num_images):
    """Generate `num_images` samples of `digit` and tile them into one image.

    Args:
        digit: Digit 0-9 selected in the UI.
        num_images: How many samples to generate (1-10 from the slider).

    Returns:
        (PIL.Image or None, status message string).
    """
    global encoder, decoder, digit_latents, model_ready

    if not model_ready:
        return None, "⏳ Model is still training. Please wait..."
    if encoder is None or decoder is None or digit_latents is None:
        return None, "❌ Model not ready yet. Please wait for training to complete."

    try:
        # Gradio may deliver component values as floats/strings; normalize.
        digit = int(digit)
        num_images = int(num_images)

        latent_vectors = digit_latents[digit]
        if len(latent_vectors) == 0:
            selected_latents = np.random.normal(0, 1, (num_images, 20))
        else:
            # Sample without replacement when possible for more variety.
            replace = len(latent_vectors) < num_images
            indices = np.random.choice(
                len(latent_vectors), num_images, replace=replace
            )
            selected_latents = latent_vectors[indices]
            # Small Gaussian jitter so repeated clicks produce variations.
            noise = np.random.normal(0, 0.1, selected_latents.shape)
            selected_latents = selected_latents + noise

        generated = decoder.predict(selected_latents, verbose=0)
        images = (generated.reshape(-1, 28, 28) * 255).astype(np.uint8)

        if num_images == 1:
            grid_img = Image.fromarray(images[0], mode='L')
        else:
            # Tile into a grid of at most 5 columns, white background.
            cols = min(5, num_images)
            rows = (num_images + cols - 1) // cols
            grid_img = Image.new('L', (cols * 28, rows * 28), color=255)
            for i, img in enumerate(images):
                x = (i % cols) * 28
                y = (i // cols) * 28
                grid_img.paste(Image.fromarray(img, mode='L'), (x, y))

        success_msg = f"✅ Generated {len(images)} images of digit {digit}!"
        return grid_img, success_msg
    except Exception as e:
        error_msg = f"❌ Error generating images: {str(e)}"
        return None, error_msg


# =============================================================================
# PART 4: GRADIO INTERFACE
# =============================================================================


def create_interface():
    """Build and return the Gradio Blocks app (status panel + generator)."""
    with gr.Blocks(title="MNIST VAE Digit Generator", theme=gr.themes.Soft()) as app:
        gr.Markdown("# 🔢 TensorFlow VAE Handwritten Digit Generator")
        gr.Markdown("Generate MNIST-style handwritten digits using a Variational Autoencoder (VAE).")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("## Training Status")
                training_status = gr.Textbox(
                    label="Model Status",
                    value="Initializing...",
                    interactive=False
                )
                refresh_btn = gr.Button("🔄 Refresh Status", size="sm")

                gr.Markdown("## Generation Controls")
                selected_digit = gr.Dropdown(
                    choices=list(range(10)),
                    value=0,
                    label="Select Digit to Generate"
                )
                num_images = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    label="Number of Images"
                )
                generate_btn = gr.Button(
                    "🎲 Generate Images", variant="primary", size="lg"
                )

            with gr.Column(scale=2):
                gr.Markdown("## Generated Images")
                output_image = gr.Image(label="Generated Digits", type="pil")
                generation_status = gr.Textbox(
                    label="Generation Status",
                    value="Model is training... Please wait before generating images.",
                    interactive=False
                )

        with gr.Accordion("ℹ️ About this App", open=False):
            gr.Markdown("""
            This app uses a **Variational Autoencoder (VAE)** to generate handwritten digits similar to the MNIST dataset.
            - Wait for training to finish
            - Select digit & number of images
            - Click 'Generate'
            """)

        # Wiring: manual refresh, generate (then refresh status), and an
        # initial status fetch when the page loads.
        refresh_btn.click(fn=get_training_status, outputs=training_status)
        generate_btn.click(
            fn=generate_digit_images,
            inputs=[selected_digit, num_images],
            outputs=[output_image, generation_status]
        ).then(fn=get_training_status, outputs=training_status)
        app.load(fn=get_training_status, outputs=training_status)

    return app


# =============================================================================
# PART 5: MAIN EXECUTION
# =============================================================================

if __name__ == "__main__":
    print("Starting MNIST VAE Digit Generator...")
    print("Model will train automatically in the background...")
    # Daemon thread: training dies with the process if the server is stopped.
    training_thread = threading.Thread(target=train_model_background, daemon=True)
    training_thread.start()
    app = create_interface()
    app.launch(share=True, debug=True, show_error=True)