Spaces:
Sleeping
Sleeping
import json
import os
import random
import re
import time
from datetime import datetime
from typing import List, Tuple

# `spaces` stays ahead of torch: ZeroGPU needs it imported before CUDA packages.
import spaces
import gradio as gr
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from PIL import Image
| # ============================================================ | |
| # Paths / basic config | |
| # ============================================================ | |
# Resolve the app's root directory; fall back to the current working
# directory when the module runs without a ``__file__`` (e.g. a notebook).
if "__file__" in globals():
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
else:
    BASE_DIR = os.getcwd()

DATA_DIR = os.path.join(BASE_DIR, "data")  # torchvision dataset root
MODEL_DIR = os.path.join(BASE_DIR, "saved_models")  # <name>.pt weight files
META_DIR = os.path.join(BASE_DIR, "saved_models_meta")  # <name>.json metadata

for _needed_dir in (DATA_DIR, MODEL_DIR, META_DIR):
    os.makedirs(_needed_dir, exist_ok=True)
del _needed_dir

# Display labels for the 10 output logits: the digit strings "0".."9".
CLASS_NAMES = list(map(str, range(10)))
| # ============================================================ | |
| # Model | |
| # ============================================================ | |
class SimpleCNN(nn.Module):
    """Two conv/ReLU/max-pool stages followed by a small MLP classifier.

    Takes single-channel inputs (the first conv has 1 input channel) sized so
    that two 2x2 poolings yield a 7x7 map, i.e. 28x28 images, and emits 10
    logits. The submodule order is significant: saved state-dict keys
    (``features.0``, ``features.3``, ``classifier.1``, ``classifier.4``)
    depend on it.

    Args:
        conv1_channels: Output channels of the first convolution.
        conv2_channels: Output channels of the second convolution.
        kernel_size: Square kernel size of both convolutions; padding is
            ``kernel_size // 2``.
        dropout: Dropout probability before the final linear layer.
        fc_dim: Width of the hidden fully connected layer.
    """

    def __init__(
        self,
        conv1_channels: int = 16,
        conv2_channels: int = 32,
        kernel_size: int = 3,
        dropout: float = 0.2,
        fc_dim: int = 128,
    ):
        super().__init__()
        conv_kwargs = {"kernel_size": kernel_size, "padding": kernel_size // 2}
        self.features = nn.Sequential(
            nn.Conv2d(1, conv1_channels, **conv_kwargs),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(conv1_channels, conv2_channels, **conv_kwargs),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # 28x28 -> 14x14 -> 7x7. With padding = k // 2 an even kernel grows
        # each conv output by one pixel (28 -> 29 -> 14 -> 15 -> 7), so the
        # final map is still 7x7.
        side = 28 // 2 // 2
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(conv2_channels * side * side, fc_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fc_dim, 10),
        )

    def forward(self, x):
        """Map a batch of images to class logits."""
        return self.classifier(self.features(x))
| # ============================================================ | |
| # Dataset helpers | |
| # ============================================================ | |
def get_datasets(dataset_name: str):
    """Return ``(train_dataset, test_dataset)`` for a supported dataset.

    Both splits are downloaded into ``DATA_DIR`` when missing and share one
    transform: ``ToTensor`` followed by ``Normalize`` with mean 0.5, std 0.5.

    Raises:
        ValueError: if ``dataset_name`` is neither "MNIST" nor "FashionMNIST".
    """
    to_normalized_tensor = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )
    if dataset_name == "MNIST":
        dataset_cls = datasets.MNIST
    elif dataset_name == "FashionMNIST":
        dataset_cls = datasets.FashionMNIST
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")
    # Build the training split first, then the test split.
    return tuple(
        dataset_cls(DATA_DIR, train=is_train, download=True, transform=to_normalized_tensor)
        for is_train in (True, False)
    )
