sqfoo commited on
Commit
d044167
·
1 Parent(s): 982bdec

Trying to resolve

Browse files
__pycache__/utilspp.cpython-312.pyc ADDED
Binary file (9.05 kB). View file
 
app.py CHANGED
@@ -4,8 +4,7 @@ import gradio as gr
4
 
5
  from stldm import InferenceHub
6
  from stldm.config import STLDM_HKO
7
- from data.dutils import resize
8
- from utilspp import gradio_visualize, gradio_gif
9
 
10
  def nowcasting(file, cfg_str, ensemble_no):
11
  # Model Setup
 
4
 
5
  from stldm import InferenceHub
6
  from stldm.config import STLDM_HKO
7
+ from utilspp import resize, gradio_gif
 
8
 
9
  def nowcasting(file, cfg_str, ensemble_no):
10
  # Model Setup
requirements.txt CHANGED
@@ -1,147 +1,64 @@
1
- absl-py==2.0.0
2
- antlr4-python3-runtime==4.9.3
3
- anyio==4.12.0
4
- argon2-cffi==25.1.0
5
- argon2-cffi-bindings==25.1.0
6
- arrow==1.4.0
7
- asttokens==3.0.1
8
- async-lru==2.0.5
9
- attrs==25.4.0
10
- babel==2.17.0
11
- beautifulsoup4==4.14.3
12
- bleach==6.2.0
13
- cachetools==5.3.2
14
- certifi==2023.11.17
15
- cffi==2.0.0
16
- charset-normalizer==3.3.2
17
- comm==0.2.3
18
- contourpy==1.3.0
19
  cycler==0.12.1
20
- debugpy==1.8.17
21
- decorator==5.2.1
22
- defusedxml==0.7.1
23
- einops==0.8.1
24
- exceptiongroup==1.3.1
25
- executing==2.2.1
26
- fastjsonschema==2.21.2
27
- fonttools==4.45.0
28
- fqdn==1.5.1
29
- google-auth==2.23.4
30
- google-auth-oauthlib==0.4.6
31
- grpcio
32
  h11==0.16.0
33
- h5py==3.7.0
34
  httpcore==1.0.9
35
  httpx==0.28.1
36
- idna==3.4
37
- imageio==2.33.0
38
- importlib-metadata==6.8.0
39
- importlib_resources==6.5.2
40
- ipykernel==6.31.0
41
- ipython==8.18.1
42
- ipywidgets==8.1.8
43
- isoduration==20.11.0
44
- jedi==0.19.2
45
  Jinja2==3.1.6
46
- joblib==1.3.2
47
- json5==0.12.1
48
- jsonpointer==3.0.0
49
- jsonschema==4.25.1
50
- jsonschema-specifications==2025.9.1
51
- jupyter==1.1.1
52
- jupyter-console==6.6.3
53
- jupyter-events==0.12.0
54
- jupyter-lsp==2.3.0
55
- jupyter_client==8.6.3
56
- jupyter_core==5.8.1
57
- jupyter_server==2.17.0
58
- jupyter_server_terminals==0.5.3
59
- jupyterlab==4.5.0
60
- jupyterlab_pygments==0.3.0
61
- jupyterlab_server==2.28.0
62
- jupyterlab_widgets==3.0.16
63
- kiwisolver==1.4.5
64
- lark==1.3.1
65
- lpips==0.1.4
66
- Markdown==3.5.1
67
- MarkupSafe==2.1.3
68
- matplotlib==3.9.4
69
- matplotlib-inline==0.2.1
70
- mistune==3.1.4
71
- nbclient==0.10.2
72
- nbconvert==7.16.6
73
- nbformat==5.10.4
74
- nest-asyncio==1.6.0
75
- networkx==3.2.1
76
- notebook==7.5.0
77
- notebook_shim==0.2.4
78
- numpy==1.24.4
79
- oauthlib==3.2.2
80
- omegaconf==2.3.0
81
- opencv-python==4.8.0.74
82
- overrides==7.7.0
83
- packaging==23.2
84
- pandas==1.4.3
85
- pandocfilters==1.5.1
86
- parso==0.8.5
87
- pexpect==4.9.0
88
- Pillow==10.1.0
89
- platformdirs==4.4.0
90
- prometheus_client==0.23.1
91
- prompt_toolkit==3.0.52
92
- protobuf==3.19.6
93
- psutil==7.1.3
94
- ptyprocess==0.7.0
95
- pure_eval==0.2.3
96
- pyasn1==0.5.1
97
- pyasn1-modules==0.3.0
98
- pycparser==2.23
99
  Pygments==2.19.2
