Spaces:
Sleeping
Sleeping
| # MNIST Handwritten Digit Generation Web App | |
| # TensorFlow/Keras version using VAE and Gradio for Google Colab | |
| # Auto-training version - model trains on startup | |
| import numpy as np | |
| import gradio as gr | |
| import tensorflow as tf | |
| from tensorflow.keras import layers, Model | |
| from tensorflow.keras.datasets import mnist | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| import io | |
| import threading | |
| import time | |
| # ============================================================================= | |
| # PART 1: VAE MODEL DEFINITION | |
| # ============================================================================= | |
class Sampling(layers.Layer):
    """Reparameterization-trick layer: draw z ~ N(z_mean, exp(z_log_var))."""

    def call(self, inputs):
        mean, log_var = inputs
        # Sample standard-normal noise matching the (batch, latent_dim) shape.
        shape = tf.shape(mean)
        eps = tf.random.normal(shape=(shape[0], shape[1]))
        # Shift/scale the noise: mean + sigma * eps, with sigma = exp(log_var / 2).
        return mean + tf.exp(0.5 * log_var) * eps
class VAE(Model):
    """Variational Autoencoder wrapping an encoder and a decoder.

    Overrides ``train_step`` to optimize the sum of a reconstruction loss
    (binary cross-entropy) and the analytic KL-divergence regularizer, and
    tracks both components plus the total with Mean metrics.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        # Running means of the per-batch losses, reported per epoch.
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        # BUG FIX: Keras defines Model.metrics as a *property* and reads it to
        # reset metric state between epochs. Without @property this method
        # shadowed the property with a plain bound method, so Keras could not
        # enumerate the trackers and epoch metrics accumulated forever.
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        """Run one optimization step; returns a dict of tracked losses."""
        # fit(x, x) delivers an (inputs, targets) tuple; inputs == targets here.
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # NOTE(review): binary_crossentropy already averages over the last
            # (pixel) axis for flat 784-vectors, so the reduce_sum below runs
            # over the batch axis and the loss scales with batch size.
            # Preserved as-is since training behavior was tuned against it.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    tf.keras.losses.binary_crossentropy(data, reconstruction), axis=-1
                )
            )
            # Closed-form KL(q(z|x)=N(mu, sigma^2) || N(0, I)), summed over
            # latent dims, averaged over the batch.
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }
def build_vae(input_shape=(784,), latent_dim=20):
    """Build and compile the VAE; return (vae, encoder, decoder)."""
    # --- Encoder: 784 -> 400 -> 400 -> (z_mean, z_log_var) -> sampled z ---
    enc_in = layers.Input(shape=input_shape)
    h = layers.Dense(400, activation='relu')(enc_in)
    h = layers.Dense(400, activation='relu')(h)
    z_mean = layers.Dense(latent_dim, name='z_mean')(h)
    z_log_var = layers.Dense(latent_dim, name='z_log_var')(h)
    z = Sampling()([z_mean, z_log_var])
    encoder = Model(enc_in, [z_mean, z_log_var, z], name='encoder')

    # --- Decoder: latent_dim -> 400 -> 400 -> 784 sigmoid pixel values ---
    dec_in = layers.Input(shape=(latent_dim,))
    h = layers.Dense(400, activation='relu')(dec_in)
    h = layers.Dense(400, activation='relu')(h)
    dec_out = layers.Dense(784, activation='sigmoid')(h)
    decoder = Model(dec_in, dec_out, name='decoder')

    # Loss is computed inside VAE.train_step, so compile only sets the optimizer.
    vae = VAE(encoder, decoder)
    vae.compile(optimizer='adam')
    return vae, encoder, decoder
| # ============================================================================= | |
| # PART 2: DATA LOADING AND TRAINING | |
| # ============================================================================= | |
# Shared mutable state between the background training thread (writer) and
# the Gradio callbacks (readers). Plain assignment is sufficient here: the
# UI only polls status strings and checks model_ready before use.
encoder = None            # trained Keras encoder Model, set after training
decoder = None            # trained Keras decoder Model, set after training
digit_latents = None      # dict: digit (0-9) -> ndarray of z_mean vectors
model_ready = False       # flipped True only after latents are computed
training_progress = "Initializing..."  # human-readable status for the UI
def train_model_background():
    """Train the VAE on a worker thread and publish results via globals.

    Runs as a daemon thread at startup: loads MNIST, trains for 20 epochs,
    then stores the encoder/decoder and per-digit latent vectors in module
    globals and flips ``model_ready``. Progress text is mirrored into
    ``training_progress`` for the UI to poll.
    """
    global encoder, decoder, digit_latents, model_ready, training_progress
    try:
        training_progress = "Loading MNIST data..."
        print("Loading MNIST data...")
        (x_train, y_train), _ = mnist.load_data()
        # Scale pixels to [0, 1] and flatten to 784-vectors for the dense VAE.
        x_train = x_train.astype('float32') / 255.0
        x_train = x_train.reshape((-1, 784))
        # Subsample to keep startup training fast on CPU-only hosts.
        x_train = x_train[:10000]
        y_train = y_train[:10000]

        training_progress = "Building VAE model..."
        print("Building VAE model...")
        vae, encoder_model, decoder_model = build_vae()

        training_progress = "Training VAE model (20 epochs)..."
        print("Training VAE model (20 epochs)...")

        class ProgressCallback(tf.keras.callbacks.Callback):
            # Mirrors per-epoch progress into the global status string.
            def on_epoch_end(self, epoch, logs=None):
                global training_progress
                training_progress = f"Training... Epoch {epoch + 1}/20 (Loss: {logs.get('loss', 0):.4f})"
                print(f"Epoch {epoch + 1}/20 completed")

        # The returned History object was never used; don't bind it.
        vae.fit(
            x_train, x_train,
            epochs=20,
            batch_size=128,
            verbose=0,
            callbacks=[ProgressCallback()]
        )

        # Publish the trained models only after fit() returns so readers
        # never see a half-trained encoder/decoder pair.
        encoder = encoder_model
        decoder = decoder_model

        training_progress = "Computing digit latent representations..."
        print("Computing digit latent representations...")
        digit_latents = compute_digit_latents(encoder, x_train, y_train)

        training_progress = "✅ Model ready! You can now generate digits."
        model_ready = True
        print("Model training completed successfully!")
    except Exception as e:
        # Broad catch is deliberate: this is the thread's top-level boundary,
        # and the error must surface in the UI rather than kill the daemon
        # thread silently.
        training_progress = f"❌ Error training model: {str(e)}"
        print(f"Error training model: {str(e)}")
def compute_digit_latents(encoder_model, x_train, y_train):
    """Group encoder z_mean vectors by digit label.

    Args:
        encoder_model: model whose ``predict(x, verbose=0)`` returns a
            ``(z_mean, z_log_var, z)`` triple; only ``z_mean`` is used.
        x_train: array of flattened training images, shape (N, 784).
        y_train: integer labels, shape (N,), values in 0..9.

    Returns:
        dict mapping each digit 0-9 to an (n_i, latent_dim) ndarray of its
        z_mean vectors; digits absent from the sample get a (1, latent_dim)
        standard-normal fallback. Returns None if encoding fails.
    """
    try:
        z_means, _, _ = encoder_model.predict(x_train, verbose=0)
        z_means = np.asarray(z_means)
        labels = np.asarray(y_train)
        # Derive the fallback dimensionality from the encoder output instead
        # of hard-coding 20, so any latent size works.
        latent_dim = z_means.shape[1]
        digit_latents = {}
        for digit in range(10):
            mask = labels == digit
            if mask.any():
                # Vectorized gather replaces the original per-sample loop.
                digit_latents[digit] = z_means[mask]
            else:
                digit_latents[digit] = np.random.normal(0, 1, (1, latent_dim))
        return digit_latents
    except Exception as e:
        print(f"Error computing digit latents: {str(e)}")
        return None
def get_training_status():
    """Return the current training-progress message (polled by the UI)."""
    return training_progress
| # ============================================================================= | |
| # PART 3: IMAGE GENERATION | |
| # ============================================================================= | |
def generate_digit_images(digit, num_images):
    """Generate a grid image of VAE samples for the requested digit.

    Args:
        digit: int 0-9, the dropdown selection (keys of ``digit_latents``).
        num_images: requested sample count (1-10 from the slider).

    Returns:
        (PIL.Image or None, status message string). Returns (None, msg)
        while the model is still training or on generation failure.
    """
    global encoder, decoder, digit_latents, model_ready
    if not model_ready:
        return None, "⏳ Model is still training. Please wait..."
    if encoder is None or decoder is None or digit_latents is None:
        return None, "❌ Model not ready yet. Please wait for training to complete."
    try:
        # ROBUSTNESS: Gradio sliders can deliver floats even with step=1;
        # np.random.choice and the grid arithmetic below need a true int.
        num_images = int(num_images)
        latent_vectors = digit_latents[digit]
        if len(latent_vectors) == 0:
            selected_latents = np.random.normal(0, 1, (num_images, 20))
        else:
            # Sample without replacement when enough stored latents exist,
            # with replacement otherwise (collapses the duplicated branch).
            indices = np.random.choice(
                len(latent_vectors),
                num_images,
                replace=len(latent_vectors) < num_images,
            )
            selected_latents = latent_vectors[indices]
            # Small Gaussian jitter so repeated requests vary.
            noise = np.random.normal(0, 0.1, selected_latents.shape)
            selected_latents = selected_latents + noise
        generated = decoder.predict(selected_latents, verbose=0)
        images = (generated.reshape(-1, 28, 28) * 255).astype(np.uint8)
        if num_images == 1:
            grid_img = Image.fromarray(images[0], mode='L')
        else:
            # Tile samples into a grid of at most 5 columns on white.
            cols = min(5, num_images)
            rows = (num_images + cols - 1) // cols
            grid_width = cols * 28
            grid_height = rows * 28
            grid_img = Image.new('L', (grid_width, grid_height), color=255)
            for i, img in enumerate(images):
                x = (i % cols) * 28
                y = (i // cols) * 28
                grid_img.paste(Image.fromarray(img, mode='L'), (x, y))
        success_msg = f"✅ Generated {len(images)} images of digit {digit}!"
        return grid_img, success_msg
    except Exception as e:
        error_msg = f"❌ Error generating images: {str(e)}"
        return None, error_msg
| # ============================================================================= | |
| # PART 4: GRADIO INTERFACE | |
| # ============================================================================= | |
def create_interface():
    """Build and return the Gradio Blocks UI.

    Layout (statement order inside the context managers defines it):
    a left column with training status + generation controls, a right
    column with the output image and status, and an about accordion.
    """
    with gr.Blocks(title="MNIST VAE Digit Generator", theme=gr.themes.Soft()) as app:
        gr.Markdown("# 🔢 TensorFlow VAE Handwritten Digit Generator")
        gr.Markdown("Generate MNIST-style handwritten digits using a Variational Autoencoder (VAE).")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("## Training Status")
                # Read-only textbox the user refreshes to see trainer progress.
                training_status = gr.Textbox(
                    label="Model Status",
                    value="Initializing...",
                    interactive=False
                )
                refresh_btn = gr.Button("🔄 Refresh Status", size="sm")
                gr.Markdown("## Generation Controls")
                selected_digit = gr.Dropdown(
                    choices=list(range(10)),
                    value=0,
                    label="Select Digit to Generate"
                )
                num_images = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    label="Number of Images"
                )
                generate_btn = gr.Button("🎲 Generate Images", variant="primary", size="lg")
            with gr.Column(scale=2):
                gr.Markdown("## Generated Images")
                output_image = gr.Image(label="Generated Digits", type="pil")
                generation_status = gr.Textbox(
                    label="Generation Status",
                    value="Model is training... Please wait before generating images.",
                    interactive=False
                )
        with gr.Accordion("ℹ️ About this App", open=False):
            gr.Markdown("""
            This app uses a **Variational Autoencoder (VAE)** to generate handwritten digits similar to the MNIST dataset.
            - Wait for training to finish
            - Select digit & number of images
            - Click 'Generate'
            """)
        # Wire events: manual refresh, generate (then re-sync the training
        # status), and an initial status fetch on page load.
        refresh_btn.click(fn=get_training_status, outputs=training_status)
        generate_btn.click(fn=generate_digit_images, inputs=[selected_digit, num_images],
                           outputs=[output_image, generation_status]).then(
            fn=get_training_status, outputs=training_status
        )
        app.load(fn=get_training_status, outputs=training_status)
    return app
| # ============================================================================= | |
| # PART 5: MAIN EXECUTION | |
| # ============================================================================= | |
if __name__ == "__main__":
    print("Starting MNIST VAE Digit Generator...")
    print("Model will train automatically in the background...")
    # Daemon thread so the trainer never blocks interpreter shutdown;
    # the UI launches immediately and polls training_progress.
    training_thread = threading.Thread(target=train_model_background, daemon=True)
    training_thread.start()
    app = create_interface()
    # share=True creates a public Gradio tunnel (intended for Colab use).
    app.launch(share=True, debug=True, show_error=True)