Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import time | |
| import math | |
| from datetime import datetime | |
| from typing import List, Tuple | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader, random_split | |
| from torchvision import datasets, transforms | |
| from PIL import Image | |
# ============================================================
# Configuration
# ============================================================
# Project root: the directory containing this file when it is run as a
# script/module; fall back to the working directory when ``__file__`` is not
# defined (e.g. an interactive session or notebook).
if "__file__" in globals():
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
else:
    BASE_DIR = os.getcwd()

# Storage locations: dataset cache, model weights (.pt), model metadata (.json).
DATA_DIR = os.path.join(BASE_DIR, "data")
MODEL_DIR = os.path.join(BASE_DIR, "saved_models")
META_DIR = os.path.join(BASE_DIR, "saved_models_meta")

# Create the storage directories at import time so later reads/writes and
# directory listings never hit a missing directory.
for _directory in (DATA_DIR, MODEL_DIR, META_DIR):
    os.makedirs(_directory, exist_ok=True)
del _directory

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Display labels for the ten output classes: "0" .. "9".
CLASS_NAMES = list(map(str, range(10)))
# ============================================================
# Model
# ============================================================
class SimpleCNN(nn.Module):
    """Small two-stage convolutional classifier.

    Layout: [Conv -> ReLU -> MaxPool(2)] x 2 -> Flatten -> Linear -> ReLU
    -> Dropout -> Linear.

    Args:
        conv1_channels: Output channels of the first convolution.
        conv2_channels: Output channels of the second convolution.
        kernel_size: Square kernel size for both convolutions; padding is
            ``kernel_size // 2``.
        dropout: Dropout probability before the output layer.
        fc_dim: Width of the hidden fully connected layer.
        in_channels: Channels of the input images (1 for MNIST/FashionMNIST).
        num_classes: Number of output logits.
        input_size: Height and width of the (square) input images.

    The three trailing arguments default to the MNIST setting, so existing
    callers and saved ``state_dict`` files are unaffected.
    """

    def __init__(self, conv1_channels: int = 16, conv2_channels: int = 32, kernel_size: int = 3,
                 dropout: float = 0.2, fc_dim: int = 128, in_channels: int = 1,
                 num_classes: int = 10, input_size: int = 28):
        super().__init__()
        padding = kernel_size // 2
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, conv1_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(conv1_channels, conv2_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Derive the flattened feature size from a dummy forward pass instead
        # of hard-coding 7 x 7, which only holds for 28 x 28 inputs. The pass
        # uses zeros and no_grad, so it consumes no RNG and records no graph;
        # parameter initialisation is identical to building the layers alone.
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, input_size, input_size)
            flattened_dim = self.features(dummy).numel()
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flattened_dim, fc_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fc_dim, num_classes),
        )

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for the batch ``x``."""
        x = self.features(x)
        return self.classifier(x)
# ============================================================
# Data utilities
# ============================================================
def get_datasets(dataset_name: str):
    """Load the train and test splits of a supported torchvision dataset.

    Images are converted to tensors and normalised with mean 0.5 / std 0.5,
    mapping pixel values from [0, 1] to [-1, 1]. Data is cached under
    ``DATA_DIR`` and downloaded if not already present.

    Args:
        dataset_name: ``"MNIST"`` or ``"FashionMNIST"``.

    Returns:
        ``(train_dataset, test_dataset)``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    supported = (
        ("MNIST", datasets.MNIST),
        ("FashionMNIST", datasets.FashionMNIST),
    )
    dataset_cls = next((cls for name, cls in supported if name == dataset_name), None)
    if dataset_cls is None:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    train_split = dataset_cls(DATA_DIR, train=True, download=True, transform=preprocessing)
    test_split = dataset_cls(DATA_DIR, train=False, download=True, transform=preprocessing)
    return train_split, test_split
