kyLELEng commited on
Commit
06d2c78
·
verified ·
1 Parent(s): 4951453

Train WeatherScenarioDiffusion-1D

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ sample_plots/future_mask_sample.png filter=lfs diff=lfs merge=lfs -text
37
+ sample_plots/real_vs_generated.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: diffusers
3
+ tags:
4
+ - time-series
5
+ - diffusion
6
+ - scenario-generation
7
+ - weather
8
+ - multivariate-time-series
9
+ ---
10
+
11
+ # WeatherScenarioDiffusion-1D
12
+
13
+ WeatherScenarioDiffusion-1D is a conditional 1D diffusion model for multivariate weather time-series scenario generation.
14
+
15
+ The model is trained on [`Duyu/Time-Series-Forecasting-Benchmark-Datasets`](https://huggingface.co/datasets/Duyu/Time-Series-Forecasting-Benchmark-Datasets), file `Weather.csv`.
16
+
17
+ ## What The Model Does
18
+
19
+ This is a single conditional diffusion model with three usage modes:
20
+
21
+ 1. **Unconditional scenario generation**: sample realistic multivariate weather trajectories from noise.
22
+ 2. **Future-mask generation**: condition on the first part of a window and generate the missing future segment.
23
+ 3. **Channel inpainting**: condition on known weather variables and generate missing variables.
24
+
25
+ The model uses:
26
+
27
+ - `diffusers.UNet1DModel`
28
+ - `diffusers.DDPMScheduler`
29
+ - mask conditioning through concatenated input channels: `noisy_x`, `observed_x`, and `observed_mask`
30
+
31
+ ## Data
32
+
33
+ - Source dataset: `Duyu/Time-Series-Forecasting-Benchmark-Datasets`
34
+ - Source file: `Weather.csv`
35
+ - Numeric channels detected: `21`
36
+ - Window length: `256`
37
+ - Stride: `4`
38
+ - Split: time-ordered 80% train / 10% validation / 10% test
39
+ - Normalization: z-score fitted only on the train split
40
+
41
+ Detected channels:
42
+
43
+ ```json
44
+ [
45
+ "feature_00",
46
+ "feature_01",
47
+ "feature_02",
48
+ "feature_03",
49
+ "feature_04",
50
+ "feature_05",
51
+ "feature_06",
52
+ "feature_07",
53
+ "feature_08",
54
+ "feature_09",
55
+ "feature_10",
56
+ "feature_11",
57
+ "feature_12",
58
+ "feature_13",
59
+ "feature_14",
60
+ "feature_15",
61
+ "feature_16",
62
+ "feature_17",
63
+ "feature_18",
64
+ "feature_19",
65
+ "feature_20"
66
+ ]
67
+ ```
68
+
69
+ ## Training
70
+
71
+ ```json
72
+ {
73
+ "dataset_repo": "Duyu/Time-Series-Forecasting-Benchmark-Datasets",
74
+ "dataset_file": "Weather.csv",
75
+ "model_repo_id": "kyLELEng/weather-scenario-diffusion-1d",
76
+ "output_dir": "/tmp/weather-scenario-diffusion-1d",
77
+ "window_length": 256,
78
+ "stride": 4,
79
+ "max_train_steps": 8000,
80
+ "train_batch_size": 128,
81
+ "eval_batch_size": 128,
82
+ "num_workers": 8,
83
+ "learning_rate": 0.0002,
84
+ "weight_decay": 0.01,
85
+ "grad_clip_norm": 1.0,
86
+ "num_train_timesteps": 1000,
87
+ "eval_every": 1000,
88
+ "save_every": 2000,
89
+ "num_eval_batches": 12,
90
+ "sample_inference_steps": 80,
91
+ "sample_count": 24,
92
+ "mixed_precision": "bf16",
93
+ "seed": 42,
94
+ "model_size": "large",
95
+ "smoke_test": false
96
+ }
97
+ ```
98
+
99
+ The training objective is noise prediction:
100
+
101
+ ```text
102
+ MSE(predicted_noise, true_noise)
103
+ ```
104
+
105
+ Known observed regions are provided as conditioning input. The loss is weighted toward unknown/masked regions so the model learns conditional reconstruction as well as unconditional generation.
106
+
107
+ ## Evaluation
108
+
109
+ ```json
110
+ {
111
+ "future_mask_mse_zspace": 0.16154611110687256,
112
+ "channel_inpainting_mse_zspace": 0.10761465132236481,
113
+ "generated_real_correlation_mae": 0.3473077408348441,
114
+ "abs_autocorrelation_mae": NaN,
115
+ "real_distribution": {
116
+ "mean": [
117
+ 0.5322151780128479,
118
+ -1.3601813316345215,
119
+ -1.3809949159622192,
120
+ -0.8077001571655273,
121
+ 1.2566546201705933,
122
+ -1.078983187675476,
123
+ -0.829918384552002,
124
+ -0.8582332134246826,
125
+ -0.8346444964408875,
126
+ -0.8351123929023743,
127
+ 1.3949605226516724,
128
+ -0.010310296900570393,
129
+ -0.5439239740371704,
130
+ 0.03890685364603996,
131
+ -0.1013171598315239,
132
+ -0.2349442094564438,
133
+ -0.5373658537864685,
134
+ -0.5422216653823853,
135
+ -0.48122739791870117,
136
+ -1.287260890007019,
137
+ 0.09792334586381912
138
+ ],
139
+ "std": [
140
+ 0.07429111748933792,
141
+ 0.35871678590774536,
142
+ 0.35377517342567444,
143
+ 0.24259121716022491,
144
+ 0.41319674253463745,
145
+ 0.18131689727306366,
146
+ 0.17187191545963287,
147
+ 0.13253048062324524,
148
+ 0.1699266880750656,
149
+ 0.17053960263729095,
150
+ 0.35085079073905945,
151
+ 0.015412000007927418,
152
+ 0.3774285912513733,
153
+ 0.5294230580329895,
154
+ 3.1814274734642822e-06,
155
+ 1.0341267625335604e-05,
156
+ 0.2504253685474396,
157
+ 0.24686115980148315,
158
+ 0.20216289162635803,
159
+ 0.3868575692176819,
160
+ 0.03018086589872837
161
+ ],
162
+ "q05": [
163
+ 0.36811354756355286,
164
+ -1.7627040147781372,
165
+ -1.7786728143692017,
166
+ -1.1219414472579956,
167
+ 0.425606906414032,
168
+ -1.2679648399353027,
169
+ -1.0427569150924683,
170
+ -0.9597867131233215,
171
+ -1.0451189279556274,
172
+ -1.0457327365875244,
173
+ 0.82159823179245,
174
+ -0.029127197340130806,
175
+ -1.0400943756103516,
176
+ -0.9979122877120972,
177
+ -0.10131397843360901,
178
+ -0.23493386805057526,
179
+ -0.6634758114814758,
180
+ -0.670289933681488,
181
+ -0.58597731590271,
182
+ -1.7917370796203613,
183
+ 0.057487256824970245
184
+ ],
185
+ "q50": [
186
+ 0.5404008626937866,
187
+ -1.4684869050979614,
188
+ -1.4889553785324097,
189
+ -0.8586000204086304,
190
+ 1.4287834167480469,
191
+ -1.144168734550476,
192
+ -0.872648298740387,
193
+ -0.9205341339111328,
194
+ -0.8782141208648682,
195
+ -0.8781652450561523,
196
+ 1.5059175491333008,
197
+ -0.0145594272762537,
198
+ -0.6258403658866882,
199
+ 0.1368420273065567,
200
+ -0.10131397843360901,
201
+ -0.23493386805057526,
202
+ -0.6634758114814758,
203
+ -0.670289933681488,
204
+ -0.58597731590271,
205
+ -1.4187328815460205,
206
+ 0.093950055539608
207
+ ],
208
+ "q95": [
209
+ 0.638533353805542,
210
+ -0.761028528213501,
211
+ -0.794667661190033,
212
+ -0.35279181599617004,
213
+ 1.6108952760696411,
214
+ -0.766464352607727,
215
+ -0.49462929368019104,
216
+ -0.5979806184768677,
217
+ -0.5036054253578186,
218
+ -0.5011382699012756,
219
+ 1.7932209968566895,
220
+ 0.014986473135650158,
221
+ 0.06979382783174515,
222
+ 0.5976912975311279,
223
+ -0.10131397843360901,
224
+ -0.23493386805057526,
225
+ 0.027409523725509644,
226
+ -0.0009677457856014371,
227
+ -0.06447356939315796,
228
+ -0.6893876194953918,
229
+ 0.1492983102798462
230
+ ]
231
+ },
232
+ "generated_distribution": {
233
+ "mean": [
234
+ 0.10389683395624161,
235
+ 0.2204890251159668,
236
+ 0.19250470399856567,
237
+ 0.23873503506183624,
238
+ 0.01685507781803608,
239
+ 0.12124405056238174,
240
+ 0.29270878434181213,
241
+ 0.05499983951449394,
242
+ 0.22790412604808807,
243
+ 0.2170621007680893,
244
+ -0.18708543479442596,
245
+ 0.021801194176077843,
246
+ -0.030199257656931877,
247
+ 0.1692427545785904,
248
+ -0.08117542415857315,
249
+ -0.09455308318138123,
250
+ 0.06154454126954079,
251
+ 0.028043851256370544,
252
+ 0.12692318856716156,
253
+ 0.24798017740249634,
254
+ 0.017302973195910454
255
+ ],
256
+ "std": [
257
+ 0.2767521142959595,
258
+ 0.3012436330318451,
259
+ 0.2993737757205963,
260
+ 0.30074891448020935,
261
+ 0.3890363872051239,
262
+ 0.3219568431377411,
263
+ 0.3009476661682129,
264
+ 0.37483009696006775,
265
+ 0.30024605989456177,
266
+ 0.304373562335968,
267
+ 0.2866939902305603,
268
+ 0.19090980291366577,
269
+ 0.41177040338516235,
270
+ 0.398755818605423,
271
+ 0.19313617050647736,
272
+ 0.23614700138568878,
273
+ 0.4057450592517853,
274
+ 0.40770408511161804,
275
+ 0.407763808965683,
276
+ 0.31721681356430054,
277
+ 0.20663630962371826
278
+ ],
279
+ "q05": [
280
+ -0.3517603576183319,
281
+ -0.2666028141975403,
282
+ -0.2890886068344116,
283
+ -0.25566428899765015,
284
+ -0.6129659414291382,
285
+ -0.38257062435150146,
286
+ -0.19731372594833374,
287
+ -0.5293506979942322,
288
+ -0.24754241108894348,
289
+ -0.2712758183479309,
290
+ -0.6728720664978027,
291
+ -0.2995768189430237,
292
+ -0.6813194751739502,
293
+ -0.5154849290847778,
294
+ -0.36515557765960693,
295
+ -0.429704487323761,
296
+ -0.5263148546218872,
297
+ -0.5553821921348572,
298
+ -0.49782344698905945,
299
+ -0.2676949203014374,
300
+ -0.3324751853942871
301
+ ],
302
+ "q50": [
303
+ 0.10788173228502274,
304
+ 0.21172723174095154,
305
+ 0.18094468116760254,
306
+ 0.2328486293554306,
307
+ 0.0098641999065876,
308
+ 0.10586627572774887,
309
+ 0.28131914138793945,
310
+ 0.0327904112637043,
311
+ 0.21327275037765503,
312
+ 0.20810022950172424,
313
+ -0.18645159900188446,
314
+ 0.026983173564076424,
315
+ -0.0433419793844223,
316
+ 0.16982388496398926,
317
+ -0.09749776124954224,
318
+ -0.12292205542325974,
319
+ 0.010389605537056923,
320
+ -0.036815427243709564,
321
+ 0.09912104904651642,
322
+ 0.23998694121837616,
323
+ 0.019317764788866043
324
+ ],
325
+ "q95": [
326
+ 0.5557213425636292,
327
+ 0.7389823198318481,
328
+ 0.7046443223953247,
329
+ 0.7413685321807861,
330
+ 0.6779510974884033,
331
+ 0.6939223408699036,
332
+ 0.8154580593109131,
333
+ 0.7213369011878967,
334
+ 0.7498506903648376,
335
+ 0.7366548776626587,
336
+ 0.28038161993026733,
337
+ 0.3271234333515167,
338
+ 0.6884999871253967,
339
+ 0.8387558460235596,
340
+ 0.25787827372550964,
341
+ 0.31810709834098816,
342
+ 0.8252443671226501,
343
+ 0.8010784387588501,
344
+ 0.8797050714492798,
345
+ 0.7847591042518616,
346
+ 0.34978875517845154
347
+ ]
348
+ },
349
+ "real_abs_autocorr_lag1": [
350
+ 0.9877822900516616,
351
+ 0.9904621143737528,
352
+ 0.9903800119819881,
353
+ 0.9930759612361029,
354
+ 0.9845821075186362,
355
+ 0.9896354175342961,
356
+ 0.9931532651062824,
357
+ 0.9829303354664611,
358
+ 0.9930693669171041,
359
+ 0.9931078152703652,
360
+ 0.9904251617511753,
361
+ 0.6185269686957802,
362
+ 0.8362734986319137,
363
+ 0.5482885824237015,
364
+ NaN,
365
+ NaN,
366
+ 0.9654974645960943,
367
+ 0.97213405366258,
368
+ 0.9692365984881656,
369
+ 0.9926912520502599,
370
+ 0.9674985063407219
371
+ ],
372
+ "generated_abs_autocorr_lag1": [
373
+ 0.1746959775402206,
374
+ 0.43644415022477073,
375
+ 0.41653727634920873,
376
+ 0.39331193400821807,
377
+ 0.3961014770471717,
378
+ 0.415143957195966,
379
+ 0.4382806229222112,
380
+ 0.4055025998367542,
381
+ 0.447771283348962,
382
+ 0.4445783266695078,
383
+ 0.35184174376495303,
384
+ 0.018016469082830094,
385
+ 0.02088415221893785,
386
+ 0.07700486234712257,
387
+ 0.016226235857376443,
388
+ 0.06005403603427778,
389
+ 0.40116921177080017,
390
+ 0.41265325284010024,
391
+ 0.37372374982861106,
392
+ 0.48624188538979995,
393
+ 0.008822063729828335
394
+ ],
395
+ "training_history": [
396
+ {
397
+ "step": 1000,
398
+ "train_loss": 0.07410214841365814,
399
+ "validation_denoising_loss": 0.05256535982092222
400
+ },
401
+ {
402
+ "step": 2000,
403
+ "train_loss": 0.047588951885700226,
404
+ "validation_denoising_loss": 0.062365236381689705
405
+ },
406
+ {
407
+ "step": 3000,
408
+ "train_loss": 0.04786738008260727,
409
+ "validation_denoising_loss": 0.08261810739835103
410
+ },
411
+ {
412
+ "step": 4000,
413
+ "train_loss": 0.030978742986917496,
414
+ "validation_denoising_loss": 0.09854291876157124
415
+ },
416
+ {
417
+ "step": 5000,
418
+ "train_loss": 0.017614271491765976,
419
+ "validation_denoising_loss": 0.11429651578267415
420
+ },
421
+ {
422
+ "step": 6000,
423
+ "train_loss": 0.022595474496483803,
424
+ "validation_denoising_loss": 0.11743361999591191
425
+ },
426
+ {
427
+ "step": 7000,
428
+ "train_loss": 0.0207576435059309,
429
+ "validation_denoising_loss": 0.12093404183785121
430
+ },
431
+ {
432
+ "step": 8000,
433
+ "train_loss": 0.018115710467100143,
434
+ "validation_denoising_loss": 0.11552448819080989
435
+ }
436
+ ],
437
+ "best_validation_denoising_loss": 0.05256535982092222,
438
+ "final_step": 8000
439
+ }
440
+ ```
441
+
442
+ Evaluation is based on held-out windows and includes:
443
+
444
+ - validation denoising loss
445
+ - future-mask inpainting MSE
446
+ - channel-inpainting MSE
447
+ - generated-vs-real distribution statistics
448
+ - cross-channel correlation matrix error
449
+ - absolute-value autocorrelation error
450
+
451
+ ## Intended Use
452
+
453
+ This model is for research and demonstration of multivariate time-series diffusion. It is not a production forecasting system.
454
+
455
+ ## Files
456
+
457
+ - `config.json`: 1D U-Net model configuration
458
+ - `diffusion_pytorch_model.safetensors`: model weights
459
+ - `scheduler/scheduler_config.json`: DDPM scheduler configuration
460
+ - `preprocess_config.json`: dataset, split, normalization, window, and channel metadata
461
+ - `normalization_stats.json`: train-split mean and standard deviation
462
+ - `evaluation_report.json`: held-out evaluation metrics
463
+ - `sample_plots/`: generated examples and conditional samples
config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "UNet1DModel",
3
+ "_diffusers_version": "0.37.1",
4
+ "act_fn": "silu",
5
+ "block_out_channels": [
6
+ 64,
7
+ 128,
8
+ 256,
9
+ 512
10
+ ],
11
+ "down_block_types": [
12
+ "DownBlock1DNoSkip",
13
+ "DownBlock1D",
14
+ "AttnDownBlock1D",
15
+ "AttnDownBlock1D"
16
+ ],
17
+ "downsample_each_block": false,
18
+ "extra_in_channels": 128,
19
+ "flip_sin_to_cos": true,
20
+ "freq_shift": 0.0,
21
+ "in_channels": 63,
22
+ "layers_per_block": 2,
23
+ "mid_block_type": "UNetMidBlock1D",
24
+ "norm_num_groups": 8,
25
+ "out_block_type": null,
26
+ "out_channels": 21,
27
+ "sample_rate": null,
28
+ "sample_size": 256,
29
+ "time_embedding_dim": null,
30
+ "time_embedding_type": "fourier",
31
+ "up_block_types": [
32
+ "AttnUpBlock1D",
33
+ "AttnUpBlock1D",
34
+ "UpBlock1D",
35
+ "UpBlock1DNoSkip"
36
+ ],
37
+ "use_timestep_embedding": false
38
+ }
diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e666b5ba8aa70e199fa07f8b3e1c5374662960cfd24a05cd01b363f7e9c6e6ec
3
+ size 166272740
evaluation_report.json ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "future_mask_mse_zspace": 0.16154611110687256,
3
+ "channel_inpainting_mse_zspace": 0.10761465132236481,
4
+ "generated_real_correlation_mae": 0.3473077408348441,
5
+ "abs_autocorrelation_mae": NaN,
6
+ "real_distribution": {
7
+ "mean": [
8
+ 0.5322151780128479,
9
+ -1.3601813316345215,
10
+ -1.3809949159622192,
11
+ -0.8077001571655273,
12
+ 1.2566546201705933,
13
+ -1.078983187675476,
14
+ -0.829918384552002,
15
+ -0.8582332134246826,
16
+ -0.8346444964408875,
17
+ -0.8351123929023743,
18
+ 1.3949605226516724,
19
+ -0.010310296900570393,
20
+ -0.5439239740371704,
21
+ 0.03890685364603996,
22
+ -0.1013171598315239,
23
+ -0.2349442094564438,
24
+ -0.5373658537864685,
25
+ -0.5422216653823853,
26
+ -0.48122739791870117,
27
+ -1.287260890007019,
28
+ 0.09792334586381912
29
+ ],
30
+ "std": [
31
+ 0.07429111748933792,
32
+ 0.35871678590774536,
33
+ 0.35377517342567444,
34
+ 0.24259121716022491,
35
+ 0.41319674253463745,
36
+ 0.18131689727306366,
37
+ 0.17187191545963287,
38
+ 0.13253048062324524,
39
+ 0.1699266880750656,
40
+ 0.17053960263729095,
41
+ 0.35085079073905945,
42
+ 0.015412000007927418,
43
+ 0.3774285912513733,
44
+ 0.5294230580329895,
45
+ 3.1814274734642822e-06,
46
+ 1.0341267625335604e-05,
47
+ 0.2504253685474396,
48
+ 0.24686115980148315,
49
+ 0.20216289162635803,
50
+ 0.3868575692176819,
51
+ 0.03018086589872837
52
+ ],
53
+ "q05": [
54
+ 0.36811354756355286,
55
+ -1.7627040147781372,
56
+ -1.7786728143692017,
57
+ -1.1219414472579956,
58
+ 0.425606906414032,
59
+ -1.2679648399353027,
60
+ -1.0427569150924683,
61
+ -0.9597867131233215,
62
+ -1.0451189279556274,
63
+ -1.0457327365875244,
64
+ 0.82159823179245,
65
+ -0.029127197340130806,
66
+ -1.0400943756103516,
67
+ -0.9979122877120972,
68
+ -0.10131397843360901,
69
+ -0.23493386805057526,
70
+ -0.6634758114814758,
71
+ -0.670289933681488,
72
+ -0.58597731590271,
73
+ -1.7917370796203613,
74
+ 0.057487256824970245
75
+ ],
76
+ "q50": [
77
+ 0.5404008626937866,
78
+ -1.4684869050979614,
79
+ -1.4889553785324097,
80
+ -0.8586000204086304,
81
+ 1.4287834167480469,
82
+ -1.144168734550476,
83
+ -0.872648298740387,
84
+ -0.9205341339111328,
85
+ -0.8782141208648682,
86
+ -0.8781652450561523,
87
+ 1.5059175491333008,
88
+ -0.0145594272762537,
89
+ -0.6258403658866882,
90
+ 0.1368420273065567,
91
+ -0.10131397843360901,
92
+ -0.23493386805057526,
93
+ -0.6634758114814758,
94
+ -0.670289933681488,
95
+ -0.58597731590271,
96
+ -1.4187328815460205,
97
+ 0.093950055539608
98
+ ],
99
+ "q95": [
100
+ 0.638533353805542,
101
+ -0.761028528213501,
102
+ -0.794667661190033,
103
+ -0.35279181599617004,
104
+ 1.6108952760696411,
105
+ -0.766464352607727,
106
+ -0.49462929368019104,
107
+ -0.5979806184768677,
108
+ -0.5036054253578186,
109
+ -0.5011382699012756,
110
+ 1.7932209968566895,
111
+ 0.014986473135650158,
112
+ 0.06979382783174515,
113
+ 0.5976912975311279,
114
+ -0.10131397843360901,
115
+ -0.23493386805057526,
116
+ 0.027409523725509644,
117
+ -0.0009677457856014371,
118
+ -0.06447356939315796,
119
+ -0.6893876194953918,
120
+ 0.1492983102798462
121
+ ]
122
+ },
123
+ "generated_distribution": {
124
+ "mean": [
125
+ 0.10389683395624161,
126
+ 0.2204890251159668,
127
+ 0.19250470399856567,
128
+ 0.23873503506183624,
129
+ 0.01685507781803608,
130
+ 0.12124405056238174,
131
+ 0.29270878434181213,
132
+ 0.05499983951449394,
133
+ 0.22790412604808807,
134
+ 0.2170621007680893,
135
+ -0.18708543479442596,
136
+ 0.021801194176077843,
137
+ -0.030199257656931877,
138
+ 0.1692427545785904,
139
+ -0.08117542415857315,
140
+ -0.09455308318138123,
141
+ 0.06154454126954079,
142
+ 0.028043851256370544,
143
+ 0.12692318856716156,
144
+ 0.24798017740249634,
145
+ 0.017302973195910454
146
+ ],
147
+ "std": [
148
+ 0.2767521142959595,
149
+ 0.3012436330318451,
150
+ 0.2993737757205963,
151
+ 0.30074891448020935,
152
+ 0.3890363872051239,
153
+ 0.3219568431377411,
154
+ 0.3009476661682129,
155
+ 0.37483009696006775,
156
+ 0.30024605989456177,
157
+ 0.304373562335968,
158
+ 0.2866939902305603,
159
+ 0.19090980291366577,
160
+ 0.41177040338516235,
161
+ 0.398755818605423,
162
+ 0.19313617050647736,
163
+ 0.23614700138568878,
164
+ 0.4057450592517853,
165
+ 0.40770408511161804,
166
+ 0.407763808965683,
167
+ 0.31721681356430054,
168
+ 0.20663630962371826
169
+ ],
170
+ "q05": [
171
+ -0.3517603576183319,
172
+ -0.2666028141975403,
173
+ -0.2890886068344116,
174
+ -0.25566428899765015,
175
+ -0.6129659414291382,
176
+ -0.38257062435150146,
177
+ -0.19731372594833374,
178
+ -0.5293506979942322,
179
+ -0.24754241108894348,
180
+ -0.2712758183479309,
181
+ -0.6728720664978027,
182
+ -0.2995768189430237,
183
+ -0.6813194751739502,
184
+ -0.5154849290847778,
185
+ -0.36515557765960693,
186
+ -0.429704487323761,
187
+ -0.5263148546218872,
188
+ -0.5553821921348572,
189
+ -0.49782344698905945,
190
+ -0.2676949203014374,
191
+ -0.3324751853942871
192
+ ],
193
+ "q50": [
194
+ 0.10788173228502274,
195
+ 0.21172723174095154,
196
+ 0.18094468116760254,
197
+ 0.2328486293554306,
198
+ 0.0098641999065876,
199
+ 0.10586627572774887,
200
+ 0.28131914138793945,
201
+ 0.0327904112637043,
202
+ 0.21327275037765503,
203
+ 0.20810022950172424,
204
+ -0.18645159900188446,
205
+ 0.026983173564076424,
206
+ -0.0433419793844223,
207
+ 0.16982388496398926,
208
+ -0.09749776124954224,
209
+ -0.12292205542325974,
210
+ 0.010389605537056923,
211
+ -0.036815427243709564,
212
+ 0.09912104904651642,
213
+ 0.23998694121837616,
214
+ 0.019317764788866043
215
+ ],
216
+ "q95": [
217
+ 0.5557213425636292,
218
+ 0.7389823198318481,
219
+ 0.7046443223953247,
220
+ 0.7413685321807861,
221
+ 0.6779510974884033,
222
+ 0.6939223408699036,
223
+ 0.8154580593109131,
224
+ 0.7213369011878967,
225
+ 0.7498506903648376,
226
+ 0.7366548776626587,
227
+ 0.28038161993026733,
228
+ 0.3271234333515167,
229
+ 0.6884999871253967,
230
+ 0.8387558460235596,
231
+ 0.25787827372550964,
232
+ 0.31810709834098816,
233
+ 0.8252443671226501,
234
+ 0.8010784387588501,
235
+ 0.8797050714492798,
236
+ 0.7847591042518616,
237
+ 0.34978875517845154
238
+ ]
239
+ },
240
+ "real_abs_autocorr_lag1": [
241
+ 0.9877822900516616,
242
+ 0.9904621143737528,
243
+ 0.9903800119819881,
244
+ 0.9930759612361029,
245
+ 0.9845821075186362,
246
+ 0.9896354175342961,
247
+ 0.9931532651062824,
248
+ 0.9829303354664611,
249
+ 0.9930693669171041,
250
+ 0.9931078152703652,
251
+ 0.9904251617511753,
252
+ 0.6185269686957802,
253
+ 0.8362734986319137,
254
+ 0.5482885824237015,
255
+ NaN,
256
+ NaN,
257
+ 0.9654974645960943,
258
+ 0.97213405366258,
259
+ 0.9692365984881656,
260
+ 0.9926912520502599,
261
+ 0.9674985063407219
262
+ ],
263
+ "generated_abs_autocorr_lag1": [
264
+ 0.1746959775402206,
265
+ 0.43644415022477073,
266
+ 0.41653727634920873,
267
+ 0.39331193400821807,
268
+ 0.3961014770471717,
269
+ 0.415143957195966,
270
+ 0.4382806229222112,
271
+ 0.4055025998367542,
272
+ 0.447771283348962,
273
+ 0.4445783266695078,
274
+ 0.35184174376495303,
275
+ 0.018016469082830094,
276
+ 0.02088415221893785,
277
+ 0.07700486234712257,
278
+ 0.016226235857376443,
279
+ 0.06005403603427778,
280
+ 0.40116921177080017,
281
+ 0.41265325284010024,
282
+ 0.37372374982861106,
283
+ 0.48624188538979995,
284
+ 0.008822063729828335
285
+ ],
286
+ "training_history": [
287
+ {
288
+ "step": 1000,
289
+ "train_loss": 0.07410214841365814,
290
+ "validation_denoising_loss": 0.05256535982092222
291
+ },
292
+ {
293
+ "step": 2000,
294
+ "train_loss": 0.047588951885700226,
295
+ "validation_denoising_loss": 0.062365236381689705
296
+ },
297
+ {
298
+ "step": 3000,
299
+ "train_loss": 0.04786738008260727,
300
+ "validation_denoising_loss": 0.08261810739835103
301
+ },
302
+ {
303
+ "step": 4000,
304
+ "train_loss": 0.030978742986917496,
305
+ "validation_denoising_loss": 0.09854291876157124
306
+ },
307
+ {
308
+ "step": 5000,
309
+ "train_loss": 0.017614271491765976,
310
+ "validation_denoising_loss": 0.11429651578267415
311
+ },
312
+ {
313
+ "step": 6000,
314
+ "train_loss": 0.022595474496483803,
315
+ "validation_denoising_loss": 0.11743361999591191
316
+ },
317
+ {
318
+ "step": 7000,
319
+ "train_loss": 0.0207576435059309,
320
+ "validation_denoising_loss": 0.12093404183785121
321
+ },
322
+ {
323
+ "step": 8000,
324
+ "train_loss": 0.018115710467100143,
325
+ "validation_denoising_loss": 0.11552448819080989
326
+ }
327
+ ],
328
+ "best_validation_denoising_loss": 0.05256535982092222,
329
+ "final_step": 8000
330
+ }
generated_corr.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eef2ab0e1c02f988cd8df95ab2fa1005a06a62d9ca994c776ef2781c6cc97deb
3
+ size 3656
normalization_stats.json ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "mean": [
3
+ 990.173828125,
4
+ 12.050569534301758,
5
+ 286.0191650390625,
6
+ 5.857066631317139,
7
+ 69.040283203125,
8
+ 15.657519340515137,
9
+ 10.033563613891602,
10
+ 5.6238627433776855,
11
+ 6.337794303894043,
12
+ 10.143278121948242,
13
+ 1205.724609375,
14
+ 1.949593424797058,
15
+ 3.6914103031158447,
16
+ 174.6820068359375,
17
+ 0.013888888992369175,
18
+ 26.4885196685791,
19
+ 154.14207458496094,
20
+ 305.16058349609375,
21
+ 362.73028564453125,
22
+ 22.88827133178711,
23
+ 412.946533203125
24
+ ],
25
+ "std": [
26
+ 8.76413345336914,
27
+ 7.477471828460693,
28
+ 7.662539482116699,
29
+ 6.227657318115234,
30
+ 19.218950271606445,
31
+ 7.9162445068359375,
32
+ 4.232591152191162,
33
+ 5.859492301940918,
34
+ 2.6961469650268555,
35
+ 4.296774864196777,
36
+ 36.82501983642578,
37
+ 48.737728118896484,
38
+ 2.5588161945343018,
39
+ 86.3623046875,
40
+ 0.13708758354187012,
41
+ 112.74883270263672,
42
+ 232.32508850097656,
43
+ 455.26654052734375,
44
+ 619.0176391601562,
45
+ 7.801519393920898,
46
+ 359.27044677734375
47
+ ]
48
+ }
preprocess_config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "source_dataset": "Duyu/Time-Series-Forecasting-Benchmark-Datasets",
3
+ "source_file": "Weather.csv",
4
+ "window_length": 256,
5
+ "stride": 4,
6
+ "num_channels": 21,
7
+ "channel_names": [
8
+ "feature_00",
9
+ "feature_01",
10
+ "feature_02",
11
+ "feature_03",
12
+ "feature_04",
13
+ "feature_05",
14
+ "feature_06",
15
+ "feature_07",
16
+ "feature_08",
17
+ "feature_09",
18
+ "feature_10",
19
+ "feature_11",
20
+ "feature_12",
21
+ "feature_13",
22
+ "feature_14",
23
+ "feature_15",
24
+ "feature_16",
25
+ "feature_17",
26
+ "feature_18",
27
+ "feature_19",
28
+ "feature_20"
29
+ ],
30
+ "split": "time_ordered_80_10_10",
31
+ "normalization": "zscore_fit_on_train_split",
32
+ "model_input": "concat(noisy_x, observed_x, observed_mask)",
33
+ "model_output": "predicted_noise"
34
+ }
real_corr.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2e3c20fd4d9e4d23aff65ab07372d52a2031128564f1720048cd01f8b9efafa0
3
+ size 3656
sample_future_conditioned_z.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e1bfc88568b4092a63b128c550e6c480dc7a0afe917acd809976b0aae31e39e
3
+ size 516224
sample_generated_z.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:63587114eaa4b9b4b73e36e64503961d8c512c53e6ffb72b2e4ed2f0d943b824
3
+ size 516224
sample_plots/future_mask_sample.png ADDED

