seriffic and Claude Opus 4.7 (1M context) committed
Commit fee1c30 · 1 parent: 7cb5930

fix: thread Mellea attempt index + diagnose riprap-models 500s


Two regressions surfaced during the May 9 demo. Fix both.

1. Mellea reroll text concatenated in the briefing prose
============================================================
The strict-streaming reconcile path on a Mellea reroll fired
token events for the new attempt without telling the frontend
the attempt index changed. Result: the SvelteKit briefing
buffer never reset between attempts, so attempt 2 + attempt 3
text appended onto the already-rendered attempt 1 text.

Root cause: app/fsm.py:step_reconcile installed a token
forwarder that explicitly dropped the attempt_idx
(`lambda d, _ai: token_cb(d)`), and
app/intents/single_address.py:_on_token took only `delta`. So
the SSE `token` events carried no `attempt` field, and
web/sveltekit/src/lib/client/agentStream.ts:onAttemptStart
never fired: `d.attempt` was always `undefined`, so
`d.attempt !== currentAttempt` never evaluated true.

Fix:
- fsm.py: forward (delta, attempt_idx) to token_cb. Probe the
  2-arg call first and fall back to the 1-arg call on TypeError so
  legacy callbacks in non-strict reconcilers keep working.
- single_address.py: _on_token(delta, attempt_idx=0) emits
`{kind: token, delta, attempt: attempt_idx + 1}` so the
client gets the 1-based attempt counter it already expects
(neighborhood.py + development_check.py already do this).
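The arity fallback can be sketched in isolation (a minimal illustration; `make_forwarder` and the lambda consumers are hypothetical stand-ins, not code from this commit):

```python
def make_forwarder(token_cb):
    """Forward (delta, attempt_idx) to token_cb, falling back to the
    legacy 1-arg call when the callback rejects the second argument."""
    def _fwd(delta: str, attempt_idx: int) -> None:
        if token_cb is None:
            return
        try:
            token_cb(delta, attempt_idx)   # new 2-arg shape
        except TypeError:
            token_cb(delta)                # legacy 1-arg shape
    return _fwd

# Legacy consumer: only sees the delta.
seen_legacy = []
make_forwarder(lambda d: seen_legacy.append(d))("hello", 2)

# New consumer: also sees the attempt index.
seen_new = []
make_forwarder(lambda d, ai: seen_new.append((d, ai)))("hello", 2)
```

One trade-off worth noting: a genuine TypeError raised *inside* a 2-arg callback would also trigger the fallback and invoke the callback twice; the probe accepts that in exchange for not breaking legacy call sites.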

2. terramind / prithvi-pluvial 500s with no diagnostic detail
============================================================
The riprap-models routes raised through to FastAPI's default
handler, returning the opaque body
`{"detail": "Internal Server Error"}`. The lablab UI's
inference._post then surfaced this as
`remote terramind/lulc unreachable: HTTP 500 from /v1/terramind`,
correct, but with no actionable detail about what had failed
inside the model service.

services/riprap-models/main.py:
- New _safe_route() wrapper: returns
`{"ok": False, "err": "<type>: <msg>", "stage": "<endpoint>"}`
with HTTP 200 instead of 500. The proxy on :7860 forwards
this body untouched so the FSM trace card now reads, e.g.,
`remote terramind/lulc non-ok: torch.cuda.OutOfMemoryError: ...`
instead of a generic Internal Server Error.
- Lifespan startup warms every heavy model (Prithvi, all three
TerraMind paths, GLiNER, Granite Embedding) before traffic is
accepted, so the first user query doesn't compete with
vLLM's CUDA-graph compile for memory bandwidth. Best-effort
per stage; failures are recorded into _LAST_ERR and do not
block startup.
- New GET /v1/diag (auth-required) snapshots loaded models,
CUDA memory state per device, and last-error per stage with
a 3-line traceback tail. Operators can hit it from outside
the Space without grepping container logs.
- /healthz also exposes last_errors.

prithvi_live.py: surfaced-error matcher now checks `err`,
`error`, and `skipped` so the new wrapped 200-with-body shape
propagates cleanly into the trace.
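The matcher change amounts to key coalescing; a minimal sketch (the function name is illustrative, not the actual helper in prithvi_live.py):

```python
def remote_error(remote: dict) -> str:
    # Prefer the new wrapped shape's "err", then the legacy "error"
    # and "skipped" keys, so both old and new riprap-models bodies
    # surface something readable in the trace.
    return (remote.get("err")
            or remote.get("error")
            or remote.get("skipped")
            or "unknown")
```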

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

app/flood_layers/prithvi_live.py CHANGED
@@ -445,9 +445,12 @@ def _fetch_inner(lat: float, lon: float, timeout_s: float) -> dict[str, Any]:
             "compute": f"remote · {remote.get('device', 'gpu')}",
             "elapsed_s": round(time.time() - t0, 2),
         }
+    err = (remote.get("err")
+           or remote.get("error")
+           or remote.get("skipped")
+           or "unknown")
     return {"ok": False,
-            "skipped": f"remote prithvi-pluvial non-ok: "
-                       f"{remote.get('error') or 'unknown'}",
+            "skipped": f"remote prithvi-pluvial non-ok: {err}",
             "elapsed_s": round(time.time() - t0, 2)}
 except _inf.RemoteUnreachable as e:
     log.info("prithvi_live: remote unreachable (%s)", e)
app/fsm.py CHANGED
@@ -1015,11 +1015,24 @@ def step_reconcile(state: State) -> State:
         query=_current_user_query() or state.get("query") or "",
         intent=_current_planner_intent() or "single_address",
     )
+    # Forward the (delta, attempt_idx) pair through. Older
+    # token_cb signatures were single-arg; we probe with the
+    # 2-arg call and fall back to 1-arg on TypeError, so
+    # single_address.py's old shape still works while new
+    # callbacks see the attempt index they need to clear the
+    # frontend buffer on a Mellea reroll.
+    def _fwd_token(delta: str, attempt_idx: int) -> None:
+        if token_cb is None:
+            return
+        try:
+            token_cb(delta, attempt_idx)
+        except TypeError:
+            token_cb(delta)
     mres = reconcile_strict_streaming(
         doc_msgs, framed_prompt,
         user_prompt="Write the cited paragraph now.",
         loop_budget=DEFAULT_LOOP_BUDGET,
-        on_token=(lambda d, _ai: token_cb(d)) if token_cb else None,
+        on_token=_fwd_token if token_cb else None,
         on_attempt_end=attempt_cb,
     )
     para = mres["paragraph"]
app/intents/single_address.py CHANGED
@@ -51,8 +51,15 @@ def run(plan, query: str, progress_q=None, strict: bool = False) -> dict:
     set_user_query(query)
     set_planner_intent(plan.intent)
     if progress_q is not None:
-        def _on_token(delta: str):
-            progress_q.put({"kind": "token", "delta": delta})
+        def _on_token(delta: str, attempt_idx: int = 0):
+            # `attempt_idx` is the 0-based Mellea reroll index. The
+            # SvelteKit client treats a change in this value as a
+            # signal to clear the live briefing buffer (per
+            # web/sveltekit/src/lib/client/agentStream.ts:onAttemptStart).
+            # We surface it as a 1-based attempt counter so the chip
+            # in the UI reads "attempt N" naturally.
+            progress_q.put({"kind": "token", "delta": delta,
+                            "attempt": attempt_idx + 1})
         def _on_mellea_attempt(attempt_idx, passed, failed):
             progress_q.put({"kind": "mellea_attempt",
                             "attempt": attempt_idx,
services/riprap-models/main.py CHANGED
@@ -707,43 +707,124 @@ def _gliner_extract(payload: GlinerIn) -> dict[str, Any]:
 
 # ---- FastAPI app ------------------------------------------------------------
 
+# Last error per route, kept on the in-memory map so /v1/diag can
+# expose it without forcing the operator to grep container logs.
+_LAST_ERR: dict[str, dict[str, Any]] = {}
+
+
+def _safe_route(stage: str, fn, payload):
+    """Wrap a route body so an uncaught exception becomes a structured
+    `{"ok": False, "err": "...", "stage": "..."}` JSON response with
+    HTTP 200 instead of FastAPI's opaque "Internal Server Error" body.
+
+    The proxy on :7860 forwards this body untouched, so the FSM
+    specialist surfaces the real reason in the trace card. Logs the
+    full traceback to stderr so operators can still root-cause from
+    the Space's runtime logs."""
+    try:
+        return fn(payload)
+    except HTTPException:
+        raise
+    except Exception as e:  # noqa: BLE001
+        import traceback
+        tb = traceback.format_exc()
+        log.error("route %s failed: %s\n%s", stage, e, tb)
+        info = {
+            "ok": False,
+            "err": f"{type(e).__name__}: {e}",
+            "stage": stage,
+            "ts": time.time(),
+        }
+        _LAST_ERR[stage] = {**info, "traceback_tail": tb.splitlines()[-3:]}
+        return info
+
+
 @asynccontextmanager
 async def lifespan(_app: FastAPI):
     log.info("riprap-models starting on device=%s auth=%s",
              _DEVICE, "yes" if _AUTH_TOKEN else "no")
+    # Pre-load the heavy models so the first user request doesn't
+    # collide with a cold-load on the same GPU as vLLM. Each warm
+    # is best-effort: a single model failing must not block the
+    # service from starting (others may still serve).
+    if os.environ.get("RIPRAP_MODELS_WARM_AT_STARTUP", "1").lower() in ("1", "true", "yes"):
+        for stage, fn in (
+            ("warm/prithvi", _load_prithvi),
+            ("warm/terramind_synthesis", _load_terramind_synthesis),
+            ("warm/terramind_lulc", lambda: _load_terramind("lulc")),
+            ("warm/terramind_buildings", lambda: _load_terramind("buildings")),
+            ("warm/embed", _load_embed),
+            ("warm/gliner", _load_gliner),
+        ):
+            try:
+                fn()
+                log.info("startup %s ok", stage)
+            except Exception as e:  # noqa: BLE001
+                log.exception("startup %s failed: %s", stage, e)
+                _LAST_ERR[stage] = {"ok": False,
+                                    "err": f"{type(e).__name__}: {e}",
+                                    "stage": stage}
     yield
     log.info("riprap-models stopping")
 
 
-app = FastAPI(title="riprap-models", version="0.4.5", lifespan=lifespan)
+app = FastAPI(title="riprap-models", version="0.5.1", lifespan=lifespan)
 
 
 @app.get("/healthz")
 def healthz():
     return {"ok": True, "device": _DEVICE,
-            "models_loaded": sorted(_INSTANCES.keys())}
+            "models_loaded": sorted(_INSTANCES.keys()),
+            "last_errors": _LAST_ERR}
+
+
+@app.get("/v1/diag", dependencies=[Depends(_require_auth)])
+def diag():
+    """Operator-only diagnostic snapshot — what's loaded, last
+    per-stage error (with a 3-line traceback tail), and CUDA
+    visibility. The proxy forwards this through the catch-all so
+    operators can hit it from outside the Space."""
+    cuda = {"available": False, "devices": []}
+    try:
+        import torch
+        cuda["available"] = bool(torch.cuda.is_available())
+        if cuda["available"]:
+            cuda["devices"] = [{
+                "name": torch.cuda.get_device_name(i),
+                "mem_total_mb": torch.cuda.get_device_properties(i).total_memory // (1024 * 1024),
+                "mem_alloc_mb": torch.cuda.memory_allocated(i) // (1024 * 1024),
+            } for i in range(torch.cuda.device_count())]
+    except Exception as e:  # noqa: BLE001
+        cuda["err"] = f"{type(e).__name__}: {e}"
+    return {
+        "device": _DEVICE,
+        "models_loaded": sorted(_INSTANCES.keys()),
+        "last_errors": _LAST_ERR,
+        "cuda": cuda,
+    }
 
 
 @app.post("/v1/prithvi-pluvial", dependencies=[Depends(_require_auth)])
 def prithvi_pluvial_route(payload: PrithviIn):
-    return _prithvi_pluvial(payload)
+    return _safe_route("prithvi-pluvial", _prithvi_pluvial, payload)
 
 
 @app.post("/v1/terramind", dependencies=[Depends(_require_auth)])
 def terramind_route(payload: TerramindIn):
-    return _terramind_inference(payload)
+    return _safe_route(f"terramind/{payload.adapter}",
+                       _terramind_inference, payload)
 
 
 @app.post("/v1/ttm-forecast", dependencies=[Depends(_require_auth)])
 def ttm_forecast_route(payload: TtmIn):
-    return _ttm_forecast(payload)
+    return _safe_route("ttm-forecast", _ttm_forecast, payload)
 
 
 @app.post("/v1/granite-embed", dependencies=[Depends(_require_auth)])
 def granite_embed_route(payload: EmbedIn):
-    return _granite_embed(payload)
+    return _safe_route("granite-embed", _granite_embed, payload)
 
 
 @app.post("/v1/gliner-extract", dependencies=[Depends(_require_auth)])
 def gliner_extract_route(payload: GlinerIn):
-    return _gliner_extract(payload)
+    return _safe_route("gliner-extract", _gliner_extract, payload)