def make_loaders(dataset_name: str, batch_size: int, val_ratio: float = 0.1, seed=None):
    """Build train / validation / test DataLoaders for ``dataset_name``.

    The official training split is randomly divided into a training subset
    and a validation subset of ``int(len(train) * val_ratio)`` samples. The
    official test split is used unchanged.

    Args:
        dataset_name: Dataset identifier accepted by ``get_datasets``.
        batch_size: Mini-batch size for all three loaders.
        val_ratio: Fraction of the training split held out for validation;
            must satisfy ``0 <= val_ratio < 1``.
        seed: Optional int seeding the train/validation split. ``None`` (the
            default) draws from PyTorch's global RNG as before; pass an int
            for a reproducible split.

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the training loader
        shuffles.

    Raises:
        ValueError: If ``val_ratio`` is outside ``[0, 1)``.
    """
    # Validate before loading/downloading any data. Otherwise a bad ratio only
    # surfaces later as an opaque random_split error about negative lengths.
    if not 0 <= val_ratio < 1:
        raise ValueError(f"val_ratio must be in [0, 1), got {val_ratio}")

    train_dataset, test_dataset = get_datasets(dataset_name)
    val_size = int(len(train_dataset) * val_ratio)
    train_size = len(train_dataset) - val_size

    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_subset, val_subset = random_split(train_dataset, [train_size, val_size], **split_kwargs)

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, test_loader
# ============================================================
# Model registry helpers
# ============================================================
def _check_model_name(model_name: str) -> str:
    """Return ``model_name`` if it is safe to use as a single file-name stem.

    Model names arrive from the UI (dropdown selections), so they are treated
    as untrusted. Empty names, ``.``/``..``, and names containing a path
    separator or NUL could address files outside the model directories and
    are rejected.

    Raises:
        ValueError: If the name is unsafe.
    """
    if not isinstance(model_name, str) or not model_name:
        raise ValueError("Model name must be a non-empty string.")
    if model_name in {".", ".."} or any(ch in model_name for ch in ("/", "\\", "\x00")):
        raise ValueError(f"Invalid model name: {model_name!r}")
    return model_name


def model_meta_path(model_name: str) -> str:
    """Return the path of the JSON metadata file for ``model_name``.

    Raises:
        ValueError: If ``model_name`` is empty or could escape ``META_DIR``.
    """
    safe_name = _check_model_name(model_name)
    return os.path.join(META_DIR, f"{safe_name}.json")


def model_weight_path(model_name: str) -> str:
    """Return the path of the ``.pt`` weights file for ``model_name``.

    Raises:
        ValueError: If ``model_name`` is empty or could escape ``MODEL_DIR``.
    """
    safe_name = _check_model_name(model_name)
    return os.path.join(MODEL_DIR, f"{safe_name}.pt")
def save_model(model: nn.Module, model_name: str, config: dict, training_summary: dict):
    """Persist a trained model's weights and a JSON metadata record.

    The ``state_dict`` is written first. The metadata file is written second
    and holds the name, the hyper-parameter ``config``, the
    ``training_summary``, and a local-time ``created_at`` stamp.
    ``list_saved_models`` discovers models by their metadata file, so a model
    becomes visible only after its weights have been written.
    """
    weights = model.state_dict()
    torch.save(weights, model_weight_path(model_name))

    created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    payload = dict(
        model_name=model_name,
        config=config,
        training_summary=training_summary,
        created_at=created_at,
    )
    with open(model_meta_path(model_name), "w", encoding="utf-8") as meta_fh:
        json.dump(payload, meta_fh, indent=2, ensure_ascii=False)
def list_saved_models(meta_dir=None) -> List[str]:
    """Return the names of saved models, most recently saved first.

    A model is listed when ``<name>.json`` exists in the metadata directory.
    Entries are ordered by the metadata file's modification time (newest
    first), with ties broken by name in reverse order.

    The UI selects element ``[0]`` after training and on refresh, so the
    newest model must come first. The previous reverse-alphabetical sort only
    achieved that when every model shared the same tag prefix.

    Args:
        meta_dir: Directory to scan; defaults to ``META_DIR``.

    Returns:
        Model names (file names without ``.json``). The list is empty if the
        directory does not exist.
    """
    directory = META_DIR if meta_dir is None else meta_dir
    try:
        filenames = os.listdir(directory)
    except FileNotFoundError:
        return []

    entries = []
    for filename in filenames:
        if not filename.endswith(".json"):
            continue
        try:
            mtime = os.path.getmtime(os.path.join(directory, filename))
        except OSError:
            # The file disappeared between listdir() and stat(); skip it.
            continue
        entries.append((mtime, filename[:-5]))
    entries.sort(reverse=True)
    return [name for _, name in entries]
