CircleStar commited on
Commit
ca16013
·
verified ·
1 Parent(s): 8f0be2d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +542 -0
app.py ADDED
@@ -0,0 +1,542 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import time
4
+ import math
5
+ from datetime import datetime
6
+ from typing import List, Tuple
7
+
8
+ import gradio as gr
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.optim as optim
12
+ from torch.utils.data import DataLoader, random_split
13
+ from torchvision import datasets, transforms
14
+ from PIL import Image
15
+
16
# ============================================================
# Configuration
# ============================================================
# Resolve all paths relative to this file so the app behaves the same
# regardless of the current working directory; fall back to the CWD in
# contexts (e.g. interactive sessions) where __file__ is undefined.
BASE_DIR = os.path.dirname(os.path.abspath(__file__)) if "__file__" in globals() else os.getcwd()
DATA_DIR = os.path.join(BASE_DIR, "data")  # torchvision dataset downloads
MODEL_DIR = os.path.join(BASE_DIR, "saved_models")  # model weights (*.pt)
META_DIR = os.path.join(BASE_DIR, "saved_models_meta")  # per-model JSON metadata
os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(META_DIR, exist_ok=True)

# Prefer GPU when available; all tensors/models are moved to this device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): labels are the digit strings "0".."9"; for FashionMNIST
# these are class *indices*, not garment names — confirm this is intended.
CLASS_NAMES = [str(i) for i in range(10)]
29
+
30
+
31
+ # ============================================================
32
+ # Model
33
+ # ============================================================
34
class SimpleCNN(nn.Module):
    """Small two-stage conv net for 1x28x28 grayscale images.

    Architecture: [Conv -> ReLU -> MaxPool(2)] x 2, then
    Flatten -> Linear -> ReLU -> Dropout -> Linear(10 logits).

    Args:
        conv1_channels: output channels of the first conv layer.
        conv2_channels: output channels of the second conv layer.
        kernel_size: square kernel size for both conv layers.
        dropout: dropout probability before the final linear layer.
        fc_dim: hidden width of the fully connected classifier.
    """

    def __init__(self, conv1_channels: int = 16, conv2_channels: int = 32, kernel_size: int = 3,
                 dropout: float = 0.2, fc_dim: int = 128):
        super().__init__()
        # "Same" padding: spatial size is preserved by the convolutions,
        # so only the two pooling layers downsample.
        same_pad = kernel_size // 2

        self.features = nn.Sequential(
            nn.Conv2d(1, conv1_channels, kernel_size=kernel_size, padding=same_pad),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(conv1_channels, conv2_channels, kernel_size=kernel_size, padding=same_pad),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # A 28x28 input halved twice by the pools leaves 7x7 feature maps.
        flat_features = conv2_channels * 7 * 7

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, fc_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fc_dim, 10),
        )

    def forward(self, x):
        """Return raw class logits of shape (batch, 10)."""
        return self.classifier(self.features(x))
