| |
| """ |
| Train SegFormer-B0 for facade segmentation on mixed rectified + unrectified data. |
| |
| Sources: |
| - CMP Facade (Xpitfire/cmp_facade) - rectified facades, ~492 images |
| - ADE20K scene_parse_150 (merve/scene_parse_150) - unrectified street-level perspective, |
| filtered to building-containing scenes |
| |
| 13-class taxonomy (preserves all CMP detail classes): |
| 0: background 7: sill |
| 1: facade 8: blind |
| 2: molding 9: balcony |
| 3: cornice 10: shop |
| 4: pillar 11: deco |
| 5: window 12: vegetation |
| 6: door |
| |
| Two-pass inference: |
| Pass 1 (unrectified street photo): collapse to coarse groups via COARSE_MAP |
| Pass 2 (rectified crop): use full 13-class output |
| |
| Base: nvidia/mit-b0 (clean ImageNet-pretrained encoder, fresh segmentation head) |
| """ |
|
|
import io
import os

import evaluate
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from datasets import load_dataset, concatenate_datasets, Dataset
from torchvision.transforms import ColorJitter, RandomPerspective
from torchvision.transforms import InterpolationMode
from transformers import (
    SegformerImageProcessor,
    SegformerForSemanticSegmentation,
    TrainingArguments,
    Trainer,
)
|
|
|
|
| |
# Run configuration.
HUB_MODEL_ID = "Marco333/segformer-b0-facade-mixed"  # target repo on the HF Hub
BASE_MODEL = "nvidia/mit-b0"  # ImageNet-pretrained encoder; seg head is trained fresh
OUTPUT_DIR = "./segformer-b0-facade-mixed"  # local checkpoint directory
NUM_LABELS = 13  # fine-grained facade taxonomy (see module docstring)
|
|
# Fine-grained 13-class facade taxonomy; list index == training label id.
_CLASS_NAMES = [
    "background",  # 0
    "facade",      # 1
    "molding",     # 2
    "cornice",     # 3
    "pillar",      # 4
    "window",      # 5
    "door",        # 6
    "sill",        # 7
    "blind",       # 8
    "balcony",     # 9
    "shop",        # 10
    "deco",        # 11
    "vegetation",  # 12
]
id2label = dict(enumerate(_CLASS_NAMES))
label2id = {name: idx for idx, name in enumerate(_CLASS_NAMES)}
|
|
| |
| |
| |
# Coarse 6-class grouping used for pass-1 inference on unrectified street
# photos (see module docstring). COARSE_MAP[fine_id] -> coarse_id, where
# fine_id follows the 13-class taxonomy above.
COARSE_LABELS = ["background", "facade_wall", "window", "door", "balcony", "vegetation"]
COARSE_MAP = np.array([
    0,   # 0  background  -> background
    1,   # 1  facade      -> facade_wall
    1,   # 2  molding     -> facade_wall
    1,   # 3  cornice     -> facade_wall
    1,   # 4  pillar      -> facade_wall
    2,   # 5  window      -> window
    3,   # 6  door        -> door
    1,   # 7  sill        -> facade_wall
    2,   # 8  blind       -> window
    4,   # 9  balcony     -> balcony
    3,   # 10 shop        -> door
    1,   # 11 deco        -> facade_wall
    5,   # 12 vegetation  -> vegetation
], dtype=np.uint8)
|
|
|
|
| |
|
|
| |
# CMP Facade label remap: raw CMP label id -> training id.
# Raw 0 stays at the ignore value 255, raw 1..11 keep their value, and
# raw 12 becomes training class 0 (background). Every other raw value
# (none expected in CMP) also falls through to 255.
CMP_REMAP = np.full(256, 255, dtype=np.uint8)
CMP_REMAP[1:12] = np.arange(1, 12, dtype=np.uint8)
CMP_REMAP[12] = 0
|
|
| |
| |
| |
# ADE20K (scene_parse_150) label remap: ADE id -> training id.
# Any ADE id not listed below stays 255 and is ignored by the loss.
# NOTE(review): the left-hand ids are assumed to be the scene_parse_150
# class ids (wall/building/window/door/...) — verify against the dataset card.
_ADE_TO_TRAIN = {
    0: 255,
    1: 1,
    2: 1,
    3: 0,
    5: 12,
    7: 0,
    9: 5,
    10: 0,
    12: 0,
    13: 0,
    14: 0,
    15: 6,
    17: 0,
    18: 12,
    21: 0,
    22: 0,
    26: 1,
    33: 0,
    39: 0,
    43: 4,
    49: 1,
    54: 0,
    87: 0,
    94: 0,
}
ADE_REMAP = np.full(256, 255, dtype=np.uint8)
for _ade_id, _train_id in _ADE_TO_TRAIN.items():
    ADE_REMAP[_ade_id] = _train_id