100
- pyparsing==3.1.1
101
- python-dateutil==2.8.2
102
- python-json-logger==4.0.0
103
- pytz==2023.3.post1
104
- PyWavelets==1.5.0
105
- PyYAML==6.0
106
- pyzmq==27.1.0
107
- referencing==0.36.2
108
- requests==2.31.0
109
- requests-oauthlib==1.3.1
110
- rfc3339-validator==0.1.4
111
- rfc3986-validator==0.1.1
112
- rfc3987-syntax==1.1.0
113
- rpds-py==0.27.1
114
- rsa==4.9
115
- SciencePlots==2.2.0
116
- scikit-image==0.19.3
117
- scikit-learn==1.1.2
118
- scipy==1.9.1
119
- Send2Trash==1.8.3
120
- six==1.16.0
121
- soupsieve==2.8
122
- stack-data==0.6.3
123
- tensorboard==2.9.0
124
- tensorboard-data-server==0.6.1
125
- tensorboard-plugin-wit==1.8.1
126
- terminado==0.18.1
127
- threadpoolctl==3.2.0
128
- tifffile==2023.9.26
129
- tinycss2==1.4.0
130
- tomli==2.3.0
131
- torch==1.12.1+cu116
132
- torchmetrics==0.11.0
133
- torchvision==0.13.1+cu116
134
- tornado==6.5.2
135
- tqdm==4.66.1
136
- traitlets==5.14.3
137
- typing_extensions==4.8.0
138
- tzdata==2025.2
139
- uri-template==1.3.0
140
- urllib3==2.1.0
141
- wcwidth==0.2.14
142
- webcolors==24.11.1
143
- webencodings==0.5.1
144
- websocket-client==1.9.0
145
- Werkzeug==3.0.1
146
- widgetsnbextension==4.0.15
147
- zipp==3.17.0
 
1
+ aiofiles==24.1.0
2
+ annotated-doc==0.0.4
3
+ annotated-types==0.7.0
4
+ anyio==4.12.1
5
+ brotli==1.2.0
6
+ certifi==2026.1.4
7
+ click==8.3.1
8
+ contourpy==1.3.3
 
 
 
 
 
 
 
 
 
 
9
  cycler==0.12.1
10
+ einops==0.8.2
11
+ fastapi==0.129.0
12
+ ffmpy==1.0.0
13
+ filelock==3.24.3
14
+ fonttools==4.61.1
15
+ fsspec==2026.2.0
16
+ gradio==6.6.0
17
+ gradio_client==2.1.0
18
+ groovy==0.1.2
 
 
 
19
  h11==0.16.0
20
+ hf-xet==1.2.0
21
  httpcore==1.0.9
22
  httpx==0.28.1
23
+ huggingface_hub==1.4.1
24
+ idna==3.11
 
 
 
 
 
 
 
25
  Jinja2==3.1.6
26
+ kiwisolver==1.4.9
27
+ markdown-it-py==4.0.0
28
+ MarkupSafe==3.0.3
29
+ matplotlib==3.10.8
30
+ mdurl==0.1.2
31
+ mpmath==1.3.0
32
+ networkx==3.6.1
33
+ numpy==1.26.4
34
+ orjson==3.11.7
35
+ packaging==26.0
36
+ pandas==3.0.1
37
+ pillow==12.1.1
38
+ pydantic==2.12.5
39
+ pydantic_core==2.41.5
40
+ pydub==0.25.1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  Pygments==2.19.2
42
+ pyparsing==3.3.2
43
+ python-dateutil==2.9.0.post0
44
+ python-multipart==0.0.22
45
+ pytz==2025.2
46
+ PyYAML==6.0.3
47
+ rich==14.3.3
48
+ safehttpx==0.1.7
49
+ safetensors==0.7.0
50
+ SciencePlots @ git+https://github.com/garrettj403/SciencePlots@5521f3b8e6c2b15b174bbea82d6662e5bf2c0d7d
51
+ semantic-version==2.10.0
52
+ setuptools==41.2.0
53
+ shellingham==1.5.4
54
+ six==1.17.0
55
+ starlette==0.52.1
56
+ sympy==1.14.0
57
+ tomlkit==0.13.3
58
+ torch==2.2.0
59
+ tqdm==4.67.3
60
+ typer==0.24.0
61
+ typer-slim==0.24.0
62
+ typing-inspection==0.4.2
63
+ typing_extensions==4.15.0
64
+ uvicorn==0.41.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
stldm/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (353 Bytes). View file
 
stldm/__pycache__/config.cpython-312.pyc ADDED
Binary file (1.24 kB). View file
 
stldm/__pycache__/inference.cpython-312.pyc ADDED
Binary file (4.9 kB). View file
 