def make_loaders(dataset_name: str, batch_size: int, val_ratio: float = 0.1):
    """Build train/validation/test ``DataLoader``s for ``dataset_name``.

    The official training split is randomly divided: ``int(len * val_ratio)``
    samples go to validation and the rest to training. Only the training
    loader shuffles.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    full_train, test_dataset = get_datasets(dataset_name)
    n_total = len(full_train)
    n_val = int(n_total * val_ratio)
    train_part, val_part = random_split(full_train, [n_total - n_val, n_val])
    return (
        DataLoader(train_part, batch_size=batch_size, shuffle=True),
        DataLoader(val_part, batch_size=batch_size, shuffle=False),
        DataLoader(test_dataset, batch_size=batch_size, shuffle=False),
    )
| # ============================================================ | |
| # Model save/load helpers | |
| # ============================================================ | |
def model_weight_path(model_name: str) -> str:
    """Return the path of the ``.pt`` weight file for ``model_name`` in MODEL_DIR."""
    file_name = f"{model_name}.pt"
    return os.path.join(MODEL_DIR, file_name)
def model_meta_path(model_name: str) -> str:
    """Return the path of the ``.json`` metadata file for ``model_name`` in META_DIR."""
    file_name = f"{model_name}.json"
    return os.path.join(META_DIR, file_name)
def list_saved_models() -> List[str]:
    """Return saved model names, sorted in descending (reverse lexical) order.

    A name is any ``*.json`` file in ``META_DIR`` with its extension removed;
    the matching weight file is not checked here.
    """
    suffix = ".json"
    return sorted(
        (entry[: -len(suffix)] for entry in os.listdir(META_DIR) if entry.endswith(suffix)),
        reverse=True,
    )
def save_model(model: nn.Module, model_name: str, config: dict, training_summary: dict):
    """Write a model's weights and its JSON metadata record to disk.

    Every state-dict tensor is detached and copied to CPU before
    ``torch.save``. The metadata stores ``model_name``, ``config`` (read back
    by ``load_model`` to rebuild the network), ``training_summary`` and a
    local-time ``created_at`` stamp.
    """
    weights = {}
    for key, tensor in model.state_dict().items():
        weights[key] = tensor.detach().cpu()
    torch.save(weights, model_weight_path(model_name))

    record = dict(
        model_name=model_name,
        config=config,
        training_summary=training_summary,
        created_at=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    )
    with open(model_meta_path(model_name), "w", encoding="utf-8") as handle:
        json.dump(record, handle, indent=2, ensure_ascii=False)
def load_model(model_name: str, device: torch.device) -> Tuple[nn.Module, dict]:
    """Rebuild a saved ``SimpleCNN`` from its metadata and load its weights.

    Args:
        model_name: Name of a model previously written by ``save_model``.
        device: Device the model is moved to after loading.

    Returns:
        ``(model, meta)``: the model in eval mode on ``device`` and the parsed
        metadata dict.

    Raises:
        FileNotFoundError: if the metadata or the weight file is missing.
    """
    meta_file = model_meta_path(model_name)
    weight_file = model_weight_path(model_name)
    if not os.path.exists(meta_file):
        raise FileNotFoundError(f"Metadata not found for model: {model_name}")
    if not os.path.exists(weight_file):
        raise FileNotFoundError(f"Weights not found for model: {model_name}")

    with open(meta_file, "r", encoding="utf-8") as f:
        meta = json.load(f)

    # The architecture hyper-parameters are stored in the metadata, so the
    # network can be rebuilt before its weights are loaded.
    cfg = meta["config"]
    model = SimpleCNN(
        conv1_channels=cfg["conv1_channels"],
        conv2_channels=cfg["conv2_channels"],
        kernel_size=cfg["kernel_size"],
        dropout=cfg["dropout"],
        fc_dim=cfg["fc_dim"],
    )
    # save_model writes a plain dict of tensors, so restrict unpickling to
    # tensors and primitive containers. A full pickle load would execute
    # arbitrary code embedded in a tampered .pt file.
    state_dict = torch.load(weight_file, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model, meta
| # ============================================================ | |
| # ZeroGPU helpers | |
| # ============================================================ | |
def get_runtime_device() -> torch.device:
    """Return the CUDA device when torch reports one available, else the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