|
|
|
|
def decode_image(data):
    """Return a PIL Image for *data*, decoding encoded bytes when needed.

    Accepts either an already-decoded value (returned unchanged) or a dict
    carrying encoded image bytes under the "bytes" key.
    """
    if not (isinstance(data, dict) and "bytes" in data):
        return data
    return Image.open(io.BytesIO(data["bytes"]))
|
|
|
|
def load_cmp():
    """Load CMP Facade and remap raw label ids to the 13-class taxonomy.

    Returns:
        dict mapping split name -> datasets.Dataset with columns
        "image" (RGB PIL image) and "annotation" (mode-"L" PIL label map).
    """
    print("Loading CMP Facade dataset...")
    ds = load_dataset("Xpitfire/cmp_facade")
    out = {}
    for split_name, split in ds.items():
        images, labels = [], []
        # Iterate the split directly instead of index-by-range.
        for ex in split:
            images.append(decode_image(ex["pixel_values"]).convert("RGB"))
            raw = np.array(decode_image(ex["label"]), dtype=np.uint8)
            # CMP_REMAP sends raw 0 -> 255 (ignore) and raw 12 -> 0 (background).
            labels.append(Image.fromarray(CMP_REMAP[raw], mode="L"))
        out[split_name] = Dataset.from_dict({"image": images, "annotation": labels})
        print(f" CMP {split_name}: {len(images)} images")
    return out
|
|
|
|
def load_ade20k():
    """Load ADE20K scene_parse_150, keep building-heavy scenes, remap labels.

    A scene is kept when at least MIN_BUILDING_FRACTION of its pixels carry
    one of the building-related ADE ids in ``building_ids``. Labels are
    remapped through ADE_REMAP (unlisted ids -> 255 = ignore).

    Returns:
        dict mapping split name ("train"/"validation") -> datasets.Dataset
        with columns "image" (RGB PIL) and "annotation" (mode-"L" PIL).
    """
    print("Loading ADE20K scene_parse_150...")
    ds = load_dataset("merve/scene_parse_150")
    # Hoisted out of the per-image loop: np.isin accepts an array directly,
    # so build it once instead of converting a set every iteration.
    building_ids = np.array([1, 2, 26, 43, 49])
    MIN_BUILDING_FRACTION = 0.03
    out = {}
    for split_name in ["train", "validation"]:
        if split_name not in ds:
            continue
        images, labels = [], []
        skipped = 0
        # Iterate the split directly instead of index-by-range.
        for ex in ds[split_name]:
            ann = ex["annotation"]
            if ann is None:  # some examples ship without an annotation
                skipped += 1
                continue
            arr = np.array(ann, dtype=np.uint8)
            frac = np.isin(arr, building_ids).sum() / arr.size
            if frac < MIN_BUILDING_FRACTION:
                skipped += 1
                continue
            labels.append(Image.fromarray(ADE_REMAP[arr], mode="L"))
            images.append(ex["image"].convert("RGB"))
        out[split_name] = Dataset.from_dict({"image": images, "annotation": labels})
        print(f" ADE20K {split_name}: {len(images)} kept, {skipped} skipped")
    return out
