CircleStar commited on
Commit
7b7175d
·
verified ·
1 Parent(s): 4662103

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -198
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  import json
3
  import time
4
- import math
5
  from datetime import datetime
6
  from typing import List, Tuple
7
 
@@ -13,13 +13,15 @@ from torch.utils.data import DataLoader, random_split
13
  from torchvision import datasets, transforms
14
  from PIL import Image
15
 
 
16
  # ============================================================
17
- # Configuration
18
  # ============================================================
19
  BASE_DIR = os.path.dirname(os.path.abspath(__file__)) if "__file__" in globals() else os.getcwd()
20
  DATA_DIR = os.path.join(BASE_DIR, "data")
21
  MODEL_DIR = os.path.join(BASE_DIR, "saved_models")
22
  META_DIR = os.path.join(BASE_DIR, "saved_models_meta")
 
23
  os.makedirs(DATA_DIR, exist_ok=True)
24
  os.makedirs(MODEL_DIR, exist_ok=True)
25
  os.makedirs(META_DIR, exist_ok=True)
@@ -32,8 +34,14 @@ CLASS_NAMES = [str(i) for i in range(10)]
32
  # Model
33
  # ============================================================
34
  class SimpleCNN(nn.Module):
35
- def __init__(self, conv1_channels: int = 16, conv2_channels: int = 32, kernel_size: int = 3,
36
- dropout: float = 0.2, fc_dim: int = 128):
 
 
 
 
 
 
37
  super().__init__()
38
  padding = kernel_size // 2
39
 
@@ -41,13 +49,13 @@ class SimpleCNN(nn.Module):
41
  nn.Conv2d(1, conv1_channels, kernel_size=kernel_size, padding=padding),
42
  nn.ReLU(),
43
  nn.MaxPool2d(2),
 
44
  nn.Conv2d(conv1_channels, conv2_channels, kernel_size=kernel_size, padding=padding),
45
  nn.ReLU(),
46
  nn.MaxPool2d(2),
47
  )
48
 
49
- # MNIST input = 1 x 28 x 28
50
- # after two 2x2 poolings => 7 x 7
51
  flattened_dim = conv2_channels * 7 * 7
52
 
