Harley-ml commited on
Commit
de07f57
·
verified ·
1 Parent(s): ca783e4

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +629 -3
README.md CHANGED
@@ -1,3 +1,629 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - forcast
5
+ - weather
6
+ - lstm
7
+ - classification
8
+ - regression
9
+ - weather-forcast
10
+ - multitask
11
+ - harley-ml
12
+ ---
13
+
14
+ # Hweh-6M
15
+
16
+ Hweh-6M is a **6 million parameter LSTM** trained to predict the next **12 hours of weather**, including temperature, humidity, pressure, precipitation, and more, using the previous **72 hours of weather context**.
17
+ We recommend using this model as a backup to a weather API or for offline forecasting if internet access is unavailable.
18
+
19
+ We want to give a shoutout to [**Open-Meteo**](https://open-meteo.com/) for providing a **free-to-use** weather-forcasting API.
20
+
21
+ ---
22
+
23
+ [unfinished]
24
+
25
+ # Inference
26
+
27
+ ```python
28
+ #!/usr/bin/env python3
29
+ from __future__ import annotations
30
+
31
+ import json
32
+ import time
33
+ from pathlib import Path
34
+ from typing import Any
35
+
36
+ import numpy as np
37
+ import pandas as pd
38
+ import requests
39
+ import torch
40
+ from transformers import AutoConfig, AutoModel
41
+ from zoneinfo import ZoneInfo
42
+
43
+ # ----------------------------
44
+ # Change these values here
45
+ # ----------------------------
46
+ MODEL_ID = r"Harley-ml/Hweh-6M" # HF repo id or local path
47
+ CITY = "Seattle"
48
+ SEQUENCE_META_PATH = "Harley-ml/Hweh-6M/weather_sequences.metadata.json"
49
+ CONTEXT_HOURS = 72
50
+ FORECAST_HOURS = 12
51
+ DEVICE = None # "cpu", "cuda", "cuda:0", or None for auto
52
+
53
+ API_BASE_URL = "https://api.open-meteo.com/v1/forecast"
54
+ MAX_RETRIES = 6
55
+ REQUEST_TIMEOUT_S = 60
56
+
57
+ HOURLY_VARS = [
58
+ "temperature_2m",
59
+ "relative_humidity_2m",
60
+ "apparent_temperature",
61
+ "precipitation",
62
+ "weather_code",
63
+ "pressure_msl",
64
+ "surface_pressure",
65
+ "cloud_cover",
66
+ "visibility",
67
+ "wind_speed_10m",
68
+ "wind_direction_10m",
69
+ ]
70
+
71
+ WEATHER_CODE_BUCKETS = 7
72
+ TEMP_SCALE = 50.0
73
+ HUMIDITY_SCALE = 100.0
74
+ WIND_SCALE = 100.0
75
+
76
+ # ----------------------------
77
+ # City metadata (82 locations)
78
+ # ----------------------------
79
+ CITY_SPECS: dict[str, dict[str, Any]] = {
80
+ "Seattle": {"location_id": "1", "latitude": 47.6062, "longitude": -122.3321, "continent": "North America", "climate_tag": "temperate_oceanic", "elevation": 56},
81
+ "Portland": {"location_id": "2", "latitude": 45.5152, "longitude": -122.6784, "continent": "North America", "climate_tag": "temperate_oceanic", "elevation": 15},
82
+ "San Francisco": {"location_id": "3", "latitude": 37.7749, "longitude": -122.4194, "continent": "North America", "climate_tag": "foggy_mediterranean", "elevation": 16},
83
+ "Los Angeles": {"location_id": "4", "latitude": 34.0522, "longitude": -118.2437, "continent": "North America", "climate_tag": "sunny_mediterranean", "elevation": 71},
84
+ "Denver": {"location_id": "5", "latitude": 39.7392, "longitude": -104.9903, "continent": "North America", "climate_tag": "semi_arid_highland", "elevation": 1609},
85
+ "Chicago": {"location_id": "6", "latitude": 41.8781, "longitude": -87.6298, "continent": "North America", "climate_tag": "humid_continental", "elevation": 181},
86
+ "Dallas": {"location_id": "7", "latitude": 32.7767, "longitude": -96.7970, "continent": "North America", "climate_tag": "hot_subhumid", "elevation": 131},
87
+ "Atlanta": {"location_id": "8", "latitude": 33.7490, "longitude": -84.3880, "continent": "North America", "climate_tag": "humid_subtropical", "elevation": 320},
88
+ "New York": {"location_id": "9", "latitude": 40.7128, "longitude": -74.0060, "continent": "North America", "climate_tag": "humid_subtropical", "elevation": 10},
89
+ "Miami": {"location_id": "10", "latitude": 25.7617, "longitude": -80.1918, "continent": "North America", "climate_tag": "tropical_humid", "elevation": 2},
90
+ "Phoenix": {"location_id": "11", "latitude": 33.4484, "longitude": -112.0740, "continent": "North America", "climate_tag": "hot_arid", "elevation": 331},
91
+ "Salt Lake City": {"location_id": "12", "latitude": 40.7608, "longitude": -111.8910, "continent": "North America", "climate_tag": "semi_arid", "elevation": 1288},
92
+ "Anchorage": {"location_id": "13", "latitude": 61.2181, "longitude": -149.9003, "continent": "North America", "climate_tag": "subarctic_snowy", "elevation": 31},
93
+ "Minneapolis": {"location_id": "14", "latitude": 44.9778, "longitude": -93.2650, "continent": "North America", "climate_tag": "cold_snowy", "elevation": 264},
94
+ "Toronto": {"location_id": "15", "latitude": 43.6532, "longitude": -79.3832, "continent": "North America", "climate_tag": "humid_continental", "elevation": 76},
95
+ "Montreal": {"location_id": "16", "latitude": 45.5017, "longitude": -73.5673, "continent": "North America", "climate_tag": "cold_snowy", "elevation": 233},
96
+ "Vancouver": {"location_id": "17", "latitude": 49.2827, "longitude": -123.1207, "continent": "North America", "climate_tag": "temperate_oceanic", "elevation": 70},
97
+ "Mexico City": {"location_id": "18", "latitude": 19.4326, "longitude": -99.1332, "continent": "North America", "climate_tag": "highland_subtropical", "elevation": 2240},
98
+ "Havana": {"location_id": "19", "latitude": 23.1136, "longitude": -82.3666, "continent": "North America", "climate_tag": "tropical_humid", "elevation": 59},
99
+ "San Juan": {"location_id": "20", "latitude": 18.4655, "longitude": -66.1057, "continent": "North America", "climate_tag": "tropical_humid", "elevation": 8},
100
+
101
+ "Lima": {"location_id": "21", "latitude": -12.0464, "longitude": -77.0428, "continent": "South America", "climate_tag": "coastal_arid", "elevation": 154},
102
+ "Santiago": {"location_id": "22", "latitude": -33.4489, "longitude": -70.6693, "continent": "South America", "climate_tag": "mediterranean", "elevation": 520},
103
+ "Buenos Aires": {"location_id": "23", "latitude": -34.6037, "longitude": -58.3816, "continent": "South America", "climate_tag": "humid_subtropical", "elevation": 25},
104
+ "Bogotá": {"location_id": "24", "latitude": 4.7110, "longitude": -74.0721, "continent": "South America", "climate_tag": "highland_cool", "elevation": 2640},
105
+ "Quito": {"location_id": "25", "latitude": -0.1807, "longitude": -78.4678, "continent": "South America", "climate_tag": "highland_equatorial", "elevation": 2850},
106
+ "Caracas": {"location_id": "26", "latitude": 10.4806, "longitude": -66.9036, "continent": "South America", "climate_tag": "tropical_humid", "elevation": 900},
107
+ "Rio de Janeiro": {"location_id": "27", "latitude": -22.9068, "longitude": -43.1729, "continent": "South America", "climate_tag": "tropical_humid", "elevation": 5},
108
+ "São Paulo": {"location_id": "28", "latitude": -23.5505, "longitude": -46.6333, "continent": "South America", "climate_tag": "humid_subtropical", "elevation": 760},
109
+ "La Paz": {"location_id": "29", "latitude": -16.4897, "longitude": -68.1193, "continent": "South America", "climate_tag": "highland_cold", "elevation": 3640},
110
+ "Cusco": {"location_id": "30", "latitude": -13.5319, "longitude": -71.9675, "continent": "South America", "climate_tag": "highland_cool", "elevation": 3399},
111
+ "Montevideo": {"location_id": "31", "latitude": -34.9011, "longitude": -56.1645, "continent": "South America", "climate_tag": "temperate_oceanic", "elevation": 43},
112
+ "Asunción": {"location_id": "32", "latitude": -25.2637, "longitude": -57.5759, "continent": "South America", "climate_tag": "humid_subtropical", "elevation": 43},
113
+ "Manaus": {"location_id": "33", "latitude": -3.1190, "longitude": -60.0217, "continent": "South America", "climate_tag": "tropical_humid", "elevation": 92},
114
+ "Recife": {"location_id": "34", "latitude": -8.0476, "longitude": -34.8770, "continent": "South America", "climate_tag": "tropical_coastal", "elevation": 4},
115
+ "Punta Arenas": {"location_id": "35", "latitude": -53.1638, "longitude": -70.9171, "continent": "South America", "climate_tag": "cold_windy", "elevation": 34},
116
+
117
+ "London": {"location_id": "36", "latitude": 51.5074, "longitude": -0.1278, "continent": "Europe", "climate_tag": "temperate_oceanic", "elevation": 11},
118
+ "Paris": {"location_id": "37", "latitude": 48.8566, "longitude": 2.3522, "continent": "Europe", "climate_tag": "temperate_oceanic", "elevation": 35},
119
+ "Madrid": {"location_id": "38", "latitude": 40.4168, "longitude": -3.7038, "continent": "Europe", "climate_tag": "hot_summer_mediterranean", "elevation": 667},
120
+ "Rome": {"location_id": "39", "latitude": 41.9028, "longitude": 12.4964, "continent": "Europe", "climate_tag": "hot_summer_mediterranean", "elevation": 21},
121
+ "Berlin": {"location_id": "40", "latitude": 52.52, "longitude": 13.4050, "continent": "Europe", "climate_tag": "temperate_continental", "elevation": 34},
122
+ "Stockholm": {"location_id": "41", "latitude": 59.3293, "longitude": 18.0686, "continent": "Europe", "climate_tag": "cold_marine", "elevation": 28},
123
+ "Oslo": {"location_id": "42", "latitude": 59.9139, "longitude": 10.7522, "continent": "Europe", "climate_tag": "cold_snowy", "elevation": 23},
124
+ "Helsinki": {"location_id": "43", "latitude": 60.1699, "longitude": 24.9384, "continent": "Europe", "climate_tag": "cold_snowy", "elevation": 25},
125
+ "Reykjavik": {"location_id": "44", "latitude": 64.1466, "longitude": -21.9426, "continent": "Europe", "climate_tag": "cold_windy", "elevation": 12},
126
+ "Kyiv": {"location_id": "45", "latitude": 50.4501, "longitude": 30.5234, "continent": "Europe", "climate_tag": "humid_continental", "elevation": 179},
127
+ "Lisbon": {"location_id": "46", "latitude": 38.7223, "longitude": -9.1393, "continent": "Europe", "climate_tag": "sunny_mediterranean", "elevation": 7},
128
+ "Athens": {"location_id": "47", "latitude": 37.9838, "longitude": 23.7275, "continent": "Europe", "climate_tag": "sunny_mediterranean", "elevation": 70},
129
+ "Zurich": {"location_id": "48", "latitude": 47.3769, "longitude": 8.5417, "continent": "Europe", "climate_tag": "temperate_continental", "elevation": 408},
130
+ "Dublin": {"location_id": "49", "latitude": 53.3498, "longitude": -6.2603, "continent": "Europe", "climate_tag": "temperate_oceanic", "elevation": 20},
131
+ "Vienna": {"location_id": "50", "latitude": 48.2082, "longitude": 16.3738, "continent": "Europe", "climate_tag": "temperate_continental", "elevation": 171},
132
+
133
+ "Dubai": {"location_id": "51", "latitude": 25.2048, "longitude": 55.2708, "continent": "Asia", "climate_tag": "hot_arid", "elevation": 16},
134
+ "Riyadh": {"location_id": "52", "latitude": 24.7136, "longitude": 46.6753, "continent": "Asia", "climate_tag": "hot_arid", "elevation": 612},
135
+ "Delhi": {"location_id": "53", "latitude": 28.7041, "longitude": 77.1025, "continent": "Asia", "climate_tag": "hot_semi_arid", "elevation": 216},
136
+ "Mumbai": {"location_id": "54", "latitude": 19.0760, "longitude": 72.8777, "continent": "Asia", "climate_tag": "tropical_humid", "elevation": 14},
137
+ "Bangkok": {"location_id": "55", "latitude": 13.7563, "longitude": 100.5018, "continent": "Asia", "climate_tag": "tropical_monsoon", "elevation": 2},
138
+ "Singapore": {"location_id": "56", "latitude": 1.3521, "longitude": 103.8198, "continent": "Asia", "climate_tag": "tropical_humid", "elevation": 15},
139
+ "Tokyo": {"location_id": "57", "latitude": 35.6762, "longitude": 139.6503, "continent": "Asia", "climate_tag": "humid_subtropical", "elevation": 40},
140
+ "Seoul": {"location_id": "58", "latitude": 37.5665, "longitude": 126.9780, "continent": "Asia", "climate_tag": "humid_continental", "elevation": 38},
141
+ "Ulaanbaatar": {"location_id": "59", "latitude": 47.8864, "longitude": 106.9057, "continent": "Asia", "climate_tag": "cold_steppe", "elevation": 1350},
142
+ "Kathmandu": {"location_id": "60", "latitude": 27.7172, "longitude": 85.3240, "continent": "Asia", "climate_tag": "highland_subtropical", "elevation": 1400},
143
+ "Chiang Mai": {"location_id": "61", "latitude": 18.7883, "longitude": 98.9853, "continent": "Asia", "climate_tag": "tropical_seasonal", "elevation": 300},
144
+ "Lhasa": {"location_id": "62", "latitude": 29.6520, "longitude": 91.1721, "continent": "Asia", "climate_tag": "high_altitude_cold", "elevation": 3656},
145
+ "Jakarta": {"location_id": "63", "latitude": -6.2088, "longitude": 106.8456, "continent": "Asia", "climate_tag": "tropical_humid", "elevation": 8},
146
+ "Manila": {"location_id": "64", "latitude": 14.5995, "longitude": 120.9842, "continent": "Asia", "climate_tag": "tropical_humid", "elevation": 16},
147
+ "Karachi": {"location_id": "65", "latitude": 24.8607, "longitude": 67.0011, "continent": "Asia", "climate_tag": "hot_arid", "elevation": 10},
148
+
149
+ "Cairo": {"location_id": "66", "latitude": 30.0444, "longitude": 31.2357, "continent": "Africa", "climate_tag": "hot_arid", "elevation": 23},
150
+ "Alexandria": {"location_id": "67", "latitude": 31.2001, "longitude": 29.9187, "continent": "Africa", "climate_tag": "coastal_mediterranean", "elevation": 5},
151
+ "Casablanca": {"location_id": "68", "latitude": 33.5731, "longitude": -7.5898, "continent": "Africa", "climate_tag": "coastal_mediterranean", "elevation": 56},
152
+ "Marrakech": {"location_id": "69", "latitude": 31.6295, "longitude": -7.9811, "continent": "Africa", "climate_tag": "hot_semi_arid", "elevation": 466},
153
+ "Lagos": {"location_id": "70", "latitude": 6.5244, "longitude": 3.3792, "continent": "Africa", "climate_tag": "tropical_humid", "elevation": 41},
154
+ "Nairobi": {"location_id": "71", "latitude": -1.2921, "longitude": 36.8219, "continent": "Africa", "climate_tag": "temperate_highland", "elevation": 1795},
155
+ "Addis Ababa": {"location_id": "72", "latitude": 8.9806, "longitude": 38.7578, "continent": "Africa", "climate_tag": "temperate_highland", "elevation": 2355},
156
+ "Cape Town": {"location_id": "73", "latitude": -33.9249, "longitude": 18.4241, "continent": "Africa", "climate_tag": "mediterranean", "elevation": 25},
157
+ "Johannesburg": {"location_id": "74", "latitude": -26.2041, "longitude": 28.0473, "continent": "Africa", "climate_tag": "subtropical_highland", "elevation": 1753},
158
+ "Windhoek": {"location_id": "75", "latitude": -22.5609, "longitude": 17.0658, "continent": "Africa", "climate_tag": "semi_arid", "elevation": 1650},
159
+ "Accra": {"location_id": "76", "latitude": 5.6037, "longitude": -0.1870, "continent": "Africa", "climate_tag": "tropical_humid", "elevation": 61},
160
+ "Kigali": {"location_id": "77", "latitude": -1.9441, "longitude": 30.0619, "continent": "Africa", "climate_tag": "highland_tropical", "elevation": 1567},
161
+ "Tunis": {"location_id": "78", "latitude": 36.8065, "longitude": 10.1815, "continent": "Africa", "climate_tag": "mediterranean", "elevation": 4},
162
+ "Dakar": {"location_id": "79", "latitude": -14.7167, "longitude": -17.4677, "continent": "Africa", "climate_tag": "hot_coastal", "elevation": 25},
163
+ "Mombasa": {"location_id": "80", "latitude": -4.0435, "longitude": 39.6682, "continent": "Africa", "climate_tag": "tropical_coastal", "elevation": 17},
164
+
165
+ "Sydney": {"location_id": "81", "latitude": -33.8688, "longitude": 151.2093, "continent": "Oceania", "climate_tag": "humid_subtropical", "elevation": 58},
166
+ "Melbourne": {"location_id": "82", "latitude": -37.8136, "longitude": 144.9631, "continent": "Oceania", "climate_tag": "temperate_oceanic", "elevation": 31},
167
+ }
168
+
169
+ CITY_TIMEZONES: dict[str, str] = {
170
+ "Seattle": "America/Los_Angeles",
171
+ "Portland": "America/Los_Angeles",
172
+ "San Francisco": "America/Los_Angeles",
173
+ "Los Angeles": "America/Los_Angeles",
174
+ "Denver": "America/Denver",
175
+ "Chicago": "America/Chicago",
176
+ "Dallas": "America/Chicago",
177
+ "Atlanta": "America/New_York",
178
+ "New York": "America/New_York",
179
+ "Miami": "America/New_York",
180
+ "Phoenix": "America/Phoenix",
181
+ "Salt Lake City": "America/Denver",
182
+ "Anchorage": "America/Anchorage",
183
+ "Minneapolis": "America/Chicago",
184
+ "Toronto": "America/Toronto",
185
+ "Montreal": "America/Toronto",
186
+ "Vancouver": "America/Vancouver",
187
+ "Mexico City": "America/Mexico_City",
188
+ "Havana": "America/Havana",
189
+ "San Juan": "America/Puerto_Rico",
190
+ "Lima": "America/Lima",
191
+ "Santiago": "America/Santiago",
192
+ "Buenos Aires": "America/Argentina/Buenos_Aires",
193
+ "Bogotá": "America/Bogota",
194
+ "Quito": "America/Guayaquil",
195
+ "Caracas": "America/Caracas",
196
+ "Rio de Janeiro": "America/Sao_Paulo",
197
+ "São Paulo": "America/Sao_Paulo",
198
+ "La Paz": "America/La_Paz",
199
+ "Cusco": "America/Lima",
200
+ "Montevideo": "America/Montevideo",
201
+ "Asunción": "America/Asuncion",
202
+ "Manaus": "America/Manaus",
203
+ "Recife": "America/Recife",
204
+ "Punta Arenas": "America/Punta_Arenas",
205
+ "London": "Europe/London",
206
+ "Paris": "Europe/Paris",
207
+ "Madrid": "Europe/Madrid",
208
+ "Rome": "Europe/Rome",
209
+ "Berlin": "Europe/Berlin",
210
+ "Stockholm": "Europe/Stockholm",
211
+ "Oslo": "Europe/Oslo",
212
+ "Helsinki": "Europe/Helsinki",
213
+ "Reykjavik": "Atlantic/Reykjavik",
214
+ "Kyiv": "Europe/Kyiv",
215
+ "Lisbon": "Europe/Lisbon",
216
+ "Athens": "Europe/Athens",
217
+ "Zurich": "Europe/Zurich",
218
+ "Dublin": "Europe/Dublin",
219
+ "Vienna": "Europe/Vienna",
220
+ "Dubai": "Asia/Dubai",
221
+ "Riyadh": "Asia/Riyadh",
222
+ "Delhi": "Asia/Kolkata",
223
+ "Mumbai": "Asia/Kolkata",
224
+ "Bangkok": "Asia/Bangkok",
225
+ "Singapore": "Asia/Singapore",
226
+ "Tokyo": "Asia/Tokyo",
227
+ "Seoul": "Asia/Seoul",
228
+ "Ulaanbaatar": "Asia/Ulaanbaatar",
229
+ "Kathmandu": "Asia/Kathmandu",
230
+ "Chiang Mai": "Asia/Bangkok",
231
+ "Lhasa": "Asia/Shanghai",
232
+ "Jakarta": "Asia/Jakarta",
233
+ "Manila": "Asia/Manila",
234
+ "Karachi": "Asia/Karachi",
235
+ "Cairo": "Africa/Cairo",
236
+ "Alexandria": "Africa/Cairo",
237
+ "Casablanca": "Africa/Casablanca",
238
+ "Marrakech": "Africa/Casablanca",
239
+ "Lagos": "Africa/Lagos",
240
+ "Nairobi": "Africa/Nairobi",
241
+ "Addis Ababa": "Africa/Addis_Ababa",
242
+ "Cape Town": "Africa/Johannesburg",
243
+ "Johannesburg": "Africa/Johannesburg",
244
+ "Windhoek": "Africa/Windhoek",
245
+ "Accra": "Africa/Accra",
246
+ "Kigali": "Africa/Kigali",
247
+ "Tunis": "Africa/Tunis",
248
+ "Dakar": "Africa/Dakar",
249
+ "Mombasa": "Africa/Nairobi",
250
+ "Sydney": "Australia/Sydney",
251
+ "Melbourne": "Australia/Melbourne",
252
+ }
253
+
254
+ # ----------------------------
255
+ # Helpers
256
+ # ----------------------------
257
+ def weather_code_to_bucket(code) -> int:
258
+ if code is None:
259
+ return 1
260
+ try:
261
+ if pd.isna(code):
262
+ return 1
263
+ except Exception:
264
+ pass
265
+
266
+ code = int(code)
267
+ if code == 0:
268
+ return 0
269
+ if code in (1, 2, 3):
270
+ return 1
271
+ if code in (45, 48):
272
+ return 2
273
+ if code in (51, 53, 55, 56, 57):
274
+ return 3
275
+ if code in (61, 63, 65, 66, 67, 80, 81, 82):
276
+ return 4
277
+ if code in (71, 73, 75, 77, 85, 86):
278
+ return 5
279
+ if code in (95, 96, 99):
280
+ return 6
281
+ return 1
282
+
283
+
284
+ def cyc(x: np.ndarray, period: float) -> tuple[np.ndarray, np.ndarray]:
285
+ angle = 2.0 * np.pi * (x / period)
286
+ return np.sin(angle), np.cos(angle)
287
+
288
+
289
+ def request_with_backoff(session: requests.Session, url: str, params: dict[str, Any]) -> dict[str, Any]:
290
+ last_exc: Exception | None = None
291
+ for attempt in range(MAX_RETRIES):
292
+ try:
293
+ resp = session.get(url, params=params, timeout=REQUEST_TIMEOUT_S)
294
+ if resp.status_code == 429:
295
+ retry_after = resp.headers.get("Retry-After")
296
+ sleep_s = float(retry_after) if retry_after else min(60.0, 2**attempt)
297
+ print(f"Rate limited. Sleeping {sleep_s:.1f}s and retrying.", flush=True)
298
+ time.sleep(sleep_s)
299
+ continue
300
+ resp.raise_for_status()
301
+ return resp.json()
302
+ except Exception as e:
303
+ last_exc = e
304
+ sleep_s = min(60.0, 2**attempt)
305
+ print(f"Request failed: {e}. Sleeping {sleep_s:.1f}s and retrying.", flush=True)
306
+ time.sleep(sleep_s)
307
+ raise RuntimeError(f"Failed after {MAX_RETRIES} retries: {params}") from last_exc
308
+
309
+
310
+ def load_sequence_meta(path: str) -> dict[str, Any]:
311
+ p = Path(path)
312
+ if not p.exists():
313
+ return {"location_to_id": {}}
314
+ with open(p, "r", encoding="utf-8") as f:
315
+ meta = json.load(f)
316
+ meta.setdefault("location_to_id", {})
317
+ return meta
318
+
319
+
320
+ def load_model():
321
+ config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
322
+ model = AutoModel.from_pretrained(MODEL_ID, config=config, trust_remote_code=True)
323
+ model.eval()
324
+ return model, config
325
+
326
+
327
+ def fetch_recent_history(city: str, context_hours: int) -> pd.DataFrame:
328
+ if city not in CITY_SPECS:
329
+ raise ValueError(f"Unknown city: {city}")
330
+
331
+ spec = CITY_SPECS[city]
332
+ session = requests.Session()
333
+ session.headers.update({"User-Agent": "Mozilla/5.0"})
334
+
335
+ params = {
336
+ "latitude": spec["latitude"],
337
+ "longitude": spec["longitude"],
338
+ "hourly": ",".join(HOURLY_VARS),
339
+ "timezone": "UTC",
340
+ "temperature_unit": "celsius",
341
+ "wind_speed_unit": "kmh",
342
+ "precipitation_unit": "mm",
343
+ "past_hours": int(context_hours) + 2,
344
+ "forecast_hours": 0,
345
+ }
346
+
347
+ data = request_with_backoff(session, API_BASE_URL, params=params)
348
+ hourly = data.get("hourly", {})
349
+ if "time" not in hourly:
350
+ raise ValueError(f"No hourly data returned for {city}: {data}")
351
+
352
+ df = pd.DataFrame(hourly)
353
+ if df.empty:
354
+ raise ValueError(f"Empty hourly response for {city}.")
355
+
356
+ df["time"] = pd.to_datetime(df["time"], errors="coerce", utc=True)
357
+ df = df.dropna(subset=["time"]).sort_values("time").drop_duplicates(subset=["time"]).reset_index(drop=True)
358
+
359
+ needed = HOURLY_VARS
360
+ missing = [c for c in needed if c not in df.columns]
361
+ if missing:
362
+ raise ValueError(f"Missing hourly columns in API response: {missing}")
363
+
364
+ for c in needed:
365
+ df[c] = pd.to_numeric(df[c], errors="coerce")
366
+
367
+ df["weather_code"] = df["weather_code"].fillna(1)
368
+ df["precipitation"] = df["precipitation"].fillna(0.0)
369
+
370
+ for c in [
371
+ "temperature_2m",
372
+ "relative_humidity_2m",
373
+ "apparent_temperature",
374
+ "precipitation",
375
+ "pressure_msl",
376
+ "surface_pressure",
377
+ "cloud_cover",
378
+ "visibility",
379
+ "wind_speed_10m",
380
+ "wind_direction_10m",
381
+ ]:
382
+ df[c] = df[c].interpolate(limit_direction="both").ffill().bfill()
383
+
384
+ now_utc = pd.Timestamp.now(tz="UTC")
385
+ df = df[df["time"] <= now_utc].copy()
386
+
387
+ if len(df) < context_hours:
388
+ raise ValueError(f"Not enough observed rows: got {len(df)}, need {context_hours}")
389
+
390
+ return df.tail(context_hours).reset_index(drop=True)
391
+
392
+
393
+ def build_single_sequence(df: pd.DataFrame) -> np.ndarray:
394
+ hour = df["time"].dt.hour.to_numpy()
395
+ doy = df["time"].dt.dayofyear.to_numpy()
396
+
397
+ hour_sin, hour_cos = cyc(hour.astype(float), 24.0)
398
+ doy_sin, doy_cos = cyc(doy.astype(float), 365.25)
399
+
400
+ temp = np.nan_to_num(df["temperature_2m"].astype(float).to_numpy(), nan=0.0)
401
+ humidity = np.nan_to_num(df["relative_humidity_2m"].astype(float).to_numpy(), nan=0.0)
402
+ apparent = np.nan_to_num(df["apparent_temperature"].astype(float).to_numpy(), nan=0.0)
403
+ precip = np.nan_to_num(df["precipitation"].astype(float).to_numpy(), nan=0.0)
404
+ pressure = np.nan_to_num(df["pressure_msl"].astype(float).to_numpy(), nan=0.0)
405
+ surface_pressure = np.nan_to_num(df["surface_pressure"].astype(float).to_numpy(), nan=0.0)
406
+ cloud_cover = np.nan_to_num(df["cloud_cover"].astype(float).to_numpy(), nan=0.0)
407
+ visibility = np.nan_to_num(df["visibility"].astype(float).to_numpy(), nan=0.0)
408
+ wind = np.nan_to_num(df["wind_speed_10m"].astype(float).to_numpy(), nan=0.0)
409
+ wind_dir = np.nan_to_num(df["wind_direction_10m"].astype(float).to_numpy(), nan=0.0)
410
+ wind_dir_sin, wind_dir_cos = cyc(wind_dir, 360.0)
411
+ weather_bucket = df["weather_code"].fillna(1).apply(weather_code_to_bucket).to_numpy(dtype=np.int64)
412
+
413
+ rows = []
414
+ for i in range(len(df)):
415
+ wc_oh = np.zeros(WEATHER_CODE_BUCKETS, dtype=np.float32)
416
+ wc_oh[weather_bucket[i]] = 1.0
417
+
418
+ row = np.concatenate(
419
+ [
420
+ np.array(
421
+ [
422
+ temp[i] / TEMP_SCALE,
423
+ humidity[i] / HUMIDITY_SCALE,
424
+ apparent[i] / TEMP_SCALE,
425
+ np.log1p(max(precip[i], 0.0)) / 3.0,
426
+ pressure[i] / 1100.0,
427
+ surface_pressure[i] / 1100.0,
428
+ cloud_cover[i] / 100.0,
429
+ visibility[i] / 50000.0,
430
+ wind[i] / WIND_SCALE,
431
+ wind_dir_sin[i],
432
+ wind_dir_cos[i],
433
+ hour_sin[i],
434
+ hour_cos[i],
435
+ doy_sin[i],
436
+ doy_cos[i],
437
+ ],
438
+ dtype=np.float32,
439
+ ),
440
+ wc_oh,
441
+ ]
442
+ )
443
+ rows.append(row)
444
+
445
+ seq = np.asarray(rows, dtype=np.float32)
446
+
447
+ if not np.isfinite(seq).all():
448
+ bad = np.argwhere(~np.isfinite(seq))
449
+ raise ValueError(f"Non-finite values remain in sequence at positions like: {bad[:10].tolist()}")
450
+
451
+ return seq
452
+
453
+
454
+ def to_iso(ts: pd.Timestamp, tz_name: str | None = None) -> str:
455
+ if tz_name:
456
+ try:
457
+ return ts.tz_convert(ZoneInfo(tz_name)).isoformat()
458
+ except Exception:
459
+ pass
460
+ return ts.isoformat()
461
+
462
+
463
+ def get_logits(out):
464
+ if isinstance(out, dict) and "logits" in out:
465
+ return out["logits"]
466
+ if hasattr(out, "logits"):
467
+ return out.logits
468
+ return out
469
+
470
+
471
+ def resolve_location_index(seq_meta: dict[str, Any], city_location_id: str) -> int:
472
+ location_to_id = seq_meta.get("location_to_id", {})
473
+
474
+ if city_location_id in location_to_id:
475
+ return int(location_to_id[city_location_id])
476
+
477
+ try:
478
+ as_int = int(city_location_id)
479
+ if as_int in location_to_id:
480
+ return int(location_to_id[as_int])
481
+ if str(as_int) in location_to_id:
482
+ return int(location_to_id[str(as_int)])
483
+ except Exception:
484
+ pass
485
+
486
+ for unk_key in ("UNK", "<UNK>", "unknown", "UNKNOWN"):
487
+ if unk_key in location_to_id:
488
+ return int(location_to_id[unk_key])
489
+
490
+ return 0
491
+
492
+
493
+ def predict():
494
+ seq_meta = load_sequence_meta(SEQUENCE_META_PATH)
495
+ model, config = load_model()
496
+
497
+ if CITY not in CITY_SPECS:
498
+ raise ValueError(f"Unknown city: {CITY}")
499
+
500
+ if CONTEXT_HOURS <= 0:
501
+ raise ValueError("CONTEXT_HOURS must be > 0")
502
+
503
+ if hasattr(config, "seq_len") and int(config.seq_len) != CONTEXT_HOURS:
504
+ raise ValueError(f"Set CONTEXT_HOURS to {int(config.seq_len)} for this model.")
505
+
506
+ city_spec = CITY_SPECS[CITY]
507
+ city_tz = CITY_TIMEZONES.get(CITY, "UTC")
508
+ model_location_id = resolve_location_index(seq_meta, str(city_spec["location_id"]))
509
+
510
+ df = fetch_recent_history(CITY, CONTEXT_HOURS)
511
+ seq = build_single_sequence(df)
512
+
513
+ X = torch.from_numpy(seq).unsqueeze(0)
514
+ loc = torch.tensor([model_location_id], dtype=torch.long)
515
+
516
+ target_device = torch.device(
517
+ DEVICE if DEVICE else ("cuda" if torch.cuda.is_available() else "cpu")
518
+ )
519
+ model = model.to(target_device)
520
+ X = X.to(target_device)
521
+ loc = loc.to(target_device)
522
+
523
+ weather_class_names = getattr(config, "weather_class_names", None)
524
+ if not weather_class_names:
525
+ weather_class_names = [f"class_{i}" for i in range(int(getattr(config, "num_weather_classes", 7)))]
526
+
527
+ with torch.no_grad():
528
+ out = model(X=X, location_id=loc)
529
+ logits = get_logits(out)
530
+
531
+ (
532
+ temp_pred,
533
+ humidity_pred,
534
+ apparent_pred,
535
+ precip_pred,
536
+ sea_level_pressure_pred,
537
+ surface_pressure_pred,
538
+ cloud_cover_pred,
539
+ wind_pred,
540
+ wind_dir_sin_pred,
541
+ wind_dir_cos_pred,
542
+ rain_logit,
543
+ weather_logits,
544
+ ) = logits
545
+
546
+ temp_pred = temp_pred.squeeze(0).detach().cpu().numpy()
547
+ humidity_pred = humidity_pred.squeeze(0).detach().cpu().numpy()
548
+ apparent_pred = apparent_pred.squeeze(0).detach().cpu().numpy()
549
+ precip_pred = precip_pred.squeeze(0).detach().cpu().numpy()
550
+ sea_level_pressure_pred = sea_level_pressure_pred.squeeze(0).detach().cpu().numpy()
551
+ surface_pressure_pred = surface_pressure_pred.squeeze(0).detach().cpu().numpy()
552
+ cloud_cover_pred = cloud_cover_pred.squeeze(0).detach().cpu().numpy()
553
+ wind_pred = wind_pred.squeeze(0).detach().cpu().numpy()
554
+ rain_prob = torch.sigmoid(rain_logit).squeeze(0).detach().cpu().numpy()
555
+ weather_probs = torch.softmax(weather_logits, dim=-1).squeeze(0).detach().cpu().numpy()
556
+ weather_idx = np.argmax(weather_probs, axis=-1).astype(np.int64)
557
+
558
+ context_start = df["time"].iloc[0]
559
+ context_end = df["time"].iloc[-1]
560
+ requested_at_utc = pd.Timestamp.now(tz="UTC")
561
+
562
+ horizon = min(
563
+ int(FORECAST_HOURS),
564
+ int(temp_pred.shape[0]),
565
+ int(humidity_pred.shape[0]),
566
+ int(weather_idx.shape[0]),
567
+ )
568
+
569
+ forecast = []
570
+ for lead in range(1, horizon + 1):
571
+ target_time = context_end + pd.Timedelta(hours=lead)
572
+ idx = lead - 1
573
+ w_idx = int(weather_idx[idx])
574
+
575
+ forecast.append(
576
+ {
577
+ "lead_hours": lead,
578
+ "target_utc": target_time.isoformat(),
579
+ "target_local": to_iso(target_time, city_tz),
580
+ "temperature_2m_c": float(temp_pred[idx]),
581
+ "relative_humidity_2m_pct": float(humidity_pred[idx]),
582
+ "apparent_temperature_c": float(apparent_pred[idx]),
583
+ "precipitation_mm": float(precip_pred[idx]),
584
+ "pressure_msl_hpa": float(sea_level_pressure_pred[idx]),
585
+ "surface_pressure_hpa": float(surface_pressure_pred[idx]),
586
+ "cloud_cover_pct": float(cloud_cover_pred[idx]),
587
+ "wind_speed_10m_kmh": float(wind_pred[idx]),
588
+ "rain_probability": float(rain_prob[idx]),
589
+ "weather_class": w_idx,
590
+ "weather_class_name": weather_class_names[w_idx] if w_idx < len(weather_class_names) else f"class_{w_idx}",
591
+ "weather_class_probabilities": {
592
+ name: float(prob) for name, prob in zip(weather_class_names, weather_probs[idx])
593
+ },
594
+ }
595
+ )
596
+
597
+ result = {
598
+ "city": CITY,
599
+ "location_id": str(city_spec["location_id"]),
600
+ "model_location_id": int(model_location_id),
601
+ "data_source": "open-meteo forecast api (past-hours context only)",
602
+ "requested_at_utc": requested_at_utc.isoformat(),
603
+ "context": {
604
+ "hours": int(len(df)),
605
+ "start_utc": context_start.isoformat(),
606
+ "end_utc": context_end.isoformat(),
607
+ "start_local": to_iso(context_start, city_tz),
608
+ "end_local": to_iso(context_end, city_tz),
609
+ },
610
+ "model": {
611
+ "model_id": MODEL_ID,
612
+ "encoder_type": getattr(config, "encoder_type", None),
613
+ "seq_len": int(getattr(config, "seq_len", CONTEXT_HOURS)),
614
+ "input_dim": int(getattr(config, "input_dim", seq.shape[1])),
615
+ "num_weather_classes": int(getattr(config, "num_weather_classes", len(weather_class_names))),
616
+ },
617
+ "forecast": forecast,
618
+ "sanity": {
619
+ "sequence_shape": list(seq.shape),
620
+ "finite_features": bool(np.isfinite(seq).all()),
621
+ },
622
+ }
623
+
624
+ print(json.dumps(result, indent=2))
625
+
626
+
627
+ if __name__ == "__main__":
628
+ predict()
629
+ ```