LeJEPA ViT-Small/8 (5-band Sentinel-2, PoC)

A small self-supervised feature extractor for Sentinel-2 L2A imagery: a 5-channel ViT-Small/8 pretrained with the LeJEPA objective (Balestriero & LeCun, 2025) on the falafel-hockey/sentinel2-lejepa-global-diverse-256 chip dataset. Built as the end-to-end reproducibility artifact for the Sentinel Change Explorer foundation-model change-detection proof of concept.

This is a proof of concept, not a general-purpose EO model. The training set is tiny, the training budget is small, and the features are only as strong as ~3450 gradient steps on ~(full dataset) chips can make them. Use it to reproduce the companion app's "Experimental" panel, not as a substitute for Clay, Prithvi, or SSL4EO-S12 models.

Architecture

Component        Details
Encoder kind     vit_small_patch8
Backbone         Vision Transformer (ViT-Small, patch size 8) with 4 register tokens. Same recipe as the Tiny variant but with 384-dim embeddings and 6 heads, trained on GPU for higher feature quality.
Input            [B, 5, 256, 256] uint16-derived reflectance (10 m/px)
Output           [B, 384, 32, 32] feature map
Feature grid     32x32 = 1024 positions
Parameter count  ~21.8M
Band order       red, green, blue, nir, swir16 (S2 B04/B03/B02/B08/B11)

Why this architecture. The 384-dim 32×32 feature grid is the visualization target for this project; see the Tiny card for the full rationale on register tokens.

A small MLP predictor head is used during training but is not included in this release; only the encoder is shipped.

Training recipe

Hyperparameter       Value
Objective            LeJEPA (predictive smooth-L1 + SIGReg)
Mask scheme          32x32 feature grid, disjoint context/target subsets
Context aggregation  mean-pool over visible positions
Target encoder       EMA of online, momentum ramp 0.996 → 1.0
Optimizer            AdamW (lr=0.001, weight_decay=0.05)
LR schedule          Cosine annealing over 3450 steps
SIGReg weight (α)    1.0
Batch size           192
Epochs               150
Training chips       (full dataset)
Precision            BF16 mixed
Device               CUDA
Final loss           268.9120
Training date        2026-04-05
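The target-encoder momentum ramp from the table can be sketched as follows. This is a minimal illustration assuming a cosine-shaped ramp from 0.996 to 1.0 over the 3450 steps; the exact ramp shape lives in the training code and may differ:

```python
import math

def ema_momentum(step: int, total_steps: int = 3450,
                 base: float = 0.996, final: float = 1.0) -> float:
    """Cosine ramp of the EMA momentum from `base` at step 0 to `final` at the end."""
    progress = min(step / total_steps, 1.0)
    return final - (final - base) * 0.5 * (1.0 + math.cos(math.pi * progress))
```

At each step the target (EMA) parameters are then updated as `ema_p = m * ema_p + (1 - m) * online_p`, so the target encoder drifts ever more slowly toward the online encoder as training ends.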

Normalization

The encoder expects inputs normalized with the per-band statistics computed over the training split of the companion dataset. These stats ship with the checkpoint (.pt file key norm_stats) so inference does not need to pull the dataset:

Band     Mean     Std
red      1298.91  1192.39
green    1086.62   908.00
blue      830.22   846.53
nir      2467.29  1264.88
swir16   2357.63  1504.00
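As a sketch, applying these per-band stats to a raw reflectance tensor looks like this. The values are hard-coded from the table for illustration only; in practice read them from the checkpoint's `norm_stats` key as shown in the Usage section:

```python
import torch

# Per-band stats from the table above, in band order (red, green, blue, nir, swir16)
MEAN = torch.tensor([1298.91, 1086.62, 830.22, 2467.29, 2357.63]).view(1, 5, 1, 1)
STD = torch.tensor([1192.39, 908.00, 846.53, 1264.88, 1504.00]).view(1, 5, 1, 1)

def normalize(bands: torch.Tensor) -> torch.Tensor:
    """bands: [B, 5, H, W] raw uint16-derived reflectance -> per-band z-scored float."""
    return (bands.float() - MEAN) / STD
```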

Intended use

  • Yes: feature extraction for small-area Sentinel-2 L2A tiles (one or a few 256x256 chips), PCA→RGB visualization, per-patch cosine-distance change detection against another time step of the same AOI.
  • No: high-recall general EO feature extraction, downstream classification without fine-tuning, deployment to workloads where feature quality matters.

Limitations

  • Tiny training budget. This PoC was trained with the small step and data budget noted above; a full-scale run is a planned follow-up that will swap in stronger weights without any API change.
  • Preset bias. 70% of training chips came from 5 specific demo AOIs in the companion app, so the encoder is best on geography resembling those (coastal/tropical, desert-construction, central-European industrial, agricultural-flood-plain, forested-burn).
  • Single sensor, single level. Sentinel-2 L2A only. No S1 SAR, no Landsat, no atmospheric TOA inputs.
  • 5 bands only. Red-edge, cirrus, and SWIR22 are intentionally excluded to keep the model compact for CPU inference.
  • No downstream supervised head. Use mean(dim=[2,3]) on the feature map to get a global descriptor, or operate on the raw spatial map.
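The mean-pooled global descriptor and the per-patch cosine-distance change map from the intended-use list can be sketched as below. Only the feature shapes from the architecture table are assumed; any [B, 384, 32, 32] tensor stands in for encoder output:

```python
import torch
import torch.nn.functional as F

def global_descriptor(feat: torch.Tensor) -> torch.Tensor:
    """[B, 384, 32, 32] feature map -> [B, 384] mean-pooled global descriptor."""
    return feat.mean(dim=[2, 3])

def change_map(feat_t0: torch.Tensor, feat_t1: torch.Tensor) -> torch.Tensor:
    """Per-patch cosine distance between two time steps of the same AOI.

    Inputs are [B, 384, 32, 32] feature maps; output is [B, 32, 32] in [0, 2],
    where larger values indicate more change at that grid position.
    """
    sim = F.cosine_similarity(feat_t0, feat_t1, dim=1)  # cosine over the channel dim
    return 1.0 - sim
```

Identical inputs give a zero change map; thresholding or percentile-clipping the map is left to the application.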

Usage

import torch
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="falafel-hockey/lejepa-vit-small-patch8-256-sentinel2-5band",
    filename="lejepa_vit_small_patch8_256_5band.pt",
)
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

# The encoder factory lives in the companion repo:
#   git clone https://github.com/awheelis/sentinel-change-explorer
from src.experimental.encoders import build_encoder

encoder = build_encoder(ckpt["config"]["encoder_kind"])
encoder.load_state_dict(ckpt["encoder_state"])
encoder.eval()

# Normalize with the stats packed into the checkpoint
norm = ckpt["norm_stats"]
mean = torch.tensor(norm["mean"]).view(1, 5, 1, 1)
std = torch.tensor(norm["std"]).view(1, 5, 1, 1).clamp(min=1.0)

# bands: [B, 5, 256, 256] float tensor in raw uint16-derived reflectance
# feat:  [B, 384, 32, 32]
with torch.no_grad():
    x = (bands - mean) / std
    feat = encoder(x)
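For the PCA→RGB visualization mentioned under intended use, a plain SVD-based PCA over the patch features is enough. This is a minimal sketch, not the app's exact rendering code; any single-chip [384, 32, 32] feature map works:

```python
import torch

def pca_rgb(feat: torch.Tensor) -> torch.Tensor:
    """Project a [C, H, W] feature map onto its top 3 PCs -> [H, W, 3] image in [0, 1]."""
    c, h, w = feat.shape
    x = feat.reshape(c, h * w).T             # [H*W, C]: one row per grid position
    x = x - x.mean(dim=0, keepdim=True)      # center features before PCA
    _, _, vh = torch.linalg.svd(x, full_matrices=False)
    rgb = x @ vh[:3].T                       # project onto top-3 principal components
    rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-8)  # min-max scale for display
    return rgb.reshape(h, w, 3)
```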

License and attribution

Citation

@misc{lejepa_vit_small_patch8_sentinel2_5band,
  title  = {LeJEPA ViT-Small/8 (5-band Sentinel-2, PoC)},
  author = {Wheelis, Alex},
  year   = {2026},
  url    = {https://huggingface.co/falafel-hockey/lejepa-vit-small-patch8-256-sentinel2-5band}
}

@misc{balestriero2025lejepa,
  title  = {LeJEPA: Provable and Scalable Self-Supervised Learning Without the Heuristics},
  author = {Balestriero, Randall and LeCun, Yann},
  year   = {2025},
  eprint = {2511.08544},
  archivePrefix = {arXiv}
}

@misc{darcet2024registers,
  title  = {Vision Transformers Need Registers},
  author = {Darcet, Timothée and Oquab, Maxime and Mairal, Julien and Bojanowski, Piotr},
  year   = {2024},
  eprint = {2309.16588},
  archivePrefix = {arXiv}
}