65
+
66
+
67
+ # ============================================================
68
+ # Data utilities
69
+ # ============================================================
70
def get_datasets(dataset_name: str):
    """Return (train, test) torchvision datasets, downloading if needed.

    Images are converted to tensors and scaled to [-1, 1] via
    Normalize((0.5,), (0.5,)).

    Raises:
        ValueError: if dataset_name is not "MNIST" or "FashionMNIST".
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    dataset_classes = {
        "MNIST": datasets.MNIST,
        "FashionMNIST": datasets.FashionMNIST,
    }
    if dataset_name not in dataset_classes:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    dataset_cls = dataset_classes[dataset_name]
    train_dataset = dataset_cls(DATA_DIR, train=True, download=True, transform=transform)
    test_dataset = dataset_cls(DATA_DIR, train=False, download=True, transform=transform)
    return train_dataset, test_dataset
86
+
87
+
88
def make_loaders(dataset_name: str, batch_size: int, val_ratio: float = 0.1):
    """Build (train, val, test) DataLoaders for the named dataset.

    The training split is carved into validation/training subsets at
    val_ratio. The split is random and unseeded, so it differs per run.
    """
    full_train, test_dataset = get_datasets(dataset_name)

    n_val = int(len(full_train) * val_ratio)
    train_part, val_part = random_split(full_train, [len(full_train) - n_val, n_val])

    # Only the training loader shuffles; eval loaders keep a fixed order.
    return (
        DataLoader(train_part, batch_size=batch_size, shuffle=True),
        DataLoader(val_part, batch_size=batch_size, shuffle=False),
        DataLoader(test_dataset, batch_size=batch_size, shuffle=False),
    )
99
+
100
+
101
+ # ============================================================
102
+ # Model registry helpers
103
+ # ============================================================
104
def model_meta_path(model_name: str) -> str:
    """Path of the JSON metadata sidecar for a saved model."""
    return os.path.join(META_DIR, model_name + ".json")
106
+
107
+
108
def model_weight_path(model_name: str) -> str:
    """Path of the .pt weight file for a saved model."""
    return os.path.join(MODEL_DIR, model_name + ".pt")
110
+
111
+
112
def save_model(model: nn.Module, model_name: str, config: dict, training_summary: dict):
    """Persist a model's state dict plus a JSON metadata sidecar.

    The metadata records the architecture config (needed to rebuild the
    model at load time), the training summary, and a creation timestamp.
    """
    torch.save(model.state_dict(), model_weight_path(model_name))
    meta = {
        "model_name": model_name,
        "config": config,
        "training_summary": training_summary,
        "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }
    with open(model_meta_path(model_name), "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2, ensure_ascii=False)
122
+
123
+
124
def list_saved_models() -> List[str]:
    """Return saved model names (metadata stems), newest first.

    Names embed a timestamp suffix, so a reverse lexicographic sort puts
    the most recently saved models at the front.
    """
    names = [fn[:-len(".json")] for fn in os.listdir(META_DIR) if fn.endswith(".json")]
    return sorted(names, reverse=True)
131
+
132
+
133
def load_model(model_name: str) -> Tuple[nn.Module, dict]:
    """Reconstruct a saved SimpleCNN from its metadata and weight files.

    Args:
        model_name: stem of the saved model (as produced by save_model).

    Returns:
        (model, meta): the model moved to DEVICE in eval mode, and the
        parsed metadata dict (contains "config" and "training_summary").

    Raises:
        FileNotFoundError: if either the metadata or the weight file is
            missing.
    """
    meta_file = model_meta_path(model_name)
    weight_file = model_weight_path(model_name)

    if not os.path.exists(meta_file):
        raise FileNotFoundError(f"Metadata not found for model: {model_name}")
    if not os.path.exists(weight_file):
        raise FileNotFoundError(f"Weights not found for model: {model_name}")

    with open(meta_file, "r", encoding="utf-8") as f:
        meta = json.load(f)

    # Rebuild the architecture from the recorded hyperparameters so the
    # state-dict tensor shapes line up with the fresh module.
    config = meta["config"]
    model = SimpleCNN(
        conv1_channels=config["conv1_channels"],
        conv2_channels=config["conv2_channels"],
        kernel_size=config["kernel_size"],
        dropout=config["dropout"],
        fc_dim=config["fc_dim"],
    )
    # weights_only=True restricts torch.load to plain tensors/containers
    # instead of arbitrary pickled objects — a state dict needs nothing
    # more, and this hardens loading against tampered checkpoint files.
    state_dict = torch.load(weight_file, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    return model, meta
158
+
159
+
160
+ # ============================================================
161
+ # Training / evaluation
162
+ # ============================================================
163
def evaluate(model: nn.Module, loader: DataLoader, criterion: nn.Module):
    """Compute (mean loss, accuracy) for `model` over `loader`.

    Runs in eval mode with gradients disabled. Returns (0.0, 0.0) when
    the loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch_images, batch_labels in loader:
            batch_images = batch_images.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)
            logits = model(batch_images)

            # Weight each batch's mean loss by its size so the final
            # average is per-sample, not per-batch.
            loss_sum += criterion(logits, batch_labels).item() * batch_images.size(0)
            n_correct += (logits.argmax(dim=1) == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

    if n_seen == 0:
        return 0.0, 0.0
    return loss_sum / n_seen, n_correct / n_seen
183
+
184
+
185
def train_model(dataset_name: str, conv1_channels: int, conv2_channels: int, kernel_size: int,
                dropout: float, fc_dim: int, learning_rate: float, batch_size: int,
                epochs: int, model_tag: str):
    """Train a SimpleCNN and stream progress as a generator.

    Yields one dict per epoch with keys:
        status (str): human-readable progress line.
        history (dict): lists of per-epoch metrics (epoch, train/val
            loss and accuracy). NOTE: the same dict object is yielded
            each time and mutated in place.
        finished (bool): True only on the final yield.
        models (list | None): refreshed saved-model names, final yield only.

    After the last epoch the model is evaluated on the test set and saved
    (weights + metadata) under "<tag>_<timestamp>".
    """
    train_loader, val_loader, test_loader = make_loaders(dataset_name, batch_size)

    model = SimpleCNN(
        conv1_channels=conv1_channels,
        conv2_channels=conv2_channels,
        kernel_size=kernel_size,
        dropout=dropout,
        fc_dim=fc_dim,
    ).to(DEVICE)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Per-epoch metric accumulators; consumed by the UI's line plot.
    history = {
        "epoch": [],
        "train_loss": [],
        "train_acc": [],
        "val_loss": [],
        "val_acc": [],
    }

    start_time = time.time()

    for epoch in range(1, epochs + 1):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            # Standard step: reset grads, forward, backward, update.
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Accumulate sample-weighted loss so the epoch average is
            # per-sample even with a ragged final batch.
            running_loss += loss.item() * images.size(0)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        train_loss = running_loss / total if total > 0 else 0.0
        train_acc = correct / total if total > 0 else 0.0
        val_loss, val_acc = evaluate(model, val_loader, criterion)

        history["epoch"].append(epoch)
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        # Intermediate progress update (finished=False, no model list yet).
        yield {
            "status": (
                f"Epoch {epoch}/{epochs} | "
                f"train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, "
                f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}"
            ),
            "history": history,
            "finished": False,
            "models": None,
        }

    # Final held-out evaluation once all epochs are done.
    test_loss, test_acc = evaluate(model, test_loader, criterion)
    elapsed = time.time() - start_time

    # Derive a filesystem-safe model name; fall back to the dataset name
    # when the user supplied no tag.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    safe_tag = model_tag.strip().replace(" ", "_") if model_tag else dataset_name.lower()
    model_name = f"{safe_tag}_{timestamp}"

    # Record every hyperparameter needed to rebuild the model at load time.
    config = {
        "dataset_name": dataset_name,
        "conv1_channels": conv1_channels,
        "conv2_channels": conv2_channels,
        "kernel_size": kernel_size,
        "dropout": dropout,
        "fc_dim": fc_dim,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "epochs": epochs,
    }
    training_summary = {
        "final_train_loss": history["train_loss"][-1],
        "final_train_acc": history["train_acc"][-1],
        "final_val_loss": history["val_loss"][-1],
        "final_val_acc": history["val_acc"][-1],
        "test_loss": test_loss,
        "test_acc": test_acc,
        "elapsed_seconds": elapsed,
        "device": str(DEVICE),
    }
    save_model(model, model_name, config, training_summary)

    final_message = (
        f"Training finished.\n\n"
        f"Saved model: {model_name}\n"
        f"Device: {DEVICE}\n"
        f"Test loss: {test_loss:.4f}\n"
        f"Test accuracy: {test_acc:.4f}\n"
        f"Elapsed time: {elapsed:.1f}s"
    )

    # Final yield: finished=True and a refreshed model list so the UI can
    # update its dropdowns.
    yield {
        "status": final_message,
        "history": history,
        "finished": True,
        "models": list_saved_models(),
    }