stldm/__pycache__/modules.cpython-312.pyc ADDED
Binary file (8.26 kB). View file
 
stldm/__pycache__/simvpv2.cpython-312.pyc ADDED
Binary file (25.7 kB). View file
 
stldm/__pycache__/stldm.cpython-312.pyc ADDED
Binary file (33.1 kB). View file
 
stldm/__pycache__/stldm_hf.cpython-312.pyc ADDED
Binary file (33.4 kB). View file
 
stldm/__pycache__/stldm_spatial.cpython-312.pyc ADDED
Binary file (32.5 kB). View file
 
stldm/__pycache__/submodules.cpython-312.pyc ADDED
Binary file (26.2 kB). View file
 
stldm/stldm.py CHANGED
@@ -1,8 +1,17 @@
 
 
 
 
 
 
 
1
  import torch, random
2
  from torch import nn
3
  from einops import rearrange
4
 
5
  from stldm.submodules import *
 
 
6
 
7
  class Down_Block(nn.Module):
8
  def __init__(self, in_ch, hid_ch, out_ch, time_dim, is_last, patch_size=None, num_groups=8, heads=4, dim_head=32):
@@ -172,12 +181,6 @@ class LDM(nn.Module):
172
 
173
  return out
174
 
175
- # constants
176
- from collections import namedtuple
177
- from torch.cuda.amp import autocast
178
- import torch.nn.functional as F
179
- from einops import reduce
180
- from tqdm.auto import tqdm
181
 
182
  ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
183
 
@@ -583,7 +586,6 @@ class GaussianDiffusion(nn.Module):
583
  else:
584
  return pred
585
 
586
- from stldm.modules import SimVPV2_Model, VAE
587
  def model_setup(model_config, print_info=False, cfg_str=None):
588
  if print_info:
589
  print('Setup the model with considering temporal attention be (BHW, T, C) ... ...')
 
1
+ # constants
2
+ from collections import namedtuple
3
+ from torch.cuda.amp import autocast
4
+ import torch.nn.functional as F
5
+ from einops import reduce
6
+ from tqdm.auto import tqdm
7
+
8
  import torch, random
9
  from torch import nn
10
  from einops import rearrange
11
 
12
  from stldm.submodules import *
13
+ from stldm.modules import SimVPV2_Model, VAE
14
+
15
 
16
  class Down_Block(nn.Module):
17
  def __init__(self, in_ch, hid_ch, out_ch, time_dim, is_last, patch_size=None, num_groups=8, heads=4, dim_head=32):
 
181
 
182
  return out
183
 
 
 
 
 
 
 
184
 
185
  ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
186
 
 
586
  else:
587
  return pred
588
 
 
589
  def model_setup(model_config, print_info=False, cfg_str=None):
590
  if print_info:
591
  print('Setup the model with considering temporal attention be (BHW, T, C) ... ...')
stldm/stldm_hf.py CHANGED
@@ -1,8 +1,17 @@
 
 
 
 
 
 
 
1
  import torch, random
2
  from torch import nn
3
  from einops import rearrange
4
 
5
  from stldm.submodules import *
 
 
6
 
7
  class Down_Block(nn.Module):
8
  def __init__(self, in_ch, hid_ch, out_ch, time_dim, is_last, patch_size=None, num_groups=8, heads=4, dim_head=32):
@@ -172,12 +181,7 @@ class LDM(nn.Module):
172
 
173
  return out
174
 
175
- # constants
176
- from collections import namedtuple
177
- from torch.cuda.amp import autocast
178
- import torch.nn.functional as F
179
- from einops import reduce
180
- from tqdm.auto import tqdm
181
 
182
  ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
183
 
