bitsofchris Claude Opus 4.7 (1M context) committed on
Commit
b620a9c
·
1 Parent(s): cb00551

Pad short series for Toto + persist forecast log to HF Dataset

Browse files

- src/forecast.py: when context is shorter than one Toto patch (32), left-
pad with the first value and set target_mask False on padded steps so
the model ignores them. Fixes the 4-hour cadence dropdown error when the
station has fewer than 32 four-hour points (~5 days of data).
- src/persist.py: pull/push the forecast SQLite to a private HF Dataset
(default bitsofchris/toto-weather-forecast-log) so the scoreboard
survives Space rebuilds. Push is locked + coalesced (60s min interval),
fired async after each refresh; pull runs once at startup.
- app.py: call persist.pull_db() on launch and persist.push_db_async()
at the end of refresh().

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Files changed (3) hide show
  1. app.py +5 -1
  2. src/forecast.py +19 -10
  3. src/persist.py +103 -0
app.py CHANGED
@@ -17,7 +17,7 @@ from datetime import datetime, timedelta, timezone
17
  import gradio as gr
18
  import pandas as pd
19
 
20
- from src import ecowitt, forecast_log, nws
21
  from src.forecast import forecast_series
22
  from src.weather_ui import (
23
  combined_figure,
@@ -154,6 +154,9 @@ def refresh(cycle_label: str = "Hourly", horizon_label: str = "24 h"):
154
  strip = emoji_strip_markdown(nws_df_raw, DISPLAY_TZ, n=12)
155
  scoreboard = render_scoreboard(log_conn)
156
 
 
 
 
157
  return hero, toto_md, nws_md, strip, fig, scoreboard
158
 
159
 
@@ -252,5 +255,6 @@ with gr.Blocks(title="Toto Weather Forecast", theme=gr.themes.Soft()) as demo:
252
 
253
 
254
  if __name__ == "__main__":
 
255
  _start_autorefresh()
256
  demo.launch()
 
17
  import gradio as gr
18
  import pandas as pd
19
 
20
+ from src import ecowitt, forecast_log, nws, persist
21
  from src.forecast import forecast_series
22
  from src.weather_ui import (
23
  combined_figure,
 
154
  strip = emoji_strip_markdown(nws_df_raw, DISPLAY_TZ, n=12)
155
  scoreboard = render_scoreboard(log_conn)
156
 
157
+ # Backup the SQLite log to the HF dataset (non-blocking).
158
+ persist.push_db_async()
159
+
160
  return hero, toto_md, nws_md, strip, fig, scoreboard
161
 
162
 
 
255
 
256
 
257
  if __name__ == "__main__":
258
+ persist.pull_db() # bootstrap the forecast log from the HF Dataset
259
  _start_autorefresh()
260
  demo.launch()
src/forecast.py CHANGED
@@ -87,23 +87,32 @@ def forecast_series(
87
  if series.empty:
88
  raise ValueError("Cannot forecast an empty series")
89
 
 
 
90
  clean = series.astype(float).interpolate(limit_direction="both")
91
 
92
  # Toto requires the context length to be a multiple of the model's
93
- # patch_size (32 for Toto-2.0-4m). Truncate the oldest points to fit.
 
 
 
94
  model = load_model(model_id, device=device)
95
  patch = int(model.config.patch_size)
96
- n = (len(clean) // patch) * patch
97
- if n < patch:
98
- raise ValueError(
99
- f"Need at least {patch} points (got {len(clean)}); "
100
- "fetch a longer history window."
101
- )
102
- clean = clean.iloc[-n:]
103
- arr = clean.to_numpy(dtype=np.float32)
 
 
 
 
104
 
105
  target = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0) # (1, 1, T)
106
- target_mask = torch.ones_like(target, dtype=torch.bool)
107
  series_ids = torch.zeros(1, 1, dtype=torch.long)
108
 
109
  target = target.to(device)
 
87
  if series.empty:
88
  raise ValueError("Cannot forecast an empty series")
89
 
90
+ import numpy as np # noqa: PLC0415
91
+
92
  clean = series.astype(float).interpolate(limit_direction="both")
93
 
94
  # Toto requires the context length to be a multiple of the model's
95
+ # patch_size (32 for Toto-2.0-4m). If we have at least one full patch,
96
+ # truncate the oldest points to fit. If we have fewer, left-pad with the
97
+ # first value and mark the padded region False in the mask so Toto
98
+ # ignores it.
99
  model = load_model(model_id, device=device)
100
  patch = int(model.config.patch_size)
101
+ raw = clean.to_numpy(dtype=np.float32)
102
+ n_raw = len(raw)
103
+
104
+ if n_raw >= patch:
105
+ n = (n_raw // patch) * patch
106
+ arr = raw[-n:]
107
+ mask_vec = np.ones(n, dtype=bool)
108
+ else:
109
+ n = patch
110
+ pad = n - n_raw
111
+ arr = np.concatenate([np.full(pad, raw[0], dtype=np.float32), raw])
112
+ mask_vec = np.concatenate([np.zeros(pad, dtype=bool), np.ones(n_raw, dtype=bool)])
113
 
114
  target = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0) # (1, 1, T)
115
+ target_mask = torch.from_numpy(mask_vec).unsqueeze(0).unsqueeze(0)
116
  series_ids = torch.zeros(1, 1, dtype=torch.long)
117
 
118
  target = target.to(device)
src/persist.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Persist the forecast SQLite DB across HF Space rebuilds.
2
+
3
+ HF Spaces' free tier has ephemeral storage — every `git push` rebuilds the
4
+ container and wipes any local files. We back the forecast log with a
5
+ private HF Dataset:
6
+
7
+ - On startup: pull the latest forecasts.db from the dataset (if any).
8
+ - After every refresh: push the current forecasts.db back.
9
+
10
+ Environment:
11
+ HF_TOKEN must have write access to the dataset
12
+ LOG_DATASET_REPO override the default dataset repo id
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import os
18
+ import shutil
19
+ import threading
20
+ import time
21
+ import traceback
22
+
23
# Dataset repo used when the LOG_DATASET_REPO environment variable is not set.
DEFAULT_REPO = "bitsofchris/toto-weather-forecast-log"
# Name of the DB file inside the dataset repo (used for both pull and push).
PATH_IN_REPO = "forecasts.db"
# Default local path of the DB file that is pulled to / pushed from.
DEFAULT_LOCAL = "data/forecasts.db"

# Held for the duration of an upload; push_db() acquires it non-blocking and
# returns False instead of waiting when it is already taken.
_push_lock = threading.Lock()
# time.time() of the most recent successful upload; 0.0 until the first one.
_last_push_at = 0.0
PUSH_MIN_INTERVAL = 60.0  # seconds — coalesce rapid pushes
30
+
31
+
32
def _repo_id() -> str:
    """Return the id of the HF Dataset repo that stores the forecast DB.

    Reads the ``LOG_DATASET_REPO`` environment variable. When it is unset,
    empty, or only whitespace, ``DEFAULT_REPO`` is returned instead, so a
    blank variable never yields an invalid (empty) repo id.
    """
    override = os.environ.get("LOG_DATASET_REPO", "").strip()
    return override or DEFAULT_REPO
34
+
35
+
36
+ def _token() -> str | None:
37
+ return os.environ.get("HF_TOKEN")
38
+
39
+
40
def pull_db(local_path: str = DEFAULT_LOCAL) -> bool:
    """Download the latest DB from the dataset, overwriting any local copy.

    Best-effort: every failure is logged to stdout and reported through the
    return value; nothing is raised to the caller.

    Args:
        local_path: Where the downloaded DB file is placed. Parent
            directories are created if missing.

    Returns:
        True when the file was downloaded and installed at ``local_path``.
        False when ``HF_TOKEN`` is not set or the download/copy failed.
    """
    tok = _token()
    if not tok:
        print("[persist] HF_TOKEN not set — skipping pull")
        return False
    try:
        from huggingface_hub import hf_hub_download  # noqa: PLC0415

        downloaded = hf_hub_download(
            repo_id=_repo_id(),
            repo_type="dataset",
            filename=PATH_IN_REPO,
            token=tok,
        )
        os.makedirs(os.path.dirname(local_path) or ".", exist_ok=True)
        # Copy to a sibling staging file and rename it into place, so a copy
        # that fails half-way never leaves a truncated DB at local_path.
        staging = f"{local_path}.pull-tmp"
        try:
            shutil.copyfile(downloaded, staging)
            os.replace(staging, local_path)
        finally:
            # Only present if the copy or rename failed.
            if os.path.exists(staging):
                os.remove(staging)
        print(f"[persist] pulled DB from {_repo_id()} ({os.path.getsize(local_path)} bytes)")
        return True
    except Exception:  # noqa: BLE001
        print("[persist] pull skipped (no remote DB or network error):")
        traceback.print_exc()
        return False
63
+
64
+
65
def push_db(local_path: str = DEFAULT_LOCAL) -> bool:
    """Upload the local DB to the dataset.

    Coalesced and lock-protected so overlapping refreshes don't issue
    redundant uploads. Best-effort: upload errors are logged to stdout and
    reported through the return value; nothing is raised to the caller.

    Args:
        local_path: Path of the DB file to upload.

    Returns:
        True when the file was uploaded. False when ``HF_TOKEN`` is not set,
        the file does not exist, the last successful push was less than
        ``PUSH_MIN_INTERVAL`` seconds ago, another push is in progress, or
        the upload failed.
    """
    global _last_push_at
    tok = _token()
    if not tok or not os.path.exists(local_path):
        return False

    # Coalesce: if we just pushed, skip (cheap check before touching the lock).
    if time.time() - _last_push_at < PUSH_MIN_INTERVAL:
        return False
    # Non-blocking: skip rather than queue behind an upload already in flight.
    if not _push_lock.acquire(blocking=False):
        return False
    try:
        # Re-check under the lock: another thread may have completed a push
        # between the check above and our acquire, which would otherwise
        # let a redundant upload through.
        if time.time() - _last_push_at < PUSH_MIN_INTERVAL:
            return False

        from huggingface_hub import HfApi  # noqa: PLC0415

        # NOTE(review): this uploads the live file by path; if another thread
        # can be writing to the DB at the same time, the uploaded copy may not
        # be a consistent snapshot — confirm against the writer.
        api = HfApi(token=tok)
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=PATH_IN_REPO,
            repo_id=_repo_id(),
            repo_type="dataset",
            commit_message="forecast log update",
        )
        # Only a successful upload starts a new coalescing window.
        _last_push_at = time.time()
        print(f"[persist] pushed DB to {_repo_id()} ({os.path.getsize(local_path)} bytes)")
        return True
    except Exception:  # noqa: BLE001
        print("[persist] push failed:")
        traceback.print_exc()
        return False
    finally:
        _push_lock.release()
97
+
98
+
99
def push_db_async(local_path: str = DEFAULT_LOCAL) -> None:
    """Fire-and-forget push so refresh() returns to the user immediately.

    Starts ``push_db(local_path)`` on a daemon thread named ``persist-push``
    and returns without waiting for it; the push result is not reported back.
    """
    worker = threading.Thread(
        name="persist-push",
        target=push_db,
        args=(local_path,),
        daemon=True,
    )
    worker.start()