Git LFS Details

  • SHA256: 9c058df052d77ae2b4194c96f773a23e08dfc085c793e8abde1a4898c374aac0
  • Pointer size: 131 Bytes
  • Size of remote file: 285 kB
sample_plots/real_vs_generated.png ADDED

Git LFS Details

  • SHA256: c35587f113ed1010a10292fda39d1a1d2862dbbe836a223a4af8956e86253de1
  • Pointer size: 131 Bytes
  • Size of remote file: 381 kB
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DDPMScheduler",
3
+ "_diffusers_version": "0.37.1",
4
+ "beta_end": 0.02,
5
+ "beta_schedule": "squaredcos_cap_v2",
6
+ "beta_start": 0.0001,
7
+ "clip_sample": true,
8
+ "clip_sample_range": 1.0,
9
+ "dynamic_thresholding_ratio": 0.995,
10
+ "num_train_timesteps": 1000,
11
+ "prediction_type": "epsilon",
12
+ "rescale_betas_zero_snr": false,
13
+ "sample_max_value": 1.0,
14
+ "steps_offset": 0,
15
+ "thresholding": false,
16
+ "timestep_spacing": "leading",
17
+ "trained_betas": null,
18
+ "variance_type": "fixed_small"
19
+ }
training_config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_repo": "Duyu/Time-Series-Forecasting-Benchmark-Datasets",
3
+ "dataset_file": "Weather.csv",
4
+ "model_repo_id": "kyLELEng/weather-scenario-diffusion-1d",
5
+ "output_dir": "/tmp/weather-scenario-diffusion-1d",
6
+ "window_length": 256,
7
+ "stride": 4,
8
+ "max_train_steps": 8000,
9
+ "train_batch_size": 128,
10
+ "eval_batch_size": 128,
11
+ "num_workers": 8,
12
+ "learning_rate": 0.0002,
13
+ "weight_decay": 0.01,
14
+ "grad_clip_norm": 1.0,
15
+ "num_train_timesteps": 1000,
16
+ "eval_every": 1000,
17
+ "save_every": 2000,
18
+ "num_eval_batches": 12,
19
+ "sample_inference_steps": 80,
20
+ "sample_count": 24,
21
+ "mixed_precision": "bf16",
22
+ "seed": 42,
23
+ "model_size": "large",
24
+ "smoke_test": false
25
+ }