# construction-safety-demo / visualizer.py
# Uploaded by 2024Lee via huggingface_hub (commit c58b8ac, verified).
"""Visualization utilities for detection and hazard analysis results."""
from __future__ import annotations
import re
import textwrap
from PIL import Image, ImageDraw, ImageFont
from detector import Detection, is_machinery, is_worker
# Color palette: box outline colors per detection category (RGB).
COLOR_WORKER = (0, 180, 0)  # green
COLOR_MACHINERY = (220, 60, 60)  # red
COLOR_OTHER = (60, 130, 220)  # blue

# Single source of truth for per-hazard styling: (hazard id, panel dot color,
# short badge text). Both lookup maps below are derived from it so their keys
# and ordering can never drift apart.
_HAZARD_STYLE = (
    ("ppe_non_compliance", (255, 107, 53), "PPE"),
    ("fall_hazard", (220, 53, 69), "FALL"),
    ("caught_between_hazard", (155, 89, 182), "CAUGHT"),
    ("unsafe_environment", (243, 156, 18), "ENV"),
)
HAZARD_COLORS = {hazard: rgb for hazard, rgb, _ in _HAZARD_STYLE}
HAZARD_LABELS = {hazard: short for hazard, _, short in _HAZARD_STYLE}
def _get_color(det: Detection) -> tuple[int, int, int]:
    """Return the box color for *det* based on its class category.

    Categories are checked in priority order: worker first, then machinery;
    anything matching neither gets the generic "other" color.
    """
    for matches, rgb in ((is_worker, COLOR_WORKER), (is_machinery, COLOR_MACHINERY)):
        if matches(det.class_name):
            return rgb
    return COLOR_OTHER
def _try_load_font(size: int = 16) -> ImageFont.FreeTypeFont | ImageFont.ImageFont:
    """Return a bold sans-serif font at *size* pixels, or Pillow's default font.

    A few common Linux install locations for DejaVu Sans Bold and Liberation
    Sans Bold are tried in order. If none of them loads, Pillow's built-in
    default font is returned: scaled to *size* when the installed Pillow
    supports ``load_default(size=...)`` (Pillow >= 10.1), otherwise its
    fixed-size bitmap font.
    """
    for path in (
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
        "/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf",
        "/usr/share/fonts/TTF/DejaVuSans-Bold.ttf",
    ):
        try:
            return ImageFont.truetype(path, size)
        except OSError:
            # Missing or unreadable font file (IOError is an alias of OSError).
            continue
        except ImportError:
            # Pillow was built without FreeType support: no TrueType path can
            # succeed, so stop probing and use the default font instead.
            break
    try:
        return ImageFont.load_default(size=size)
    except (TypeError, ImportError, OSError):
        # TypeError: Pillow < 10.1 has no ``size`` argument.
        # ImportError: scaling the default font also requires FreeType.
        return ImageFont.load_default()
def _extract_mentioned_ids(text: str) -> set[str]:
"""Extract entity IDs (w1, ex1, o3, ...) mentioned in hazard explanations."""
return set(re.findall(r"\b([a-z]{1,2}\d+)\b", text))
def _collect_hazards(hazard_result) -> tuple[dict[str, list[str]], list[tuple[tuple[int, int, int], str]]]:
    """Map detection IDs to hazard badge labels and build the summary-panel rows.

    Returns ``(hazarded_ids, hazard_lines)``: ``hazarded_ids`` maps every ID
    mentioned in a hazard explanation to the short labels of the hazards that
    mention it; ``hazard_lines`` holds one ``(color, text)`` row per hazard.
    Both are empty when *hazard_result* is falsy or reports no hazards.
    """
    hazarded_ids: dict[str, list[str]] = {}
    hazard_lines: list[tuple[tuple[int, int, int], str]] = []
    if not (hazard_result and hazard_result.hazards):
        return hazarded_ids, hazard_lines
    for h in hazard_result.hazards:
        # ``or ""`` also tolerates an explicit None explanation.
        explanation = hazard_result.explanations.get(h) or ""
        short = HAZARD_LABELS.get(h, h[:4].upper())
        for mid in _extract_mentioned_ids(explanation):
            hazarded_ids.setdefault(mid, []).append(short)
        color = HAZARD_COLORS.get(h, (200, 200, 200))
        label = h.replace("_", " ").title()
        hazard_lines.append((color, f"{label}: {explanation}"))
    return hazarded_ids, hazard_lines


