#!/usr/bin/env python3
"""
Train SegFormer-B0 for facade segmentation on mixed rectified + unrectified data.
Sources:
  - CMP Facade (Xpitfire/cmp_facade) - rectified facades, ~492 images
  - ADE20K scene_parse_150 (merve/scene_parse_150) - unrectified street-level
    perspective, filtered to building-containing scenes

13-class taxonomy (preserves all CMP detail classes):
  0: background    7: sill
  1: facade        8: blind
  2: molding       9: balcony
  3: cornice      10: shop
  4: pillar       11: deco
  5: window       12: vegetation
  6: door

Two-pass inference (sketched in predict() below):
  Pass 1 (unrectified street photo): collapse to coarse groups via COARSE_MAP
  Pass 2 (rectified crop): use full 13-class output
Base: nvidia/mit-b0 (clean ImageNet-pretrained encoder, fresh segmentation head)
"""
import io
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from datasets import load_dataset, concatenate_datasets, Dataset
from transformers import (
SegformerImageProcessor,
SegformerForSemanticSegmentation,
TrainingArguments,
Trainer,
)
import evaluate
from torchvision.transforms import ColorJitter, InterpolationMode, RandomPerspective
# ─── Configuration ──────────────────────────────────────────────────────────
HUB_MODEL_ID = "Marco333/segformer-b0-facade-mixed"
BASE_MODEL = "nvidia/mit-b0"
OUTPUT_DIR = "./segformer-b0-facade-mixed"
NUM_LABELS = 13
id2label = {
0: "background",
1: "facade",
2: "molding",
3: "cornice",
4: "pillar",
5: "window",
6: "door",
7: "sill",
8: "blind",
9: "balcony",
10: "shop",
11: "deco",
12: "vegetation",
}
label2id = {v: k for k, v in id2label.items()}
# ─── Coarse grouping for Pass 1 (unrectified street photo inference) ────────
# Maps fine-grained class IDs β†’ coarse group names
# Usage: coarse_pred = COARSE_MAP[fine_pred] (numpy fancy indexing)
COARSE_LABELS = ["background", "facade_wall", "window", "door", "balcony", "vegetation"]
COARSE_MAP = np.array([
0, # 0 background β†’ background
1, # 1 facade β†’ facade_wall
1, # 2 molding β†’ facade_wall
1, # 3 cornice β†’ facade_wall
1, # 4 pillar β†’ facade_wall
2, # 5 window β†’ window
3, # 6 door β†’ door
1, # 7 sill β†’ facade_wall
2, # 8 blind β†’ window
4, # 9 balcony β†’ balcony
3, # 10 shop β†’ door
1, # 11 deco β†’ facade_wall
5, # 12 vegetation β†’ vegetation
], dtype=np.uint8)
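# A minimal sketch of the Pass-1 collapse: one line of NumPy fancy indexing.
# `fine_pred` stands for any HxW uint8 array of 13-class predictions.
def to_coarse(fine_pred: np.ndarray) -> np.ndarray:
    """Collapse a 13-class prediction map to the 6 coarse groups."""
    return COARSE_MAP[fine_pred]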
# ─── Label Remapping Tables ─────────────────────────────────────────────────
# CMP Facade: paletted PNG, values 1-12 β†’ preserve all detail classes
CMP_REMAP = np.full(256, 255, dtype=np.uint8)
CMP_REMAP[0] = 255 # unlabeled -> ignore
CMP_REMAP[1] = 1 # facade -> facade
CMP_REMAP[2] = 2 # molding -> molding
CMP_REMAP[3] = 3 # cornice -> cornice
CMP_REMAP[4] = 4 # pillar -> pillar
CMP_REMAP[5] = 5 # window -> window
CMP_REMAP[6] = 6 # door -> door
CMP_REMAP[7] = 7 # sill -> sill
CMP_REMAP[8] = 8 # blind -> blind
CMP_REMAP[9] = 9 # balcony -> balcony
CMP_REMAP[10] = 10 # shop -> shop
CMP_REMAP[11] = 11 # deco -> deco
CMP_REMAP[12] = 0 # background -> background
# ADE20K: grayscale 0-150 (1-indexed, 0=unlabeled)
# Maps to coarse equivalents in the 13-class taxonomy.
# ADE20K has no molding/cornice/sill/etc β€” those stay 255 (ignore).
ADE_REMAP = np.full(256, 255, dtype=np.uint8)
ADE_REMAP[0] = 255 # unlabeled -> ignore
ADE_REMAP[1] = 1 # wall -> facade
ADE_REMAP[2] = 1 # building -> facade
ADE_REMAP[3] = 0 # sky -> background
ADE_REMAP[5] = 12 # tree -> vegetation
ADE_REMAP[7] = 0 # road -> background
ADE_REMAP[9] = 5 # windowpane -> window
ADE_REMAP[10] = 0 # grass -> background
ADE_REMAP[12] = 0 # sidewalk -> background
ADE_REMAP[13] = 0 # person -> background
ADE_REMAP[14] = 0 # earth -> background
ADE_REMAP[15] = 6 # door -> door
ADE_REMAP[17] = 0 # mountain -> background
ADE_REMAP[18] = 12 # plant -> vegetation
ADE_REMAP[21] = 0 # car -> background
ADE_REMAP[22] = 0 # water -> background
ADE_REMAP[26] = 1 # house -> facade
ADE_REMAP[33] = 0 # fence -> background
ADE_REMAP[39] = 0 # railing -> background
ADE_REMAP[43] = 4 # column -> pillar
ADE_REMAP[49] = 1 # skyscraper -> facade
ADE_REMAP[54] = 0 # stairs -> background
ADE_REMAP[87] = 0 # awning -> background
ADE_REMAP[94] = 0 # pole -> background
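# Sanity check (illustrative): the remap tables should only emit valid class
# IDs or the 255 ignore index; this catches typos if the taxonomy changes.
for _table in (CMP_REMAP, ADE_REMAP):
    assert set(np.unique(_table)) <= set(id2label) | {255}, "bad remap value"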
def decode_image(data):
"""Decode image from dict-with-bytes or PIL Image."""
if isinstance(data, dict) and "bytes" in data:
return Image.open(io.BytesIO(data["bytes"]))
return data
def load_cmp():
"""Load and remap CMP Facade dataset."""
print("Loading CMP Facade dataset...")
ds = load_dataset("Xpitfire/cmp_facade")
out = {}
for split_name in ds.keys():
images, labels = [], []
for i in range(len(ds[split_name])):
ex = ds[split_name][i]
img = decode_image(ex["pixel_values"]).convert("RGB")
lbl = decode_image(ex["label"])
arr = np.array(lbl, dtype=np.uint8)
remapped = CMP_REMAP[arr]
labels.append(Image.fromarray(remapped, mode="L"))
images.append(img)
out[split_name] = Dataset.from_dict({"image": images, "annotation": labels})
print(f" CMP {split_name}: {len(images)} images")
return out
def load_ade20k():
"""Load ADE20K, filter to building-containing scenes, remap labels."""
print("Loading ADE20K scene_parse_150...")
ds = load_dataset("merve/scene_parse_150")
    # ADE20K IDs: 1=wall, 2=building, 26=house, 43=column, 49=skyscraper
    building_ids = {1, 2, 26, 43, 49}
    MIN_BUILDING_FRACTION = 0.03  # keep scenes with >=3% building pixels
out = {}
for split_name in ["train", "validation"]:
if split_name not in ds:
continue
images, labels = [], []
skipped = 0
for i in range(len(ds[split_name])):
ex = ds[split_name][i]
ann = ex["annotation"]
if ann is None:
skipped += 1
continue
arr = np.array(ann, dtype=np.uint8)
frac = np.isin(arr, list(building_ids)).sum() / arr.size
if frac < MIN_BUILDING_FRACTION:
skipped += 1
continue
remapped = ADE_REMAP[arr]
labels.append(Image.fromarray(remapped, mode="L"))
images.append(ex["image"].convert("RGB"))
out[split_name] = Dataset.from_dict({"image": images, "annotation": labels})
print(f" ADE20K {split_name}: {len(images)} kept, {skipped} skipped")
return out
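# Two-pass inference sketch (not exercised during training; assumes a trained
# checkpoint and its processor). Pass 1 runs on the raw street photo with
# coarse=True; Pass 2 runs on a rectified crop with coarse=False to keep the
# full 13-class output.
def predict(image, model, processor, coarse=False):
    """Segment a PIL image; optionally collapse to coarse groups (Pass 1)."""
    inputs = processor(image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # Upsample from the model's 1/4-resolution output to the image size.
    logits = nn.functional.interpolate(
        logits, size=image.size[::-1], mode="bilinear", align_corners=False
    )
    pred = logits.argmax(dim=1)[0].numpy().astype(np.uint8)
    return to_coarse(pred) if coarse else pred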
def main():
# ─── Load datasets ────────────────────────────────────────────────────────
cmp = load_cmp()
ade = load_ade20k()
# Combine: CMP train+test + ADE20K train -> train
# CMP eval + ADE20K validation -> val
train_parts = []
for s in ["train", "test"]:
if s in cmp:
train_parts.append(cmp[s])
if "train" in ade:
train_parts.append(ade["train"])
val_parts = []
if "eval" in cmp:
val_parts.append(cmp["eval"])
if "validation" in ade:
val_parts.append(ade["validation"])
train_ds = concatenate_datasets(train_parts)
val_ds = concatenate_datasets(val_parts)
print(f"\nFinal dataset: train={len(train_ds)}, val={len(val_ds)}")
# ─── Image processor ──────────────────────────────────────────────────────
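    # Note: do_reduce_labels=False below is deliberate: the label maps already
    # use 0 as a real class (background) and 255 as the ignore index, so no
    # label shifting is needed.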
image_processor = SegformerImageProcessor.from_pretrained(
BASE_MODEL,
do_reduce_labels=False,
size={"height": 512, "width": 512},
)
# ─── Augmentation ─────────────────────────────────────────────────────────
color_jitter = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.05)
    perspective_img = RandomPerspective(distortion_scale=0.3, p=0.4, fill=0)
    # Labels must be warped with NEAREST interpolation: the default BILINEAR
    # would blend class IDs into meaningless in-between values.
    perspective_lbl = RandomPerspective(
        distortion_scale=0.3,
        p=0.4,
        fill=255,
        interpolation=InterpolationMode.NEAREST,
    )
def train_transforms(batch):
imgs, lbls = [], []
for img, ann in zip(batch["image"], batch["annotation"]):
img = color_jitter(img)
            # Re-seed identically before each call so image and label receive
            # the exact same perspective warp (RandomPerspective draws its
            # parameters from torch's global RNG).
            seed = torch.randint(0, 2**32, (1,)).item()
            torch.manual_seed(seed)
            img = perspective_img(img)
            torch.manual_seed(seed)
            ann = perspective_lbl(ann)
imgs.append(img)
lbls.append(ann)
return image_processor(imgs, lbls)
    def val_transforms(batch):
        return image_processor(batch["image"], batch["annotation"])
train_ds.set_transform(train_transforms)
val_ds.set_transform(val_transforms)
# ─── Model ────────────────────────────────────────────────────────────────
print(f"\nLoading model from {BASE_MODEL} (clean encoder, fresh seg head)...")
model = SegformerForSemanticSegmentation.from_pretrained(
BASE_MODEL,
id2label=id2label,
label2id=label2id,
num_labels=NUM_LABELS,
ignore_mismatched_sizes=True,
)
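    # nvidia/mit-b0 ships only the encoder, so the 13-class decode head is
    # initialized from scratch; ignore_mismatched_sizes keeps the load
    # permissive should a differently-shaped head ever be present.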
# ─── Metrics ──────────────────────────────────────────────────────────────
metric = evaluate.load("mean_iou")
def compute_metrics(eval_pred):
with torch.no_grad():
logits, labels = eval_pred
            logits_tensor = torch.from_numpy(logits)
            # SegFormer outputs logits at 1/4 of the input resolution, so
            # upsample to the label size before taking the argmax.
            logits_tensor = nn.functional.interpolate(
                logits_tensor,
                size=labels.shape[-2:],
                mode="bilinear",
                align_corners=False,
            ).argmax(dim=1)
pred_labels = logits_tensor.detach().cpu().numpy()
metrics = metric.compute(
predictions=pred_labels,
references=labels,
num_labels=NUM_LABELS,
ignore_index=255,
reduce_labels=False,
)
for key, value in metrics.items():
if isinstance(value, np.ndarray):
metrics[key] = value.tolist()
return metrics
# ─── Training arguments ───────────────────────────────────────────────────
training_args = TrainingArguments(
output_dir=OUTPUT_DIR,
hub_model_id=HUB_MODEL_ID,
push_to_hub=True,
# Optimizer
learning_rate=6e-5,
lr_scheduler_type="polynomial",
warmup_ratio=0.1,
weight_decay=0.01,
# Epochs & batches
num_train_epochs=80,
per_device_train_batch_size=4,
per_device_eval_batch_size=4,
gradient_accumulation_steps=2, # effective batch = 8
# Eval & saving
eval_strategy="epoch",
save_strategy="epoch",
save_total_limit=3,
load_best_model_at_end=True,
metric_for_best_model="mean_iou",
greater_is_better=True,
eval_accumulation_steps=5,
# Logging & monitoring
logging_strategy="steps",
logging_steps=10,
logging_first_step=True,
disable_tqdm=True,
report_to=["trackio"],
run_name="segformer-b0-facade-mixed",
# Critical for segmentation tasks
remove_unused_columns=False,
label_names=["labels"],
# Performance
fp16=True,
dataloader_num_workers=4,
)
# ─── Trainer ──────────────────────────────────────────────────────────────
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_ds,
eval_dataset=val_ds,
compute_metrics=compute_metrics,
)
print("\nStarting training...")
trainer.train()
print("\nPushing best model to HuggingFace Hub...")
trainer.push_to_hub(
commit_message="SegFormer-B0 facade mixed 13-class: rectified (CMP) + unrectified (ADE20K) β€” clean nvidia/mit-b0 base",
)
image_processor.save_pretrained(OUTPUT_DIR)
image_processor.push_to_hub(HUB_MODEL_ID)
print(f"\nDone! Model at: https://huggingface.co/{HUB_MODEL_ID}")
if __name__ == "__main__":
main()