297
+
298
+
299
+ # ============================================================
300
+ # Inference helpers
301
+ # ============================================================
302
def preprocess_uploaded_image(image: Image.Image):
    """Turn an uploaded PIL image into a normalized 1x1x28x28 tensor.

    Mirrors the training-time pipeline: grayscale, resize to 28x28,
    tensor conversion, then Normalize((0.5,), (0.5,)) into [-1, 1].

    Raises:
        ValueError: if no image was provided.
    """
    if image is None:
        raise ValueError("Please upload an image.")

    to_model_input = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    # unsqueeze(0) adds the batch dimension the model expects.
    return to_model_input(image).unsqueeze(0)
315
+
316
+
317
def predict_uploaded_image(model_name: str, image: Image.Image):
    """Classify an uploaded image with the selected saved model.

    Returns:
        (result_text, prob_table): a summary string and a mapping of
        class name -> softmax probability, or
        ("Please select a model.", None) when no model is chosen.
    """
    if not model_name:
        return "Please select a model.", None

    model, meta = load_model(model_name)
    batch = preprocess_uploaded_image(image).to(DEVICE)

    with torch.no_grad():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
        pred_idx = int(torch.argmax(logits, dim=1).item())

    result_text = (
        f"Prediction: {CLASS_NAMES[pred_idx]}\n"
        f"Confidence: {max(probs):.4f}\n\n"
        f"Model: {model_name}\n"
        f"Dataset: {meta['config']['dataset_name']}"
    )

    prob_table = dict(zip(CLASS_NAMES, map(float, probs)))
    return result_text, prob_table
