# 2024Lee's picture
# Upload folder using huggingface_hub
# c58b8ac verified
"""Construction Safety Hazard Detection — Gradio Demo (HF Spaces)."""
from __future__ import annotations
import logging
import os
import sys
import gradio as gr
from PIL import Image
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from pipeline import PipelineResult, SafetyPipeline
from vlm_manager import VLMManager
import visualizer
# Identifier of the vision-language model handed to VLMManager
# (presumably a Hugging Face Hub repo id -- confirm against vlm_manager).
MODEL_ID = "Qwen/Qwen3.5-9B"

# Build the shared pipeline once, at import time, so that every call to
# analyze_image() reuses the same VLMManager / SafetyPipeline instances.
logger.info("Initializing pipeline with %s ...", MODEL_ID)
_vlm = VLMManager(model_id=MODEL_ID)
_pipeline = SafetyPipeline(_vlm)
logger.info("Pipeline ready.")
def analyze_image(
    image: Image.Image | None,
) -> tuple[Image.Image | None, str, str, str, str]:
    """Run the module-level safety pipeline on *image* and format the outputs.

    Args:
        image: The uploaded picture, or ``None`` when nothing was provided.

    Returns:
        A 5-tuple ``(annotated_image, detections, prompt, hazards, timing)``.
        With no image, the first item is ``None``, the second carries an
        "upload an image" message, and the remaining strings are empty.
        When the hazard result has a non-empty raw response, it is appended
        to the hazard text under a ``--- Raw VLM Response ---`` header.
    """
    # Guard clause: nothing to analyse, so only the placeholder message is shown.
    if image is None:
        return None, "Please upload a construction site image.", "", "", ""

    outcome: PipelineResult = _pipeline.run(image)
    hazards = outcome.hazard_result

    detections_summary = visualizer.format_detection_text(outcome.detections)
    guided_prompt = outcome.detection_guided_prompt
    hazard_summary = visualizer.format_hazard_text(hazards)

    # One-line latency report, one segment per stage plus the total.
    timing_line = " | ".join(
        (
            f"Detection: {outcome.detection_time_ms:.0f}ms",
            f"Analysis: {outcome.analysis_time_ms:.0f}ms",
            f"Total: {outcome.total_time_ms:.0f}ms",
        )
    )

    # Only add the raw-response section when the VLM actually produced one.
    raw_reply = hazards.raw_response
    raw_suffix = f"\n\n--- Raw VLM Response ---\n{raw_reply}" if raw_reply else ""

    return (
        outcome.annotated_image,
        detections_summary,
        guided_prompt,
        hazard_summary + raw_suffix,
        timing_line,
    )
# Build the Gradio UI. The resulting `demo` object is what the __main__
# guard at the bottom of the file launches.
# NOTE(review): the nesting below was reconstructed from a copy whose
# indentation had been stripped -- confirm the row/column grouping against
# the deployed layout.
with gr.Blocks(
    title="Construction Safety Hazard Detection",
    theme=gr.themes.Soft(),
) as demo:
    # Page header: title, one-line description and the paper citation.
    gr.Markdown(
        "# Construction Safety Hazard Detection\n"
        "**Two-stage detection-guided VLM pipeline** -- "
        "upload a construction site image to identify safety hazards.\n\n"
        "*Based on: Integration of Object Detection and Small VLMs for "
        "Construction Safety Hazard Identification (Adil et al., 2025)*"
    )
    # First row: upload + trigger + timing on one side, annotated result on the other.
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input Image", type="pil", height=400)
            run_btn = gr.Button("Analyze Safety Hazards", variant="primary", size="lg")
            timing_text = gr.Textbox(label="Timing", interactive=False)
        with gr.Column(scale=1):
            output_image = gr.Image(label="Detection + Hazard Results", type="pil", height=400)
    # Second row: read-only text output of each pipeline stage, side by side.
    with gr.Row():
        with gr.Column():
            det_output = gr.Textbox(label="Stage 1: Detected Objects", lines=8, interactive=False)
        with gr.Column():
            hazard_output = gr.Textbox(label="Stage 2: Hazard Analysis", lines=8, interactive=False)
    # Collapsed by default (open=False): the prompt text returned by the pipeline.
    with gr.Accordion("Detection-Guided Prompt (sent to VLM)", open=False):
        prompt_output = gr.Textbox(label="Constructed Prompt", lines=15, interactive=False)
    # Static reference table describing the four hazard categories.
    gr.Markdown(
        "### Hazard Categories\n"
        "| Category | Description |\n"
        "|---|---|\n"
        "| **PPE Non-Compliance** | Workers not wearing hard hats or high-visibility safety vests |\n"
        "| **Fall Hazard** | Workers near elevated positions, excavations, or temporary structures without fall protection |\n"
        "| **Caught-between Hazard** | Workers at risk of being struck, crushed, or pinned by machinery or structures |\n"
        "| **Unsafe Environment** | Exposed rebar, uneven terrain, debris, open electrical wires, poor lighting |"
    )
    # Wire the button to analyze_image. The order of `outputs` must match the
    # tuple that function returns: (annotated image, detections, prompt,
    # hazard text, timing).
    run_btn.click(
        fn=analyze_image,
        inputs=[input_image],
        outputs=[output_image, det_output, prompt_output, hazard_output, timing_text],
    )
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from GRADIO_SERVER_PORT when
    # set and falls back to 7860 otherwise.
    listen_port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=listen_port)