| from PIL import Image |
| from io import BytesIO |
| from matplotlib.figure import Figure |
| from torchvision import transforms |
| from tqdm import tqdm |
| from typing import Literal, Any |
| from urllib.request import urlopen |
| import gradio as gr |
| import matplotlib.pyplot as plt |
| import os |
| import spaces |
| import sys |
| import torch |
| import torch.nn.functional as F |
|
|
|
|
# Class names of the seven Litton landscape types. `predict` pairs them
# positionally with the model's softmax outputs, so their order matters.
LABELS = ["Panoramic", "Feature", "Detail", "Enclosed", "Focal", "Ephemeral", "Canopied"]

# Checkpoint file name: looked up in the working directory and, when absent,
# downloaded from the LCL model zoo under this same name.
MODELFILE = "Litton-7type-visual-landscape-model.pth"
|
|
|
|
# Run inference on the GPU when CUDA is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
def _download(url: str, dest: str, chunk_size: int = 1 << 16) -> None:
    """Stream *url* into the file *dest*, showing a tqdm progress bar.

    The body is written to ``dest + ".part"`` and only renamed to *dest* after
    it has been received completely. This ensures an interrupted or truncated
    download never leaves a partial file behind, which the
    ``os.path.exists(MODELFILE)`` check would otherwise accept on every later
    start.

    Args:
        url: Address to fetch.
        dest: Final path of the downloaded file.
        chunk_size: Number of bytes requested per ``read`` call.

    Raises:
        urllib.error.URLError: on network / HTTP failures (from ``urlopen``).
        OSError: if fewer/more bytes arrive than the server's Content-Length
            announced, or on filesystem errors.
    """
    tmp_path = dest + ".part"
    try:
        with urlopen(url) as resp:
            # HTTPResponse is not subscriptable; headers are on `.headers`,
            # and Content-Length may be absent (e.g. chunked encoding).
            length = resp.headers.get("Content-Length")
            total = int(length) if length is not None else None
            received = 0
            with open(tmp_path, "wb") as out, tqdm(
                total=total, desc="Downloading", unit="B", unit_scale=True
            ) as progress:
                while chunk := resp.read(chunk_size):
                    out.write(chunk)
                    received += len(chunk)
                    progress.update(len(chunk))
        # http.client does not raise when the connection closes early on a
        # sized read, so verify the byte count explicitly.
        if total is not None and received != total:
            raise OSError(
                f"incomplete download of {url}: got {received} of {total} bytes"
            )
        os.replace(tmp_path, dest)
    except BaseException:
        # Remove the partial file (also on Ctrl-C) so a retry starts clean.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


# Fetch the checkpoint on first start if it is not present locally.
if not os.path.exists(MODELFILE):
    model_url = f"https://lclab.thu.edu.tw/modelzoo/{MODELFILE}"
    print(f"fetch model from {model_url}...", file=sys.stderr)
    _download(model_url, MODELFILE)
|
|
# Load the pickled network straight onto `device`. The saved object is a
# wrapper whose `.module` attribute is the actual model (presumably
# torch.nn.DataParallel — TODO confirm).
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, i.e. run code embedded in the file. This is only acceptable because
# the checkpoint comes from a trusted source; prefer saving a state_dict and
# loading it with weights_only=True.
_loaded = torch.load(MODELFILE, map_location=device, weights_only=False)
model = _loaded.module
del _loaded
# Switch layers such as dropout / batch-norm to inference behaviour.
model.eval()