@@ -591,7 +595,6 @@ class GaussianDiffusion(
591
  else:
592
  return pred
593
 
594
- from stldm.modules import SimVPV2_Model, VAE
595
  def model_setup(model_config, print_info=False, cfg_str=None):
596
  if print_info:
597
  print('Setup the model with considering temporal attention be (BHW, T, C) ... ...')
 
1
+ # constants
2
+ from collections import namedtuple
3
+ from torch.cuda.amp import autocast
4
+ import torch.nn.functional as F
5
+ from einops import reduce
6
+ from tqdm.auto import tqdm
7
+
8
  import torch, random
9
  from torch import nn
10
  from einops import rearrange
11
 
12
  from stldm.submodules import *
13
+ from stldm.modules import SimVPV2_Model, VAE
14
+
15
 
16
  class Down_Block(nn.Module):
17
  def __init__(self, in_ch, hid_ch, out_ch, time_dim, is_last, patch_size=None, num_groups=8, heads=4, dim_head=32):
 
181
 
182
  return out
183
 
184
+
 
 
 
 
 
185
 
186
  ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
187
 
 
595
  else:
596
  return pred
597
 
 
598
  def model_setup(model_config, print_info=False, cfg_str=None):
599
  if print_info:
600
  print('Setup the model with considering temporal attention be (BHW, T, C) ... ...')
stldm/stldm_spatial.py CHANGED
@@ -1,8 +1,16 @@
 
 
 
 
 
 
 
1
  import torch, random
2
  from torch import nn
3
  from einops import rearrange
4
 
5
  from stldm.submodules import *
 
6
 
7
  class Down_Block(nn.Module):
8
  def __init__(self, in_ch, hid_ch, out_ch, time_dim, is_last, patch_size=None, num_groups=8, heads=4, dim_head=32):
@@ -152,12 +160,7 @@ class LDM(nn.Module):
152
  out = up_block(out, t, hids1.pop(), hids2.pop())
153
  return out
154
 
155
- # constants
156
- from collections import namedtuple
157
- from torch.cuda.amp import autocast
158
- import torch.nn.functional as F
159
- from einops import reduce
160
- from tqdm.auto import tqdm
161
 
162
  ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
163
 
@@ -562,7 +565,6 @@ class GaussianDiffusion(nn.Module):
562
  else:
563
  return pred
564
 
565
- from stldm.modules import SimVPV2_Model, VAE
566
  def model_setup(model_config, print_info=False, cfg_str=None):
567
  if print_info:
568
  print('Setup a Spatial diffusion with replacing a Temporal attention with Spatial attention')
 
1
+ # constants
2
+ from collections import namedtuple
3
+ from torch.cuda.amp import autocast
4
+ import torch.nn.functional as F
5
+ from einops import reduce
6
+ from tqdm.auto import tqdm
7
+
8
  import torch, random
9
  from torch import nn
10
  from einops import rearrange
11
 
12
  from stldm.submodules import *
13
+ from stldm.modules import SimVPV2_Model, VAE
14
 
15
  class Down_Block(nn.Module):
16
  def __init__(self, in_ch, hid_ch, out_ch, time_dim, is_last, patch_size=None, num_groups=8, heads=4, dim_head=32):
 
160
  out = up_block(out, t, hids1.pop(), hids2.pop())
161
  return out
162
 
163
+
 
 
 
 
 
164
 
165
  ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
166
 
 
565
  else:
566
  return pred
567
 
 
568
  def model_setup(model_config, print_info=False, cfg_str=None):
569
  if print_info:
570
  print('Setup a Spatial diffusion with replacing a Temporal attention with Spatial attention')
utilspp.py CHANGED
@@ -1,116 +1,17 @@
1
- import os
2
  import torch
3
- import numpy as np
4
- import lpips as lp
5
- import pandas as pd
6
- import torchmetrics
7
  import matplotlib.pyplot as plt
8
- from bisect import bisect_right
9
- import torchvision.transforms as T
10
- from torch import nn
11
 
12
  from matplotlib.colors import ListedColormap, BoundaryNorm
13
  from matplotlib.lines import Line2D
 
 
14
 
15
- from data import dutils
16
-
17
- # =======================================================================
18
- # Scheduler Helper Function
19
- # =======================================================================
20
-
21
- class SequentialLR(torch.optim.lr_scheduler._LRScheduler):
22
- """Receives the list of schedulers that is expected to be called sequentially during
23
- optimization process and milestone points that provides exact intervals to reflect
24
- which scheduler is supposed to be called at a given epoch.
25
-
26
- Args:
27
- schedulers (list): List of chained schedulers.
28
- milestones (list): List of integers that reflects milestone points.
29
-
30
- Example:
31
- >>> # Assuming optimizer uses lr = 1. for all groups
32
- >>> # lr = 0.1 if epoch == 0
33
- >>> # lr = 0.1 if epoch == 1
34
- >>> # lr = 0.9 if epoch == 2
35
- >>> # lr = 0.81 if epoch == 3
36
- >>> # lr = 0.729 if epoch == 4
37
- >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
38
- >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
39
- >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2])
40
- >>> for epoch in range(100):
41
- >>> train(...)
42
- >>> validate(...)
43
- >>> scheduler.step()
44
- """
45
-
46
- def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False):
47
- for scheduler_idx in range(1, len(schedulers)):
48
- if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
49
- raise ValueError(
50
- "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
51
- "got schedulers at index {} and {} to be different".format(0, scheduler_idx)
52
- )
53
- if (len(milestones) != len(schedulers) - 1):
54
- raise ValueError(
55
- "Sequential Schedulers expects number of schedulers provided to be one more "
56
- "than the number of milestone points, but got number of schedulers {} and the "
57
- "number of milestones to be equal to {}".format(len(schedulers), len(milestones))
58
- )
59
- self.optimizer = optimizer
60
- self._schedulers = schedulers
61
- self._milestones = milestones
62
- self.last_epoch = last_epoch + 1
63
-
64
- def step(self, ref=None):
65
- self.last_epoch += 1
66
- idx = bisect_right(self._milestones, self.last_epoch)
67
- if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
68
- self._schedulers[idx].step(0)
69
- else:
70
- # Check HERE
71
- if isinstance(self._schedulers[idx], torch.optim.lr_scheduler.ReduceLROnPlateau):
72
- self._schedulers[idx].step(ref)
73
- else:
74
- self._schedulers[idx].step()
75
-
76
- def state_dict(self):
77
- """Returns the state of the scheduler as a :class:`dict`.
78
-
79
- It contains an entry for every variable in self.__dict__ which
80
- is not the optimizer.
81
- The wrapped scheduler states will also be saved.
82
- """
83
- state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
84
- state_dict['_schedulers'] = [None] * len(self._schedulers)
85
-
86
- for idx, s in enumerate(self._schedulers):
87
- state_dict['_schedulers'][idx] = s.state_dict()
88
-
89
- return state_dict
90
-
91
- def load_state_dict(self, state_dict):
92
- """Loads the schedulers state.
93
-
94
- Args:
95
- state_dict (dict): scheduler state. Should be an object returned
96
- from a call to :meth:`state_dict`.
97
- """
98
- _schedulers = state_dict.pop('_schedulers')
99
- self.__dict__.update(state_dict)
100
- # Restore state_dict keys in order to prevent side effects
101
- # https://github.com/pytorch/pytorch/issues/32756
102
- state_dict['_schedulers'] = _schedulers
103
-
104
- for idx, s in enumerate(_schedulers):
105
- self._schedulers[idx].load_state_dict(s)
106
-
107
- def warmup_lambda(warmup_steps, min_lr_ratio=0.1):
108
- def ret_lambda(epoch):
109
- if epoch <= warmup_steps:
110
- return min_lr_ratio + (1.0 - min_lr_ratio) * epoch / warmup_steps
111
- else:
112
- return 1.0
113
- return ret_lambda
114
 