339
+
340
+
341
def test_random_sample(model_name: str):
    """Run the selected model on one randomly chosen test-set sample.

    Returns:
        (display_image, result_text, prob_table): the sample as an HxW
        numpy array in [0, 1] for gr.Image(type="numpy"), a summary
        string, and class-name -> probability mapping; or
        (None, "Please select a model.", None) when no model is chosen.
    """
    if not model_name:
        return None, "Please select a model.", None

    model, meta = load_model(model_name)
    dataset_name = meta["config"]["dataset_name"]
    _, test_dataset = get_datasets(dataset_name)

    idx = torch.randint(low=0, high=len(test_dataset), size=(1,)).item()
    image_tensor, label = test_dataset[idx]

    with torch.no_grad():
        logits = model(image_tensor.unsqueeze(0).to(DEVICE))
        probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
        pred_idx = int(torch.argmax(logits, dim=1).item())

    # Bug fix: the dataset tensor is normalized to [-1, 1] via
    # Normalize((0.5,), (0.5,)), which gr.Image cannot render correctly.
    # Undo the normalization (x * std + mean) back to [0, 1] and convert
    # to a numpy array for the gr.Image(type="numpy") component.
    display_img = (image_tensor.squeeze(0).cpu() * 0.5 + 0.5).clamp(0.0, 1.0).numpy()
    prob_table = {CLASS_NAMES[i]: float(probs[i]) for i in range(len(CLASS_NAMES))}
    result_text = (
        f"Random test sample\n"
        f"Ground truth: {label}\n"
        f"Prediction: {pred_idx}\n"
        f"Confidence: {max(probs):.4f}\n"
        f"Model dataset: {dataset_name}"
    )
    return display_img, result_text, prob_table