def load_model(model_name: str) -> Tuple[nn.Module, dict]:
    """Rebuild a saved ``SimpleCNN`` and return it with its metadata.

    The architecture is reconstructed from ``meta["config"]``. The saved
    ``state_dict`` is then loaded onto ``DEVICE`` and the model is switched
    to eval mode.

    Args:
        model_name: Name of a model previously written by ``save_model``.

    Returns:
        ``(model, meta)``, where ``meta`` is the parsed metadata JSON.

    Raises:
        FileNotFoundError: If the metadata or the weights file is missing.
    """
    meta_file = model_meta_path(model_name)
    weight_file = model_weight_path(model_name)
    if not os.path.exists(meta_file):
        raise FileNotFoundError(f"Metadata not found for model: {model_name}")
    if not os.path.exists(weight_file):
        raise FileNotFoundError(f"Weights not found for model: {model_name}")

    with open(meta_file, "r", encoding="utf-8") as f:
        meta = json.load(f)

    config = meta["config"]
    model = SimpleCNN(
        conv1_channels=config["conv1_channels"],
        conv2_channels=config["conv2_channels"],
        kernel_size=config["kernel_size"],
        dropout=config["dropout"],
        fc_dim=config["fc_dim"],
    )

    try:
        # weights_only=True restricts unpickling to tensors and plain
        # containers, which is all a state_dict needs. A tampered .pt file
        # therefore cannot execute arbitrary code while loading.
        state_dict = torch.load(weight_file, map_location=DEVICE, weights_only=True)
    except TypeError:
        # PyTorch releases older than 1.13 do not accept ``weights_only``.
        state_dict = torch.load(weight_file, map_location=DEVICE)

    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    return model, meta
