sqfoo commited on
Commit
dc3d7a9
·
1 Parent(s): a6657db

Made Improvement

Browse files
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
__pycache__/utilspp.cpython-312.pyc CHANGED
Binary files a/__pycache__/utilspp.cpython-312.pyc and b/__pycache__/utilspp.cpython-312.pyc differ
 
app.py CHANGED
@@ -4,7 +4,7 @@ import gradio as gr
4
 
5
  from stldm import InferenceHub
6
  from stldm.config import STLDM_HKO
7
- from utilspp import resize, gradio_gif
8
 
9
  def nowcasting(file, cfg_str, ensemble_no):
10
  # Model Setup
@@ -30,30 +30,38 @@ def nowcasting(file, cfg_str, ensemble_no):
30
  raise ValueError("The input should have at least 5 frames for STLDM to predict")
31
  x = x[0, -5:]
32
 
33
- out = {}
34
- for i in range(ensemble_no):
 
35
  y_pred = Forecastor(input_x=x, include_mu=False)
36
- out[f'Ensemble {i+1}'] = torch.cat((x, y_pred), dim=0)
37
 
 
38
  figure = gradio_gif(out, len(out['Ensemble 1']))
39
 
40
- return figure
41
 
42
 
43
 
44
  with gr.Blocks() as demo:
45
- gr.Markdown("# STLDM official demo for nowcasting")
46
  gr.Markdown("Please upload the radar sequences with **at least 5 frames** in the format of .npy file, and **STLDM** will predict the future 20 frames based on the past 5 frames.")
47
- gr.Markdown('Please refer to [paper](https://arxiv.org/abs/2512.21118) and [code](https://github.com/sqfoo/stldm_official) for more details about STLDM.')
 
48
 
 
49
  file_input = gr.File(label="Upload the input radar squences", file_types=[".npy"])
 
 
50
  cfg_str = gr.Slider(0.0, 2.0, value=1.0, step=0.1, label="Classifier Free Guidance Scale")
51
  ensemble_no = gr.Slider(1, 10, value=2, step=1, label="How many ensemble predictions?")
52
 
53
- output = gr.Image(label="Nowcasting Results")
 
 
54
  btn = gr.Button("Forecast Now!")
55
- btn.click(fn=nowcasting, inputs=[file_input, cfg_str, ensemble_no], outputs=output)
56
 
57
  if __name__ == "__main__":
58
- demo.launch()
59
 
 
4
 
5
  from stldm import InferenceHub
6
  from stldm.config import STLDM_HKO
7
+ from utilspp import resize, gradio_gif, gradio_visualize
8
 
9
  def nowcasting(file, cfg_str, ensemble_no):
10
  # Model Setup
 
30
  raise ValueError("The input should have at least 5 frames for STLDM to predict")
31
  x = x[0, -5:]
32
 
33
+ y_pred, mu = Forecastor(input_x=x, include_mu=True)
34
+ out = {'Deterministic': mu, 'Ensemble 1': y_pred}
35
+ for i in range(1, ensemble_no):
36
  y_pred = Forecastor(input_x=x, include_mu=False)
37
+ out[f'Ensemble {i+1}'] = y_pred
38
 
39
+ past_frames = gradio_visualize(x)
40
  figure = gradio_gif(out, len(out['Ensemble 1']))
41
 
42
+ return past_frames, figure
43
 
44
 
45
 
46
  with gr.Blocks() as demo:
47
+ gr.Markdown("# STLDM Official Demo for **HKO-7** Nowcasting")
48
  gr.Markdown("Please upload the radar sequences with **at least 5 frames** in the format of .npy file, and **STLDM** will predict the future 20 frames based on the past 5 frames.")
49
+ gr.Markdown('**Paper** - [STLDM: Spatio-Temporal Latent Diffusion Model for Precipitation Nowcasting](https://arxiv.org/abs/2512.21118)')
50
+ gr.Markdown('**Code** - [https://github.com/sqfoo/stldm_official](https://github.com/sqfoo/stldm_official)')
51
 
52
+ gr.Markdown("## Input Frames")
53
  file_input = gr.File(label="Upload the input radar squences", file_types=[".npy"])
54
+
55
+ gr.Markdown("## Parameters")
56
  cfg_str = gr.Slider(0.0, 2.0, value=1.0, step=0.1, label="Classifier Free Guidance Scale")
57
  ensemble_no = gr.Slider(1, 10, value=2, step=1, label="How many ensemble predictions?")
58
 