367
+
368
+
369
def get_model_info(model_name: str):
    """Return a saved model's metadata as pretty-printed JSON text.

    Returns a short notice string when no model is selected or when the
    metadata file cannot be found.
    """
    if not model_name:
        return "No model selected."

    meta_file = model_meta_path(model_name)
    if not os.path.exists(meta_file):
        return "Selected model metadata not found."

    with open(meta_file, "r", encoding="utf-8") as f:
        return json.dumps(json.load(f), indent=2, ensure_ascii=False)
378
+
379
+
380
def refresh_models_dropdown():
    """Re-scan saved models and update the dropdown, selecting the newest."""
    names = list_saved_models()
    selected = names[0] if names else None
    return gr.update(choices=names, value=selected)
383
+
384
+
385
+ # ============================================================
386
+ # Gradio callbacks
387
+ # ============================================================
388
def training_callback(dataset_name, conv1_channels, conv2_channels, kernel_size,
                      dropout, fc_dim, learning_rate, batch_size, epochs, model_tag):
    """Adapt train_model's progress dicts into Gradio output tuples.

    Yields (status_text, rows, train_dropdown_update, test_dropdown_update)
    where rows is [[epoch, train_loss, train_acc, val_loss, val_acc], ...].
    The dropdown updates are no-ops until training finishes and the new
    model has been saved.
    """
    for step in train_model(
        dataset_name=dataset_name,
        conv1_channels=conv1_channels,
        conv2_channels=conv2_channels,
        kernel_size=kernel_size,
        dropout=dropout,
        fc_dim=fc_dim,
        learning_rate=learning_rate,
        batch_size=batch_size,
        epochs=epochs,
        model_tag=model_tag,
    ):
        hist = step["history"]
        line_data = [
            list(row)
            for row in zip(hist["epoch"], hist["train_loss"], hist["train_acc"],
                           hist["val_loss"], hist["val_acc"])
        ]

        # Leave the dropdowns untouched until the final progress update.
        dropdown_update = gr.update()
        if step["finished"] and step["models"] is not None:
            models = step["models"]
            dropdown_update = gr.update(choices=models, value=models[0] if models else None)

        yield step["status"], line_data, dropdown_update, dropdown_update
419
+
420
+
421
+ # ============================================================
422
+ # UI
423
+ # ============================================================
424
# Snapshot of saved models taken at app start; used to seed the dropdown.
initial_models = list_saved_models()

