| import os |
| import gradio as gr |
| from cluster_visualize import PredictionArgs, generate_visualization |
|
|
# Directory (relative to the working directory) that is searched recursively
# for checkpoint files.
checkpoint_dir = 'checkpoints'

# One (model name, checkpoint path) pair per "*.pth.tar" file found under
# checkpoint_dir, in the order os.walk reports them. The model name is the
# part of the file name before its first dot.
_discovered = [
    (filename.split('.')[0], os.path.join(folder, filename))
    for folder, _subdirs, filenames in os.walk(checkpoint_dir)
    for filename in filenames
    if filename.endswith(".pth.tar")
]

# Parallel lists: models[i] is the name derived from the file checkpoints[i].
models = [name for name, _path in _discovered]
checkpoints = [path for _name, path in _discovered]
|
|
def generate_coc_viz(model_name,
                     image,
                     stage,
                     block,
                     head,
                     alpha):
    """Run the cluster visualization and return its results in UI order.

    All arguments are packed into a ``PredictionArgs`` and handed to
    ``generate_visualization``.

    Args:
        model_name: Model name selected in the UI; passed as ``model`` and
            used to pick the matching checkpoint file.
        image: Input image, passed through unchanged.
        stage: Passed through unchanged as ``stage``.
        block: Passed through unchanged as ``block``.
        head: Passed through unchanged as ``head``.
        alpha: Passed through unchanged as ``alpha``.

    Returns:
        A ``(probability, coc_visualization)`` tuple, i.e. the two values
        returned by ``generate_visualization`` in swapped order.
    """
    # Use the checkpoint that was discovered for the selected model instead
    # of always loading the same file regardless of the dropdown choice.
    # `models` and `checkpoints` are parallel module-level lists.
    # NOTE(review): assumes PredictionArgs accepts the joined path produced
    # by the discovery loop (not just a bare file name) -- confirm.
    try:
        checkpoint = checkpoints[models.index(model_name)]
    except ValueError:
        # Unknown or missing model name (e.g. nothing selected): keep the
        # previous hard-coded default so existing behaviour is preserved.
        checkpoint = 'coc_tiny_plain.pth.tar'

    args = PredictionArgs(
        model=model_name,
        image=image,
        checkpoint=checkpoint,
        stage=stage,
        block=block,
        head=head,
        alpha=alpha,
    )
    coc_visualization, probability = generate_visualization(args)
    return probability, coc_visualization
|
|
|
|
# Web UI definition. The input components are listed in the same order as the
# parameters of generate_coc_viz, and the output components in the same order
# as the values it returns.
demo = gr.Interface(
    fn=generate_coc_viz,
    inputs=[
        gr.Dropdown(models, label="Model Name"),
        gr.Image(label="Input Image"),
        gr.Slider(0, 3, step=1, label="Stage"),
        gr.Slider(-1, 4, step=1, label="Block"),
        gr.Slider(0, 7, step=1, label="Head"),
        gr.Number(0.5, label="Alpha"),
    ],
    outputs=[
        gr.Number(label="Probability"),
        gr.Image(label="Cluster Visualization"),
    ],
)


# Only start the server when this file is executed as a script, so that
# importing the module (e.g. to reuse `demo` or `generate_coc_viz`) does not
# launch a blocking web server as a side effect.
if __name__ == "__main__":
    demo.launch()
|
|