CircleStar committed on
Commit
39e478f
·
verified ·
1 Parent(s): 78038de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -81
app.py CHANGED
@@ -5,6 +5,7 @@ import random
5
  from datetime import datetime
6
  from typing import List, Tuple
7
 
 
8
  import gradio as gr
9
  import torch
10
  import torch.nn as nn
@@ -15,7 +16,7 @@ from PIL import Image
15
 
16
 
17
  # ============================================================
18
- # Paths / Device
19
  # ============================================================
20
  BASE_DIR = os.path.dirname(os.path.abspath(__file__)) if "__file__" in globals() else os.getcwd()
21
  DATA_DIR = os.path.join(BASE_DIR, "data")
@@ -26,8 +27,6 @@ os.makedirs(DATA_DIR, exist_ok=True)
26
  os.makedirs(MODEL_DIR, exist_ok=True)
27
  os.makedirs(META_DIR, exist_ok=True)
28
 
29
- # Force CPU on Hugging Face Spaces for this lightweight demo
30
- DEVICE = torch.device("cpu")
31
  CLASS_NAMES = [str(i) for i in range(10)]
32
 
33
 
@@ -56,8 +55,7 @@ class SimpleCNN(nn.Module):
56
  nn.MaxPool2d(2),
57
  )
58
 
59
- # 28x28 -> 14x14 -> 7x7
60
- flattened_dim = conv2_channels * 7 * 7
61
 
