Commit ·
b620a9c
1
Parent(s): cb00551
Pad short series for Toto + persist forecast log to HF Dataset
Browse files- src/forecast.py: when context is shorter than one Toto patch (32), left-
pad with the first value and set target_mask False on padded steps so
the model ignores them. Fixes the 4-hour cadence dropdown error when the
station has fewer than 32 four-hour points (~5 days of data).
- src/persist.py: pull/push the forecast SQLite to a private HF Dataset
(default bitsofchris/toto-weather-forecast-log) so the scoreboard
survives Space rebuilds. Push is locked + coalesced (60s min interval),
fired async after each refresh; pull runs once at startup.
- app.py: call persist.pull_db() on launch and persist.push_db_async()
at the end of refresh().
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- app.py +5 -1
- src/forecast.py +19 -10
- src/persist.py +103 -0
app.py
CHANGED
|
@@ -17,7 +17,7 @@ from datetime import datetime, timedelta, timezone
|
|
| 17 |
import gradio as gr
|
| 18 |
import pandas as pd
|
| 19 |
|
| 20 |
-
from src import ecowitt, forecast_log, nws
|
| 21 |
from src.forecast import forecast_series
|
| 22 |
from src.weather_ui import (
|
| 23 |
combined_figure,
|
|
@@ -154,6 +154,9 @@ def refresh(cycle_label: str = "Hourly", horizon_label: str = "24 h"):
|
|
| 154 |
strip = emoji_strip_markdown(nws_df_raw, DISPLAY_TZ, n=12)
|
| 155 |
scoreboard = render_scoreboard(log_conn)
|
| 156 |
|
|
|
|
|
|
|
|
|
|
| 157 |
return hero, toto_md, nws_md, strip, fig, scoreboard
|
| 158 |
|
| 159 |
|
|
@@ -252,5 +255,6 @@ with gr.Blocks(title="Toto Weather Forecast", theme=gr.themes.Soft()) as demo:
|
|
| 252 |
|
| 253 |
|
| 254 |
if __name__ == "__main__":
|
|
|
|
| 255 |
_start_autorefresh()
|
| 256 |
demo.launch()
|
|
|
|
| 17 |
import gradio as gr
|
| 18 |
import pandas as pd
|
| 19 |
|
| 20 |
+
from src import ecowitt, forecast_log, nws, persist
|
| 21 |
from src.forecast import forecast_series
|
| 22 |
from src.weather_ui import (
|
| 23 |
combined_figure,
|
|
|
|
| 154 |
strip = emoji_strip_markdown(nws_df_raw, DISPLAY_TZ, n=12)
|
| 155 |
scoreboard = render_scoreboard(log_conn)
|
| 156 |
|
| 157 |
+
# Backup the SQLite log to the HF dataset (non-blocking).
|
| 158 |
+
persist.push_db_async()
|
| 159 |
+
|
| 160 |
return hero, toto_md, nws_md, strip, fig, scoreboard
|
| 161 |
|
| 162 |
|
|
|
|
| 255 |
|
| 256 |
|
| 257 |
if __name__ == "__main__":
|
| 258 |
+
persist.pull_db() # bootstrap the forecast log from the HF Dataset
|
| 259 |
_start_autorefresh()
|
| 260 |
demo.launch()
|
src/forecast.py
CHANGED
|
@@ -87,23 +87,32 @@ def forecast_series(
|
|
| 87 |
if series.empty:
|
| 88 |
raise ValueError("Cannot forecast an empty series")
|
| 89 |
|
|
|
|
|
|
|
| 90 |
clean = series.astype(float).interpolate(limit_direction="both")
|
| 91 |
|
| 92 |
# Toto requires the context length to be a multiple of the model's
|
| 93 |
-
# patch_size (32 for Toto-2.0-4m).
|
|
|
|
|
|
|
|
|
|
| 94 |
model = load_model(model_id, device=device)
|
| 95 |
patch = int(model.config.patch_size)
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
target = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0) # (1, 1, T)
|
| 106 |
-
target_mask = torch.
|
| 107 |
series_ids = torch.zeros(1, 1, dtype=torch.long)
|
| 108 |
|
| 109 |
target = target.to(device)
|
|
|
|
| 87 |
if series.empty:
|
| 88 |
raise ValueError("Cannot forecast an empty series")
|
| 89 |
|
| 90 |
+
import numpy as np # noqa: PLC0415
|
| 91 |
+
|
| 92 |
clean = series.astype(float).interpolate(limit_direction="both")
|
| 93 |
|
| 94 |
# Toto requires the context length to be a multiple of the model's
|
| 95 |
+
# patch_size (32 for Toto-2.0-4m). If we have at least one full patch,
|
| 96 |
+
# truncate the oldest points to fit. If we have fewer, left-pad with the
|
| 97 |
+
# first value and mark the padded region False in the mask so Toto
|
| 98 |
+
# ignores it.
|
| 99 |
model = load_model(model_id, device=device)
|
| 100 |
patch = int(model.config.patch_size)
|
| 101 |
+
raw = clean.to_numpy(dtype=np.float32)
|
| 102 |
+
n_raw = len(raw)
|
| 103 |
+
|
| 104 |
+
if n_raw >= patch:
|
| 105 |
+
n = (n_raw // patch) * patch
|
| 106 |
+
arr = raw[-n:]
|
| 107 |
+
mask_vec = np.ones(n, dtype=bool)
|
| 108 |
+
else:
|
| 109 |
+
n = patch
|
| 110 |
+
pad = n - n_raw
|
| 111 |
+
arr = np.concatenate([np.full(pad, raw[0], dtype=np.float32), raw])
|
| 112 |
+
mask_vec = np.concatenate([np.zeros(pad, dtype=bool), np.ones(n_raw, dtype=bool)])
|
| 113 |
|
| 114 |
target = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0) # (1, 1, T)
|
| 115 |
+
target_mask = torch.from_numpy(mask_vec).unsqueeze(0).unsqueeze(0)
|
| 116 |
series_ids = torch.zeros(1, 1, dtype=torch.long)
|
| 117 |
|
| 118 |
target = target.to(device)
|
src/persist.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Persist the forecast SQLite DB across HF Space rebuilds.
|
| 2 |
+
|
| 3 |
+
HF Spaces' free tier has ephemeral storage — every `git push` rebuilds the
|
| 4 |
+
container and wipes any local files. We back the forecast log with a
|
| 5 |
+
private HF Dataset:
|
| 6 |
+
|
| 7 |
+
- On startup: pull the latest forecasts.db from the dataset (if any).
|
| 8 |
+
- After every refresh: push the current forecasts.db back.
|
| 9 |
+
|
| 10 |
+
Environment:
|
| 11 |
+
HF_TOKEN must have write access to the dataset
|
| 12 |
+
LOG_DATASET_REPO override the default dataset repo id
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import shutil
|
| 19 |
+
import threading
|
| 20 |
+
import time
|
| 21 |
+
import traceback
|
| 22 |
+
|
| 23 |
+
DEFAULT_REPO = "bitsofchris/toto-weather-forecast-log"
|
| 24 |
+
PATH_IN_REPO = "forecasts.db"
|
| 25 |
+
DEFAULT_LOCAL = "data/forecasts.db"
|
| 26 |
+
|
| 27 |
+
_push_lock = threading.Lock()
|
| 28 |
+
_last_push_at = 0.0
|
| 29 |
+
PUSH_MIN_INTERVAL = 60.0 # seconds — coalesce rapid pushes
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _repo_id() -> str:
|
| 33 |
+
return os.environ.get("LOG_DATASET_REPO", DEFAULT_REPO)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _token() -> str | None:
|
| 37 |
+
return os.environ.get("HF_TOKEN")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def pull_db(local_path: str = DEFAULT_LOCAL) -> bool:
|
| 41 |
+
"""Download the latest DB from the dataset, overwriting any local copy.
|
| 42 |
+
Returns True on success."""
|
| 43 |
+
tok = _token()
|
| 44 |
+
if not tok:
|
| 45 |
+
print("[persist] HF_TOKEN not set — skipping pull")
|
| 46 |
+
return False
|
| 47 |
+
try:
|
| 48 |
+
from huggingface_hub import hf_hub_download # noqa: PLC0415
|
| 49 |
+
downloaded = hf_hub_download(
|
| 50 |
+
repo_id=_repo_id(),
|
| 51 |
+
repo_type="dataset",
|
| 52 |
+
filename=PATH_IN_REPO,
|
| 53 |
+
token=tok,
|
| 54 |
+
)
|
| 55 |
+
os.makedirs(os.path.dirname(local_path) or ".", exist_ok=True)
|
| 56 |
+
shutil.copyfile(downloaded, local_path)
|
| 57 |
+
print(f"[persist] pulled DB from {_repo_id()} ({os.path.getsize(local_path)} bytes)")
|
| 58 |
+
return True
|
| 59 |
+
except Exception: # noqa: BLE001
|
| 60 |
+
print(f"[persist] pull skipped (no remote DB or network error):")
|
| 61 |
+
traceback.print_exc()
|
| 62 |
+
return False
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def push_db(local_path: str = DEFAULT_LOCAL) -> bool:
|
| 66 |
+
"""Upload the local DB to the dataset. Coalesced and lock-protected so
|
| 67 |
+
overlapping refreshes don't issue redundant uploads."""
|
| 68 |
+
global _last_push_at
|
| 69 |
+
tok = _token()
|
| 70 |
+
if not tok or not os.path.exists(local_path):
|
| 71 |
+
return False
|
| 72 |
+
|
| 73 |
+
# Coalesce: if we just pushed, skip.
|
| 74 |
+
if time.time() - _last_push_at < PUSH_MIN_INTERVAL:
|
| 75 |
+
return False
|
| 76 |
+
if not _push_lock.acquire(blocking=False):
|
| 77 |
+
return False
|
| 78 |
+
try:
|
| 79 |
+
from huggingface_hub import HfApi # noqa: PLC0415
|
| 80 |
+
api = HfApi(token=tok)
|
| 81 |
+
api.upload_file(
|
| 82 |
+
path_or_fileobj=local_path,
|
| 83 |
+
path_in_repo=PATH_IN_REPO,
|
| 84 |
+
repo_id=_repo_id(),
|
| 85 |
+
repo_type="dataset",
|
| 86 |
+
commit_message="forecast log update",
|
| 87 |
+
)
|
| 88 |
+
_last_push_at = time.time()
|
| 89 |
+
print(f"[persist] pushed DB to {_repo_id()} ({os.path.getsize(local_path)} bytes)")
|
| 90 |
+
return True
|
| 91 |
+
except Exception: # noqa: BLE001
|
| 92 |
+
print("[persist] push failed:")
|
| 93 |
+
traceback.print_exc()
|
| 94 |
+
return False
|
| 95 |
+
finally:
|
| 96 |
+
_push_lock.release()
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def push_db_async(local_path: str = DEFAULT_LOCAL) -> None:
|
| 100 |
+
"""Fire-and-forget push so refresh() returns to the user immediately."""
|
| 101 |
+
threading.Thread(
|
| 102 |
+
target=push_db, args=(local_path,), daemon=True, name="persist-push"
|
| 103 |
+
).start()
|