# seriffic's picture
# Backend evolution: Phases 1-10 specialists + agentic FSM + Mellea + LiteLLM router
# 6a82282
"""TerraMind-NYC inference ensemble: one base, hot-swap adapters.
This is what Riprap's FSM specialist nodes consume. Loads the TerraMind
1.0 base model once into memory, then swaps a single active adapter
(LULC / TiM / Buildings) per task call. Per ADR-007 we don't merge
adapters — sequential swap is simpler and matches our deployment shape.
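Usage:
    from shared.inference_ensemble import TerraMindNYCEnsemble
    ens = TerraMindNYCEnsemble(adapters_root="adapters/")
    ens.discover()  # cache adapter weights; infer() only accepts discovered tasks
    out = ens.infer(s2l2a=s2l2a_chip, s1rtc=s1rtc_chip, dem=dem_chip,
                    tasks=["lulc", "buildings"])
    # Inputs are keyword-only. Each output is an argmax class-index map:
    # {"lulc": [224, 224] long, "buildings": [224, 224] long, ...}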
The first call materializes the base; subsequent task switches reuse it.
The adapter swap is ~50 ms per task per call, dominated by file I/O the
first time and a state-dict overwrite afterwards.
"""
from __future__ import annotations
import sys
from dataclasses import dataclass, field
from pathlib import Path
import torch
import yaml
from peft import LoraConfig, inject_adapter_in_model
from safetensors.torch import load_file
sys.path.insert(0, str(Path(__file__).parent))
from train_lora import build_task # noqa: E402
@dataclass
class AdapterSlot:
    """Cached state for one discovered adapter.

    Instances are created by ``TerraMindNYCEnsemble.discover()`` and
    completed (``task`` filled in) the first time the adapter is activated.
    """

    # Lookup key for the adapter inside the ensemble.
    name: str
    # Parsed contents of the adapter's config.yaml; handed to build_task().
    config: dict
    # Tensors read from adapter_model.safetensors; only keys starting with
    # "encoder." are loaded into the model's encoder.
    lora_state: dict = field(default_factory=dict)
    # Tensors read from decoder_head.safetensors; the first dotted key
    # component names the model attribute each tensor belongs to.
    head_state: dict = field(default_factory=dict)
    task: object | None = None  # lazy-built Lightning task
    # Number of segmentation classes, copied from config["num_classes"].
    num_classes: int = 0
class TerraMindNYCEnsemble:
    """One TerraMind base per num_classes group, N adapters total.

    Per-adapter num_classes differs (LULC=5, Buildings=2) so each
    adapter gets its own Lightning task with the right segmentation
    head shape. Tasks are lazy-built on first set_adapter call. The
    base TerraMind weights are duplicated across tasks (acceptable on
    MI300X with 192 GB; if memory-constrained, group tasks by
    num_classes and share encoder via PEFT adapter switching within a
    group).

    Typical flow: construct, call ``discover()`` once to cache adapter
    weights in RAM, then call ``infer(...)`` with the task names to run.
    """

    def __init__(self, adapters_root: Path | str = "adapters/",
                 device: str | None = None):
        """Record where adapters live and which device to run on.

        No files are read here; ``discover()`` does the scanning.

        Args:
            adapters_root: Directory holding one sub-directory per adapter.
            device: Torch device string. Defaults to ``"cuda"`` when
                ``torch.cuda.is_available()`` is true, otherwise ``"cpu"``.
        """
        self.adapters_root = Path(adapters_root)
        self.device = device or (
            "cuda" if torch.cuda.is_available() else "cpu")
        self._adapters: dict[str, AdapterSlot] = {}
        self._active_adapter: str | None = None

    # ---- adapter discovery + caching --------------------------------------
    def discover(self) -> list[str]:
        """Scan adapters_root and cache LoRA + decoder weights into RAM.

        Every ``<adapters_root>/*/config.yaml`` is considered. Weights are
        read from the sibling ``output/`` directory, or failing that from
        the config's ``output_dir``. An adapter with no
        ``adapter_model.safetensors`` in either place is skipped.

        Calling this again replaces previously cached slots (including any
        task already built for them); ``set_adapter`` rebuilds on demand.

        Returns:
            Names of the adapters cached by this call, in sorted
            config-path order.

        Raises:
            KeyError: If a config has no ``num_classes`` entry.
            Exception: Whatever ``load_file`` raises, e.g. when
                ``decoder_head.safetensors`` is missing next to the adapter.
        """
        names = []
        for cfg_path in sorted(self.adapters_root.glob("*/config.yaml")):
            cfg = yaml.safe_load(cfg_path.read_text())
            local_dir = cfg_path.parent / "output"
            # Fall back to the output_dir from config (e.g. an absolute
            # droplet path) when the sibling output/ has no adapter file.
            configured_dir = Path(cfg.get("output_dir", local_dir)).resolve()
            adapter_path = next(
                (d for d in (local_dir, configured_dir)
                 if (d / "adapter_model.safetensors").exists()),
                None,
            )
            if adapter_path is None:
                continue
            slot = AdapterSlot(
                name=cfg.get("task_name", cfg_path.parent.name),
                config=cfg,
                lora_state=load_file(
                    adapter_path / "adapter_model.safetensors"),
                head_state=load_file(
                    adapter_path / "decoder_head.safetensors"),
                num_classes=cfg["num_classes"],
            )
            self._adapters[slot.name] = slot
            names.append(slot.name)
        return names

    # ---- swap + inference -------------------------------------------------
    def _build_slot_task(self, slot: AdapterSlot):
        """Build a Lightning task for this adapter, restore weights.

        Does nothing when the slot already has a task. Otherwise builds the
        task from ``slot.config``, moves it to ``self.device`` in eval mode,
        loads the ``encoder.*`` tensors into ``model.encoder`` and each
        head tensor group into the model attribute named by the first
        component of its key.

        Loading is non-strict, so tensors that do not match the built model
        are not applied. A ``RuntimeWarning`` is emitted in that case,
        because the adapter would otherwise run silently un-adapted.
        """
        if slot.task is not None:
            return
        import warnings  # local import: only needed on a slot's first build

        task = build_task(slot.config).to(self.device).eval()
        model = task.model

        enc_state = {k.removeprefix("encoder."): v.to(self.device)
                     for k, v in slot.lora_state.items()
                     if k.startswith("encoder.")}
        loaded = model.encoder.load_state_dict(enc_state, strict=False)
        # missing_keys is expected to be large (the base weights are not in
        # the adapter file); only unexpected keys signal a mismatch.
        stray = list(getattr(loaded, "unexpected_keys", ()))
        if stray:
            warnings.warn(
                f"adapter {slot.name!r}: {len(stray)} encoder tensor(s) "
                f"did not match the built encoder and were not loaded "
                f"(first: {stray[0]!r})",
                RuntimeWarning, stacklevel=2)

        # Group head tensors by the model attribute they belong to.
        head_grouped: dict[str, dict] = {}
        for k, v in slot.head_state.items():
            sub, _, rest = k.partition(".")
            head_grouped.setdefault(sub, {})[rest] = v.to(self.device)
        for sub, state in head_grouped.items():
            m = getattr(model, sub, None)
            if m is None:
                warnings.warn(
                    f"adapter {slot.name!r}: model has no attribute "
                    f"{sub!r}; skipped {len(state)} head tensor(s)",
                    RuntimeWarning, stacklevel=2)
                continue
            loaded = m.load_state_dict(state, strict=False)
            stray = list(getattr(loaded, "unexpected_keys", ()))
            if stray:
                warnings.warn(
                    f"adapter {slot.name!r}: {len(stray)} tensor(s) for "
                    f"{sub!r} did not match and were not loaded "
                    f"(first: {stray[0]!r})",
                    RuntimeWarning, stacklevel=2)
        slot.task = task

    def set_adapter(self, name: str):
        """Make ``name`` the active adapter, building its task if needed.

        Raises:
            KeyError: If ``name`` has not been cached by ``discover()``.
        """
        slot = self._adapters.get(name)
        if slot is None:
            raise KeyError(f"adapter {name!r} not loaded; "
                           f"available: {list(self._adapters)}")
        # Always go through the (idempotent) builder: a repeated discover()
        # replaces the slot with an unbuilt one while the name stays active,
        # so "already active" does not imply "already built".
        self._build_slot_task(slot)
        self._active_adapter = name

    @property
    def _task(self):
        """Convenience accessor for the currently active adapter's task."""
        if self._active_adapter is None:
            return None
        return self._adapters[self._active_adapter].task

    @staticmethod
    def _with_batch_dim(t: torch.Tensor | None) -> torch.Tensor | None:
        """Prepend a batch axis to 3-D or 4-D input; pass others through.

        NOTE: a 4-D tensor is always treated as unbatched, so an already
        batched static chip [B, C, H, W] cannot be told apart from an
        unbatched temporal chip [C, T, H, W] and gets an extra axis too.
        """
        if t is None:
            return None
        return t.unsqueeze(0) if t.dim() in (3, 4) else t

    @torch.no_grad()
    def infer(self, *, s2l2a: torch.Tensor,
              s1rtc: torch.Tensor | None = None,
              dem: torch.Tensor | None = None,
              tasks: list[str]) -> dict[str, torch.Tensor]:
        """Run multiple tasks against the same input chip.

        Each tensor: [C, T, H, W] (temporal mode) or [C, H, W] (static).
        Outputs: dict {task_name: argmax-class map [H, W] long}.

        All arguments are keyword-only. The optional modalities are only
        passed to the model when provided. Results are moved to CPU.

        Raises:
            KeyError: If a name in ``tasks`` was not cached by
                ``discover()``.
        """
        x = {"S2L2A": self._with_batch_dim(s2l2a).to(self.device)}
        if s1rtc is not None:
            x["S1RTC"] = self._with_batch_dim(s1rtc).to(self.device)
        if dem is not None:
            x["DEM"] = self._with_batch_dim(dem).to(self.device)

        out = {}
        for task_name in tasks:
            self.set_adapter(task_name)
            res = self._task.model(x)
            # Some model wrappers return an object carrying `.output`;
            # others return the logits tensor directly.
            logits = getattr(res, "output", res)
            # Class axis is dim 1; drop the batch axis added above.
            out[task_name] = logits.argmax(dim=1).squeeze(0).cpu()
        return out

    # ---- diagnostics ------------------------------------------------------
    def memory_estimate_gb(self) -> float:
        """Return the weight footprint of all built tasks in decimal GB.

        Sums the actual parameter and buffer storage of every task built so
        far, so the figure follows the real dtype and model size instead of
        assuming a fixed per-task constant. Activations and the adapter
        state dicts cached in RAM by ``discover()`` are not included.
        """
        total_bytes = 0
        for slot in self._adapters.values():
            if slot.task is None:
                continue
            tensors = [*slot.task.parameters(), *slot.task.buffers()]
            total_bytes += sum(t.numel() * t.element_size() for t in tensors)
        return total_bytes / 1e9

    def info(self) -> dict:
        """Return a small diagnostic snapshot of the ensemble state."""
        return {
            "device": self.device,
            "loaded_adapters": list(self._adapters),
            "active_adapter": self._active_adapter,
            "base_built": self._task is not None,
        }
if __name__ == "__main__":
    # Smoke check: report which adapters are on disk, then dump the
    # ensemble diagnostics. Exits with a message when nothing is found.
    ensemble = TerraMindNYCEnsemble("adapters/")
    names = ensemble.discover()
    print(f"Discovered adapters: {names}")
    if names:
        print(ensemble.info())
    else:
        sys.exit("No adapters found. Train at least one before using the "
                 "ensemble.")