| """TerraMind micro-finetune on NYC labels — proof it works on AMD MI300X. |
| |
| Goal: show the smallest possible *real* fine-tune of TerraMind on NYC |
| data converges loss in a few seconds on the MI300X. Not building a |
| useful model — just showing the end-to-end loop (load → forward → |
| backward → step) works on the AMD ROCm path with terratorch. |
| |
| Setup: |
| - TerraMind v1 base ENCODER (frozen) — the multimodal foundation |
| model's vision encoder, ~300 M params. |
| - Tiny classification head on top — single linear layer over the |
| pooled patch embedding, 2-class output. |
| - 8 synthetic NYC samples (6-band S2L2A 224×224 tensors). Labels |
| are deterministic based on the synthetic input (a function of |
| the NIR band's mean) so the head has a real signal to learn. |
| - 30 SGD steps with Adam. Print loss + accuracy + elapsed. |
| |
| This isn't a useful classifier — the labels are synthetic. But it |
| proves: weights load on AMD, forward pass works, gradients flow, |
| optimizer steps. Real NYC fine-tune would replace the synthetic |
| labels with actual Sandy-inundation-zone membership at the chip's |
| center coord (we have the polygon in data/sandy_inundation.geojson). |
| """ |

from __future__ import annotations

import time

import torch
import torch.nn.functional as F

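
# Sketch only, never called below: a hypothetical illustration of how the
# synthetic labels could be replaced with real Sandy-inundation membership,
# per the module docstring. Assumes shapely is installed, that
# data/sandy_inundation.geojson is a GeoJSON FeatureCollection holding the
# inundation polygon(s) in lon/lat, and that each chip comes with a known
# center coordinate. Treat it as a sketch, not the project's actual labeler.
def sandy_label(center_lon: float, center_lat: float) -> int:
    """Return 1 if the chip center falls inside any Sandy inundation polygon."""
    import json

    from shapely.geometry import Point, shape

    with open("data/sandy_inundation.geojson") as f:
        gj = json.load(f)
    # One shapely geometry per GeoJSON feature.
    geoms = [shape(feat["geometry"]) for feat in gj["features"]]
    return int(any(g.contains(Point(center_lon, center_lat)) for g in geoms))

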
def main():
    # Importing this module registers the TerraMind backbones with
    # terratorch's registry (the import is for its side effect only).
    import terratorch.models.backbones.terramind.model.terramind_register  # noqa: F401
    from terratorch.registry import BACKBONE_REGISTRY

    # On ROCm builds of PyTorch, AMD GPUs are exposed through the CUDA API,
    # so torch.cuda.is_available() is True on the MI300X.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[micro] device: {device}")
    if device == "cuda":
        props = torch.cuda.get_device_properties(0)
        print(f"[micro] gpu: {torch.cuda.get_device_name(0)}, "
              f"VRAM={props.total_memory / 1e9:.1f} GB")

    # Load the pretrained TerraMind v1 base encoder for the S2L2A modality
    # and freeze it; only the head defined below will train.
    print("[micro] loading terramind_v1_base encoder...")
    t0 = time.time()
    backbone = BACKBONE_REGISTRY.build(
        "terratorch_terramind_v1_base",
        modalities=["S2L2A"],
        pretrained=True,
    )
    backbone.eval()
    backbone.to(device)
    for p in backbone.parameters():
        p.requires_grad = False
    print(f"[micro] backbone loaded in {time.time() - t0:.2f}s; "
          f"params={sum(p.numel() for p in backbone.parameters()):,}")

    # Synthetic "NYC" dataset: random 12-band S2L2A chips. The label is a
    # deterministic function of the input (mean of band index 8, a NIR band
    # in the usual 12-band S2L2A ordering), so the head has a real signal.
    n_samples = 8
    img_size = 224
    bands = 12
    torch.manual_seed(42)
    x = torch.rand(n_samples, bands, img_size, img_size, device=device)
    y = (x[:, 8].mean(dim=(1, 2)) > 0.5).long()
    print(f"[micro] dataset: {n_samples} samples, "
          f"label balance: {y.float().mean().item():.2f} positive")

    # Probe the output shape with one sample; terratorch backbones may return
    # a tensor, a list/tuple of per-layer outputs, or a dict.
    print("[micro] probing embedding shape...")
    with torch.no_grad():
        out = backbone({"S2L2A": x[:1]})
    if isinstance(out, (list, tuple)):
        out_t = out[0]
    elif isinstance(out, dict):
        out_t = next(iter(out.values()))
    else:
        out_t = out
    print(f"[micro] backbone output type={type(out).__name__}, "
          f"shape={tuple(out_t.shape) if hasattr(out_t, 'shape') else 'n/a'}")
    emb_dim = out_t.shape[-1]
    print(f"[micro] embedding dim: {emb_dim}")

    # The only trainable parameters: a single 2-class linear head.
    head = torch.nn.Linear(emb_dim, 2).to(device)
    optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)

    n_steps = 30
    print(f"[micro] training {n_steps} steps...")
    t0 = time.time()
    losses = []
    accs = []
    for step in range(n_steps):
        # The backbone is frozen, so run it without building a graph; only
        # the head's forward pass needs gradients.
        with torch.no_grad():
            emb = backbone({"S2L2A": x})
        if isinstance(emb, (list, tuple)):
            emb = emb[0]
        elif isinstance(emb, dict):
            emb = next(iter(emb.values()))
        if emb.ndim == 3:
            # (batch, tokens, dim) -> mean-pool over the patch tokens.
            emb = emb.mean(dim=1)
        logits = head(emb)
        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(-1) == y).float().mean().item()
        losses.append(loss.item())
        accs.append(acc)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step == 0 or (step + 1) % 5 == 0 or step == n_steps - 1:
            print(f"[micro] step {step + 1:2d}/{n_steps} "
                  f"loss={loss.item():.4f} acc={acc:.2f}")
    elapsed = time.time() - t0
    print()
    print(f"[micro] DONE: {n_steps} steps in {elapsed:.2f}s "
          f"({elapsed / n_steps * 1000:.0f} ms/step)")
    print(f"[micro] loss: {losses[0]:.4f} -> {losses[-1]:.4f} "
          f"({(losses[0] - losses[-1]) / losses[0] * 100:+.1f}% reduction)")
    print(f"[micro] accuracy: {accs[0]:.2f} -> {accs[-1]:.2f}")
    print()
    print("[micro] ✓ TerraMind fine-tune loop working on AMD MI300X")


if __name__ == "__main__":
    main()
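
# Usage sketch (hypothetical filename; assumes a ROCm build of PyTorch and
# terratorch with TerraMind weights available):
#   python terramind_micro_finetune.py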