115
  # =======================================================================
116
  # Utils in utils :)
@@ -131,281 +32,6 @@ def to_cpu_tensor(*args):
131
  return out[0]
132
  return out
133
 
134
- def merge_leading_dims(tensor, n=2):
135
- '''
136
- Merge the first N dimension of a tensor
137
- '''
138
- return tensor.reshape((-1, *tensor.shape[n:]))
139
-
140
- # =======================================================================
141
- # Model Preparation, saving & loading (copied from utils.py)
142
- # =======================================================================
143
- def build_model_name(model_type, model_config):
144
- '''
145
- Build the model name (without extension)
146
- '''
147
- model_name = model_type + '_'
148
- for k, v in model_config.items():
149
- model_name += k
150
- if type(v) is list or type(v) is tuple:
151
- model_name += '-'
152
- for i, item in enumerate(v):
153
- model_name += (str(item) if type(item) is not bool else '') + ('-' if i < len(v)-1 else '')
154
- else:
155
- model_name += (('-' + str(v)) if type(v) is not bool else '')
156
- model_name += '_'
157
- return model_name[:-1]
158
-
159
- def build_model_path(base_dir, dataset_type, model_type, timestamp=None):
160
- if timestamp is None:
161
- return os.path.join(base_dir, dataset_type, model_type)
162
- elif timestamp == True:
163
- return os.path.join(base_dir, dataset_type, model_type, pd.Timestamp.now().strftime('%Y%m%d%H%M%S'))
164
- return os.path.join(base_dir, dataset_type, model_type, timestamp)
165
-
166
- # =======================================================================
167
- # Preprocess Function for Loading HKO-7 dataset
168
- # =======================================================================
169
-
170
- def hko7_preprocess(x_seq, x_mask, dt_clip, args):
171
- resize = args.resize if 'resize' in args else x_seq.shape[-1]
172
- seq_len = args.seq_len if 'seq_len' in args else 5
173
-
174
- # post-process on HKO-10
175
- x_seq = x_seq.transpose((1, 0, 2, 3, 4)) / 255. # => (batch_size, seq_length, 1, 480, 480)
176
- if 'scale' in args and args.scale == 'non-linear':
177
- x_seq = dutils.linear_to_nonlinear_batched(x_seq, dt_clip)
178
- else:
179
- x_seq = dutils.nonlinear_to_linear_batched(x_seq, dt_clip)
180
-
181
- b, t, c, h, w = x_seq.shape
182
- assert c == 1, f'# channels ({c}) != 1'
183
-
184
- # resize (downsample) the images if necessary
185
- x_seq = torch.Tensor(x_seq).float().reshape((b*t, c, h, w))
186
- if resize != h:
187
- tform = T.Compose([
188
- T.ToPILImage(),
189
- T.Resize(resize),
190
- T.ToTensor(),
191
- ])
192
- else:
193
- tform = T.Compose([])
194
-
195
- x_seq = torch.stack([tform(x_frame) for x_frame in x_seq], dim=0)
196
- x_seq = x_seq.reshape((b, t, c, resize, resize))
197
-
198
- x, y = x_seq[:, :seq_len], x_seq[:, seq_len:]
199
- return x, y
200
-
201
- # =======================================================================
202
- # Evaluation Metrics-Related
203
- # =======================================================================
204
-
205
- mae = lambda *args: torch.nn.functional.l1_loss(*args).cpu().detach().numpy()
206
- mse = lambda *args: torch.nn.functional.mse_loss(*args).cpu().detach().numpy()
207
-
208
- def ssim(y_pred, y):
209
- y, y_pred = to_cpu_tensor(y, y_pred)
210
- b, t, c, h, w = y.shape
211
- y = y.reshape((b*t, c, h, w))
212
- y_pred = y_pred.reshape((b*t, c, h, w))
213
- # to further ensure any of the input is not negative
214
- y = torch.clamp(y, 0, 1)
215
- y_pred = torch.clamp(y_pred, 0, 1)
216
- return torchmetrics.image.StructuralSimilarityIndexMeasure(data_range=1.0)(y_pred, y)
217
-
218
- def psnr(y_pred, y):
219
- y, y_pred = to_cpu_tensor(y, y_pred)
220
- b, t, c, h, w = y.shape
221
- y = y.reshape((b*t, c, h, w))
222
- y_pred = y_pred.reshape((b*t, c, h, w))
223
- acc_score = 0
224
- for i in range(b*t):
225
- acc_score += torchmetrics.image.PeakSignalNoiseRatio(data_range=1.0)(y_pred[i], y[i]) / (b*t)
226
- return acc_score
227
-
228
- GLOBAL_LPIPS_OBJ = None # a static variable
229
- def lpips64(y_pred, y, net='vgg'):
230
- # convert the image range into [-1, 1], assuming the input range to be [0, 1]
231
- y = merge_leading_dims(y)
232
- y_pred = merge_leading_dims(y_pred)
233
-
234
- y = torch.nn.functional.interpolate(y, (64, 64), mode='bicubic').clamp(0,1)
235
- y_pred = torch.nn.functional.interpolate(y_pred, (64, 64), mode='bicubic').clamp(0,1)
236
-
237
- y = (2 * y - 1)
238
- y_pred = (2 * y_pred - 1)
239
- global GLOBAL_LPIPS_OBJ
240
- if GLOBAL_LPIPS_OBJ is None:
241
- GLOBAL_LPIPS_OBJ = lp.LPIPS(net=net).to(y.device)
242
- return GLOBAL_LPIPS_OBJ(y_pred, y).mean()
243
-
244
- def tfpn(y_pred, y, threshold, radius=1):
245
- '''
246
- convert to cpu, and merge the first two dimensions
247
- '''
248
- y = merge_leading_dims(y)
249
- y_pred = merge_leading_dims(y_pred)
250
- with torch.no_grad():
251
- if radius > 1:
252
- pool = nn.MaxPool2d(radius)
253
- y = pool(y)
254
- y_pred = pool(y_pred)
255
- y = torch.where(y >= threshold, 1, 0)
256
- y_pred = torch.where(y_pred >= threshold, 1, 0)
257
- mat = torchmetrics.functional.confusion_matrix(y_pred, y, task='binary', threshold=threshold)
258
- (tn, fp), (fn, tp) = to_cpu_tensor(mat)
259
- return tp, tn, fp, fn
260
-
261
- def tfpn_pool(y_pred, y, threshold, radius):
262
- y_pred = merge_leading_dims(y_pred)
263
- y = merge_leading_dims(y)
264
- pool = nn.MaxPool2d(radius, stride=radius//4 if radius//4 > 0 else radius)
265
- with torch.no_grad():
266
- y = torch.where(y>=threshold, 1, 0).float()
267
- y_pred = torch.where(y_pred>=threshold, 1, 0).float()
268
- y = pool(y)
269
- y_pred = pool(y_pred)
270
- mat = torchmetrics.functional.confusion_matrix(y_pred, y, task='binary', threshold=threshold)
271
- (tn, fp), (fn, tp) = to_cpu_tensor(mat)
272
- return tp, tn, fp, fn
273
-
274
- def csi(tp, tn, fp, fn):
275
- '''Critical Success Index. The larger the better.'''
276
- if (tp + fn + fp) < 1e-7:
277
- return 0.
278
- return tp / (tp + fn + fp)
279
-
280
- def hss(tp, tn, fp, fn):
281
- '''Heidke Skill Score. (-inf, 1]. Larger better.'''
282
- if (tp+fn)*(fn+tn) + (tp+fp)*(fp+tn) == 0:
283
- return 0.
284
- return 2 * (tp*tn - fp*fn) / ((tp+fn)*(fn+tn) + (tp+fp)*(fp+tn))
285
-
286
- # =======================================================================
287
- # Data Visualization
288
- # =======================================================================
289
-
290
- def torch_visualize(sequences, savedir=None, horizontal=10, vmin=0, vmax=1):
291
- '''
292
- input: sequences, a list/dict of numpy/torch arrays with shape (B, T, C, H, W)
293
- C is assumed to be 1 and squeezed
294
- If batch > 1, only the first sequence will be printed
295
- '''
296
- # First pass: compute the vertical height and convert to proper format
297
- vertical = 0
298
- display_texts = []
299
- if (type(sequences) is dict):
300
- temp = []
301
- for k, v in sequences.items():
302
- vertical += int(np.ceil(v.shape[1] / horizontal))
303
- temp.append(v)
304
- display_texts.append(k)
305
- sequences = temp
306
- else:
307
- for i, sequence in enumerate(sequences):
308
- vertical += int(np.ceil(sequence.shape[1] / horizontal))
309
- display_texts.append(f'Item {i+1}')
310
- sequences = to_cpu_tensor(*sequences)
311
- # Plot the sequences
312
- j = 0
313
- fig, axes = plt.subplots(vertical, horizontal, figsize=(2*horizontal, 2*vertical), tight_layout=True)
314
- plt.setp(axes, xticks=[], yticks=[])
315
- for k, sequence in enumerate(sequences):
316
- # only take the first batch, now seq[0] is the temporal dim
317
- sequence = sequence[0].squeeze() # (T, H, W)
318
- axes[j, 0].set_ylabel(display_texts[k])
319
- for i, frame in enumerate(sequence):
320
- j_shift = j + i // horizontal
321
- i_shift = i % horizontal
322
- axes[j_shift, i_shift].imshow(frame, vmin=vmin, vmax=vmax, cmap='gray')
323
- j += int(np.ceil(sequence.shape[0] / horizontal))
324
- if savedir:
325
- plt.savefig(savedir + '' if savedir.endswith('.png') else '.png')
326
- plt.close()
327
- else:
328
- plt.show()
329
-
330
- """ Visualize function with colorbar and a line seprate input and output """
331
- def color_visualize(sequences, savedir='', horizontal=5, skip=1, ypos=0):
332
- '''
333
- input: sequences, a list/dict of numpy/torch arrays with shape (B, T, C, H, W)
334
- C is assumed to be 1 and squeezed
335
- If batch > 1, only the first sequence will be printed
336
- '''
337
- plt.style.use(['science', 'no-latex'])
338
- VIL_COLORS = [[0, 0, 0],
339
- [0.30196078431372547, 0.30196078431372547, 0.30196078431372547],
340
- [0.1568627450980392, 0.7450980392156863, 0.1568627450980392],
341
- [0.09803921568627451, 0.5882352941176471, 0.09803921568627451],
342
- [0.0392156862745098, 0.4117647058823529, 0.0392156862745098],
343
- [0.0392156862745098, 0.29411764705882354, 0.0392156862745098],
344
- [0.9607843137254902, 0.9607843137254902, 0.0],
345
- [0.9294117647058824, 0.6745098039215687, 0.0],
346
- [0.9411764705882353, 0.43137254901960786, 0.0],
347
- [0.6274509803921569, 0.0, 0.0],
348
- [0.9058823529411765, 0.0, 1.0]]
349
-
350
- VIL_LEVELS = [0.0, 16.0, 31.0, 59.0, 74.0, 100.0, 133.0, 160.0, 181.0, 219.0, 255.0]
351
-
352
- # First pass: compute the vertical height and convert to proper format
353
- vertical = 0
354
- display_texts = []
355
- if (type(sequences) is dict):
356
- temp = []
357
- for k, v in sequences.items():
358
- vertical += int(np.ceil(v.shape[1] / horizontal))
359
- temp.append(v)
360
- display_texts.append(k)
361
- sequences = temp
362
- else:
363
- for i, sequence in enumerate(sequences):
364
- vertical += int(np.ceil(sequence.shape[1] / horizontal))
365
- display_texts.append(f'Item {i+1}')
366
- sequences = to_cpu_tensor(*sequences)
367
- # Plot the sequences
368
- j = 0
369
- fig, axes = plt.subplots(vertical, horizontal, figsize=(2*horizontal, 2*vertical), tight_layout=True)
370
- plt.subplots_adjust(hspace=0.0, wspace=0.0) # tight layout
371
- plt.setp(axes, xticks=[], yticks=[])
372
- for k, sequence in enumerate(sequences):
373
- # only take the first batch, now seq[0] is the temporal dim
374
- sequence = sequence[0].squeeze() # (T, H, W)
375
-
376
- ## =================
377
- # = labels of time =
378
- if k == 0:
379
- for i in range(len(sequence)):
380
- axes[j, i].set_xlabel(f'$t-{(len(sequence)-i)-1}$', fontsize=16)
381
- axes[j, i].xaxis.set_label_position('top')
382
- elif k == len(sequences)-1:
383
- for i in range(len(sequence)):
384
- axes[j, i].set_xlabel(f'$t+{skip*i+1}$', fontsize=16)
385
- axes[j, i].xaxis.set_label_position('bottom')
386
- ## =================
387
- axes[j, 0].set_ylabel(display_texts[k], fontsize=16)
388
- for i, frame in enumerate(sequence):
389
- j_shift = j + i // horizontal
390
- i_shift = i % horizontal
391
- im = axes[j_shift, i_shift].imshow(frame*255, cmap=ListedColormap(VIL_COLORS), \
392
- norm=BoundaryNorm(VIL_LEVELS, ListedColormap(VIL_COLORS).N))
393
- j += int(np.ceil(sequence.shape[0] / horizontal))
394
-
395
- ## = plot splittin line =
396
- if ypos == 0:
397
- ypos = 1 - 1 / len(sequences) - 0.017
398
- fig.lines.append(Line2D((0, 1), (ypos, ypos), transform=fig.transFigure, ls='--', linewidth=2, color='#444'))
399
- # color bar
400
- cax = fig.add_axes([1, 0.05, 0.02, 0.5])
401
- fig.colorbar(im, cax=cax)
402
- ## =================
403
- if savedir:
404
- plt.savefig(savedir + '' if len(savedir)>0 else 'out.png')
405
- plt.close()
406
- else:
407
- plt.show()
408
-
409
  from tempfile import NamedTemporaryFile
410
 
411
  """ Visualize function with colorbar and a line seprate input and output """
@@ -491,7 +117,6 @@ def gradio_visualize(sequences, horizontal=5, skip=1, ypos=0):
491
 
492
  return file_path
493
 
494
- import matplotlib.animation as animation
495
 
496
  def gradio_gif(sequences, T):
497
  '''
@@ -550,17 +175,4 @@ def gradio_gif(sequences, T):
550
  file_path = ff.name
551
 
552
  plt.close()
553
- return file_path
554
-
555
- # import matplotlib.pyplot as plt
556
- # import matplotlib.animation as animation
557
- # def make_gif(frames, save_path):
558
- # fig, ax = plt.subplots(figsize=(4,4))
559
- # im = ax.imshow(frames[0].squeeze(), cmap='gray', vmin=0, vmax=1, animated=True)
560
- # ax.set_axis_off()
561
-
562
- # def update(i):
563
- # im.set_array(frames[i].squeeze())
564
- # return im,
565
- # animation_fig =
566
- # animation_fig.save(f"./{save_path}.gif")
 
 
1
  import torch
 
 
 
 
2
  import matplotlib.pyplot as plt
3
+ import torch.nn.functional as F
 
 
4
 
5
  from matplotlib.colors import ListedColormap, BoundaryNorm
6
  from matplotlib.lines import Line2D
7
+ import matplotlib.animation as animation
8
+ import scienceplots
9
 
10
+ def resize(seq, size):
11
+ # seq shape : (B, T, 1, H, W)
12
+ seq = F.interpolate(seq.squeeze(dim=2), size=size, mode='bilinear', align_corners=False) # (B, T, H, W)
13
+ seq = seq.clamp(0,1)
14
+ return seq.unsqueeze(2) # (B, T, 1, H, W)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  # =======================================================================
17
  # Utils in utils :)
 
32
  return out[0]
33
  return out
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  from tempfile import NamedTemporaryFile
36
 
37
  """ Visualize function with colorbar and a line seprate input and output """
 
117
 
118
  return file_path
119
 
 
120
 
121
  def gradio_gif(sequences, T):
122
  '''
 
175
  file_path = ff.name
176
 
177
  plt.close()
178
+ return file_path