def _tag_size(text: str, font) -> tuple[int, int]:
    """Return (width, height) of the padded background rectangle for *text*."""
    left, top, right, bottom = font.getbbox(text)
    return right - left + 6, bottom - top + 4


def _draw_tag(draw: ImageDraw.ImageDraw, x: int, y: int, text: str, font, fill) -> None:
    """Draw white *text* on a filled rectangle whose top-left corner is (x, y)."""
    left, top, _, _ = font.getbbox(text)
    tag_w, tag_h = _tag_size(text, font)
    draw.rectangle([x, y, x + tag_w, y + tag_h], fill=fill)
    # Offset by the glyph bbox origin so the ink sits inside the padding
    # instead of being pushed down by the font's internal top bearing.
    draw.text((x + 3 - left, y + 2 - top), text, fill="white", font=font)


def _fit_text(draw: ImageDraw.ImageDraw, text: str, font, max_width: float) -> str:
    """Return *text*, or its longest prefix plus "..." whose rendered width fits *max_width*.

    *text* must be single-line (``ImageDraw.textlength`` rejects newlines).
    """
    ellipsis = "..."
    if draw.textlength(text, font=font) <= max_width:
        return text
    # Binary search on prefix length; assumes rendered width is non-decreasing
    # as characters are added, which holds for ordinary text.
    lo, hi = 0, len(text)
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if draw.textlength(text[:mid] + ellipsis, font=font) <= max_width:
            lo = mid
        else:
            hi = mid - 1
    return text[:lo].rstrip() + ellipsis


