Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -430,37 +430,35 @@ with gr.Blocks(title="CNN Trainer and Tester") as demo:
|
|
| 430 |
"Tab 1 trains a simple CNN. Tab 2 loads a saved model and tests it on uploaded images or random test samples."
|
| 431 |
)
|
| 432 |
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
)
|
| 464 |
|
| 465 |
with gr.Tab("Test"):
|
| 466 |
with gr.Row():
|
|
|
|
| 430 |
"Tab 1 trains a simple CNN. Tab 2 loads a saved model and tests it on uploaded images or random test samples."
|
| 431 |
)
|
| 432 |
|
| 433 |
+
with gr.Tab("Train"):
|
| 434 |
+
with gr.Row():
|
| 435 |
+
with gr.Column(scale=1):
|
| 436 |
+
dataset_name = gr.Dropdown(
|
| 437 |
+
choices=["MNIST", "FashionMNIST"],
|
| 438 |
+
value="MNIST",
|
| 439 |
+
label="Dataset"
|
| 440 |
+
)
|
| 441 |
+
conv1_channels = gr.Slider(8, 64, value=16, step=8, label="Conv1 Channels")
|
| 442 |
+
conv2_channels = gr.Slider(16, 128, value=32, step=16, label="Conv2 Channels")
|
| 443 |
+
kernel_size = gr.Dropdown(choices=[3, 5], value=3, label="Kernel Size")
|
| 444 |
+
dropout = gr.Slider(0.0, 0.7, value=0.2, step=0.05, label="Dropout")
|
| 445 |
+
fc_dim = gr.Slider(32, 256, value=128, step=32, label="FC Hidden Dimension")
|
| 446 |
+
learning_rate = gr.Number(value=0.001, label="Learning Rate")
|
| 447 |
+
batch_size = gr.Dropdown(choices=[32, 64, 128, 256], value=64, label="Batch Size")
|
| 448 |
+
epochs = gr.Slider(1, 10, value=3, step=1, label="Epochs")
|
| 449 |
+
model_tag = gr.Textbox(label="Model Tag", placeholder="e.g. mnist_demo")
|
| 450 |
+
train_btn = gr.Button("Start Training", variant="primary")
|
| 451 |
+
|
| 452 |
+
with gr.Column(scale=1):
|
| 453 |
+
train_status = gr.Textbox(label="Training Status", lines=10)
|
| 454 |
+
train_plot = gr.LinePlot(
|
| 455 |
+
x="epoch",
|
| 456 |
+
y="value",
|
| 457 |
+
color="metric",
|
| 458 |
+
title="Training Curves",
|
| 459 |
+
y_title="Value",
|
| 460 |
+
x_title="Epoch",
|
| 461 |
+
)
|
|
|
|
|
|
|
| 462 |
|
| 463 |
with gr.Tab("Test"):
|
| 464 |
with gr.Row():
|