| """ |
| Gradio interface for HATSAT application. |
| """ |
|
|
| import gradio as gr |
| from PIL import Image |
|
|
| from config import REQUIRED_IMAGE_SIZE, UPSCALE_FACTOR |
| from utils.image_utils import validate_image_size, upscale_image |
| from interface.css_styles import generate_css, get_sample_images |
|
|
|
|
def upscale_and_display(image, model, device):
    """Run the super-resolution model on an uploaded image.

    Returns a ``(result, status)`` pair: the enhanced PIL image (or
    ``None`` on any failure) plus a human-readable status string shown
    in the UI.
    """
    # Guard: nothing uploaded and no sample selected.
    if image is None:
        return None, "Please upload an image or select a sample image."

    # The model only accepts one exact input size; reject anything else.
    size_ok, detail = validate_image_size(image)
    if not size_ok:
        return None, f"❌ Error: {detail}"

    # Broad catch is deliberate: any model/runtime failure becomes a
    # status message instead of crashing the Gradio callback.
    try:
        enhanced = upscale_image(image, model, device)
    except Exception as err:
        return None, f"❌ Error processing image: {str(err)}"
    return enhanced, "✅ Image successfully enhanced!"
|
|
|
|
def select_sample_image(image_path):
    """Load and return a sample image from disk.

    Args:
        image_path: Filesystem path to the sample image. Falsy values
            (``None``, empty string) are treated as "no selection".

    Returns:
        A fully loaded ``PIL.Image``, or ``None`` when no path was given.
    """
    if not image_path:
        return None
    # Image.open is lazy and keeps the file handle open until the pixel
    # data is read; load inside a context manager and hand back a copy
    # so the handle is released promptly.
    with Image.open(image_path) as img:
        return img.copy()
|
|
|
|
def create_interface(model, device):
    """Create and configure the Gradio interface.

    Builds the full Blocks layout (header, acknowledgments accordion,
    optional sample-image buttons, input/output panes, submit button and
    status box) and wires the event handlers. Statement order here
    determines the rendered component order.

    Args:
        model: Loaded super-resolution model, passed through to
            ``upscale_and_display`` on each submit.
        device: Device the model runs on — presumably a torch device
            string/object; TODO confirm against the caller.

    Returns:
        The assembled ``gr.Blocks`` application (not yet launched).
    """
    css = generate_css()

    with gr.Blocks(css=css, title="HATSAT - Super-Resolution for Satellite Images") as iface:
        # Header and the hard input-size constraint, surfaced up front.
        gr.Markdown("# HATSAT - Super-Resolution for Satellite Images")
        gr.Markdown(f"Upload a satellite image or select a sample to enhance its resolution by {UPSCALE_FACTOR}x.")
        gr.Markdown(f"⚠️ **Important**: Images must be exactly **{REQUIRED_IMAGE_SIZE[0]}x{REQUIRED_IMAGE_SIZE[1]} pixels** for the model to work properly.")

        # Collapsed credits for the base model and training dataset.
        with gr.Accordion("Acknowledgments", open=False):
            gr.Markdown("""
            ### Base Model: HAT (Hybrid Attention Transformer)
            This model is a fine tuned version of **HAT**:
            - **GitHub Repository**: [https://github.com/XPixelGroup/HAT](https://github.com/XPixelGroup/HAT)
            - **Paper**: [Activating More Pixels in Image Super-Resolution Transformer](https://arxiv.org/abs/2205.04437)
            - **Authors**: Xiangyu Chen, Xintao Wang, Jiantao Zhou, Yu Qiao, Chao Dong

            ### Training Dataset: SEN2NAIPv2
            The model was fine-tuned using the **SEN2NAIPv2** dataset:
            - **HuggingFace Dataset**: [https://huggingface.co/datasets/tacofoundation/SEN2NAIPv2](https://huggingface.co/datasets/tacofoundation/SEN2NAIPv2)
            - **Description**: High-resolution satellite imagery dataset for super-resolution tasks
            """)

        # Optional row of clickable sample thumbnails. The buttons carry
        # no label text; their appearance comes from the CSS class.
        sample_images = get_sample_images()
        sample_buttons = []
        if sample_images:
            gr.Markdown("**Sample Images (click to select):**")
            with gr.Row():
                for i, img_path in enumerate(sample_images):
                    btn = gr.Button(
                        "",
                        elem_id=f"sample_btn_{i}",
                        elem_classes="sample-image-btn"
                    )
                    # Keep (button, path) pairs so handlers can be wired
                    # after the layout is declared.
                    sample_buttons.append((btn, img_path))

        # Side-by-side input/output panes.
        with gr.Row():
            input_image = gr.Image(
                type="pil",
                label=f"Input Image (must be {REQUIRED_IMAGE_SIZE[0]}x{REQUIRED_IMAGE_SIZE[1]} pixels)",
                elem_classes="image-container",
                sources=["upload"],
                height=500,
                width=500
            )

            output_image = gr.Image(
                type="pil",
                label=f"Enhanced Output ({UPSCALE_FACTOR}x)",
                elem_classes="image-container",
                interactive=False,
                height=500,
                width=500,
                show_download_button=True
            )

        submit_btn = gr.Button("Enhance Image", variant="primary")

        # Read-only status line fed by upscale_and_display's second
        # return value.
        status_message = gr.Textbox(
            label="Status",
            interactive=False,
            show_label=True
        )

        # Wire each sample button to load its image into the input pane.
        # `path=img_path` binds the loop variable at lambda-definition
        # time, avoiding the late-binding closure pitfall.
        if sample_images:
            for btn, img_path in sample_buttons:
                btn.click(fn=lambda path=img_path: select_sample_image(path), outputs=input_image)

        # Main action: run the model and report image + status.
        submit_btn.click(
            fn=lambda img: upscale_and_display(img, model, device),
            inputs=input_image,
            outputs=[output_image, status_message]
        )

    return iface