Backend evolution: Phases 1-10 specialists + agentic FSM + Mellea + LiteLLM router
6a82282
"""TerraMind micro-finetune on NYC labels — proof it works on AMD MI300X.
Goal: show the smallest possible *real* fine-tune of TerraMind on NYC
data converges loss in a few seconds on the MI300X. Not building a
useful model — just showing the end-to-end loop (load → forward →
backward → step) works on the AMD ROCm path with terratorch.
Setup:
- TerraMind v1 base ENCODER (frozen) — the multimodal foundation
model's vision encoder, ~300 M params.
- Tiny classification head on top — single linear layer over the
pooled patch embedding, 2-class output.
- 8 synthetic NYC samples (6-band S2L2A 224×224 tensors). Labels
are deterministic based on the synthetic input (a function of
the NIR band's mean) so the head has a real signal to learn.
- 30 SGD steps with Adam. Print loss + accuracy + elapsed.
This isn't a useful classifier — the labels are synthetic. But it
proves: weights load on AMD, forward pass works, gradients flow,
optimizer steps. Real NYC fine-tune would replace the synthetic
labels with actual Sandy-inundation-zone membership at the chip's
center coord (we have the polygon in data/sandy_inundation.geojson).
"""
from __future__ import annotations
import time
import torch
import torch.nn.functional as F
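

# The docstring above notes that a real fine-tune would derive labels from
# Sandy-inundation-zone membership at each chip's center coordinate. The
# helper below is a minimal sketch of that lookup, assuming
# data/sandy_inundation.geojson holds polygons in EPSG:4326 (lon/lat) and
# that geopandas + shapely are installed. It is not called anywhere in this
# script; the micro-finetune below sticks to synthetic labels.
def sandy_zone_label(lon: float, lat: float) -> int:
    """Return 1 if (lon, lat) falls inside a Sandy inundation polygon, else 0."""
    import geopandas as gpd  # lazy import so the synthetic run needs no extra deps
    from shapely.geometry import Point

    zones = gpd.read_file("data/sandy_inundation.geojson")
    return int(zones.geometry.contains(Point(lon, lat)).any())
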
def main():
    # Force-import the terramind registration module so the registry
    # gets populated.
    import terratorch.models.backbones.terramind.model.terramind_register  # noqa
    from terratorch.registry import BACKBONE_REGISTRY

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[micro] device: {device}")
    if device == "cuda":
        props = torch.cuda.get_device_properties(0)
        print(f"[micro] gpu: {torch.cuda.get_device_name(0)}, "
              f"VRAM={props.total_memory/1e9:.1f} GB")

    # ---- Load TerraMind base ENCODER (not the generative variant) ----
    print("[micro] loading terramind_v1_base encoder...")
    t0 = time.time()
    backbone = BACKBONE_REGISTRY.build(
        "terratorch_terramind_v1_base",
        modalities=["S2L2A"],  # one input modality is enough for fine-tune
        pretrained=True,
    )
    backbone.eval()
    backbone.to(device)
    # Freeze backbone — we're only training the head.
    for p in backbone.parameters():
        p.requires_grad = False
    print(f"[micro] backbone loaded in {time.time()-t0:.2f}s; "
          f"params={sum(p.numel() for p in backbone.parameters()):,}")

    # ---- Build synthetic NYC dataset --------------------------------
    # 8 samples. Labels = (1 if mean of the narrow-NIR band > 0.5 else 0).
    # Real fine-tune would use Sandy-zone membership at the chip's
    # geographic center, derived from data/sandy_inundation.geojson.
    n_samples = 8
    img_size = 224
    # TerraMind v1's S2L2A encoder is trained on the full 12-band L2A
    # tensor (B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12).
    # 6-band subsets (Phase 1's Sen1Floods11 ordering) won't work here.
    bands = 12
    torch.manual_seed(42)
    x = torch.rand(n_samples, bands, img_size, img_size, device=device)
    # Synthetic label rule: True if narrow NIR (band index 8 = B8A) mean > 0.5.
    y = (x[:, 8].mean(dim=(1, 2)) > 0.5).long().to(device)
    print(f"[micro] dataset: {n_samples} samples, "
          f"label balance: {y.float().mean().item():.2f} positive")

    # ---- Embedding shape probe -------------------------------------
    print("[micro] probing embedding shape...")
    with torch.no_grad():
        # TerraMind backbone expects (B, C, T, H, W) for time-series, or
        # the encoder takes a dict with the modality key. Use the dict
        # form that the registry's modality config knows about.
        out = backbone({"S2L2A": x[:1]})
    # Embedding shape varies by variant; check what we got.
    if isinstance(out, (list, tuple)):
        out_t = out[0]
    elif isinstance(out, dict):
        out_t = next(iter(out.values()))
    else:
        out_t = out
    print(f"[micro] backbone output type={type(out).__name__}, "
          f"shape={tuple(out_t.shape) if hasattr(out_t, 'shape') else 'n/a'}")
    # Pool to per-sample embedding. Output is typically
    # (B, num_patches, dim) for ViT-style; mean-pool patches.
    emb_dim = out_t.shape[-1]
    print(f"[micro] embedding dim: {emb_dim}")

    # ---- Tiny linear head -------------------------------------------
    head = torch.nn.Linear(emb_dim, 2).to(device)
    optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)

    # ---- Train loop --------------------------------------------------
    n_steps = 30
    print(f"[micro] training {n_steps} steps...")
    t0 = time.time()
    losses = []
    accs = []
    for step in range(n_steps):
        with torch.no_grad():
            emb = backbone({"S2L2A": x})
        if isinstance(emb, (list, tuple)):
            emb = emb[0]
        elif isinstance(emb, dict):
            emb = next(iter(emb.values()))
        # Mean-pool over patches (axis 1) -> (B, dim)
        if emb.ndim == 3:
            emb = emb.mean(dim=1)
        logits = head(emb)
        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(-1) == y).float().mean().item()
        losses.append(loss.item())
        accs.append(acc)
        optimizer.zero_grad()
        loss.backward()
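        # Minimal extra guard for the docstring's "gradients flow" claim: on
        # the first step, confirm every trainable head parameter actually
        # received a gradient before the optimizer update.
        if step == 0:
            assert all(p.grad is not None for p in head.parameters()), \
                "no gradient reached the classification head"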
        optimizer.step()

        if step == 0 or (step + 1) % 5 == 0 or step == n_steps - 1:
            print(f"[micro] step {step+1:2d}/{n_steps} "
                  f"loss={loss.item():.4f} acc={acc:.2f}")

    elapsed = time.time() - t0
    print()
    print(f"[micro] DONE — {n_steps} steps in {elapsed:.2f}s "
          f"({elapsed/n_steps*1000:.0f} ms/step)")
    print(f"[micro] loss: {losses[0]:.4f} -> {losses[-1]:.4f} "
          f"({(losses[0]-losses[-1])/losses[0]*100:+.1f}% reduction)")
    print(f"[micro] accuracy: {accs[0]:.2f} -> {accs[-1]:.2f}")
    print()
    print("[micro] ✓ TerraMind fine-tune loop working on AMD MI300X")

if __name__ == "__main__":
    main()