62
  self.classifier = nn.Sequential(
63
  nn.Flatten(),
@@ -132,7 +130,9 @@ def list_saved_models() -> List[str]:
132
 
133
 
134
  def save_model(model: nn.Module, model_name: str, config: dict, training_summary: dict):
135
- torch.save(model.state_dict(), model_weight_path(model_name))
 
 
136
  payload = {
137
  "model_name": model_name,
138
  "config": config,
@@ -143,7 +143,7 @@ def save_model(model: nn.Module, model_name: str, config: dict, training_summary
143
  json.dump(payload, f, indent=2, ensure_ascii=False)
144
 
145
 
146
- def load_model(model_name: str) -> Tuple[nn.Module, dict]:
147
  meta_file = model_meta_path(model_name)
148
  weight_file = model_weight_path(model_name)
149
 
@@ -164,40 +164,23 @@ def load_model(model_name: str) -> Tuple[nn.Module, dict]:
164
  dropout=cfg["dropout"],
165
  fc_dim=cfg["fc_dim"],
166
  )
167
- state_dict = torch.load(weight_file, map_location=DEVICE)
 
168
  model.load_state_dict(state_dict)
169
- model.to(DEVICE)
170
  model.eval()
171
  return model, meta
172
 
173
 
174
  # ============================================================
175
- # Train / Eval
176
  # ============================================================
177
- def evaluate(model: nn.Module, loader: DataLoader, criterion: nn.Module):
178
- model.eval()
179
- total_loss = 0.0
180
- total = 0
181
- correct = 0
182
-
183
- with torch.no_grad():
184
- for images, labels in loader:
185
- images, labels = images.to(DEVICE), labels.to(DEVICE)
186
-
187
- outputs = model(images)
188
- loss = criterion(outputs, labels)
189
-
190
- total_loss += loss.item() * images.size(0)
191
- preds = outputs.argmax(dim=1)
192
- correct += (preds == labels).sum().item()
193
- total += labels.size(0)
194
 
195
- avg_loss = total_loss / total if total else 0.0
196
- acc = correct / total if total else 0.0
197
- return avg_loss, acc
198
 
199
-
200
- def train_model(
201
  dataset_name: str,
202
  conv1_channels: int,
203
  conv2_channels: int,
@@ -209,6 +192,8 @@ def train_model(
209
  epochs: int,
210
  model_tag: str,
211
  ):
 
 
212
  train_loader, val_loader, test_loader = make_loaders(dataset_name, batch_size)
213
 
214
  model = SimpleCNN(
@@ -217,7 +202,7 @@ def train_model(
217
  kernel_size=kernel_size,
218
  dropout=dropout,
219
  fc_dim=fc_dim,
220
- ).to(DEVICE)
221
 
222
  criterion = nn.CrossEntropyLoss()
223
  optimizer = optim.Adam(model.parameters(), lr=learning_rate)
@@ -226,6 +211,27 @@ def train_model(
226
  logs = []
227
  start_time = time.time()
228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  for epoch in range(1, epochs + 1):
230
  model.train()
231
  running_loss = 0.0
@@ -233,7 +239,7 @@ def train_model(
233
  correct = 0
234
 
235
  for images, labels in train_loader:
236
- images, labels = images.to(DEVICE), labels.to(DEVICE)
237
 
238
  optimizer.zero_grad()
239
  outputs = model(images)
@@ -248,7 +254,7 @@ def train_model(
248
 
249
  train_loss = running_loss / total if total else 0.0
250
  train_acc = correct / total if total else 0.0
251
- val_loss, val_acc = evaluate(model, val_loader, criterion)
252
 
253
  row = {
254
  "epoch": epoch,
@@ -265,13 +271,7 @@ def train_model(
265
  f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}"
266
  )
267
 
268
- yield (
269
- "\n".join(logs),
270
- history,
271
- gr.update(),
272
- )
273
-
274
- test_loss, test_acc = evaluate(model, test_loader, criterion)
275
  elapsed = time.time() - start_time
276
 
277
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -298,7 +298,7 @@ def train_model(
298
  "test_loss": round(test_loss, 4),
299
  "test_acc": round(test_acc, 4),
300
  "elapsed_seconds": round(elapsed, 2),
301
- "device": str(DEVICE),
302
  }
303
 
304
  save_model(model, model_name, config, training_summary)
@@ -306,27 +306,24 @@ def train_model(
306
  logs.append("")
307
  logs.append("Training finished.")
308
  logs.append(f"Saved model: {model_name}")
309
- logs.append(f"Device: {DEVICE}")
310
  logs.append(f"Test loss: {test_loss:.4f}")
311
  logs.append(f"Test accuracy: {test_acc:.4f}")
312
  logs.append(f"Elapsed time: {elapsed:.1f}s")
313
 
314
- models = list_saved_models()
315
- selected = model_name if model_name in models else (models[0] if models else None)
316
 
317
- yield (
318
- "\n".join(logs),
319
- history,
320
- gr.update(choices=models, value=selected),
321
- )
322
 
 
 
 
 
323
 
324
- # ============================================================
325
- # Inference
326
- # ============================================================
327
- def preprocess_uploaded_image(image: Image.Image):
328
  if image is None:
329
- raise ValueError("Please upload an image.")
 
 
 
330
 
331
  transform = transforms.Compose(
332
  [
@@ -336,37 +333,32 @@ def preprocess_uploaded_image(image: Image.Image):
336
  transforms.Normalize((0.5,), (0.5,))
337
  ]
338
  )
339
- tensor = transform(image).unsqueeze(0)
340
- return tensor
341
 
342
-
343
- def predict_uploaded_image(model_name: str, image: Image.Image):
344
- if not model_name:
345
- return "Please select a model.", None
346
-
347
- model, meta = load_model(model_name)
348
- tensor = preprocess_uploaded_image(image).to(DEVICE)
349
 
350
  with torch.no_grad():
351
  logits = model(tensor)
352
- probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
353
  pred_idx = int(torch.argmax(logits, dim=1).item())
354
 
355
  result_text = (
356
  f"Prediction: {CLASS_NAMES[pred_idx]}\n"
357
  f"Confidence: {max(probs):.4f}\n\n"
358
  f"Model: {model_name}\n"
359
- f"Dataset: {meta['config']['dataset_name']}"
 
360
  )
361
  prob_dict = {CLASS_NAMES[i]: float(probs[i]) for i in range(10)}
362
  return result_text, prob_dict
363
 
364
 
365
- def test_random_sample(model_name: str):
 
366
  if not model_name:
367
  return None, "Please select a model.", None
368
 
369
- model, meta = load_model(model_name)
 
370
  dataset_name = meta["config"]["dataset_name"]
371
 
372
  _, test_dataset = get_datasets(dataset_name)
@@ -374,8 +366,8 @@ def test_random_sample(model_name: str):
374
  image_tensor, label = test_dataset[idx]
375
 
376
  with torch.no_grad():
377
- logits = model(image_tensor.unsqueeze(0).to(DEVICE))
378
- probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
379
  pred_idx = int(torch.argmax(logits, dim=1).item())
380
 
381
  display_img = image_tensor.squeeze(0).cpu().numpy()
@@ -385,12 +377,62 @@ def test_random_sample(model_name: str):
385
  f"Ground truth: {label}\n"
386
  f"Prediction: {pred_idx}\n"
387
  f"Confidence: {max(probs):.4f}\n"
388
- f"Model dataset: {dataset_name}"
 
389
  )
390
  prob_dict = {CLASS_NAMES[i]: float(probs[i]) for i in range(10)}
391
  return display_img, result_text, prob_dict
392
 
393
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
394
  def get_model_info(model_name: str):
395
  if not model_name:
396
  return {"message": "No model selected."}
@@ -428,7 +470,7 @@ with gr.Blocks(title="Image Classification") as demo:
428
  dataset_name = gr.Dropdown(
429
  choices=["MNIST", "FashionMNIST"],
430
  value="MNIST",
431
- label="Dataset"
432
  )
433
  conv1_channels = gr.Slider(8, 64, value=16, step=8, label="Conv1 Channels")
434
  conv2_channels = gr.Slider(16, 128, value=32, step=16, label="Conv2 Channels")
@@ -444,6 +486,7 @@ with gr.Blocks(title="Image Classification") as demo:
444
  with gr.Column():
445
  train_status = gr.Textbox(label="Training Log", lines=18)
446
  train_history = gr.JSON(label="Training History")
 
447
 
448
  with gr.Tab("Test"):
449
  with gr.Row():
@@ -451,7 +494,7 @@ with gr.Blocks(title="Image Classification") as demo:
451
  model_selector = gr.Dropdown(
452
  choices=initial_models,
453
  value=initial_models[0] if initial_models else None,
454
- label="Select Saved Model"
455
  )
456
  refresh_btn = gr.Button("Refresh Model List")
457
  load_info_btn = gr.Button("Show Model Info")
@@ -460,7 +503,7 @@ with gr.Blocks(title="Image Classification") as demo:
460
  with gr.Column():
461
  upload_image = gr.Image(type="pil", label="Upload Image")
462
  predict_btn = gr.Button("Predict Uploaded Image", variant="primary")
463
- predict_text = gr.Textbox(label="Prediction Result", lines=6)
464
  predict_probs = gr.Label(label="Class Probabilities")
465
 
466
  with gr.Row():
@@ -468,11 +511,11 @@ with gr.Blocks(title="Image Classification") as demo:
468
 
469
  with gr.Row():
470
  random_sample_image = gr.Image(type="numpy", label="Random Test Image")
471
- random_sample_text = gr.Textbox(label="Random Sample Result", lines=6)
472
  random_sample_probs = gr.Label(label="Random Sample Probabilities")
473
 
474
  train_btn.click(
475
- fn=train_model,
476
  inputs=[
477
  dataset_name,
478
  conv1_channels,
@@ -485,7 +528,7 @@ with gr.Blocks(title="Image Classification") as demo:
485
  epochs,
486
  model_tag,
487
  ],
488
- outputs=[train_status, train_history, model_selector],
489
  )
490
 
491
  refresh_btn.click(
@@ -501,17 +544,17 @@ with gr.Blocks(title="Image Classification") as demo:
501
  )
502
 
503
  predict_btn.click(
504
- fn=predict_uploaded_image,
505
  inputs=[model_selector, upload_image],
506
  outputs=[predict_text, predict_probs],
507
  )
508
 
509
  random_test_btn.click(
510
- fn=test_random_sample,
511
  inputs=[model_selector],
512
  outputs=[random_sample_image, random_sample_text, random_sample_probs],
513
  )
514
 
515
 
516
  if __name__ == "__main__":
517
- demo.launch(ssr_mode=False)
 
5
  from datetime import datetime
6
  from typing import List, Tuple
7
 
8
+ import spaces
9
  import gradio as gr
10
  import torch
11
  import torch.nn as nn
 
16
 
17
 
18
  # ============================================================
19
+ # Paths / basic config
20
  # ============================================================
21
  BASE_DIR = os.path.dirname(os.path.abspath(__file__)) if "__file__" in globals() else os.getcwd()
22
  DATA_DIR = os.path.join(BASE_DIR, "data")
 
27
  os.makedirs(MODEL_DIR, exist_ok=True)
28
  os.makedirs(META_DIR, exist_ok=True)
29
 
 
 
30
  CLASS_NAMES = [str(i) for i in range(10)]
31
 
32
 
 
55
  nn.MaxPool2d(2),
56
  )
57
 
58
+ flattened_dim = conv2_channels * 7 * 7 # 28x28 -> 14x14 -> 7x7
 
59
 
60
  self.classifier = nn.Sequential(
61
  nn.Flatten(),
 
130
 
131
 
132
  def save_model(model: nn.Module, model_name: str, config: dict, training_summary: dict):
133
+ cpu_state_dict = {k: v.detach().cpu() for k, v in model.state_dict().items()}
134
+ torch.save(cpu_state_dict, model_weight_path(model_name))
135
+
136
  payload = {
137
  "model_name": model_name,
138
  "config": config,
 
143
  json.dump(payload, f, indent=2, ensure_ascii=False)
144
 
145
 
146
+ def load_model(model_name: str, device: torch.device) -> Tuple[nn.Module, dict]:
147
  meta_file = model_meta_path(model_name)
148
  weight_file = model_weight_path(model_name)
149
 
 
164
  dropout=cfg["dropout"],
165
  fc_dim=cfg["fc_dim"],
166
  )
167
+
168
+ state_dict = torch.load(weight_file, map_location="cpu")
169
  model.load_state_dict(state_dict)
170
+ model.to(device)
171
  model.eval()
172
  return model, meta
173
 
174
 
175
  # ============================================================
176
+ # ZeroGPU helpers
177
  # ============================================================
178
+ def get_runtime_device() -> torch.device:
179
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
 
 
 
181
 
182
+ @spaces.GPU(duration=120)
183
+ def _train_on_gpu(
184
  dataset_name: str,
185
  conv1_channels: int,
186
  conv2_channels: int,
 
192
  epochs: int,
193
  model_tag: str,
194
  ):
195
+ device = get_runtime_device()
196
+
197
  train_loader, val_loader, test_loader = make_loaders(dataset_name, batch_size)
198
 
199
  model = SimpleCNN(
 
202
  kernel_size=kernel_size,
203
  dropout=dropout,
204
  fc_dim=fc_dim,
205
+ ).to(device)
206
 
207
  criterion = nn.CrossEntropyLoss()
208
  optimizer = optim.Adam(model.parameters(), lr=learning_rate)
 
211
  logs = []
212
  start_time = time.time()
213
 
214
+ def evaluate(loader):
215
+ model.eval()
216
+ total_loss = 0.0
217
+ total = 0
218
+ correct = 0
219
+
220
+ with torch.no_grad():
221
+ for images, labels in loader:
222
+ images, labels = images.to(device), labels.to(device)
223
+ outputs = model(images)
224
+ loss = criterion(outputs, labels)
225
+
226
+ total_loss += loss.item() * images.size(0)
227
+ preds = outputs.argmax(dim=1)
228
+ correct += (preds == labels).sum().item()
229
+ total += labels.size(0)
230
+
231
+ avg_loss = total_loss / total if total else 0.0
232
+ acc = correct / total if total else 0.0
233
+ return avg_loss, acc
234
+
235
  for epoch in range(1, epochs + 1):
236
  model.train()
237
  running_loss = 0.0
 
239
  correct = 0
240
 
241
  for images, labels in train_loader:
242
+ images, labels = images.to(device), labels.to(device)
243
 
244
  optimizer.zero_grad()
245
  outputs = model(images)
 
254
 
255
  train_loss = running_loss / total if total else 0.0
256
  train_acc = correct / total if total else 0.0
257
+ val_loss, val_acc = evaluate(val_loader)
258
 
259
  row = {
260
  "epoch": epoch,
 
271
  f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}"
272
  )
273
 
274
+ test_loss, test_acc = evaluate(test_loader)
 
 
 
 
 
 
275
  elapsed = time.time() - start_time
276
 
277
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
 
298
  "test_loss": round(test_loss, 4),
299
  "test_acc": round(test_acc, 4),
300
  "elapsed_seconds": round(elapsed, 2),
301
+ "device": str(device),
302
  }
303
 
304
  save_model(model, model_name, config, training_summary)
 
306
  logs.append("")
307
  logs.append("Training finished.")
308
  logs.append(f"Saved model: {model_name}")
309
+ logs.append(f"Device: {device}")
310
  logs.append(f"Test loss: {test_loss:.4f}")
311
  logs.append(f"Test accuracy: {test_acc:.4f}")
312
  logs.append(f"Elapsed time: {elapsed:.1f}s")
313
 
314
+ return "\n".join(logs), history, training_summary, model_name
 
315
 
 
 
 
 
 
316
 
317
+ @spaces.GPU(duration=60)
318
+ def _predict_uploaded_image_gpu(model_name: str, image: Image.Image):
319
+ if not model_name:
320
+ return "Please select a model.", None
321
 
 
 
 
 
322
  if image is None:
323
+ return "Please upload an image.", None
324
+
325
+ device = get_runtime_device()
326
+ model, meta = load_model(model_name, device)
327
 
328
  transform = transforms.Compose(
329
  [
 
333
  transforms.Normalize((0.5,), (0.5,))
334
  ]
335
  )
 
 
336
 
337
+ tensor = transform(image).unsqueeze(0).to(device)
 
 
 
 
 
 
338
 
339
  with torch.no_grad():
340
  logits = model(tensor)
341
+ probs = torch.softmax(logits, dim=1).squeeze(0).detach().cpu().tolist()
342
  pred_idx = int(torch.argmax(logits, dim=1).item())
343
 
344
  result_text = (
345
  f"Prediction: {CLASS_NAMES[pred_idx]}\n"
346
  f"Confidence: {max(probs):.4f}\n\n"
347
  f"Model: {model_name}\n"
348
+ f"Dataset: {meta['config']['dataset_name']}\n"
349
+ f"Runtime device: {device}"
350
  )
351
  prob_dict = {CLASS_NAMES[i]: float(probs[i]) for i in range(10)}
352
  return result_text, prob_dict
353
 
354
 
355
+ @spaces.GPU(duration=60)
356
+ def _test_random_sample_gpu(model_name: str):
357
  if not model_name:
358
  return None, "Please select a model.", None
359
 
360
+ device = get_runtime_device()
361
+ model, meta = load_model(model_name, device)
362
  dataset_name = meta["config"]["dataset_name"]
363
 
364
  _, test_dataset = get_datasets(dataset_name)
 
366
  image_tensor, label = test_dataset[idx]
367
 
368
  with torch.no_grad():
369
+ logits = model(image_tensor.unsqueeze(0).to(device))
370
+ probs = torch.softmax(logits, dim=1).squeeze(0).detach().cpu().tolist()
371
  pred_idx = int(torch.argmax(logits, dim=1).item())
372
 
373
  display_img = image_tensor.squeeze(0).cpu().numpy()
 
377
  f"Ground truth: {label}\n"
378
  f"Prediction: {pred_idx}\n"
379
  f"Confidence: {max(probs):.4f}\n"
380
+ f"Model dataset: {dataset_name}\n"
381
+ f"Runtime device: {device}"
382
  )
383
  prob_dict = {CLASS_NAMES[i]: float(probs[i]) for i in range(10)}
384
  return display_img, result_text, prob_dict
385
 
386
 
387
+ # ============================================================
388
+ # UI callbacks
389
+ # ============================================================
390
+ def train_callback(
391
+ dataset_name,
392
+ conv1_channels,
393
+ conv2_channels,
394
+ kernel_size,
395
+ dropout,
396
+ fc_dim,
397
+ learning_rate,
398
+ batch_size,
399
+ epochs,
400
+ model_tag,
401
+ ):
402
+ try:
403
+ logs, history, summary, model_name = _train_on_gpu(
404
+ dataset_name,
405
+ int(conv1_channels),
406
+ int(conv2_channels),
407
+ int(kernel_size),
408
+ float(dropout),
409
+ int(fc_dim),
410
+ float(learning_rate),
411
+ int(batch_size),
412
+ int(epochs),
413
+ model_tag,
414
+ )
415
+ models = list_saved_models()
416
+ selected = model_name if model_name in models else (models[0] if models else None)
417
+ return logs, history, summary, gr.update(choices=models, value=selected)
418
+ except Exception as e:
419
+ return f"Training failed:\n{str(e)}", None, None, gr.update()
420
+
421
+
422
+ def predict_uploaded_image_callback(model_name, image):
423
+ try:
424
+ return _predict_uploaded_image_gpu(model_name, image)
425
+ except Exception as e:
426
+ return f"Prediction failed:\n{str(e)}", None
427
+
428
+
429
+ def test_random_sample_callback(model_name):
430
+ try:
431
+ return _test_random_sample_gpu(model_name)
432
+ except Exception as e:
433
+ return None, f"Random test failed:\n{str(e)}", None
434
+
435
+
436
  def get_model_info(model_name: str):
437
  if not model_name:
438
  return {"message": "No model selected."}
 
470
  dataset_name = gr.Dropdown(
471
  choices=["MNIST", "FashionMNIST"],
472
  value="MNIST",
473
+ label="Dataset",
474
  )
475
  conv1_channels = gr.Slider(8, 64, value=16, step=8, label="Conv1 Channels")
476
  conv2_channels = gr.Slider(16, 128, value=32, step=16, label="Conv2 Channels")
 
486
  with gr.Column():
487
  train_status = gr.Textbox(label="Training Log", lines=18)
488
  train_history = gr.JSON(label="Training History")
489
+ train_summary = gr.JSON(label="Training Summary")
490
 
491
  with gr.Tab("Test"):
492
  with gr.Row():
 
494
  model_selector = gr.Dropdown(
495
  choices=initial_models,
496
  value=initial_models[0] if initial_models else None,
497
+ label="Select Saved Model",
498
  )
499
  refresh_btn = gr.Button("Refresh Model List")
500
  load_info_btn = gr.Button("Show Model Info")
 
503
  with gr.Column():
504
  upload_image = gr.Image(type="pil", label="Upload Image")
505
  predict_btn = gr.Button("Predict Uploaded Image", variant="primary")
506
+ predict_text = gr.Textbox(label="Prediction Result", lines=7)
507
  predict_probs = gr.Label(label="Class Probabilities")
508
 
509
  with gr.Row():
 
511
 
512
  with gr.Row():
513
  random_sample_image = gr.Image(type="numpy", label="Random Test Image")
514
+ random_sample_text = gr.Textbox(label="Random Sample Result", lines=7)
515
  random_sample_probs = gr.Label(label="Random Sample Probabilities")
516
 
517
  train_btn.click(
518
+ fn=train_callback,
519
  inputs=[
520
  dataset_name,
521
  conv1_channels,
 
528
  epochs,
529
  model_tag,
530
  ],
531
+ outputs=[train_status, train_history, train_summary, model_selector],
532
  )
533
 
534
  refresh_btn.click(
 
544
  )
545
 
546
  predict_btn.click(
547
+ fn=predict_uploaded_image_callback,
548
  inputs=[model_selector, upload_image],
549
  outputs=[predict_text, predict_probs],
550
  )
551
 
552
  random_test_btn.click(
553
+ fn=test_random_sample_callback,
554
  inputs=[model_selector],
555
  outputs=[random_sample_image, random_sample_text, random_sample_probs],
556
  )
557
 
558
 
559
  if __name__ == "__main__":
560
+ demo.launch()