with gr.Blocks(title="CNN Trainer and Tester") as demo:
    gr.Markdown("# Simple CNN Trainer and Tester")
    gr.Markdown(
        "This app is designed for lightweight image classification experiments on MNIST or FashionMNIST. "
        "Tab 1 trains a simple CNN. Tab 2 loads a saved model and tests it on uploaded images or random test samples."
    )

    with gr.Tabs():
        # --- Tab 1: hyperparameter controls + live training status/plot ---
        with gr.Tab("Train"):
            with gr.Row():
                with gr.Column(scale=1):
                    dataset_name = gr.Dropdown(
                        choices=["MNIST", "FashionMNIST"],
                        value="MNIST",
                        label="Dataset"
                    )
                    conv1_channels = gr.Slider(8, 64, value=16, step=8, label="Conv1 Channels")
                    conv2_channels = gr.Slider(16, 128, value=32, step=16, label="Conv2 Channels")
                    kernel_size = gr.Dropdown(choices=[3, 5], value=3, label="Kernel Size")
                    dropout = gr.Slider(0.0, 0.7, value=0.2, step=0.05, label="Dropout")
                    fc_dim = gr.Slider(32, 256, value=128, step=32, label="FC Hidden Dimension")
                    learning_rate = gr.Number(value=0.001, label="Learning Rate")
                    batch_size = gr.Dropdown(choices=[32, 64, 128, 256], value=64, label="Batch Size")
                    epochs = gr.Slider(1, 10, value=3, step=1, label="Epochs")
                    model_tag = gr.Textbox(label="Model Tag", placeholder="e.g. mnist_demo")
                    train_btn = gr.Button("Start Training", variant="primary")

                with gr.Column(scale=1):
                    train_status = gr.Textbox(label="Training Status", lines=10)
                    # Long-format plot: one record per (epoch, metric) pair,
                    # produced by format_lineplot_rows below.
                    train_plot = gr.LinePlot(
                        x="epoch",
                        y="value",
                        color="metric",
                        title="Training Curves",
                        y_title="Value",
                        x_title="Epoch",
                        width=700,
                        height=400,
                    )

        # --- Tab 2: pick a saved model, inspect it, and run predictions ---
        with gr.Tab("Test"):
            with gr.Row():
                with gr.Column(scale=1):
                    model_selector = gr.Dropdown(
                        choices=initial_models,
                        value=initial_models[0] if initial_models else None,
                        label="Select Saved Model"
                    )
                    refresh_btn = gr.Button("Refresh Model List")
                    model_info = gr.Code(label="Model Metadata", language="json")
                    load_info_btn = gr.Button("Show Model Info")

                with gr.Column(scale=1):
                    upload_image = gr.Image(type="pil", label="Upload Image")
                    predict_btn = gr.Button("Predict Uploaded Image", variant="primary")
                    predict_text = gr.Textbox(label="Prediction Result", lines=6)
                    predict_probs = gr.Label(label="Class Probabilities")

            with gr.Row():
                random_test_btn = gr.Button("Test Random Sample")
            with gr.Row():
                random_sample_image = gr.Image(type="numpy", label="Random Test Image")
                random_sample_text = gr.Textbox(label="Random Sample Result", lines=6)
                random_sample_probs = gr.Label(label="Random Sample Probabilities")

    # Reshape [[epoch, tl, ta, vl, va], ...] rows into the long-format
    # records gr.LinePlot expects (one dict per metric per epoch).
    def format_lineplot_rows(rows):
        output = []
        for epoch, train_loss, train_acc, val_loss, val_acc in rows:
            output.append({"epoch": epoch, "value": train_loss, "metric": "train_loss"})
            output.append({"epoch": epoch, "value": train_acc, "metric": "train_acc"})
            output.append({"epoch": epoch, "value": val_loss, "metric": "val_loss"})
            output.append({"epoch": epoch, "value": val_acc, "metric": "val_acc"})
        return output

    # Generator wrapper so each yielded progress tuple carries plot-ready rows.
    def wrapped_training_callback(*args):
        for status, rows, train_dd_update, test_dd_update in training_callback(*args):
            yield status, format_lineplot_rows(rows), train_dd_update, test_dd_update

    # Hidden placeholder dropdowns: training_callback yields four outputs,
    # but only model_selector (Test tab) is a visible dropdown to refresh.
    # NOTE(review): test_model_selector_hidden is never wired to anything —
    # presumably left over from an earlier layout; confirm before removing.
    train_model_selector_hidden = gr.Dropdown(visible=False)
    test_model_selector_hidden = gr.Dropdown(visible=False)

    # Streamed outputs: status text, plot rows, and dropdown refreshes.
    train_btn.click(
        fn=wrapped_training_callback,
        inputs=[
            dataset_name, conv1_channels, conv2_channels, kernel_size,
            dropout, fc_dim, learning_rate, batch_size, epochs, model_tag
        ],
        outputs=[train_status, train_plot, train_model_selector_hidden, model_selector],
    )

    refresh_btn.click(
        fn=refresh_models_dropdown,
        inputs=None,
        outputs=model_selector,
    )

    load_info_btn.click(
        fn=get_model_info,
        inputs=model_selector,
        outputs=model_info,
    )

    predict_btn.click(
        fn=predict_uploaded_image,
        inputs=[model_selector, upload_image],
        outputs=[predict_text, predict_probs],
    )

    random_test_btn.click(
        fn=test_random_sample,
        inputs=[model_selector],
        outputs=[random_sample_image, random_sample_text, random_sample_probs],
    )


if __name__ == "__main__":
    demo.launch()