# ============================================================
# Training / evaluation
# ============================================================
def evaluate(model: nn.Module, loader: DataLoader, criterion: nn.Module):
    """Measure the mean loss and accuracy of ``model`` over ``loader``.

    The model is put in eval mode and run without gradient tracking; each
    batch is moved to ``DEVICE`` first.

    Returns:
        ``(avg_loss, accuracy)``. ``avg_loss`` is the batch loss weighted by
        batch size, divided by the sample count; this is a per-sample mean
        assuming ``criterion`` averages over the batch, as
        ``CrossEntropyLoss`` does by default. ``accuracy`` is the fraction of
        argmax predictions equal to the label. Both are ``0.0`` when the
        loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(DEVICE)
            batch_y = batch_y.to(DEVICE)
            logits = model(batch_x)
            batch_n = batch_y.size(0)
            loss_sum += criterion(logits, batch_y).item() * batch_n
            n_correct += logits.argmax(dim=1).eq(batch_y).sum().item()
            n_seen += batch_n
    if not n_seen:
        return 0.0, 0.0
    return loss_sum / n_seen, n_correct / n_seen
def _sanitize_model_tag(model_tag, fallback: str) -> str:
    """Turn a user-supplied tag into a file-name-safe stem, or return ``fallback``.

    Every character other than a word character, ``.`` or ``-`` becomes
    ``_``. Spaces therefore still map to ``_``, and path separators can no
    longer escape the model directories. Leading dots are dropped to avoid
    ``..`` and hidden files. ``fallback`` is used when nothing remains.
    """
    import re

    cleaned = re.sub(r"[^\w.-]", "_", (model_tag or "").strip()).lstrip(".")
    return cleaned or fallback


def _train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader``; return ``(mean_loss, accuracy)``."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch mean is per-sample even when the
        # final batch is smaller.
        running_loss += loss.item() * images.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    if total == 0:
        return 0.0, 0.0
    return running_loss / total, correct / total


def train_model(dataset_name: str, conv1_channels: int, conv2_channels: int, kernel_size: int,
                dropout: float, fc_dim: int, learning_rate: float, batch_size: int,
                epochs: int, model_tag: str):
    """Train a ``SimpleCNN``, save it, and stream progress as a generator.

    Each yielded dict has these keys:

    * ``"status"``: human-readable progress text (a summary on the last item).
    * ``"history"``: a snapshot of per-epoch lists under ``epoch``,
      ``train_loss``, ``train_acc``, ``val_loss`` and ``val_acc``.
    * ``"finished"``: ``True`` only on the final item.
    * ``"models"``: ``None`` while training; the refreshed list of saved
      model names on the final item.

    After the last epoch the model is evaluated on the test split and saved
    as ``<tag>_<YYYYmmdd_HHMMSS>``. ``<tag>`` is ``model_tag`` made
    file-name safe, or the lower-cased dataset name if no usable tag is given.

    Raises:
        ValueError: If ``epochs`` or ``batch_size`` is below 1, or if
            ``learning_rate`` is not a positive number.
    """
    # UI widgets (sliders / number boxes) may deliver floats, or None for a
    # cleared number box. Coerce and validate before any dataset download so
    # bad input fails fast; layer sizes and loop bounds must be ints.
    conv1_channels = int(conv1_channels)
    conv2_channels = int(conv2_channels)
    kernel_size = int(kernel_size)
    fc_dim = int(fc_dim)
    batch_size = int(batch_size)
    epochs = int(epochs)
    dropout = float(dropout)
    if epochs < 1:
        raise ValueError(f"epochs must be at least 1, got {epochs}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")
    if learning_rate is None or not float(learning_rate) > 0:
        raise ValueError(f"learning_rate must be a positive number, got {learning_rate}")
    learning_rate = float(learning_rate)

    train_loader, val_loader, test_loader = make_loaders(dataset_name, batch_size)
    model = SimpleCNN(
        conv1_channels=conv1_channels,
        conv2_channels=conv2_channels,
        kernel_size=kernel_size,
        dropout=dropout,
        fc_dim=fc_dim,
    ).to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    history = {key: [] for key in ("epoch", "train_loss", "train_acc", "val_loss", "val_acc")}
    start_time = time.perf_counter()

    for epoch in range(1, epochs + 1):
        train_loss, train_acc = _train_one_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_acc = evaluate(model, val_loader, criterion)

        history["epoch"].append(epoch)
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        yield {
            "status": (
                f"Epoch {epoch}/{epochs} | "
                f"train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, "
                f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}"
            ),
            # Yield a copy so a consumer holding an earlier update does not see
            # it change as later epochs append to ``history``.
            "history": {key: list(values) for key, values in history.items()},
            "finished": False,
            "models": None,
        }

    test_loss, test_acc = evaluate(model, test_loader, criterion)
    elapsed = time.perf_counter() - start_time

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    safe_tag = _sanitize_model_tag(model_tag, dataset_name.lower())
    model_name = f"{safe_tag}_{timestamp}"

    config = {
        "dataset_name": dataset_name,
        "conv1_channels": conv1_channels,
        "conv2_channels": conv2_channels,
        "kernel_size": kernel_size,
        "dropout": dropout,
        "fc_dim": fc_dim,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "epochs": epochs,
    }
    training_summary = {
        "final_train_loss": history["train_loss"][-1],
        "final_train_acc": history["train_acc"][-1],
        "final_val_loss": history["val_loss"][-1],
        "final_val_acc": history["val_acc"][-1],
        "test_loss": test_loss,
        "test_acc": test_acc,
        "elapsed_seconds": elapsed,
        "device": str(DEVICE),
    }
    save_model(model, model_name, config, training_summary)

    final_message = (
        f"Training finished.\n\n"
        f"Saved model: {model_name}\n"
        f"Device: {DEVICE}\n"
        f"Test loss: {test_loss:.4f}\n"
        f"Test accuracy: {test_acc:.4f}\n"
        f"Elapsed time: {elapsed:.1f}s"
    )
    yield {
        "status": final_message,
        "history": {key: list(values) for key, values in history.items()},
        "finished": True,
        "models": list_saved_models(),
    }
# ============================================================
# Inference helpers
# ============================================================
def preprocess_uploaded_image(image: Image.Image):
    """Convert an uploaded PIL image into a single-image model batch.

    The image is reduced to one grayscale channel, resized to 28 x 28,
    converted to a tensor in [0, 1], and normalised with mean 0.5 / std 0.5
    to [-1, 1].

    Returns:
        A tensor of shape ``(1, 1, 28, 28)``.

    Raises:
        ValueError: If ``image`` is ``None`` (nothing was uploaded).
    """
    if image is None:
        raise ValueError("Please upload an image.")

    to_model_input = transforms.Compose(
        [
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )
    single = to_model_input(image)
    return single.unsqueeze(0)
def predict_uploaded_image(model_name: str, image: Image.Image):
    """Classify an uploaded image with the selected saved model.

    Returns:
        ``(result_text, prob_table)``. ``prob_table`` maps each class label
        to its softmax probability. The function returns ``(message, None)``
        when no model is selected or no image was uploaded.
    """
    if not model_name:
        return "Please select a model.", None
    # Check the cheap precondition before loading weights from disk. Report it
    # the same way as a missing model instead of raising an error.
    if image is None:
        return "Please upload an image.", None

    model, meta = load_model(model_name)
    dataset_name = meta["config"]["dataset_name"]
    # FashionMNIST class indices denote garment types; show their names rather
    # than bare digits. Other datasets (MNIST) keep the digit labels.
    if dataset_name == "FashionMNIST":
        class_names = list(datasets.FashionMNIST.classes)
    else:
        class_names = CLASS_NAMES

    tensor = preprocess_uploaded_image(image).to(DEVICE)
    with torch.no_grad():
        logits = model(tensor)
    probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
    pred_idx = int(torch.argmax(logits, dim=1).item())
    conf = max(probs)

    result_text = (
        f"Prediction: {class_names[pred_idx]}\n"
        f"Confidence: {conf:.4f}\n\n"
        f"Model: {model_name}\n"
        f"Dataset: {dataset_name}"
    )
    prob_table = {name: float(p) for name, p in zip(class_names, probs)}
    return result_text, prob_table
def test_random_sample(model_name: str):
    """Classify a random image from the selected model's test split.

    Returns:
        ``(display_img, result_text, prob_table)``. ``display_img`` is the
        sample as a 2-D ``uint8`` NumPy array in [0, 255], the form
        ``gr.Image(type="numpy")`` renders. ``prob_table`` maps each class
        label to its softmax probability. The function returns
        ``(None, message, None)`` when no model is selected.
    """
    if not model_name:
        return None, "Please select a model.", None

    model, meta = load_model(model_name)
    dataset_name = meta["config"]["dataset_name"]
    _, test_dataset = get_datasets(dataset_name)
    # FashionMNIST class indices denote garment types; show their names rather
    # than bare digits. MNIST keeps the digit labels.
    if dataset_name == "FashionMNIST":
        class_names = list(datasets.FashionMNIST.classes)
    else:
        class_names = CLASS_NAMES

    idx = torch.randint(low=0, high=len(test_dataset), size=(1,)).item()
    image_tensor, label = test_dataset[idx]

    with torch.no_grad():
        logits = model(image_tensor.unsqueeze(0).to(DEVICE))
    probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
    pred_idx = int(torch.argmax(logits, dim=1).item())

    # Dataset tensors are normalised to [-1, 1]. Undo that and convert to an
    # 8-bit NumPy image: Gradio's Image output accepts NumPy arrays / PIL
    # images / paths, not a raw float torch tensor.
    display_img = (
        (image_tensor.squeeze(0) * 0.5 + 0.5)
        .clamp(0.0, 1.0)
        .mul(255)
        .round()
        .to(torch.uint8)
        .cpu()
        .numpy()
    )

    prob_table = {name: float(p) for name, p in zip(class_names, probs)}
    result_text = (
        f"Random test sample\n"
        f"Ground truth: {class_names[label]}\n"
        f"Prediction: {class_names[pred_idx]}\n"
        f"Confidence: {max(probs):.4f}\n"
        f"Model dataset: {dataset_name}"
    )
    return display_img, result_text, prob_table
def get_model_info(model_name: str):
    """Return the selected model's metadata as pretty-printed JSON text.

    A short human-readable message is returned instead when no model is
    selected or its metadata file does not exist.
    """
    if not model_name:
        return "No model selected."

    path = model_meta_path(model_name)
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as fh:
            return json.dumps(json.load(fh), indent=2, ensure_ascii=False)
    return "Selected model metadata not found."
def refresh_models_dropdown():
    """Rescan saved models and return a Gradio update for a model dropdown.

    The choices are replaced with the current model list and the first entry
    is selected (``None`` when no models exist).
    """
    available = list_saved_models()
    first = available[0] if len(available) > 0 else None
    return gr.update(choices=available, value=first)
# ============================================================
# Gradio callbacks
# ============================================================
def training_callback(dataset_name, conv1_channels, conv2_channels, kernel_size,
                      dropout, fc_dim, learning_rate, batch_size, epochs, model_tag):
    """Relay ``train_model`` progress in a UI-friendly shape.

    For each training update this yields the 4-tuple
    ``(status_text, rows, selector_update, selector_update)``:

    * ``rows`` is a list of ``[epoch, train_loss, train_acc, val_loss,
      val_acc]`` lists.
    * Both selector updates are the same object: a no-op ``gr.update()``
      while training, and the new model list (first entry selected) once
      training has finished.
    """
    progress = train_model(
        dataset_name=dataset_name,
        conv1_channels=conv1_channels,
        conv2_channels=conv2_channels,
        kernel_size=kernel_size,
        dropout=dropout,
        fc_dim=fc_dim,
        learning_rate=learning_rate,
        batch_size=batch_size,
        epochs=epochs,
        model_tag=model_tag,
    )
    metric_keys = ("epoch", "train_loss", "train_acc", "val_loss", "val_acc")
    for step in progress:
        history = step["history"]
        rows = [list(record) for record in zip(*(history[key] for key in metric_keys))]

        if step["finished"] and step["models"] is not None:
            choices = step["models"]
            selector_update = gr.update(choices=choices, value=choices[0] if choices else None)
        else:
            selector_update = gr.update()

        yield step["status"], rows, selector_update, selector_update
# ============================================================
# UI
# ============================================================
# Model list shown in the Test tab's dropdown when the app starts.
initial_models = list_saved_models()

with gr.Blocks(title="CNN Trainer and Tester") as demo:
    gr.Markdown("# Simple CNN Trainer and Tester")
    gr.Markdown(
        "This app is designed for lightweight image classification experiments on MNIST or FashionMNIST. "
        "Tab 1 trains a simple CNN. Tab 2 loads a saved model and tests it on uploaded images or random test samples."
    )

    with gr.Tabs():
        with gr.Tab("Train"):
            with gr.Row():
                with gr.Column(scale=1):
                    dataset_name = gr.Dropdown(
                        choices=["MNIST", "FashionMNIST"],
                        value="MNIST",
                        label="Dataset"
                    )
                    conv1_channels = gr.Slider(8, 64, value=16, step=8, label="Conv1 Channels")
                    conv2_channels = gr.Slider(16, 128, value=32, step=16, label="Conv2 Channels")
                    kernel_size = gr.Dropdown(choices=[3, 5], value=3, label="Kernel Size")
                    dropout = gr.Slider(0.0, 0.7, value=0.2, step=0.05, label="Dropout")
                    fc_dim = gr.Slider(32, 256, value=128, step=32, label="FC Hidden Dimension")
                    learning_rate = gr.Number(value=0.001, label="Learning Rate")
                    batch_size = gr.Dropdown(choices=[32, 64, 128, 256], value=64, label="Batch Size")
                    epochs = gr.Slider(1, 10, value=3, step=1, label="Epochs")
                    model_tag = gr.Textbox(label="Model Tag", placeholder="e.g. mnist_demo")
                    train_btn = gr.Button("Start Training", variant="primary")
                with gr.Column(scale=1):
                    train_status = gr.Textbox(label="Training Status", lines=10)
                    train_plot = gr.LinePlot(
                        x="epoch",
                        y="value",
                        color="metric",
                        title="Training Curves",
                        y_title="Value",
                        x_title="Epoch",
                    )

        with gr.Tab("Test"):
            with gr.Row():
                with gr.Column(scale=1):
                    model_selector = gr.Dropdown(
                        choices=initial_models,
                        value=initial_models[0] if initial_models else None,
                        label="Select Saved Model"
                    )
                    refresh_btn = gr.Button("Refresh Model List")
                    model_info = gr.Code(label="Model Metadata", language="json")
                    load_info_btn = gr.Button("Show Model Info")
                with gr.Column(scale=1):
                    upload_image = gr.Image(type="pil", label="Upload Image")
                    predict_btn = gr.Button("Predict Uploaded Image", variant="primary")
                    predict_text = gr.Textbox(label="Prediction Result", lines=6)
                    predict_probs = gr.Label(label="Class Probabilities")
            with gr.Row():
                random_test_btn = gr.Button("Test Random Sample")
            with gr.Row():
                random_sample_image = gr.Image(type="numpy", label="Random Test Image")
                random_sample_text = gr.Textbox(label="Random Sample Result", lines=6)
                random_sample_probs = gr.Label(label="Random Sample Probabilities")

    def format_lineplot_rows(rows):
        """Reshape ``[epoch, train_loss, train_acc, val_loss, val_acc]`` rows
        into the long-format table ``gr.LinePlot`` plots.

        Each epoch yields one record per metric, with columns ``epoch``,
        ``value`` and ``metric``.
        """
        import pandas as pd

        metric_names = ("train_loss", "train_acc", "val_loss", "val_acc")
        records = [
            {"epoch": row[0], "value": value, "metric": metric}
            for row in rows
            for metric, value in zip(metric_names, row[1:])
        ]
        # gr.LinePlot takes a pandas DataFrame, not a list of dicts. Explicit
        # columns keep the frame well-formed even when ``rows`` is empty.
        return pd.DataFrame(records, columns=["epoch", "value", "metric"])

    def wrapped_training_callback(*args):
        """Adapt training_callback's row lists into LinePlot-ready data."""
        for status, rows, train_dd_update, test_dd_update in training_callback(*args):
            yield status, format_lineplot_rows(rows), train_dd_update, test_dd_update

    # Invisible dropdowns. The first receives the training callback's extra
    # selector update (train_btn's outputs below); the second is not wired to
    # any event here.
    train_model_selector_hidden = gr.Dropdown(visible=False)
    test_model_selector_hidden = gr.Dropdown(visible=False)

    train_btn.click(
        fn=wrapped_training_callback,
        inputs=[
            dataset_name, conv1_channels, conv2_channels, kernel_size,
            dropout, fc_dim, learning_rate, batch_size, epochs, model_tag
        ],
        outputs=[train_status, train_plot, train_model_selector_hidden, model_selector],
    )
    refresh_btn.click(
        fn=refresh_models_dropdown,
        inputs=None,
        outputs=model_selector,
    )
    load_info_btn.click(
        fn=get_model_info,
        inputs=model_selector,
        outputs=model_info,
    )
    predict_btn.click(
        fn=predict_uploaded_image,
        inputs=[model_selector, upload_image],
        outputs=[predict_text, predict_probs],
    )
    random_test_btn.click(
        fn=test_random_sample,
        inputs=[model_selector],
        outputs=[random_sample_image, random_sample_text, random_sample_probs],
    )

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()