|
|
|
|
def main():
    """Train SegFormer-B0 on mixed CMP (rectified) + ADE20K (unrectified) data."""
    # ---- Data -------------------------------------------------------------
    cmp = load_cmp()
    ade = load_ade20k()

    # CMP "train" + "test" both go into training; CMP "eval" is held out.
    train_parts = [cmp[s] for s in ("train", "test") if s in cmp]
    if "train" in ade:
        train_parts.append(ade["train"])

    val_parts = []
    if "eval" in cmp:
        val_parts.append(cmp["eval"])
    if "validation" in ade:
        val_parts.append(ade["validation"])

    train_ds = concatenate_datasets(train_parts)
    val_ds = concatenate_datasets(val_parts)
    print(f"\nFinal dataset: train={len(train_ds)}, val={len(val_ds)}")

    # ---- Preprocessing / augmentation --------------------------------------
    image_processor = SegformerImageProcessor.from_pretrained(
        BASE_MODEL,
        do_reduce_labels=False,  # labels are already 0-based; 255 = ignore
        size={"height": 512, "width": 512},
    )

    color_jitter = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.05)
    # Paired geometric augmentation: image and mask must be warped identically.
    # BUGFIX: the mask transform must use NEAREST interpolation — the default
    # (BILINEAR) blends class ids at warped edges, producing invalid labels.
    perspective_img = RandomPerspective(distortion_scale=0.3, p=0.4, fill=0)
    perspective_lbl = RandomPerspective(
        distortion_scale=0.3,
        p=0.4,
        fill=255,  # warped-in border pixels are ignored by the loss
        interpolation=InterpolationMode.NEAREST,
    )

    def train_transforms(batch):
        """Jitter + paired perspective warp, then SegFormer preprocessing."""
        imgs, lbls = [], []
        for img, ann in zip(batch["image"], batch["annotation"]):
            img = color_jitter(img)
            # Re-seed the global torch RNG before each warp so image and mask
            # draw the same apply/skip decision and perspective parameters.
            # NOTE: this deliberately clobbers global RNG state.
            seed = torch.randint(0, 2**32, (1,)).item()
            torch.manual_seed(seed)
            img = perspective_img(img)
            torch.manual_seed(seed)
            ann = perspective_lbl(ann)
            imgs.append(img)
            lbls.append(ann)
        return image_processor(imgs, lbls)

    def val_transforms(batch):
        """No augmentation at eval time — resize/normalize only."""
        return image_processor(list(batch["image"]), list(batch["annotation"]))

    train_ds.set_transform(train_transforms)
    val_ds.set_transform(val_transforms)

    # ---- Model -------------------------------------------------------------
    print(f"\nLoading model from {BASE_MODEL} (clean encoder, fresh seg head)...")
    model = SegformerForSemanticSegmentation.from_pretrained(
        BASE_MODEL,
        id2label=id2label,
        label2id=label2id,
        num_labels=NUM_LABELS,
        ignore_mismatched_sizes=True,  # fresh 13-class head on the pretrained encoder
    )

    # ---- Metrics -----------------------------------------------------------
    metric = evaluate.load("mean_iou")

    def compute_metrics(eval_pred):
        """Upsample logits to label resolution, argmax, and compute mean IoU."""
        with torch.no_grad():
            logits, labels = eval_pred
            logits_tensor = torch.from_numpy(logits)
            # SegFormer emits logits at reduced resolution; upsample to the
            # label map's H x W before taking the per-pixel argmax.
            pred_labels = (
                nn.functional.interpolate(
                    logits_tensor,
                    size=labels.shape[-2:],
                    mode="bilinear",
                    align_corners=False,
                )
                .argmax(dim=1)
                .cpu()
                .numpy()
            )
            metrics = metric.compute(
                predictions=pred_labels,
                references=labels,
                num_labels=NUM_LABELS,
                ignore_index=255,
                reduce_labels=False,
            )
            # Per-category IoU/accuracy come back as arrays; convert to lists
            # so the Trainer can serialize them in its logs.
            for key, value in metrics.items():
                if isinstance(value, np.ndarray):
                    metrics[key] = value.tolist()
            return metrics

    # ---- Training ----------------------------------------------------------
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        hub_model_id=HUB_MODEL_ID,
        push_to_hub=True,
        # Optimizer schedule (SegFormer-style poly decay with warmup).
        learning_rate=6e-5,
        lr_scheduler_type="polynomial",
        warmup_ratio=0.1,
        weight_decay=0.01,
        # Effective batch size = 4 x 2 (grad accumulation).
        num_train_epochs=80,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=2,
        # Evaluate/checkpoint every epoch; keep the best model by mean IoU.
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="mean_iou",
        greater_is_better=True,
        eval_accumulation_steps=5,  # offload eval logits to CPU periodically
        # Logging.
        logging_strategy="steps",
        logging_steps=10,
        logging_first_step=True,
        disable_tqdm=True,
        report_to=["trackio"],
        run_name="segformer-b0-facade-mixed",
        # set_transform produces non-model columns; keep them for the collator.
        remove_unused_columns=False,
        label_names=["labels"],
        # Throughput.
        fp16=True,
        dataloader_num_workers=4,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        compute_metrics=compute_metrics,
    )

    print("\nStarting training...")
    trainer.train()

    print("\nPushing best model to HuggingFace Hub...")
    trainer.push_to_hub(
        commit_message="SegFormer-B0 facade mixed 13-class: rectified (CMP) + unrectified (ADE20K) β clean nvidia/mit-b0 base",
    )
    image_processor.save_pretrained(OUTPUT_DIR)
    image_processor.push_to_hub(HUB_MODEL_ID)

    print(f"\nDone! Model at: https://huggingface.co/{HUB_MODEL_ID}")


if __name__ == "__main__":
    main()
|
|