def draw_results(
    image: Image.Image,
    detections: list[Detection],
    hazard_result=None,
    line_width: int = 3,
) -> Image.Image:
    """Draw bounding boxes, labels and hazard annotations on a copy of *image*.

    - Workers: green boxes; machinery: red boxes; everything else: blue boxes.
    - Detections whose ID is mentioned in a hazard explanation are outlined in
      (255, 50, 50) with a thicker line and get a badge listing the short
      hazard labels.
    - When hazards are present, a dark summary panel with one row per hazard
      is appended below the image.

    Args:
        image: Source image. It is converted to RGB and pasted into a new
            canvas; the input image is not drawn on.
        detections: Detections to draw. ``det.xyxy`` is multiplied by the image
            width/height, so it is presumably normalized to [0, 1].
        hazard_result: Optional object exposing ``hazards`` (hazard ids) and
            ``explanations`` (mapping hazard id -> text). ``None`` or an empty
            ``hazards`` draws boxes only, with no panel.
        line_width: Box outline width in pixels (hazarded boxes use +1).

    Returns:
        A new RGB image of size ``(W, H + panel_height)``.
    """
    W, H = image.size
    hazarded_ids, hazard_lines = _collect_hazards(hazard_result)

    # Fonts scale with the image so labels stay legible on large frames.
    font_size = max(14, min(W, H) // 40)
    font = _try_load_font(font_size)
    font_sm = _try_load_font(max(11, font_size - 4))
    font_panel = _try_load_font(max(12, font_size - 2))

    # Derive panel row height from the panel font's metrics. A fixed 24px step
    # made rows overlap once the font grew past roughly 20px.
    _, ref_top, _, ref_bottom = font_panel.getbbox("Ag")
    row_h = max(24, ref_bottom + 8)
    panel_h = max(30, row_h * len(hazard_lines) + 12) if hazard_lines else 0

    out = Image.new("RGB", (W, H + panel_h), (30, 30, 30))
    out.paste(image.convert("RGB"), (0, 0))
    draw = ImageDraw.Draw(out)

    for det in detections:
        x1, y1, x2, y2 = det.xyxy
        px1, py1, px2, py2 = int(x1 * W), int(y1 * H), int(x2 * W), int(y2 * H)
        is_hazarded = det.det_id in hazarded_ids
        color = (255, 50, 50) if is_hazarded else _get_color(det)
        lw = line_width + 1 if is_hazarded else line_width
        draw.rectangle([px1, py1, px2, py2], outline=color, width=lw)

        # ID label above the box, kept on-canvas horizontally and vertically.
        label = f"{det.det_id} ({det.class_name})"
        tag_w, tag_h = _tag_size(label, font)
        lx = max(0, min(px1, W - tag_w))
        ly = max(0, py1 - tag_h)
        _draw_tag(draw, lx, ly, label, font, color)

        if is_hazarded:
            # Warning badge just below the box's bottom-right corner, clamped
            # into the image area so it never spills into the summary panel.
            badge = " ".join(hazarded_ids[det.det_id])
            bw, bh = _tag_size(badge, font_sm)
            bx = max(0, px2 - bw)
            by = max(0, min(py2 + 2, H - bh))
            _draw_tag(draw, bx, by, badge, font_sm, (200, 30, 30))

    # Hazard summary panel: a colored dot followed by the hazard text.
    text_mid = (ref_top + ref_bottom) // 2  # vertical center of the glyphs
    max_text_w = W - 28 - 6
    y = H + 6
    for color, text in hazard_lines:
        cy = y + text_mid
        draw.ellipse([10, cy - 6, 22, cy + 6], fill=color)
        # Collapse newlines/whitespace runs so each hazard stays on one row,
        # then truncate by rendered pixel width rather than character count.
        display = _fit_text(draw, " ".join(text.split()), font_panel, max_text_w)
        draw.text((28, y), display, fill="white", font=font_panel)
        y += row_h
    return out
def draw_detections(
    image: Image.Image,
    detections: list[Detection],
    line_width: int = 3,
) -> Image.Image:
    """Render detection boxes and ID/class labels only.

    Convenience wrapper around ``draw_results`` with no hazard result, so no
    warning badges are drawn and no summary panel is appended.
    """
    return draw_results(image, detections, None, line_width)
def format_detection_text(detections: list[Detection]) -> str:
    """Render *detections* as a human-readable multi-line summary.

    Returns ``"No objects detected."`` when *detections* is empty. Otherwise it
    returns a count header, a blank line, and then one line per detection with
    its ID, class name, center and size, each formatted to three decimals.
    """
    if not detections:
        return "No objects detected."
    header = f"Detected {len(detections)} objects:\n"
    rows = (
        f" [{d.det_id}] {d.class_name} "
        f"center=({d.cx:.3f}, {d.cy:.3f}) size=({d.w:.3f} x {d.h:.3f})"
        for d in detections
    )
    return "\n".join([header, *rows])
def format_hazard_text(hazard_result) -> str:
    """Format a hazard analysis result as readable text.

    Args:
        hazard_result: Object exposing ``hazards`` (iterable of hazard ids such
            as ``"fall_hazard"``) and ``explanations`` (mapping hazard id ->
            text). ``None`` is accepted and treated as "no hazards", matching
            how ``draw_results`` handles a missing result.

    Returns:
        ``"No hazards identified."`` when there is nothing to report. Otherwise
        a header and, per hazard, a title-cased name line followed by its
        explanation line when one exists.
    """
    # Previously a None result raised AttributeError on ``.hazards``.
    if not hazard_result or not hazard_result.hazards:
        return "No hazards identified."
    explanations = hazard_result.explanations
    lines = ["Hazards identified:\n"]
    for h in hazard_result.hazards:
        label = h.replace("_", " ").title()
        lines.append(f" ⚠ {label}")
        if h in explanations:
            lines.append(f" → {explanations[h]}")
    return "\n".join(lines)