firstkillday commited on
Commit
22c6f75
·
verified ·
1 Parent(s): 7a33832

Upload src/utils/util.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/utils/util.py +146 -145
src/utils/util.py CHANGED
@@ -1,145 +1,146 @@
1
- import importlib
2
- import os
3
- import os.path as osp
4
- import shutil
5
- import sys
6
- from pathlib import Path
7
-
8
- import av
9
- import numpy as np
10
- import torch
11
- import torchvision
12
- from einops import rearrange
13
- from PIL import Image
14
-
15
-
16
- def seed_everything(seed):
17
- import random
18
-
19
- import numpy as np
20
-
21
- torch.manual_seed(seed)
22
- torch.cuda.manual_seed_all(seed)
23
- np.random.seed(seed % (2**32))
24
- random.seed(seed)
25
-
26
-
27
- def import_filename(filename):
28
- spec = importlib.util.spec_from_file_location("mymodule", filename)
29
- module = importlib.util.module_from_spec(spec)
30
- sys.modules[spec.name] = module
31
- spec.loader.exec_module(module)
32
- return module
33
-
34
-
35
- def delete_additional_ckpt(base_path, num_keep):
36
- dirs = []
37
- for d in os.listdir(base_path):
38
- if d.startswith("checkpoint-"):
39
- dirs.append(d)
40
- num_tot = len(dirs)
41
- if num_tot <= num_keep:
42
- return
43
- # ensure ckpt is sorted and delete the ealier!
44
- del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
45
- for d in del_dirs:
46
- path_to_dir = osp.join(base_path, d)
47
- if osp.exists(path_to_dir):
48
- shutil.rmtree(path_to_dir)
49
-
50
-
51
- def save_videos_from_pil(pil_images, path, fps=8, audio_path=None):
52
- import av
53
-
54
- save_fmt = Path(path).suffix
55
- os.makedirs(os.path.dirname(path), exist_ok=True)
56
- width, height = pil_images[0].size
57
-
58
- if save_fmt == ".mp4":
59
- codec = "libx264"
60
- container = av.open(path, "w")
61
- stream = container.add_stream(codec, rate=fps)
62
-
63
- stream.width = width
64
- stream.height = height
65
-
66
- for pil_image in pil_images:
67
- # pil_image = Image.fromarray(image_arr).convert("RGB")
68
- av_frame = av.VideoFrame.from_image(pil_image)
69
- container.mux(stream.encode(av_frame))
70
- container.mux(stream.encode())
71
- container.close()
72
-
73
- elif save_fmt == ".gif":
74
- pil_images[0].save(
75
- fp=path,
76
- format="GIF",
77
- append_images=pil_images[1:],
78
- save_all=True,
79
- duration=(1 / fps * 1000),
80
- loop=0,
81
- )
82
- else:
83
- raise ValueError("Unsupported file type. Use .mp4 or .gif.")
84
-
85
-
86
- def save_videos_grid(videos: torch.Tensor, path: str, audio_path=None, rescale=False, n_rows=6, fps=8):
87
- videos = rearrange(videos, "b c t h w -> t b c h w")
88
- height, width = videos.shape[-2:]
89
- outputs = []
90
-
91
- for x in videos:
92
- x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w)
93
- x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c)
94
- if rescale:
95
- x = (x + 1.0) / 2.0 # -1,1 -> 0,1
96
- x = (x * 255).numpy().astype(np.uint8)
97
- x = Image.fromarray(x)
98
-
99
- outputs.append(x)
100
-
101
- os.makedirs(os.path.dirname(path), exist_ok=True)
102
-
103
- save_videos_from_pil(outputs, path, fps, audio_path=audio_path)
104
-
105
-
106
- def save_video2imgs(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
107
- videos = rearrange(videos, "b c t h w -> t b c h w")
108
- height, width = videos.shape[-2:]
109
-
110
- os.makedirs(os.path.dirname(path), exist_ok=True)
111
-
112
- for i, x in enumerate(videos):
113
- x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w)
114
- x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c)
115
- if rescale:
116
- x = (x + 1.0) / 2.0 # -1,1 -> 0,1
117
- x = (x * 255).numpy().astype(np.uint8)
118
- x = Image.fromarray(x)
119
- img_name = osp.join(path, f"{i}.png")
120
- x.save(img_name)
121
-
122
-
123
- def read_frames(video_path):
124
- container = av.open(video_path)
125
-
126
- video_stream = next(s for s in container.streams if s.type == "video")
127
- frames = []
128
- for packet in container.demux(video_stream):
129
- for frame in packet.decode():
130
- image = Image.frombytes(
131
- "RGB",
132
- (frame.width, frame.height),
133
- frame.to_rgb().to_ndarray(),
134
- )
135
- frames.append(image)
136
-
137
- return frames
138
-
139
-
140
- def get_fps(video_path):
141
- container = av.open(video_path)
142
- video_stream = next(s for s in container.streams if s.type == "video")
143
- fps = video_stream.average_rate
144
- container.close()
145
- return fps
 
 
1
+ import importlib
2
+ import os
3
+ import os.path as osp
4
+ import shutil
5
+ import sys
6
+ from pathlib import Path
7
+
8
+ import av
9
+ import numpy as np
10
+ import torch
11
+ import torchvision
12
+ from einops import rearrange
13
+ from PIL import Image
14
+
15
+
16
+ def seed_everything(seed):
17
+ import random
18
+
19
+ import numpy as np
20
+
21
+ torch.manual_seed(seed)
22
+ if torch.cuda.is_available():
23
+ torch.cuda.manual_seed_all(seed)
24
+ np.random.seed(seed % (2**32))
25
+ random.seed(seed)
26
+
27
+
28
+ def import_filename(filename):
29
+ spec = importlib.util.spec_from_file_location("mymodule", filename)
30
+ module = importlib.util.module_from_spec(spec)
31
+ sys.modules[spec.name] = module
32
+ spec.loader.exec_module(module)
33
+ return module
34
+
35
+
36
+ def delete_additional_ckpt(base_path, num_keep):
37
+ dirs = []
38
+ for d in os.listdir(base_path):
39
+ if d.startswith("checkpoint-"):
40
+ dirs.append(d)
41
+ num_tot = len(dirs)
42
+ if num_tot <= num_keep:
43
+ return
44
+ # ensure ckpt is sorted and delete the ealier!
45
+ del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
46
+ for d in del_dirs:
47
+ path_to_dir = osp.join(base_path, d)
48
+ if osp.exists(path_to_dir):
49
+ shutil.rmtree(path_to_dir)
50
+
51
+
52
+ def save_videos_from_pil(pil_images, path, fps=8, audio_path=None):
53
+ import av
54
+
55
+ save_fmt = Path(path).suffix
56
+ os.makedirs(os.path.dirname(path), exist_ok=True)
57
+ width, height = pil_images[0].size
58
+
59
+ if save_fmt == ".mp4":
60
+ codec = "libx264"
61
+ container = av.open(path, "w")
62
+ stream = container.add_stream(codec, rate=fps)
63
+
64
+ stream.width = width
65
+ stream.height = height
66
+
67
+ for pil_image in pil_images:
68
+ # pil_image = Image.fromarray(image_arr).convert("RGB")
69
+ av_frame = av.VideoFrame.from_image(pil_image)
70
+ container.mux(stream.encode(av_frame))
71
+ container.mux(stream.encode())
72
+ container.close()
73
+
74
+ elif save_fmt == ".gif":
75
+ pil_images[0].save(
76
+ fp=path,
77
+ format="GIF",
78
+ append_images=pil_images[1:],
79
+ save_all=True,
80
+ duration=(1 / fps * 1000),
81
+ loop=0,
82
+ )
83
+ else:
84
+ raise ValueError("Unsupported file type. Use .mp4 or .gif.")
85
+
86
+
87
+ def save_videos_grid(videos: torch.Tensor, path: str, audio_path=None, rescale=False, n_rows=6, fps=8):
88
+ videos = rearrange(videos, "b c t h w -> t b c h w")
89
+ height, width = videos.shape[-2:]
90
+ outputs = []
91
+
92
+ for x in videos:
93
+ x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w)
94
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c)
95
+ if rescale:
96
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
97
+ x = (x * 255).numpy().astype(np.uint8)
98
+ x = Image.fromarray(x)
99
+
100
+ outputs.append(x)
101
+
102
+ os.makedirs(os.path.dirname(path), exist_ok=True)
103
+
104
+ save_videos_from_pil(outputs, path, fps, audio_path=audio_path)
105
+
106
+
107
+ def save_video2imgs(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
108
+ videos = rearrange(videos, "b c t h w -> t b c h w")
109
+ height, width = videos.shape[-2:]
110
+
111
+ os.makedirs(os.path.dirname(path), exist_ok=True)
112
+
113
+ for i, x in enumerate(videos):
114
+ x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w)
115
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c)
116
+ if rescale:
117
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
118
+ x = (x * 255).numpy().astype(np.uint8)
119
+ x = Image.fromarray(x)
120
+ img_name = osp.join(path, f"{i}.png")
121
+ x.save(img_name)
122
+
123
+
124
+ def read_frames(video_path):
125
+ container = av.open(video_path)
126
+
127
+ video_stream = next(s for s in container.streams if s.type == "video")
128
+ frames = []
129
+ for packet in container.demux(video_stream):
130
+ for frame in packet.decode():
131
+ image = Image.frombytes(
132
+ "RGB",
133
+ (frame.width, frame.height),
134
+ frame.to_rgb().to_ndarray(),
135
+ )
136
+ frames.append(image)
137
+
138
+ return frames
139
+
140
+
141
+ def get_fps(video_path):
142
+ container = av.open(video_path)
143
+ video_stream = next(s for s in container.streams if s.type == "video")
144
+ fps = video_stream.average_rate
145
+ container.close()
146
+ return fps