def _sanitize_model_tag(model_tag, dataset_name: str) -> str:
    """Return a filename-safe prefix for a saved model's name.

    Each character other than word characters, ``.`` and ``-`` becomes ``_``,
    so a tag can no longer contain path separators such as ``/``. An empty,
    whitespace-only or missing tag falls back to ``dataset_name.lower()``.
    """
    tag = (model_tag or "").strip()
    if not tag:
        return dataset_name.lower()
    return re.sub(r"[^\w.-]", "_", tag)


def _evaluate(model: nn.Module, loader, criterion, device: torch.device):
    """Return ``(mean loss, accuracy)`` of ``model`` over ``loader`` without gradients."""
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            total_loss += criterion(outputs, labels).item() * images.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    if not total:
        return 0.0, 0.0
    return total_loss / total, correct / total


def _train_one_epoch(model: nn.Module, loader, criterion, optimizer, device: torch.device):
    """Run one optimisation pass over ``loader``; return ``(mean loss, accuracy)``.

    The metrics are accumulated from the outputs computed during the training
    steps (train mode), not from a separate evaluation pass.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    if not total:
        return 0.0, 0.0
    return running_loss / total, correct / total


# Request a ZeroGPU device for the whole training run. Per the Hugging Face
# `spaces` docs, the decorator has no effect outside ZeroGPU hardware.
# NOTE(review): 300 s is an estimated upper bound for 10 epochs; tune it.
@spaces.GPU(duration=300)
def _train_on_gpu(
    dataset_name: str,
    conv1_channels: int,
    conv2_channels: int,
    kernel_size: int,
    dropout: float,
    fc_dim: int,
    learning_rate: float,
    batch_size: int,
    epochs: int,
    model_tag: str,
):
    """Train a ``SimpleCNN``, evaluate it, save it and report the results.

    The model is trained with Adam and cross-entropy for ``epochs`` epochs. It
    is evaluated on the validation split after every epoch and on the test
    split at the end. Weights and metadata are then saved as
    ``<sanitized tag>_<YYYYmmdd_HHMMSS>``.

    Returns:
        ``(log_text, history, training_summary, model_name)``, where
        ``history`` holds one dict of rounded metrics per epoch.
    """
    device = get_runtime_device()
    train_loader, val_loader, test_loader = make_loaders(dataset_name, batch_size)
    model = SimpleCNN(
        conv1_channels=conv1_channels,
        conv2_channels=conv2_channels,
        kernel_size=kernel_size,
        dropout=dropout,
        fc_dim=fc_dim,
    ).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    history = []
    logs = []
    # perf_counter is monotonic, so wall-clock adjustments cannot skew the timing.
    start_time = time.perf_counter()

    for epoch in range(1, epochs + 1):
        train_loss, train_acc = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = _evaluate(model, val_loader, criterion, device)
        history.append(
            {
                "epoch": epoch,
                "train_loss": round(train_loss, 4),
                "train_acc": round(train_acc, 4),
                "val_loss": round(val_loss, 4),
                "val_acc": round(val_acc, 4),
            }
        )
        logs.append(
            f"Epoch {epoch}/{epochs} | "
            f"train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, "
            f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}"
        )

    test_loss, test_acc = _evaluate(model, test_loader, criterion, device)
    elapsed = time.perf_counter() - start_time

    # The tag is user input from a textbox; it becomes part of file paths.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_name = f"{_sanitize_model_tag(model_tag, dataset_name)}_{timestamp}"

    config = {
        "dataset_name": dataset_name,
        "conv1_channels": conv1_channels,
        "conv2_channels": conv2_channels,
        "kernel_size": kernel_size,
        "dropout": dropout,
        "fc_dim": fc_dim,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "epochs": epochs,
    }
    # With epochs < 1 there is no history; the "final_*" metrics are then None.
    last_epoch = history[-1] if history else {}
    training_summary = {
        "final_train_loss": last_epoch.get("train_loss"),
        "final_train_acc": last_epoch.get("train_acc"),
        "final_val_loss": last_epoch.get("val_loss"),
        "final_val_acc": last_epoch.get("val_acc"),
        "test_loss": round(test_loss, 4),
        "test_acc": round(test_acc, 4),
        "elapsed_seconds": round(elapsed, 2),
        "device": str(device),
    }
    save_model(model, model_name, config, training_summary)

    logs.extend(
        [
            "",
            "Training finished.",
            f"Saved model: {model_name}",
            f"Device: {device}",
            f"Test loss: {test_loss:.4f}",
            f"Test accuracy: {test_acc:.4f}",
            f"Elapsed time: {elapsed:.1f}s",
        ]
    )
    return "\n".join(logs), history, training_summary, model_name
# Run inference on a ZeroGPU device; per the Hugging Face `spaces` docs the
# decorator has no effect outside ZeroGPU hardware.
@spaces.GPU
def _predict_uploaded_image_gpu(model_name: str, image: Image.Image):
    """Classify an uploaded image with a saved model.

    The image is converted to 1-channel grayscale, resized to 28x28 and
    normalised with mean 0.5 / std 0.5 before inference.

    Args:
        model_name: Saved model to load.
        image: PIL image from the upload widget, or ``None``.

    Returns:
        ``(result_text, prob_dict)``, where ``prob_dict`` maps each class name
        to its softmax probability. Returns ``(message, None)`` when no model
        is selected or no image was uploaded.
    """
    if not model_name:
        return "Please select a model.", None
    if image is None:
        return "Please upload an image.", None

    device = get_runtime_device()
    model, meta = load_model(model_name, device)
    dataset_name = meta["config"]["dataset_name"]
    # FashionMNIST class indices are clothing categories, not digits;
    # torchvision exposes their names as a class attribute.
    if dataset_name == "FashionMNIST":
        class_names = list(datasets.FashionMNIST.classes)
    else:
        class_names = CLASS_NAMES

    transform = transforms.Compose(
        [
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )
    tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(tensor)
    probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
    pred_idx = int(torch.argmax(logits, dim=1).item())

    result_text = (
        f"Prediction: {class_names[pred_idx]}\n"
        f"Confidence: {max(probs):.4f}\n\n"
        f"Model: {model_name}\n"
        f"Dataset: {dataset_name}\n"
        f"Runtime device: {device}"
    )
    prob_dict = {name: float(p) for name, p in zip(class_names, probs)}
    return result_text, prob_dict
# Run inference on a ZeroGPU device; per the Hugging Face `spaces` docs the
# decorator has no effect outside ZeroGPU hardware.
@spaces.GPU
def _test_random_sample_gpu(model_name: str):
    """Classify one random sample from the test split of the model's dataset.

    Args:
        model_name: Saved model to load; its metadata selects the dataset.

    Returns:
        ``(display_img, result_text, prob_dict)``. ``display_img`` is a uint8
        grayscale array in 0..255 and ``prob_dict`` maps class names to
        softmax probabilities. Returns ``(None, message, None)`` when no model
        is selected.
    """
    if not model_name:
        return None, "Please select a model.", None

    device = get_runtime_device()
    model, meta = load_model(model_name, device)
    dataset_name = meta["config"]["dataset_name"]
    _, test_dataset = get_datasets(dataset_name)
    # FashionMNIST class indices are clothing categories, not digits;
    # torchvision exposes their names as a class attribute.
    if dataset_name == "FashionMNIST":
        class_names = list(datasets.FashionMNIST.classes)
    else:
        class_names = CLASS_NAMES

    idx = random.randint(0, len(test_dataset) - 1)
    image_tensor, label = test_dataset[idx]
    with torch.no_grad():
        logits = model(image_tensor.unsqueeze(0).to(device))
    probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
    pred_idx = int(torch.argmax(logits, dim=1).item())

    # get_datasets normalises pixels to [-1, 1]. Undo Normalize((0.5,), (0.5,))
    # and emit uint8 so the image widget gets an ordinary 0..255 image rather
    # than a float array with negative values.
    pixels = (image_tensor.squeeze(0) * 0.5 + 0.5).clamp(0.0, 1.0)
    display_img = (pixels * 255).round().to(torch.uint8).cpu().numpy()

    result_text = (
        f"Random test sample\n"
        f"Ground truth: {class_names[label]}\n"
        f"Prediction: {class_names[pred_idx]}\n"
        f"Confidence: {max(probs):.4f}\n"
        f"Model dataset: {dataset_name}\n"
        f"Runtime device: {device}"
    )
    prob_dict = {name: float(p) for name, p in zip(class_names, probs)}
    return display_img, result_text, prob_dict
| # ============================================================ | |
| # UI callbacks | |
| # ============================================================ | |
def train_callback(
    dataset_name,
    conv1_channels,
    conv2_channels,
    kernel_size,
    dropout,
    fc_dim,
    learning_rate,
    batch_size,
    epochs,
    model_tag,
):
    """Gradio handler for the "Start Training" button.

    Casts the raw widget values to the types ``_train_on_gpu`` expects and runs
    training. It then refreshes the model dropdown, selecting the newly saved
    model (or the first listed model if it is not found). Any exception is
    reported in the log box and leaves the dropdown untouched.
    """
    try:
        call_args = (
            dataset_name,
            int(conv1_channels),
            int(conv2_channels),
            int(kernel_size),
            float(dropout),
            int(fc_dim),
            float(learning_rate),
            int(batch_size),
            int(epochs),
            model_tag,
        )
        logs, history, summary, new_model = _train_on_gpu(*call_args)
        saved = list_saved_models()
        if new_model in saved:
            chosen = new_model
        else:
            chosen = saved[0] if saved else None
        return logs, history, summary, gr.update(choices=saved, value=chosen)
    except Exception as err:
        return f"Training failed:\n{err}", None, None, gr.update()
def predict_uploaded_image_callback(model_name, image):
    """Gradio handler for "Predict Uploaded Image"; reports any failure as text."""
    try:
        outcome = _predict_uploaded_image_gpu(model_name, image)
    except Exception as err:
        return f"Prediction failed:\n{err}", None
    return outcome
| def test_random_sample_callback(model_name): | |
| try: | |
| return _test_random_sample_gpu(model_name) | |
| except Exception as e: | |
| return None, f"Random test failed:\n{str(e)}", None | |
def get_model_info(model_name: str):
    """Return the stored metadata dict of ``model_name`` for the JSON viewer.

    Returns a ``{"message": ...}`` dict instead when no model is selected or
    its metadata file does not exist.
    """
    if not model_name:
        return {"message": "No model selected."}
    meta_file = model_meta_path(model_name)
    if os.path.exists(meta_file):
        with open(meta_file, "r", encoding="utf-8") as fh:
            return json.load(fh)
    return {"message": "Metadata not found."}
def refresh_models_dropdown():
    """Re-scan saved models and reset the dropdown to the first name (or None)."""
    names = list_saved_models()
    return gr.update(choices=names, value=next(iter(names), None))
| # ============================================================ | |
| # UI | |
| # ============================================================ | |
# Saved model names present at start-up; used to seed the model dropdown.
initial_models = list_saved_models()

with gr.Blocks(title="Image Classification") as demo:
    gr.Markdown("# Image Classification")
    gr.Markdown(
        "Train a simple CNN on MNIST or FashionMNIST, then test saved models "
        "with an uploaded image or a random sample."
    )
    with gr.Tabs():
        # Train tab: hyper-parameter controls (left column) and training
        # output (right column).
        with gr.Tab("Train"):
            with gr.Row():
                with gr.Column():
                    dataset_name = gr.Dropdown(
                        choices=["MNIST", "FashionMNIST"],
                        value="MNIST",
                        label="Dataset",
                    )
                    conv1_channels = gr.Slider(8, 64, value=16, step=8, label="Conv1 Channels")
                    conv2_channels = gr.Slider(16, 128, value=32, step=16, label="Conv2 Channels")
                    kernel_size = gr.Dropdown(choices=[3, 5], value=3, label="Kernel Size")
                    dropout = gr.Slider(0.0, 0.7, value=0.2, step=0.05, label="Dropout")
                    fc_dim = gr.Slider(32, 256, value=128, step=32, label="FC Hidden Dimension")
                    # Free-form number: no range limit is set on this widget.
                    learning_rate = gr.Number(value=0.001, label="Learning Rate")
                    batch_size = gr.Dropdown(choices=[32, 64, 128, 256], value=64, label="Batch Size")
                    epochs = gr.Slider(1, 10, value=3, step=1, label="Epochs")
                    # Optional prefix for the saved model's name.
                    model_tag = gr.Textbox(label="Model Tag", placeholder="e.g. mnist_demo")
                    train_btn = gr.Button("Start Training", variant="primary")
                with gr.Column():
                    train_status = gr.Textbox(label="Training Log", lines=18)
                    train_history = gr.JSON(label="Training History")
                    train_summary = gr.JSON(label="Training Summary")
        # Test tab: model selection/metadata, uploaded-image prediction and
        # random test-sample prediction.
        with gr.Tab("Test"):
            with gr.Row():
                with gr.Column():
                    model_selector = gr.Dropdown(
                        choices=initial_models,
                        value=initial_models[0] if initial_models else None,
                        label="Select Saved Model",
                    )
                    refresh_btn = gr.Button("Refresh Model List")
                    load_info_btn = gr.Button("Show Model Info")
                    model_info = gr.JSON(label="Model Metadata")
                with gr.Column():
                    # type="pil" hands the handler a PIL image (or None).
                    upload_image = gr.Image(type="pil", label="Upload Image")
                    predict_btn = gr.Button("Predict Uploaded Image", variant="primary")
                    predict_text = gr.Textbox(label="Prediction Result", lines=7)
                    predict_probs = gr.Label(label="Class Probabilities")
            with gr.Row():
                random_test_btn = gr.Button("Test Random Sample")
            with gr.Row():
                random_sample_image = gr.Image(type="numpy", label="Random Test Image")
                random_sample_text = gr.Textbox(label="Random Sample Result", lines=7)
                random_sample_probs = gr.Label(label="Random Sample Probabilities")

    # Event wiring. Handler return tuples map positionally onto `outputs`.
    # Training also updates `model_selector` so the new model is listed and selected.
    train_btn.click(
        fn=train_callback,
        inputs=[
            dataset_name,
            conv1_channels,
            conv2_channels,
            kernel_size,
            dropout,
            fc_dim,
            learning_rate,
            batch_size,
            epochs,
            model_tag,
        ],
        outputs=[train_status, train_history, train_summary, model_selector],
    )
    refresh_btn.click(
        fn=refresh_models_dropdown,
        inputs=None,
        outputs=model_selector,
    )
    load_info_btn.click(
        fn=get_model_info,
        inputs=model_selector,
        outputs=model_info,
    )
    predict_btn.click(
        fn=predict_uploaded_image_callback,
        inputs=[model_selector, upload_image],
        outputs=[predict_text, predict_probs],
    )
    random_test_btn.click(
        fn=test_random_sample_callback,
        inputs=[model_selector],
        outputs=[random_sample_image, random_sample_text, random_sample_probs],
    )

if __name__ == "__main__":
    demo.launch()