# Evaluation transform: resize the shorter side to 256 px, take the central
# 224x224 crop, convert to a float tensor, then normalise each channel with
# the usual ImageNet mean/std.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
|
|
@spaces.GPU
def predict(image: Image.Image | None) -> Figure:
    """Classify *image* into the Litton landscape types and plot the result.

    Args:
        image: PIL image from the Gradio input component; ``None`` when the
            user pressed start without uploading or selecting an image.

    Returns:
        A matplotlib figure with one bar per entry of ``LABELS`` showing the
        softmax probability in percent.

    Raises:
        gr.Error: when no image was provided, so Gradio shows a readable
            message instead of a stack trace.
    """
    if image is None:
        raise gr.Error("請先上傳或選擇一張影像")

    # Normalise the mode (RGBA, L, P, ...) to the 3 channels `preprocess` expects.
    rgb = image.convert("RGB")
    # ToTensor yields (C, H, W); add the batch axis -> (1, C, H, W).
    batch = preprocess(rgb).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)

    # Only the first len(LABELS) outputs are used so probabilities line up
    # with LABELS (the head may emit extra logits — TODO confirm).
    n_classes = len(LABELS)
    probs = F.softmax(logits[:, :n_classes], dim=1)[0].cpu()

    return draw_bar_chart(
        {
            "class": LABELS,
            # Plain Python floats, matching draw_bar_chart's declared input type.
            "probs": (probs * 100).tolist(),
        }
    )
|
|
|
|
def draw_bar_chart(data: dict[str, list[str | float]]) -> Figure:
    """Render class probabilities as a bar chart with a value label per bar.

    Args:
        data: Mapping with ``"class"`` (bar labels) and ``"probs"`` (bar
            heights in percent); both sequences must have the same length.

    Returns:
        A new matplotlib ``Figure`` with one annotated bar per class.
    """
    classes = data["class"]
    probabilities = data["probs"]

    # Build the Figure directly rather than through pyplot: pyplot registers
    # every figure in a global manager and keeps it alive until plt.close(),
    # so a server drawing one chart per request would leak memory.
    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    ax.bar(classes, probabilities, color="skyblue")

    ax.set_xlabel("Class")
    ax.set_ylabel("Probability (%)")
    ax.set_title("Class Probability")

    # Categorical bars sit at x = 0, 1, 2, ...; print each value just above its bar.
    for i, prob in enumerate(probabilities):
        ax.text(i, prob + 0.01, f"{prob:.2f}%", ha="center", va="bottom")

    fig.tight_layout()

    return fig
|
|
|
|
def choose_example(imgpath: str) -> gr.Image:
    """Load a clicked example thumbnail's file as the new input image.

    The image is rescaled so that its longer side is 512 px, keeping the
    aspect ratio; smaller images are scaled up.

    Args:
        imgpath: Path of the selected example image file (the thumbnails use
            ``type="filepath"``).

    Returns:
        A ``gr.Image`` update carrying the resized PIL image as its value.
    """
    # Context manager closes the file handle immediately instead of leaving it
    # open until garbage collection; `resize` returns an independent, fully
    # loaded image, so it stays valid after the source is closed.
    with Image.open(imgpath) as src:
        width, height = src.size
        ratio = 512 / max(width, height)
        # Clamp to 1 px: int() truncation gives 0 for extremely elongated
        # images, and PIL refuses to resize to a zero dimension.
        new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
        resized = src.resize(new_size)
    # NOTE(review): this label differs from the "上傳影像" label the input
    # component is created with in get_layout, so the label changes once an
    # example is picked — confirm this is intended.
    return gr.Image(value=resized, label="輸入影像(不支援 SVG 格式)", type="pil")