53
  self.classifier = nn.Sequential(
@@ -65,13 +73,15 @@ class SimpleCNN(nn.Module):
65
 
66
 
67
  # ============================================================
68
- # Data utilities
69
  # ============================================================
70
  def get_datasets(dataset_name: str):
71
- transform = transforms.Compose([
72
- transforms.ToTensor(),
73
- transforms.Normalize((0.5,), (0.5,))
74
- ])
 
 
75
 
76
  if dataset_name == "MNIST":
77
  train_dataset = datasets.MNIST(DATA_DIR, train=True, download=True, transform=transform)
@@ -90,23 +100,34 @@ def make_loaders(dataset_name: str, batch_size: int, val_ratio: float = 0.1):
90
 
91
  val_size = int(len(train_dataset) * val_ratio)
92
  train_size = len(train_dataset) - val_size
 
93
  train_subset, val_subset = random_split(train_dataset, [train_size, val_size])
94
 
95
  train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
96
  val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
97
  test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
 
98
  return train_loader, val_loader, test_loader
99
 
100
 
101
  # ============================================================
102
- # Model registry helpers
103
  # ============================================================
 
 
 
 
104
  def model_meta_path(model_name: str) -> str:
105
  return os.path.join(META_DIR, f"{model_name}.json")
106
 
107
 
108
- def model_weight_path(model_name: str) -> str:
109
- return os.path.join(MODEL_DIR, f"{model_name}.pt")
 
 
 
 
 
110
 
111
 
112
  def save_model(model: nn.Module, model_name: str, config: dict, training_summary: dict):
@@ -121,15 +142,6 @@ def save_model(model: nn.Module, model_name: str, config: dict, training_summary
121
  json.dump(payload, f, indent=2, ensure_ascii=False)
122
 
123
 
124
- def list_saved_models() -> List[str]:
125
- models = []
126
- for filename in os.listdir(META_DIR):
127
- if filename.endswith(".json"):
128
- models.append(filename[:-5])
129
- models.sort(reverse=True)
130
- return models
131
-
132
-
133
  def load_model(model_name: str) -> Tuple[nn.Module, dict]:
134
  meta_file = model_meta_path(model_name)
135
  weight_file = model_weight_path(model_name)
@@ -142,13 +154,14 @@ def load_model(model_name: str) -> Tuple[nn.Module, dict]:
142
  with open(meta_file, "r", encoding="utf-8") as f:
143
  meta = json.load(f)
144
 
145
- config = meta["config"]
 
146
  model = SimpleCNN(
147
- conv1_channels=config["conv1_channels"],
148
- conv2_channels=config["conv2_channels"],
149
- kernel_size=config["kernel_size"],
150
- dropout=config["dropout"],
151
- fc_dim=config["fc_dim"],
152
  )
153
  state_dict = torch.load(weight_file, map_location=DEVICE)
154
  model.load_state_dict(state_dict)
@@ -158,17 +171,18 @@ def load_model(model_name: str) -> Tuple[nn.Module, dict]:
158
 
159
 
160
  # ============================================================
161
- # Training / evaluation
162
  # ============================================================
163
  def evaluate(model: nn.Module, loader: DataLoader, criterion: nn.Module):
164
  model.eval()
165
  total_loss = 0.0
166
- correct = 0
167
  total = 0
 
168
 
169
  with torch.no_grad():
170
  for images, labels in loader:
171
  images, labels = images.to(DEVICE), labels.to(DEVICE)
 
172
  outputs = model(images)
173
  loss = criterion(outputs, labels)
174
 
@@ -177,14 +191,23 @@ def evaluate(model: nn.Module, loader: DataLoader, criterion: nn.Module):
177
  correct += (preds == labels).sum().item()
178
  total += labels.size(0)
179
 
180
- avg_loss = total_loss / total if total > 0 else 0.0
181
- acc = correct / total if total > 0 else 0.0
182
  return avg_loss, acc
183
 
184
 
185
- def train_model(dataset_name: str, conv1_channels: int, conv2_channels: int, kernel_size: int,
186
- dropout: float, fc_dim: int, learning_rate: float, batch_size: int,
187
- epochs: int, model_tag: str):
 
 
 
 
 
 
 
 
 
188
  train_loader, val_loader, test_loader = make_loaders(dataset_name, batch_size)
189
 
190
  model = SimpleCNN(
@@ -198,21 +221,15 @@ def train_model(dataset_name: str, conv1_channels: int, conv2_channels: int, ker
198
  criterion = nn.CrossEntropyLoss()
199
  optimizer = optim.Adam(model.parameters(), lr=learning_rate)
200
 
201
- history = {
202
- "epoch": [],
203
- "train_loss": [],
204
- "train_acc": [],
205
- "val_loss": [],
206
- "val_acc": [],
207
- }
208
-
209
  start_time = time.time()
210
 
211
  for epoch in range(1, epochs + 1):
212
  model.train()
213
  running_loss = 0.0
214
- correct = 0
215
  total = 0
 
216
 
217
  for images, labels in train_loader:
218
  images, labels = images.to(DEVICE), labels.to(DEVICE)
@@ -228,32 +245,36 @@ def train_model(dataset_name: str, conv1_channels: int, conv2_channels: int, ker
228
  correct += (preds == labels).sum().item()
229
  total += labels.size(0)
230
 
231
- train_loss = running_loss / total if total > 0 else 0.0
232
- train_acc = correct / total if total > 0 else 0.0
233
  val_loss, val_acc = evaluate(model, val_loader, criterion)
234
 
235
- history["epoch"].append(epoch)
236
- history["train_loss"].append(train_loss)
237
- history["train_acc"].append(train_acc)
238
- history["val_loss"].append(val_loss)
239
- history["val_acc"].append(val_acc)
240
-
241
- yield {
242
- "status": (
243
- f"Epoch {epoch}/{epochs} | "
244
- f"train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, "
245
- f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}"
246
- ),
247
- "history": history,
248
- "finished": False,
249
- "models": None,
250
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
252
  test_loss, test_acc = evaluate(model, test_loader, criterion)
253
  elapsed = time.time() - start_time
254
 
255
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
256
- safe_tag = model_tag.strip().replace(" ", "_") if model_tag else dataset_name.lower()
257
  model_name = f"{safe_tag}_{timestamp}"
258
 
259
  config = {
@@ -267,49 +288,53 @@ def train_model(dataset_name: str, conv1_channels: int, conv2_channels: int, ker
267
  "batch_size": batch_size,
268
  "epochs": epochs,
269
  }
 
270
  training_summary = {
271
- "final_train_loss": history["train_loss"][-1],
272
- "final_train_acc": history["train_acc"][-1],
273
- "final_val_loss": history["val_loss"][-1],
274
- "final_val_acc": history["val_acc"][-1],
275
- "test_loss": test_loss,
276
- "test_acc": test_acc,
277
- "elapsed_seconds": elapsed,
278
  "device": str(DEVICE),
279
  }
 
280
  save_model(model, model_name, config, training_summary)
281
 
282
- final_message = (
283
- f"Training finished.\n\n"
284
- f"Saved model: {model_name}\n"
285
- f"Device: {DEVICE}\n"
286
- f"Test loss: {test_loss:.4f}\n"
287
- f"Test accuracy: {test_acc:.4f}\n"
288
- f"Elapsed time: {elapsed:.1f}s"
289
- )
290
 
291
- yield {
292
- "status": final_message,
293
- "history": history,
294
- "finished": True,
295
- "models": list_saved_models(),
296
- }
 
 
297
 
298
 
299
  # ============================================================
300
- # Inference helpers
301
  # ============================================================
302
  def preprocess_uploaded_image(image: Image.Image):
303
  if image is None:
304
  raise ValueError("Please upload an image.")
305
 
306
- transform = transforms.Compose([
307
- transforms.Grayscale(num_output_channels=1),
308
- transforms.Resize((28, 28)),
309
- transforms.ToTensor(),
310
- transforms.Normalize((0.5,), (0.5,))
311
- ])
312
-
 
313
  tensor = transform(image).unsqueeze(0)
314
  return tensor
315
 
@@ -326,16 +351,14 @@ def predict_uploaded_image(model_name: str, image: Image.Image):
326
  probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
327
  pred_idx = int(torch.argmax(logits, dim=1).item())
328
 
329
- conf = max(probs)
330
  result_text = (
331
  f"Prediction: {CLASS_NAMES[pred_idx]}\n"
332
- f"Confidence: {conf:.4f}\n\n"
333
  f"Model: {model_name}\n"
334
  f"Dataset: {meta['config']['dataset_name']}"
335
  )
336
-
337
- prob_table = {CLASS_NAMES[i]: float(probs[i]) for i in range(len(CLASS_NAMES))}
338
- return result_text, prob_table
339
 
340
 
341
  def test_random_sample(model_name: str):
@@ -344,9 +367,9 @@ def test_random_sample(model_name: str):
344
 
345
  model, meta = load_model(model_name)
346
  dataset_name = meta["config"]["dataset_name"]
347
- _, test_dataset = get_datasets(dataset_name)
348
 
349
- idx = torch.randint(low=0, high=len(test_dataset), size=(1,)).item()
 
350
  image_tensor, label = test_dataset[idx]
351
 
352
  with torch.no_grad():
@@ -354,8 +377,8 @@ def test_random_sample(model_name: str):
354
  probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
355
  pred_idx = int(torch.argmax(logits, dim=1).item())
356
 
357
- display_img = image_tensor.squeeze(0).cpu()
358
- prob_table = {CLASS_NAMES[i]: float(probs[i]) for i in range(len(CLASS_NAMES))}
359
  result_text = (
360
  f"Random test sample\n"
361
  f"Ground truth: {label}\n"
@@ -363,18 +386,21 @@ def test_random_sample(model_name: str):
363
  f"Confidence: {max(probs):.4f}\n"
364
  f"Model dataset: {dataset_name}"
365
  )
366
- return display_img, result_text, prob_table
 
367
 
368
 
369
  def get_model_info(model_name: str):
370
  if not model_name:
371
- return "No model selected."
 
372
  meta_file = model_meta_path(model_name)
373
  if not os.path.exists(meta_file):
374
- return "Selected model metadata not found."
 
375
  with open(meta_file, "r", encoding="utf-8") as f:
376
  meta = json.load(f)
377
- return json.dumps(meta, indent=2, ensure_ascii=False)
378
 
379
 
380
  def refresh_models_dropdown():
@@ -382,97 +408,55 @@ def refresh_models_dropdown():
382
  return gr.update(choices=models, value=models[0] if models else None)
383
 
384
 
385
- # ============================================================
386
- # Gradio callbacks
387
- # ============================================================
388
- def training_callback(dataset_name, conv1_channels, conv2_channels, kernel_size,
389
- dropout, fc_dim, learning_rate, batch_size, epochs, model_tag):
390
- for step in train_model(
391
- dataset_name=dataset_name,
392
- conv1_channels=conv1_channels,
393
- conv2_channels=conv2_channels,
394
- kernel_size=kernel_size,
395
- dropout=dropout,
396
- fc_dim=fc_dim,
397
- learning_rate=learning_rate,
398
- batch_size=batch_size,
399
- epochs=epochs,
400
- model_tag=model_tag,
401
- ):
402
- line_data = [
403
- [e, tl, ta, vl, va]
404
- for e, tl, ta, vl, va in zip(
405
- step["history"]["epoch"],
406
- step["history"]["train_loss"],
407
- step["history"]["train_acc"],
408
- step["history"]["val_loss"],
409
- step["history"]["val_acc"],
410
- )
411
- ]
412
-
413
- dropdown_update = gr.update()
414
- if step["finished"] and step["models"] is not None:
415
- models = step["models"]
416
- dropdown_update = gr.update(choices=models, value=models[0] if models else None)
417
-
418
- yield step["status"], line_data, dropdown_update, dropdown_update
419
-
420
-
421
  # ============================================================
422
  # UI
423
  # ============================================================
424
  initial_models = list_saved_models()
425
 
426
- with gr.Blocks(title="CNN Trainer and Tester") as demo:
427
- gr.Markdown("# Simple CNN Trainer and Tester")
428
  gr.Markdown(
429
- "This app is designed for lightweight image classification experiments on MNIST or FashionMNIST. "
430
- "Tab 1 trains a simple CNN. Tab 2 loads a saved model and tests it on uploaded images or random test samples."
431
  )
432
 
433
- with gr.Tab("Train"):
434
- with gr.Row():
435
- with gr.Column(scale=1):
436
- dataset_name = gr.Dropdown(
437
- choices=["MNIST", "FashionMNIST"],
438
- value="MNIST",
439
- label="Dataset"
440
- )
441
- conv1_channels = gr.Slider(8, 64, value=16, step=8, label="Conv1 Channels")
442
- conv2_channels = gr.Slider(16, 128, value=32, step=16, label="Conv2 Channels")
443
- kernel_size = gr.Dropdown(choices=[3, 5], value=3, label="Kernel Size")
444
- dropout = gr.Slider(0.0, 0.7, value=0.2, step=0.05, label="Dropout")
445
- fc_dim = gr.Slider(32, 256, value=128, step=32, label="FC Hidden Dimension")
446
- learning_rate = gr.Number(value=0.001, label="Learning Rate")
447
- batch_size = gr.Dropdown(choices=[32, 64, 128, 256], value=64, label="Batch Size")
448
- epochs = gr.Slider(1, 10, value=3, step=1, label="Epochs")
449
- model_tag = gr.Textbox(label="Model Tag", placeholder="e.g. mnist_demo")
450
- train_btn = gr.Button("Start Training", variant="primary")
451
-
452
- with gr.Column(scale=1):
453
- train_status = gr.Textbox(label="Training Status", lines=10)
454
- train_plot = gr.LinePlot(
455
- x="epoch",
456
- y="value",
457
- color="metric",
458
- title="Training Curves",
459
- y_title="Value",
460
- x_title="Epoch",
461
- )
462
 
463
  with gr.Tab("Test"):
464
  with gr.Row():
465
- with gr.Column(scale=1):
466
  model_selector = gr.Dropdown(
467
  choices=initial_models,
468
  value=initial_models[0] if initial_models else None,
469
  label="Select Saved Model"
470
  )
471
  refresh_btn = gr.Button("Refresh Model List")
472
- model_info = gr.Code(label="Model Metadata", language="json")
473
  load_info_btn = gr.Button("Show Model Info")
 
474
 
475
- with gr.Column(scale=1):
476
  upload_image = gr.Image(type="pil", label="Upload Image")
477
  predict_btn = gr.Button("Predict Uploaded Image", variant="primary")
478
  predict_text = gr.Textbox(label="Prediction Result", lines=6)
@@ -480,34 +464,27 @@ with gr.Tab("Train"):
480
 
481
  with gr.Row():
482
  random_test_btn = gr.Button("Test Random Sample")
 
483
  with gr.Row():
484
  random_sample_image = gr.Image(type="numpy", label="Random Test Image")
485
  random_sample_text = gr.Textbox(label="Random Sample Result", lines=6)
486
  random_sample_probs = gr.Label(label="Random Sample Probabilities")
487
 
488
- def format_lineplot_rows(rows):
489
- output = []
490
- for epoch, train_loss, train_acc, val_loss, val_acc in rows:
491
- output.append({"epoch": epoch, "value": train_loss, "metric": "train_loss"})
492
- output.append({"epoch": epoch, "value": train_acc, "metric": "train_acc"})
493
- output.append({"epoch": epoch, "value": val_loss, "metric": "val_loss"})
494
- output.append({"epoch": epoch, "value": val_acc, "metric": "val_acc"})
495
- return output
496
-
497
- def wrapped_training_callback(*args):
498
- for status, rows, train_dd_update, test_dd_update in training_callback(*args):
499
- yield status, format_lineplot_rows(rows), train_dd_update, test_dd_update
500
-
501
- train_model_selector_hidden = gr.Dropdown(visible=False)
502
- test_model_selector_hidden = gr.Dropdown(visible=False)
503
-
504
  train_btn.click(
505
- fn=wrapped_training_callback,
506
  inputs=[
507
- dataset_name, conv1_channels, conv2_channels, kernel_size,
508
- dropout, fc_dim, learning_rate, batch_size, epochs, model_tag
 
 
 
 
 
 
 
 
509
  ],
510
- outputs=[train_status, train_plot, train_model_selector_hidden, model_selector],
511
  )
512
 
513
  refresh_btn.click(
@@ -534,5 +511,6 @@ with gr.Tab("Train"):
534
  outputs=[random_sample_image, random_sample_text, random_sample_probs],
535
  )
536
 
 
537
  if __name__ == "__main__":
538
- demo.launch()
 
1
  import os
2
  import json
3
  import time
4
+ import random
5
  from datetime import datetime
6
  from typing import List, Tuple
7
 
 
13
  from torchvision import datasets, transforms
14
  from PIL import Image
15
 
16
+
17
  # ============================================================
18
+ # Paths / Device
19
  # ============================================================
20
  BASE_DIR = os.path.dirname(os.path.abspath(__file__)) if "__file__" in globals() else os.getcwd()
21
  DATA_DIR = os.path.join(BASE_DIR, "data")
22
  MODEL_DIR = os.path.join(BASE_DIR, "saved_models")
23
  META_DIR = os.path.join(BASE_DIR, "saved_models_meta")
24
+
25
  os.makedirs(DATA_DIR, exist_ok=True)
26
  os.makedirs(MODEL_DIR, exist_ok=True)
27
  os.makedirs(META_DIR, exist_ok=True)
 
34
  # Model
35
  # ============================================================
36
  class SimpleCNN(nn.Module):
37
+ def __init__(
38
+ self,
39
+ conv1_channels: int = 16,
40
+ conv2_channels: int = 32,
41
+ kernel_size: int = 3,
42
+ dropout: float = 0.2,
43
+ fc_dim: int = 128,
44
+ ):
45
  super().__init__()
46
  padding = kernel_size // 2
47
 
 
49
  nn.Conv2d(1, conv1_channels, kernel_size=kernel_size, padding=padding),
50
  nn.ReLU(),
51
  nn.MaxPool2d(2),
52
+
53
  nn.Conv2d(conv1_channels, conv2_channels, kernel_size=kernel_size, padding=padding),
54
  nn.ReLU(),
55
  nn.MaxPool2d(2),
56
  )
57
 
58
+ # 28x28 -> 14x14 -> 7x7
 
59
  flattened_dim = conv2_channels * 7 * 7
60
 
61
  self.classifier = nn.Sequential(
 
73
 
74
 
75
  # ============================================================
76
+ # Dataset helpers
77
  # ============================================================
78
  def get_datasets(dataset_name: str):
79
+ transform = transforms.Compose(
80
+ [
81
+ transforms.ToTensor(),
82
+ transforms.Normalize((0.5,), (0.5,))
83
+ ]
84
+ )
85
 
86
  if dataset_name == "MNIST":
87
  train_dataset = datasets.MNIST(DATA_DIR, train=True, download=True, transform=transform)
 
100
 
101
  val_size = int(len(train_dataset) * val_ratio)
102
  train_size = len(train_dataset) - val_size
103
+
104
  train_subset, val_subset = random_split(train_dataset, [train_size, val_size])
105
 
106
  train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
107
  val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
108
  test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
109
+
110
  return train_loader, val_loader, test_loader
111
 
112
 
113
  # ============================================================
114
+ # Model save/load helpers
115
  # ============================================================
116
+ def model_weight_path(model_name: str) -> str:
117
+ return os.path.join(MODEL_DIR, f"{model_name}.pt")
118
+
119
+
120
  def model_meta_path(model_name: str) -> str:
121
  return os.path.join(META_DIR, f"{model_name}.json")
122
 
123
 
124
+ def list_saved_models() -> List[str]:
125
+ names = []
126
+ for fn in os.listdir(META_DIR):
127
+ if fn.endswith(".json"):
128
+ names.append(fn[:-5])
129
+ names.sort(reverse=True)
130
+ return names
131
 
132
 
133
  def save_model(model: nn.Module, model_name: str, config: dict, training_summary: dict):
 
142
  json.dump(payload, f, indent=2, ensure_ascii=False)
143
 
144
 
 
 
 
 
 
 
 
 
 
145
  def load_model(model_name: str) -> Tuple[nn.Module, dict]:
146
  meta_file = model_meta_path(model_name)
147
  weight_file = model_weight_path(model_name)
 
154
  with open(meta_file, "r", encoding="utf-8") as f:
155
  meta = json.load(f)
156
 
157
+ cfg = meta["config"]
158
+
159
  model = SimpleCNN(
160
+ conv1_channels=cfg["conv1_channels"],
161
+ conv2_channels=cfg["conv2_channels"],
162
+ kernel_size=cfg["kernel_size"],
163
+ dropout=cfg["dropout"],
164
+ fc_dim=cfg["fc_dim"],
165
  )
166
  state_dict = torch.load(weight_file, map_location=DEVICE)
167
  model.load_state_dict(state_dict)
 
171
 
172
 
173
  # ============================================================
174
+ # Train / Eval
175
  # ============================================================
176
  def evaluate(model: nn.Module, loader: DataLoader, criterion: nn.Module):
177
  model.eval()
178
  total_loss = 0.0
 
179
  total = 0
180
+ correct = 0
181
 
182
  with torch.no_grad():
183
  for images, labels in loader:
184
  images, labels = images.to(DEVICE), labels.to(DEVICE)
185
+
186
  outputs = model(images)
187
  loss = criterion(outputs, labels)
188
 
 
191
  correct += (preds == labels).sum().item()
192
  total += labels.size(0)
193
 
194
+ avg_loss = total_loss / total if total else 0.0
195
+ acc = correct / total if total else 0.0
196
  return avg_loss, acc
197
 
198
 
199
+ def train_model(
200
+ dataset_name: str,
201
+ conv1_channels: int,
202
+ conv2_channels: int,
203
+ kernel_size: int,
204
+ dropout: float,
205
+ fc_dim: int,
206
+ learning_rate: float,
207
+ batch_size: int,
208
+ epochs: int,
209
+ model_tag: str,
210
+ ):
211
  train_loader, val_loader, test_loader = make_loaders(dataset_name, batch_size)
212
 
213
  model = SimpleCNN(
 
221
  criterion = nn.CrossEntropyLoss()
222
  optimizer = optim.Adam(model.parameters(), lr=learning_rate)
223
 
224
+ history = []
225
+ logs = []
 
 
 
 
 
 
226
  start_time = time.time()
227
 
228
  for epoch in range(1, epochs + 1):
229
  model.train()
230
  running_loss = 0.0
 
231
  total = 0
232
+ correct = 0
233
 
234
  for images, labels in train_loader:
235
  images, labels = images.to(DEVICE), labels.to(DEVICE)
 
245
  correct += (preds == labels).sum().item()
246
  total += labels.size(0)
247
 
248
+ train_loss = running_loss / total if total else 0.0
249
+ train_acc = correct / total if total else 0.0
250
  val_loss, val_acc = evaluate(model, val_loader, criterion)
251
 
252
+ row = {
253
+ "epoch": epoch,
254
+ "train_loss": round(train_loss, 4),
255
+ "train_acc": round(train_acc, 4),
256
+ "val_loss": round(val_loss, 4),
257
+ "val_acc": round(val_acc, 4),
 
 
 
 
 
 
 
 
 
258
  }
259
+ history.append(row)
260
+
261
+ logs.append(
262
+ f"Epoch {epoch}/{epochs} | "
263
+ f"train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, "
264
+ f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}"
265
+ )
266
+
267
+ yield (
268
+ "\n".join(logs),
269
+ history,
270
+ gr.update(),
271
+ )
272
 
273
  test_loss, test_acc = evaluate(model, test_loader, criterion)
274
  elapsed = time.time() - start_time
275
 
276
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
277
+ safe_tag = model_tag.strip().replace(" ", "_") if model_tag.strip() else dataset_name.lower()
278
  model_name = f"{safe_tag}_{timestamp}"
279
 
280
  config = {
 
288
  "batch_size": batch_size,
289
  "epochs": epochs,
290
  }
291
+
292
  training_summary = {
293
+ "final_train_loss": history[-1]["train_loss"] if history else None,
294
+ "final_train_acc": history[-1]["train_acc"] if history else None,
295
+ "final_val_loss": history[-1]["val_loss"] if history else None,
296
+ "final_val_acc": history[-1]["val_acc"] if history else None,
297
+ "test_loss": round(test_loss, 4),
298
+ "test_acc": round(test_acc, 4),
299
+ "elapsed_seconds": round(elapsed, 2),
300
  "device": str(DEVICE),
301
  }
302
+
303
  save_model(model, model_name, config, training_summary)
304
 
305
+ logs.append("")
306
+ logs.append("Training finished.")
307
+ logs.append(f"Saved model: {model_name}")
308
+ logs.append(f"Device: {DEVICE}")
309
+ logs.append(f"Test loss: {test_loss:.4f}")
310
+ logs.append(f"Test accuracy: {test_acc:.4f}")
311
+ logs.append(f"Elapsed time: {elapsed:.1f}s")
 
312
 
313
+ models = list_saved_models()
314
+ selected = model_name if model_name in models else (models[0] if models else None)
315
+
316
+ yield (
317
+ "\n".join(logs),
318
+ history,
319
+ gr.update(choices=models, value=selected),
320
+ )
321
 
322
 
323
  # ============================================================
324
+ # Inference
325
  # ============================================================
326
  def preprocess_uploaded_image(image: Image.Image):
327
  if image is None:
328
  raise ValueError("Please upload an image.")
329
 
330
+ transform = transforms.Compose(
331
+ [
332
+ transforms.Grayscale(num_output_channels=1),
333
+ transforms.Resize((28, 28)),
334
+ transforms.ToTensor(),
335
+ transforms.Normalize((0.5,), (0.5,))
336
+ ]
337
+ )
338
  tensor = transform(image).unsqueeze(0)
339
  return tensor
340
 
 
351
  probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
352
  pred_idx = int(torch.argmax(logits, dim=1).item())
353
 
 
354
  result_text = (
355
  f"Prediction: {CLASS_NAMES[pred_idx]}\n"
356
+ f"Confidence: {max(probs):.4f}\n\n"
357
  f"Model: {model_name}\n"
358
  f"Dataset: {meta['config']['dataset_name']}"
359
  )
360
+ prob_dict = {CLASS_NAMES[i]: float(probs[i]) for i in range(10)}
361
+ return result_text, prob_dict
 
362
 
363
 
364
  def test_random_sample(model_name: str):
 
367
 
368
  model, meta = load_model(model_name)
369
  dataset_name = meta["config"]["dataset_name"]
 
370
 
371
+ _, test_dataset = get_datasets(dataset_name)
372
+ idx = random.randint(0, len(test_dataset) - 1)
373
  image_tensor, label = test_dataset[idx]
374
 
375
  with torch.no_grad():
 
377
  probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
378
  pred_idx = int(torch.argmax(logits, dim=1).item())
379
 
380
+ display_img = image_tensor.squeeze(0).cpu().numpy()
381
+
382
  result_text = (
383
  f"Random test sample\n"
384
  f"Ground truth: {label}\n"
 
386
  f"Confidence: {max(probs):.4f}\n"
387
  f"Model dataset: {dataset_name}"
388
  )
389
+ prob_dict = {CLASS_NAMES[i]: float(probs[i]) for i in range(10)}
390
+ return display_img, result_text, prob_dict
391
 
392
 
393
  def get_model_info(model_name: str):
394
  if not model_name:
395
+ return {"message": "No model selected."}
396
+
397
  meta_file = model_meta_path(model_name)
398
  if not os.path.exists(meta_file):
399
+ return {"message": "Metadata not found."}
400
+
401
  with open(meta_file, "r", encoding="utf-8") as f:
402
  meta = json.load(f)
403
+ return meta
404
 
405
 
406
  def refresh_models_dropdown():
 
408
  return gr.update(choices=models, value=models[0] if models else None)
409
 
410
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
411
  # ============================================================
412
  # UI
413
  # ============================================================
414
  initial_models = list_saved_models()
415
 
416
+ with gr.Blocks(title="Image Classification") as demo:
417
+ gr.Markdown("# Image Classification")
418
  gr.Markdown(
419
+ "Train a simple CNN on MNIST or FashionMNIST, then test saved models "
420
+ "with an uploaded image or a random sample."
421
  )
422
 
423
+ with gr.Tabs():
424
+ with gr.Tab("Train"):
425
+ with gr.Row():
426
+ with gr.Column():
427
+ dataset_name = gr.Dropdown(
428
+ choices=["MNIST", "FashionMNIST"],
429
+ value="MNIST",
430
+ label="Dataset"
431
+ )
432
+ conv1_channels = gr.Slider(8, 64, value=16, step=8, label="Conv1 Channels")
433
+ conv2_channels = gr.Slider(16, 128, value=32, step=16, label="Conv2 Channels")
434
+ kernel_size = gr.Dropdown(choices=[3, 5], value=3, label="Kernel Size")
435
+ dropout = gr.Slider(0.0, 0.7, value=0.2, step=0.05, label="Dropout")
436
+ fc_dim = gr.Slider(32, 256, value=128, step=32, label="FC Hidden Dimension")
437
+ learning_rate = gr.Number(value=0.001, label="Learning Rate")
438
+ batch_size = gr.Dropdown(choices=[32, 64, 128, 256], value=64, label="Batch Size")
439
+ epochs = gr.Slider(1, 10, value=3, step=1, label="Epochs")
440
+ model_tag = gr.Textbox(label="Model Tag", placeholder="e.g. mnist_demo")
441
+ train_btn = gr.Button("Start Training", variant="primary")
442
+
443
+ with gr.Column():
444
+ train_status = gr.Textbox(label="Training Log", lines=18)
445
+ train_history = gr.JSON(label="Training History")
 
 
 
 
 
 
446
 
447
  with gr.Tab("Test"):
448
  with gr.Row():
449
+ with gr.Column():
450
  model_selector = gr.Dropdown(
451
  choices=initial_models,
452
  value=initial_models[0] if initial_models else None,
453
  label="Select Saved Model"
454
  )
455
  refresh_btn = gr.Button("Refresh Model List")
 
456
  load_info_btn = gr.Button("Show Model Info")
457
+ model_info = gr.JSON(label="Model Metadata")
458
 
459
+ with gr.Column():
460
  upload_image = gr.Image(type="pil", label="Upload Image")
461
  predict_btn = gr.Button("Predict Uploaded Image", variant="primary")
462
  predict_text = gr.Textbox(label="Prediction Result", lines=6)
 
464
 
465
  with gr.Row():
466
  random_test_btn = gr.Button("Test Random Sample")
467
+
468
  with gr.Row():
469
  random_sample_image = gr.Image(type="numpy", label="Random Test Image")
470
  random_sample_text = gr.Textbox(label="Random Sample Result", lines=6)
471
  random_sample_probs = gr.Label(label="Random Sample Probabilities")
472
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
473
  train_btn.click(
474
+ fn=train_model,
475
  inputs=[
476
+ dataset_name,
477
+ conv1_channels,
478
+ conv2_channels,
479
+ kernel_size,
480
+ dropout,
481
+ fc_dim,
482
+ learning_rate,
483
+ batch_size,
484
+ epochs,
485
+ model_tag,
486
  ],
487
+ outputs=[train_status, train_history, model_selector],
488
  )
489
 
490
  refresh_btn.click(
 
511
  outputs=[random_sample_image, random_sample_text, random_sample_probs],
512
  )
513
 
514
+
515
  if __name__ == "__main__":
516
+ demo.launch()