59
+ gr.Markdown("## Predictions")
60
+ input_frames = gr.Image(label="Past 5 frames")
61
+ prediction = gr.Image(label="Evolving Predictions")
62
  btn = gr.Button("Forecast Now!")
63
+ btn.click(fn=nowcasting, inputs=[file_input, cfg_str, ensemble_no], outputs=[input_frames, prediction])
64
 
65
  if __name__ == "__main__":
66
+ demo.launch(share=True)
67
 
stldm/__pycache__/inference.cpython-312.pyc CHANGED
Binary files a/stldm/__pycache__/inference.cpython-312.pyc and b/stldm/__pycache__/inference.cpython-312.pyc differ
 
stldm/__pycache__/stldm_hf.cpython-312.pyc CHANGED
Binary files a/stldm/__pycache__/stldm_hf.cpython-312.pyc and b/stldm/__pycache__/stldm_hf.cpython-312.pyc differ
 
stldm/inference.py CHANGED
@@ -84,9 +84,9 @@ class InferenceHub:
84
 
85
  input_x = input_x.to(self.model.device)
86
  if include_mu:
87
- y_pred, mu = self.model(input_x, includ_mu=include_mu)
88
  else:
89
- y_pred = self.model(input_x, includ_mu=include_mu)
90
  mu = None
91
 
92
  if mu is not None:
 
84
 
85
  input_x = input_x.to(self.model.device)
86
  if include_mu:
87
+ y_pred, mu = self.model(input_x, include_mu=include_mu)
88
  else:
89
+ y_pred = self.model(input_x, include_mu=include_mu)
90
  mu = None
91
 
92
  if mu is not None:
stldm/stldm_hf.py CHANGED
@@ -588,7 +588,7 @@ class GaussianDiffusion(
588
  return loss.mean()
589
 
590
  @torch.no_grad()
591
- def forward(self, input_x, include_mu=False, **kwargs):
592
  pred, mu = self.predict(input_x, compute_loss=False)
593
  if include_mu:
594
  return pred, mu
 
588
  return loss.mean()
589
 
590
  @torch.no_grad()
591
+ def forward(self, input_x, include_mu, **kwargs):
592
  pred, mu = self.predict(input_x, compute_loss=False)
593
  if include_mu:
594
  return pred, mu
utilspp.py CHANGED
@@ -1,4 +1,5 @@
1
  import torch
 
2
  import matplotlib.pyplot as plt
3
  import torch.nn.functional as F
4
 
@@ -34,75 +35,102 @@ def to_cpu_tensor(*args):
34
 
35
  from tempfile import NamedTemporaryFile
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  """ Visualize function with colorbar and a line seprate input and output """
38
- def gradio_visualize(sequences, horizontal=5, skip=1, ypos=0):
39
  '''
40
- input: sequences, a list/dict of numpy/torch arrays with shape (B, T, C, H, W)
41
  C is assumed to be 1 and squeezed
42
  If batch > 1, only the first sequence will be printed
43
  '''
44
- plt.style.use(['science', 'no-latex'])
45
- VIL_COLORS = [[0, 0, 0],
46
- [0.30196078431372547, 0.30196078431372547, 0.30196078431372547],
47
- [0.1568627450980392, 0.7450980392156863, 0.1568627450980392],
48
- [0.09803921568627451, 0.5882352941176471, 0.09803921568627451],
49
- [0.0392156862745098, 0.4117647058823529, 0.0392156862745098],
50
- [0.0392156862745098, 0.29411764705882354, 0.0392156862745098],
51
- [0.9607843137254902, 0.9607843137254902, 0.0],
52
- [0.9294117647058824, 0.6745098039215687, 0.0],
53
- [0.9411764705882353, 0.43137254901960786, 0.0],
54
- [0.6274509803921569, 0.0, 0.0],
55
- [0.9058823529411765, 0.0, 1.0]]
56
-
57
- VIL_LEVELS = [0.0, 16.0, 31.0, 59.0, 74.0, 100.0, 133.0, 160.0, 181.0, 219.0, 255.0]
58
-
59
- # First pass: compute the vertical height and convert to proper format
60
- vertical = 0
61
- display_texts = []
62
- if (type(sequences) is dict):
63
- temp = []
64
- for k, v in sequences.items():
65
- vertical += int(np.ceil(v.shape[1] / horizontal))
66
- temp.append(v)
67
- display_texts.append(k)
68
- sequences = temp
69
- else:
70
- for i, sequence in enumerate(sequences):
71
- vertical += int(np.ceil(sequence.shape[1] / horizontal))
72
- display_texts.append(f'Item {i+1}')
73
- sequences = to_cpu_tensor(*sequences)
74
- # Plot the sequences
75
- j = 0
76
- fig, axes = plt.subplots(vertical, horizontal, figsize=(2*horizontal, 2*vertical), tight_layout=True)
77
  plt.subplots_adjust(hspace=0.0, wspace=0.0) # tight layout
78
- plt.setp(axes, xticks=[], yticks=[])
79
- for k, sequence in enumerate(sequences):
80
- # only take the first batch, now seq[0] is the temporal dim
81
- sequence = sequence.squeeze() # (T, H, W)
82
-
83
- ## =================
84
- # = labels of time =
85
- if k == 0:
86
- for i in range(len(sequence)):
87
- axes[j, i].set_xlabel(f'$t-{(len(sequence)-i)-1}$', fontsize=16)
88
- axes[j, i].xaxis.set_label_position('top')
89
- elif k == len(sequences)-1:
90
- for i in range(len(sequence)):
91
- axes[j, i].set_xlabel(f'$t+{skip*i+1}$', fontsize=16)
92
- axes[j, i].xaxis.set_label_position('bottom')
93
- ## =================
94
- axes[j, 0].set_ylabel(display_texts[k], fontsize=16)
95
- for i, frame in enumerate(sequence):
96
- j_shift = j + i // horizontal
97
- i_shift = i % horizontal
98
- im = axes[j_shift, i_shift].imshow(frame*255, cmap=ListedColormap(VIL_COLORS), \
99
- norm=BoundaryNorm(VIL_LEVELS, ListedColormap(VIL_COLORS).N))
100
- j += int(np.ceil(sequence.shape[0] / horizontal))
 
 
 
 
 
 
 
 
101
 
102
- ## = plot splittin line =
103
- if ypos == 0:
104
- ypos = 1 - 1 / len(sequences) - 0.017
105
- fig.lines.append(Line2D((0, 1), (ypos, ypos), transform=fig.transFigure, ls='--', linewidth=2, color='#444'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  # color bar
107
  cax = fig.add_axes([1, 0.05, 0.02, 0.5])
108
  fig.colorbar(im, cax=cax)
@@ -124,20 +152,20 @@ def gradio_gif(sequences, T):
124
  C is assumed to be 1 and squeezed
125
  If batch > 1, only the first sequence will be printed
126
  '''
127
- plt.style.use(['science', 'no-latex'])
128
- VIL_COLORS = [[0, 0, 0],
129
- [0.30196078431372547, 0.30196078431372547, 0.30196078431372547],
130
- [0.1568627450980392, 0.7450980392156863, 0.1568627450980392],
131
- [0.09803921568627451, 0.5882352941176471, 0.09803921568627451],
132
- [0.0392156862745098, 0.4117647058823529, 0.0392156862745098],
133
- [0.0392156862745098, 0.29411764705882354, 0.0392156862745098],
134
- [0.9607843137254902, 0.9607843137254902, 0.0],
135
- [0.9294117647058824, 0.6745098039215687, 0.0],
136
- [0.9411764705882353, 0.43137254901960786, 0.0],
137
- [0.6274509803921569, 0.0, 0.0],
138
- [0.9058823529411765, 0.0, 1.0]]
139
-
140
- VIL_LEVELS = [0.0, 16.0, 31.0, 59.0, 74.0, 100.0, 133.0, 160.0, 181.0, 219.0, 255.0]
141
 
142
  horizontal = len(sequences)
143
  fig_size = 3
@@ -146,25 +174,25 @@ def gradio_gif(sequences, T):
146
  plt.setp(axes, xticks=[], yticks=[])
147
 
148
  if horizontal == 1:
149
- for i, sequence in enumerate(sequences.values()):
150
  axes.set_xticks([])
151
  axes.set_yticks([])
152
- axes.set_xlabel(f'Ensemble {i+1}', fontsize=12)
153
  frame = sequence[0].squeeze()
154
  im = axes.imshow(frame*255, cmap=ListedColormap(VIL_COLORS), \
155
  norm=BoundaryNorm(VIL_LEVELS, ListedColormap(VIL_COLORS).N), animated=True)
156
  else:
157
- for i, sequence in enumerate(sequences.values()):
158
  axes[i].set_xticks([])
159
  axes[i].set_yticks([])
160
- axes[i].set_xlabel(f'Ensemble {i+1}', fontsize=12)
161
  frame = sequence[0].squeeze()
162
  im = axes[i].imshow(frame*255, cmap=ListedColormap(VIL_COLORS), \
163
  norm=BoundaryNorm(VIL_LEVELS, ListedColormap(VIL_COLORS).N), animated=True)
164
 
165
  title = fig.suptitle('', y=0.9, x=0.505, fontsize=16) # Initialize an empty super title
166
 
167
- fig.colorbar(im)
168
 
169
  def animate(t):
170
  if horizontal == 1:
 
1
  import torch
2
+ import numpy as np
3
  import matplotlib.pyplot as plt
4
  import torch.nn.functional as F
5
 
 
35
 
36
  from tempfile import NamedTemporaryFile
37
 
38
+ plt.style.use(['science', 'no-latex'])
39
+ VIL_COLORS = [[0, 0, 0],
40
+ [0.30196078431372547, 0.30196078431372547, 0.30196078431372547],
41
+ [0.1568627450980392, 0.7450980392156863, 0.1568627450980392],
42
+ [0.09803921568627451, 0.5882352941176471, 0.09803921568627451],
43
+ [0.0392156862745098, 0.4117647058823529, 0.0392156862745098],
44
+ [0.0392156862745098, 0.29411764705882354, 0.0392156862745098],
45
+ [0.9607843137254902, 0.9607843137254902, 0.0],
46
+ [0.9294117647058824, 0.6745098039215687, 0.0],
47
+ [0.9411764705882353, 0.43137254901960786, 0.0],
48
+ [0.6274509803921569, 0.0, 0.0],
49
+ [0.9058823529411765, 0.0, 1.0]]
50
+
51
+ VIL_LEVELS = [0.0, 16.0, 31.0, 59.0, 74.0, 100.0, 133.0, 160.0, 181.0, 219.0, 255.0]
52
+
53
  """ Visualize function with colorbar and a line seprate input and output """
54
+ def gradio_visualize(sequence):
55
  '''
56
+ input: sequences, a list/dict of numpy/torch arrays with shape (T, C, H, W)
57
  C is assumed to be 1 and squeezed
58
  If batch > 1, only the first sequence will be printed
59
  '''
60
+
61
+ fig_size = 3
62
+ fig, axes = plt.subplots(1, len(sequence), figsize=(fig_size*len(sequence), fig_size), tight_layout=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  plt.subplots_adjust(hspace=0.0, wspace=0.0) # tight layout
64
+ plt.setp(axes, xticks=[], yticks=[])
65
+
66
+ for i, frame in enumerate(sequence):
67
+ axes[i].set_xticks([])
68
+ axes[i].set_yticks([])
69
+ axes[i].set_xlabel(f'$t-{(len(sequence)-i)-1}$', fontsize=12)
70
+ frame = frame.squeeze()
71
+ im = axes[i].imshow(frame*255, cmap=ListedColormap(VIL_COLORS), norm=BoundaryNorm(VIL_LEVELS, ListedColormap(VIL_COLORS).N))
72
+
73
+
74
+
75
+ # # First pass: compute the vertical height and convert to proper format
76
+ # vertical = 0
77
+ # display_texts = []
78
+ # if (type(sequences) is dict):
79
+ # temp = []
80
+ # for k, v in sequences.items():
81
+ # vertical += int(np.ceil(v.shape[1] / horizontal))
82
+ # temp.append(v)
83
+ # display_texts.append(k)
84
+ # sequences = temp
85
+ # else:
86
+ # for i, sequence in enumerate(sequences):
87
+ # vertical += int(np.ceil(sequence.shape[1] / horizontal))
88
+ # display_texts.append(f'Item {i+1}')
89
+ # sequences = to_cpu_tensor(*sequences)
90
+ # # Plot the sequences
91
+ # j = 0
92
+ # fig, axes = plt.subplots(vertical, horizontal, figsize=(2*horizontal, 2*vertical), tight_layout=True)
93
+ # plt.subplots_adjust(hspace=0.0, wspace=0.0) # tight layout
94
+ # plt.setp(axes, xticks=[], yticks=[])
95
 
96
+ # if vertical == 1:
97
+ # for k, sequence in enumerate(sequences.values()):
98
+ # for i in range(len(sequence)):
99
+ # axes[i].set_xticks([])
100
+ # axes[i].set_yticks([])
101
+ # axes[i].set_xlabel(f'$t-{(len(sequence)-i)-1}$', fontsize=12)
102
+ # frame = sequence[i].squeeze()
103
+ # im = axes[i].imshow(frame*255, cmap=ListedColormap(VIL_COLORS), \
104
+ # norm=BoundaryNorm(VIL_LEVELS, ListedColormap(VIL_COLORS).N), animated=True)
105
+ # else:
106
+
107
+ # for k, sequence in enumerate(sequences):
108
+ # # only take the first batch, now seq[0] is the temporal dim
109
+ # sequence = sequence.squeeze() # (T, H, W)
110
+
111
+ # ## =================
112
+ # # = labels of time =
113
+ # if k == 0:
114
+ # for i in range(len(sequence)):
115
+ # axes[j, i].set_xlabel(f'$t-{(len(sequence)-i)-1}$', fontsize=16)
116
+ # axes[j, i].xaxis.set_label_position('top')
117
+ # elif k == len(sequences)-1:
118
+ # for i in range(len(sequence)):
119
+ # axes[j, i].set_xlabel(f'$t+{skip*i+1}$', fontsize=16)
120
+ # axes[j, i].xaxis.set_label_position('bottom')
121
+ # ## =================
122
+ # axes[j, 0].set_ylabel(display_texts[k], fontsize=16)
123
+ # for i, frame in enumerate(sequence):
124
+ # j_shift = j + i // horizontal
125
+ # i_shift = i % horizontal
126
+ # im = axes[j_shift, i_shift].imshow(frame*255, cmap=ListedColormap(VIL_COLORS), \
127
+ # norm=BoundaryNorm(VIL_LEVELS, ListedColormap(VIL_COLORS).N))
128
+ # j += int(np.ceil(sequence.shape[0] / horizontal))
129
+
130
+ # # ## = plot splittin line =
131
+ # # if ypos == 0:
132
+ # # ypos = 1 - 1 / len(sequences) - 0.017
133
+ # # fig.lines.append(Line2D((0, 1), (ypos, ypos), transform=fig.transFigure, ls='--', linewidth=2, color='#444'))
134
  # color bar
135
  cax = fig.add_axes([1, 0.05, 0.02, 0.5])
136
  fig.colorbar(im, cax=cax)
 
152
  C is assumed to be 1 and squeezed
153
  If batch > 1, only the first sequence will be printed
154
  '''
155
+ # plt.style.use(['science', 'no-latex'])
156
+ # VIL_COLORS = [[0, 0, 0],
157
+ # [0.30196078431372547, 0.30196078431372547, 0.30196078431372547],
158
+ # [0.1568627450980392, 0.7450980392156863, 0.1568627450980392],
159
+ # [0.09803921568627451, 0.5882352941176471, 0.09803921568627451],
160
+ # [0.0392156862745098, 0.4117647058823529, 0.0392156862745098],
161
+ # [0.0392156862745098, 0.29411764705882354, 0.0392156862745098],
162
+ # [0.9607843137254902, 0.9607843137254902, 0.0],
163
+ # [0.9294117647058824, 0.6745098039215687, 0.0],
164
+ # [0.9411764705882353, 0.43137254901960786, 0.0],
165
+ # [0.6274509803921569, 0.0, 0.0],
166
+ # [0.9058823529411765, 0.0, 1.0]]
167
+
168
+ # VIL_LEVELS = [0.0, 16.0, 31.0, 59.0, 74.0, 100.0, 133.0, 160.0, 181.0, 219.0, 255.0]
169
 
170
  horizontal = len(sequences)
171
  fig_size = 3
 
174
  plt.setp(axes, xticks=[], yticks=[])
175
 
176
  if horizontal == 1:
177
+ for i, (key, sequence) in enumerate(sequences.items()):
178
  axes.set_xticks([])
179
  axes.set_yticks([])
180
+ axes.set_xlabel(f'{key}', fontsize=12)
181
  frame = sequence[0].squeeze()
182
  im = axes.imshow(frame*255, cmap=ListedColormap(VIL_COLORS), \
183
  norm=BoundaryNorm(VIL_LEVELS, ListedColormap(VIL_COLORS).N), animated=True)
184
  else:
185
+ for i, (key, sequence) in enumerate(sequences.items()):
186
  axes[i].set_xticks([])
187
  axes[i].set_yticks([])
188
+ axes[i].set_xlabel(f'{key}', fontsize=12)
189
  frame = sequence[0].squeeze()
190
  im = axes[i].imshow(frame*255, cmap=ListedColormap(VIL_COLORS), \
191
  norm=BoundaryNorm(VIL_LEVELS, ListedColormap(VIL_COLORS).N), animated=True)
192
 
193
  title = fig.suptitle('', y=0.9, x=0.505, fontsize=16) # Initialize an empty super title
194
 
195
+ # fig.colorbar(im)
196
 
197
  def animate(t):
198
  if horizontal == 1: