#!/usr/bin/env python3
"""
Train SegFormer-B0 for facade segmentation on mixed rectified + unrectified data.

Sources:
  - CMP Facade (Xpitfire/cmp_facade): rectified facades, ~492 images
  - ADE20K scene_parse_150 (merve/scene_parse_150): unrectified street-level
    perspective, filtered to building-containing scenes

13-class taxonomy (preserves all CMP detail classes):
    0: background     7: sill
    1: facade         8: blind
    2: molding        9: balcony
    3: cornice       10: shop
    4: pillar        11: deco
    5: window        12: vegetation
    6: door

Two-pass inference:
    Pass 1 (unrectified street photo): collapse to coarse groups via COARSE_MAP
    Pass 2 (rectified crop): use the full 13-class output

Base: nvidia/mit-b0 (clean ImageNet-pretrained encoder, fresh segmentation head)
"""
import io

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from datasets import load_dataset, concatenate_datasets, Dataset
from transformers import (
    SegformerImageProcessor,
    SegformerForSemanticSegmentation,
    TrainingArguments,
    Trainer,
)
import evaluate
from torchvision.transforms import ColorJitter, InterpolationMode, RandomPerspective

# ─── Configuration ──────────────────────────────────────────────────────────
HUB_MODEL_ID = "Marco333/segformer-b0-facade-mixed"
BASE_MODEL = "nvidia/mit-b0"
OUTPUT_DIR = "./segformer-b0-facade-mixed"
NUM_LABELS = 13

id2label = {
    0: "background",
    1: "facade",
    2: "molding",
    3: "cornice",
    4: "pillar",
    5: "window",
    6: "door",
    7: "sill",
    8: "blind",
    9: "balcony",
    10: "shop",
    11: "deco",
    12: "vegetation",
}
label2id = {v: k for k, v in id2label.items()}

# ─── Coarse grouping for Pass 1 (unrectified street photo inference) ────────
# Maps fine-grained class IDs → coarse group IDs (indices into COARSE_LABELS).
# Usage: coarse_pred = COARSE_MAP[fine_pred]  (numpy fancy indexing)
COARSE_LABELS = ["background", "facade_wall", "window", "door", "balcony", "vegetation"]
COARSE_MAP = np.array([
    0,  # 0  background  → background
    1,  # 1  facade      → facade_wall
    1,  # 2  molding     → facade_wall
    1,  # 3  cornice     → facade_wall
    1,  # 4  pillar      → facade_wall
    2,  # 5  window      → window
    3,  # 6  door        → door
    1,  # 7  sill        → facade_wall
    2,  # 8  blind       → window
    4,  # 9  balcony     → balcony
    3,  # 10 shop        → door
    1,  # 11 deco        → facade_wall
    5,  # 12 vegetation  → vegetation
], dtype=np.uint8)

# ─── Label Remapping Tables ──────────────────────────────────────────────────
# CMP Facade: paletted PNG, values 1-12 → preserve all detail classes
CMP_REMAP = np.full(256, 255, dtype=np.uint8)
CMP_REMAP[0] = 255   # unlabeled -> ignore
CMP_REMAP[1] = 1     # facade -> facade
CMP_REMAP[2] = 2     # molding -> molding
CMP_REMAP[3] = 3     # cornice -> cornice
CMP_REMAP[4] = 4     # pillar -> pillar
CMP_REMAP[5] = 5     # window -> window
CMP_REMAP[6] = 6     # door -> door
CMP_REMAP[7] = 7     # sill -> sill
CMP_REMAP[8] = 8     # blind -> blind
CMP_REMAP[9] = 9     # balcony -> balcony
CMP_REMAP[10] = 10   # shop -> shop
CMP_REMAP[11] = 11   # deco -> deco
CMP_REMAP[12] = 0    # background -> background
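
# A minimal sanity-check sketch (illustrative only; never called) of the uint8
# lookup-table remap pattern used by CMP_REMAP above and ADE_REMAP below:
# indexing a 256-entry table with a uint8 label array remaps every pixel in one
# vectorized step. `_remap_sanity_check` is a hypothetical helper, not part of
# the training pipeline.
def _remap_sanity_check():
    toy = np.array([[5, 6], [12, 99]], dtype=np.uint8)  # window, door, background, junk
    fine = CMP_REMAP[toy]
    assert fine.tolist() == [[5, 6], [0, 255]]  # unknown raw values fall into ignore (255)
    # Pass-1 style collapse; mask ignore pixels first since COARSE_MAP has only 13 entries:
    coarse = np.where(fine == 255, 255, COARSE_MAP[np.minimum(fine, 12)])
    assert coarse.tolist() == [[2, 3], [0, 255]]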

# ADE20K: grayscale 0-150 (1-indexed, 0 = unlabeled).
# Maps to coarse equivalents in the 13-class taxonomy.
# ADE20K has no molding/cornice/sill/etc — those stay 255 (ignore).
ADE_REMAP = np.full(256, 255, dtype=np.uint8)
ADE_REMAP[0] = 255   # unlabeled -> ignore
ADE_REMAP[1] = 1     # wall -> facade
ADE_REMAP[2] = 1     # building -> facade
ADE_REMAP[3] = 0     # sky -> background
ADE_REMAP[5] = 12    # tree -> vegetation
ADE_REMAP[7] = 0     # road -> background
ADE_REMAP[9] = 5     # windowpane -> window
ADE_REMAP[10] = 0    # grass -> background
ADE_REMAP[12] = 0    # sidewalk -> background
ADE_REMAP[13] = 0    # person -> background
ADE_REMAP[14] = 0    # earth -> background
ADE_REMAP[15] = 6    # door -> door
ADE_REMAP[17] = 0    # mountain -> background
ADE_REMAP[18] = 12   # plant -> vegetation
ADE_REMAP[21] = 0    # car -> background
ADE_REMAP[22] = 0    # water -> background
ADE_REMAP[26] = 1    # house -> facade
ADE_REMAP[33] = 0    # fence -> background
ADE_REMAP[39] = 0    # railing -> background
ADE_REMAP[43] = 4    # column -> pillar
ADE_REMAP[49] = 1    # skyscraper -> facade
ADE_REMAP[54] = 0    # stairs -> background
ADE_REMAP[87] = 0    # awning -> background
ADE_REMAP[94] = 0    # pole -> background


def decode_image(data):
    """Decode image from dict-with-bytes or PIL Image."""
    if isinstance(data, dict) and "bytes" in data:
        return Image.open(io.BytesIO(data["bytes"]))
    return data


def load_cmp():
    """Load and remap CMP Facade dataset."""
    print("Loading CMP Facade dataset...")
    ds = load_dataset("Xpitfire/cmp_facade")
    out = {}
    for split_name in ds.keys():
        images, labels = [], []
        for i in range(len(ds[split_name])):
            ex = ds[split_name][i]
            img = decode_image(ex["pixel_values"]).convert("RGB")
            lbl = decode_image(ex["label"])
            arr = np.array(lbl, dtype=np.uint8)
            remapped = CMP_REMAP[arr]
            labels.append(Image.fromarray(remapped, mode="L"))
            images.append(img)
        out[split_name] = Dataset.from_dict({"image": images, "annotation": labels})
        print(f"  CMP {split_name}: {len(images)} images")
    return out


def load_ade20k():
    """Load ADE20K, filter to building-containing scenes, remap labels."""
    print("Loading ADE20K scene_parse_150...")
    ds = load_dataset("merve/scene_parse_150")
    building_ids = {1, 2, 26, 43, 49}
    MIN_BUILDING_FRACTION = 0.03
    out = {}
    for split_name in ["train", "validation"]:
        if split_name not in ds:
            continue
        images, labels = [], []
        skipped = 0
        for i in range(len(ds[split_name])):
            ex = ds[split_name][i]
            ann = ex["annotation"]
            if ann is None:
                skipped += 1
                continue
            arr = np.array(ann, dtype=np.uint8)
            frac = np.isin(arr, list(building_ids)).sum() / arr.size
            if frac < MIN_BUILDING_FRACTION:
                skipped += 1
                continue
            remapped = ADE_REMAP[arr]
            labels.append(Image.fromarray(remapped, mode="L"))
            images.append(ex["image"].convert("RGB"))
        out[split_name] = Dataset.from_dict({"image": images, "annotation": labels})
        print(f"  ADE20K {split_name}: {len(images)} kept, {skipped} skipped")
    return out


def main():
    # ─── Load datasets ──────────────────────────────────────────────────────
    cmp = load_cmp()
    ade = load_ade20k()

    # Combine: CMP train+test + ADE20K train      -> train
    #          CMP eval       + ADE20K validation -> val
    train_parts = []
    for s in ["train", "test"]:
        if s in cmp:
            train_parts.append(cmp[s])
    if "train" in ade:
        train_parts.append(ade["train"])

    val_parts = []
    if "eval" in cmp:
        val_parts.append(cmp["eval"])
    if "validation" in ade:
        val_parts.append(ade["validation"])

    train_ds = concatenate_datasets(train_parts)
    val_ds = concatenate_datasets(val_parts)
    print(f"\nFinal dataset: train={len(train_ds)}, val={len(val_ds)}")

    # ─── Image processor ──────────────────────────────────────────────────
    image_processor = SegformerImageProcessor.from_pretrained(
        BASE_MODEL,
        do_reduce_labels=False,
        size={"height": 512, "width": 512},
    )
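    # do_reduce_labels stays False on purpose: the remap tables above already
    # use 255 as the ignore index and keep 0 as a real class (background).
    # Reducing would shift every ID down by one and break the 13-class taxonomy.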

    # ─── Augmentation ─────────────────────────────────────────────────────
    color_jitter = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.05)
    perspective_img = RandomPerspective(distortion_scale=0.3, p=0.4, fill=0)
    # Labels must be warped with nearest-neighbor interpolation: the default
    # bilinear would blend class IDs (and the 255 ignore fill) into invalid values.
    perspective_lbl = RandomPerspective(
        distortion_scale=0.3, p=0.4, fill=255,
        interpolation=InterpolationMode.NEAREST,
    )

    def train_transforms(batch):
        imgs, lbls = [], []
        for img, ann in zip(batch["image"], batch["annotation"]):
            img = color_jitter(img)
            # Re-seed the global RNG before each transform so image and label
            # receive the exact same perspective warp.
            seed = torch.randint(0, 2**32, (1,)).item()
            torch.manual_seed(seed)
            img = perspective_img(img)
            torch.manual_seed(seed)
            ann = perspective_lbl(ann)
            imgs.append(img)
            lbls.append(ann)
        return image_processor(imgs, lbls)

    def val_transforms(batch):
        return image_processor(batch["image"], batch["annotation"])

    train_ds.set_transform(train_transforms)
    val_ds.set_transform(val_transforms)

    # ─── Model ──────────────────────────────────────────────────────────────
    print(f"\nLoading model from {BASE_MODEL} (clean encoder, fresh seg head)...")
    model = SegformerForSemanticSegmentation.from_pretrained(
        BASE_MODEL,
        id2label=id2label,
        label2id=label2id,
        num_labels=NUM_LABELS,
        ignore_mismatched_sizes=True,
    )

    # ─── Metrics ────────────────────────────────────────────────────────────
    metric = evaluate.load("mean_iou")

    def compute_metrics(eval_pred):
        with torch.no_grad():
            logits, labels = eval_pred
            logits_tensor = torch.from_numpy(logits)
            # Upsample logits to label resolution before taking the argmax.
            logits_tensor = nn.functional.interpolate(
                logits_tensor,
                size=labels.shape[-2:],
                mode="bilinear",
                align_corners=False,
            ).argmax(dim=1)
            pred_labels = logits_tensor.detach().cpu().numpy()
            metrics = metric.compute(
                predictions=pred_labels,
                references=labels,
                num_labels=NUM_LABELS,
                ignore_index=255,
                reduce_labels=False,
            )
            for key, value in metrics.items():
                if isinstance(value, np.ndarray):
                    metrics[key] = value.tolist()
            return metrics

    # ─── Training arguments ─────────────────────────────────────────────────
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        hub_model_id=HUB_MODEL_ID,
        push_to_hub=True,
        # Optimizer
        learning_rate=6e-5,
        lr_scheduler_type="polynomial",
        warmup_ratio=0.1,
        weight_decay=0.01,
        # Epochs & batches
        num_train_epochs=80,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=2,  # effective batch = 8
        # Eval & saving
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="mean_iou",
        greater_is_better=True,
        eval_accumulation_steps=5,
        # Logging & monitoring
        logging_strategy="steps",
        logging_steps=10,
        logging_first_step=True,
        disable_tqdm=True,
        report_to=["trackio"],
        run_name="segformer-b0-facade-mixed",
        # Critical for segmentation tasks
        remove_unused_columns=False,
        label_names=["labels"],
        # Performance
        fp16=True,
        dataloader_num_workers=4,
    )

    # ─── Trainer ────────────────────────────────────────────────────────────
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        compute_metrics=compute_metrics,
    )

    print("\nStarting training...")
    trainer.train()

    print("\nPushing best model to HuggingFace Hub...")
    trainer.push_to_hub(
        commit_message=(
            "SegFormer-B0 facade mixed 13-class: rectified (CMP) + "
            "unrectified (ADE20K) — clean nvidia/mit-b0 base"
        ),
    )
    image_processor.save_pretrained(OUTPUT_DIR)
    image_processor.push_to_hub(HUB_MODEL_ID)
    print(f"\nDone! Model at: https://huggingface.co/{HUB_MODEL_ID}")


if __name__ == "__main__":
    main()
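

# ─── Appendix: two-pass inference sketch ────────────────────────────────────
# A minimal sketch of the two-pass scheme described in the module docstring,
# assuming a trained checkpoint exists at HUB_MODEL_ID. `run_two_pass` is an
# illustrative helper, not part of the training pipeline, and is never called
# by main().
def run_two_pass(image_path: str, rectified: bool = False) -> np.ndarray:
    """Return an (H, W) uint8 mask: coarse IDs for pass 1, fine IDs for pass 2."""
    processor = SegformerImageProcessor.from_pretrained(HUB_MODEL_ID)
    model = SegformerForSemanticSegmentation.from_pretrained(HUB_MODEL_ID).eval()
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits  # (1, 13, H/4, W/4)
    logits = nn.functional.interpolate(  # upsample back to the input resolution
        logits, size=image.size[::-1], mode="bilinear", align_corners=False
    )
    fine = logits.argmax(dim=1)[0].numpy().astype(np.uint8)
    # Pass 2 (rectified crop): keep the full 13-class prediction.
    # Pass 1 (unrectified street photo): collapse fine IDs to coarse groups.
    return fine if rectified else COARSE_MAP[fine]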