|
|
|
|
def get_layout():
    """Build the Gradio Blocks UI for the Litton7 landscape classifier.

    Page structure:
    - a title/citation header;
    - a row holding the image input (with three clickable example thumbnails)
      next to the result plot;
    - a start button;
    - a footer.

    The start button runs `predict` on the input image. Selecting a thumbnail
    loads that file into the input via `choose_example`.

    Returns:
        The configured (not yet launched) ``gr.Blocks`` app.
    """
    page_css = """
    .main-title {
        font-size: 24px;
        font-weight: bold;
        text-align: center;
        margin-bottom: 20px;
    }
    .reference {
        text-align: center;
        font-size: 1.2em;
        color: #d1d5db;
        margin-bottom: 20px;
    }
    .reference a {
        color: #FB923C;
        text-decoration: none;
    }
    .reference a:hover {
        text-decoration: underline;
        color: #FB923C;
    }
    .title {
        border-bottom: 1px solid;
    }
    .footer {
        text-align: center;
        margin-top: 30px;
        padding-top: 20px;
        border-top: 1px solid #ddd;
        color: #d1d5db;
        font-size: 14px;
    }
    .example-image {
        height: 220px;
        padding: 25px;
    }
    """

    # Dark palette applied on top of the Base theme.
    palette_overrides = {
        "body_text_color": "*neutral_100",
        "body_text_color_subdued": "*neutral_600",
        "background_fill_primary": "*neutral_950",
        "background_fill_secondary": "*neutral_600",
        "border_color_accent": "*secondary_800",
        "color_accent": "*primary_50",
        "color_accent_soft": "*secondary_800",
        "code_background_fill": "*neutral_700",
        "block_background_fill_dark": "*body_background_fill",
        "block_info_text_color": "#6b7280",
        "block_label_text_color": "*neutral_300",
        "block_label_text_weight": "700",
        "block_title_text_color": "*block_label_text_color",
        "block_title_text_weight": "300",
        "panel_background_fill": "*neutral_800",
        "table_text_color_dark": "*secondary_800",
        "checkbox_background_color_selected": "*primary_500",
        "checkbox_label_background_fill": "*neutral_500",
        "checkbox_label_background_fill_hover": "*neutral_700",
        "checkbox_label_text_color": "*neutral_200",
        "input_background_fill": "*neutral_700",
        "input_background_fill_focus": "*neutral_600",
        "slider_color": "*primary_500",
        "table_even_background_fill": "*neutral_700",
        "table_odd_background_fill": "*neutral_600",
        "table_row_focus": "*neutral_800",
    }
    theme = gr.themes.Base(
        primary_hue="orange",
        secondary_hue="cyan",
        neutral_hue="gray",
    ).set(**palette_overrides)

    citation_html = "".join(
        [
            '<div class="main-title">Litton7景觀分類模型</div>',
            '<div class="reference">引用資料:',
            '<a href="https://www.airitilibrary.com/Article/Detail/10125434-N202406210003-00003" target="_blank">',
            "何立智、李沁築、邱浩修(2024)。Litton7:Litton視覺景觀分類深度學習模型。戶外遊憩研究,37(2)",
            "</a>",
            "</div>",
        ]
    )

    # Example thumbnails: static, non-interactive images that hand their file
    # path to the select handler.
    example_files = ["examples/beach.jpg", "examples/field.jpg", "examples/sky.jpg"]
    thumbnail_kwargs = {
        "show_label": False,
        "type": "filepath",
        "elem_classes": "example-image",
        "interactive": False,
        "show_download_button": False,
        "show_fullscreen_button": False,
        "show_share_button": False,
    }

    with gr.Blocks(css=page_css, theme=theme) as demo:
        with gr.Column():
            gr.HTML(value=citation_html)

            with gr.Row(equal_height=True):
                with gr.Group():
                    input_image = gr.Image(label="上傳影像", type="pil", height="256px")
                    gr.Label("範例影像", show_label=False)
                    with gr.Row():
                        thumbnails = [
                            gr.Image(value=path, **thumbnail_kwargs)
                            for path in example_files
                        ]
                result_plot = gr.Plot(label="分類結果")

            start_button = gr.Button("開始", variant="primary")
            gr.HTML(
                '<div class="footer">© 2024 LCL 版權所有<br>開發者:何立智、楊哲睿</div>',
            )

        # Event wiring: classify on click, then load an example on select.
        start_button.click(fn=predict, inputs=input_image, outputs=result_plot)
        for thumb in thumbnails:
            thumb.select(fn=choose_example, inputs=thumb, outputs=input_image)

    return demo
|
|
|
|
if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    app = get_layout()
    app.launch()
|
|