SynLayers committed on
Commit
50acfb7
·
verified ·
1 Parent(s): d4bf927

Upload tools/tools.py with huggingface_hub

Files changed (1)
  1. tools/tools.py +394 -0
tools/tools.py ADDED
@@ -0,0 +1,394 @@
+ import os, yaml, random
+ import torch
+ import numpy as np
+ from typing import Union
+ import pickle
+ from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
+ from peft import LoraConfig
+ from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
+
+ from models.mmdit import CustomFluxTransformer2DModel
+ from models.pipeline import CustomFluxPipeline
+ from models.multiLayer_adapter import MultiLayerAdapter
+
+ def save_checkpoint(transformer, multiLayer_adapter, optimizer, optimizer_adapter, scheduler, scheduler_adapter, step, save_dir):
+     import gc
+
+     trans_dir = os.path.join(save_dir, "transformer")
+     adapter_dir = os.path.join(save_dir, "adapter")
+     os.makedirs(trans_dir, exist_ok=True)
+     os.makedirs(adapter_dir, exist_ok=True)
+
+     # Get state dicts and immediately move them to CPU to avoid GPU memory buildup
+     flux_transformer_lora_state_dict = get_peft_model_state_dict(transformer)
+     flux_transformer_lora_state_dict = {k: v.detach().cpu().to(torch.float32) for k, v in flux_transformer_lora_state_dict.items()}
+
+     flux_adapter_lora_state_dict = get_peft_model_state_dict(multiLayer_adapter)
+     flux_adapter_lora_state_dict = {k: v.detach().cpu().to(torch.float32) for k, v in flux_adapter_lora_state_dict.items()}
+
+     CustomFluxPipeline.save_lora_weights(
+         trans_dir,
+         flux_transformer_lora_state_dict,
+         safe_serialization=True,
+     )
+     # Free the CPU copy after saving
+     del flux_transformer_lora_state_dict
+
+     CustomFluxPipeline.save_lora_weights(
+         adapter_dir,
+         flux_adapter_lora_state_dict,
+         safe_serialization=True,
+     )
+     # Free the CPU copy after saving
+     del flux_adapter_lora_state_dict
+
+     torch.save({"layer_pe": transformer.layer_pe.detach().cpu().to(torch.float32)}, os.path.join(save_dir, "layer_pe.pth"))
+
+     torch.save(optimizer.state_dict(), os.path.join(trans_dir, "optimizer.bin"))
+     torch.save(optimizer_adapter.state_dict(), os.path.join(adapter_dir, "optimizer.bin"))
+
+     torch.save(scheduler.state_dict(), os.path.join(trans_dir, "scheduler.bin"))
+     torch.save(scheduler_adapter.state_dict(), os.path.join(adapter_dir, "scheduler.bin"))
+
+     save_path = os.path.join(save_dir, "random_states_0.pkl")
+     state = {
+         "step": step,
+         "random_state": random.getstate(),
+         "numpy_random_seed": np.random.get_state(),
+         "torch_manual_seed": torch.get_rng_state(),
+     }
+
+     if torch.cuda.is_available():
+         state["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()  # one tensor per device
+
+     with open(save_path, "wb") as f:
+         pickle.dump(state, f)
+
+     # Force garbage collection and clear the CUDA cache
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+     print(f"[INFO] Saved RNG states + step {step} to {save_path}")
+
+
+ def load_checkpoint(transformer, multiLayer_adapter, optimizer, optimizer_adapter, scheduler, scheduler_adapter, ckpt_dir, device="cuda"):
+     trans_dir = os.path.join(ckpt_dir, "transformer")
+     adapter_dir = os.path.join(ckpt_dir, "adapter")
+     start_step = 0
+
+     lora_path = os.path.join(trans_dir, "pytorch_lora_weights.safetensors")
+     lora_path_adapter = os.path.join(adapter_dir, "pytorch_lora_weights.safetensors")
+     if os.path.exists(lora_path):
+         lora_state_dict = CustomFluxPipeline.lora_state_dict(lora_path)
+         stripped = {k.replace("transformer.", "", 1) if k.startswith("transformer.") else k: v for k, v in lora_state_dict.items()}
+         result = set_peft_model_state_dict(transformer, stripped)
+         if result.unexpected_keys:
+             print(f"[WARN] Transformer LoRA: {len(result.unexpected_keys)} unexpected keys")
+         print(f"[INFO] Loaded Transformer LoRA weights ({len(stripped)} keys).")
+     if os.path.exists(lora_path_adapter):
+         lora_state_dict = CustomFluxPipeline.lora_state_dict(lora_path_adapter)
+         stripped = {k.replace("transformer.", "", 1) if k.startswith("transformer.") else k: v for k, v in lora_state_dict.items()}
+         result = set_peft_model_state_dict(multiLayer_adapter, stripped)
+         if result.unexpected_keys:
+             print(f"[WARN] Adapter LoRA: {len(result.unexpected_keys)} unexpected keys")
+         print(f"[INFO] Loaded Adapter LoRA weights ({len(stripped)} keys).")
+
+     pe_path = os.path.join(ckpt_dir, "layer_pe.pth")
+     if os.path.exists(pe_path):
+         layer_pe = torch.load(pe_path, map_location=device)
+         transformer.load_state_dict(layer_pe, strict=False)
+
+     opt_path = os.path.join(trans_dir, "optimizer.bin")
+     opt_path_adapter = os.path.join(adapter_dir, "optimizer.bin")
+     if os.path.exists(opt_path):
+         optimizer.load_state_dict(torch.load(opt_path, map_location=device))
+         print("[INFO] Loaded transformer optimizer state.")
+     if os.path.exists(opt_path_adapter):
+         optimizer_adapter.load_state_dict(torch.load(opt_path_adapter, map_location=device))
+         print("[INFO] Loaded adapter optimizer state.")
+
+     sch_path = os.path.join(trans_dir, "scheduler.bin")
+     sch_path_adapter = os.path.join(adapter_dir, "scheduler.bin")
+     if os.path.exists(sch_path):
+         scheduler.load_state_dict(torch.load(sch_path, map_location=device))
+         print("[INFO] Loaded transformer scheduler state.")
+     if os.path.exists(sch_path_adapter):
+         scheduler_adapter.load_state_dict(torch.load(sch_path_adapter, map_location=device))
+         print("[INFO] Loaded adapter scheduler state.")
+
+     rng_file = None
+     for f in os.listdir(ckpt_dir):
+         if f.startswith("random_states_") and f.endswith(".pkl"):
+             rng_file = os.path.join(ckpt_dir, f)
+             break
+
+     if rng_file:
+         with open(rng_file, "rb") as f:
+             state = pickle.load(f)
+         start_step = state.get("step", 0)
+
+         if "random_state" in state:
+             random.setstate(state["random_state"])
+         if "numpy_random_seed" in state:
+             np.random.set_state(state["numpy_random_seed"])
+         if "torch_manual_seed" in state:
+             torch.set_rng_state(state["torch_manual_seed"])
+         if "torch_cuda_manual_seed" in state and torch.cuda.is_available():
+             torch.cuda.set_rng_state_all(state["torch_cuda_manual_seed"])
+
+         print(f"[INFO] Resumed RNG states + step {start_step}")
+
+     return start_step
+
+
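+ # A minimal resume sketch (hypothetical objects; assumes two optimizer/LR-scheduler
+ # pairs, one for the transformer LoRA and one for the adapter LoRA):
+ #
+ #     start_step = load_checkpoint(transformer, adapter, opt, opt_a, sched, sched_a, "ckpts/step_1000")
+ #     # ...train from start_step, then...
+ #     save_checkpoint(transformer, adapter, opt, opt_a, sched, sched_a, step, "ckpts/step_2000")
+
+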
+ def load_config(path):
+     with open(path, "r") as f:
+         return yaml.safe_load(f)
+
+
+ def seed_everything(seed: int):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+         torch.backends.cudnn.deterministic = True
+
+
+ def get_input_box(layer_boxes, image_size=512):
+     """
+     Quantize layer boxes to a 16-pixel grid for latent-space alignment.
+
+     Args:
+         layer_boxes: List of boxes in xyxy format [x0, y0, x1, y1]
+         image_size: Image size to clamp bounds (default 512)
+
+     Returns:
+         List of quantized boxes in xyxy format
+     """
+     list_layer_box = []
+     for layer_box in layer_boxes:
+         min_col, min_row = layer_box[0], layer_box[1]
+         max_col, max_row = layer_box[2], layer_box[3]
+
+         # Floor for min (start of box)
+         quantized_min_row = (min_row // 16) * 16
+         quantized_min_col = (min_col // 16) * 16
+
+         # Ceiling for max (end of box): (val + 15) // 16 * 16
+         quantized_max_row = ((max_row + 15) // 16) * 16
+         quantized_max_col = ((max_col + 15) // 16) * 16
+
+         # Clamp to image bounds
+         quantized_min_row = max(0, quantized_min_row)
+         quantized_min_col = max(0, quantized_min_col)
+         quantized_max_row = min(image_size, quantized_max_row)
+         quantized_max_col = min(image_size, quantized_max_col)
+
+         # Ensure a minimum box size of 16 pixels (1 latent token) in each dimension;
+         # this prevents zero-size boxes that cause reshape errors
+         if quantized_max_col <= quantized_min_col:
+             # Expand the box, preferring to expand max if there's room
+             if quantized_min_col + 16 <= image_size:
+                 quantized_max_col = quantized_min_col + 16
+             else:
+                 quantized_min_col = max(0, quantized_max_col - 16)
+                 quantized_max_col = quantized_min_col + 16
+
+         if quantized_max_row <= quantized_min_row:
+             # Expand the box, preferring to expand max if there's room
+             if quantized_min_row + 16 <= image_size:
+                 quantized_max_row = quantized_min_row + 16
+             else:
+                 quantized_min_row = max(0, quantized_max_row - 16)
+                 quantized_max_row = quantized_min_row + 16
+
+         list_layer_box.append((quantized_min_col, quantized_min_row, quantized_max_col, quantized_max_row))
+     return list_layer_box
+
+
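+ # Worked example (hypothetical box): get_input_box([[10, 20, 100, 90]]) returns
+ # [(0, 16, 112, 96)]: mins floor to the 16-pixel grid, maxes round up, so the
+ # quantized box always covers the original one.
+
+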
+ def set_lora_into_transformer(
+     model: Union[CustomFluxTransformer2DModel, MultiLayerAdapter],
+     lora_rank: int,
+     lora_alpha: float = 1.0,
+     lora_dropout: float = 0.1,
+ ):
+     target_modules = [
+         "to_k", "to_q", "to_v",
+         "to_out.0",
+         "add_k_proj", "add_q_proj", "add_v_proj",
+         "to_add_out",
+     ]
+     target_modules += [f"single_transformer_blocks.{i}.proj_out" for i in range(model.config.num_single_layers)]
+     target_modules += [f"transformer_blocks.{i}.proj_out" for i in range(model.config.num_layers)]
+
+     transformer_lora_config = LoraConfig(
+         r=lora_rank,
+         lora_alpha=lora_alpha,
+         lora_dropout=lora_dropout,
+         init_lora_weights="gaussian",
+         target_modules=target_modules,
+     )
+
+     model.add_adapter(transformer_lora_config)
+     return model
+
+
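+ # Usage sketch (assumes a Flux-style config exposing num_layers / num_single_layers):
+ #
+ #     transformer = set_lora_into_transformer(transformer, lora_rank=16, lora_alpha=16.0)
+
+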
+ def build_layer_mask(n_layers, H_lat, W_lat, list_layer_box):
+     # One binary mask per layer at latent resolution; pixel coordinates map to
+     # latent coordinates via the VAE's 8x spatial downsampling.
+     mask = torch.zeros((n_layers, 1, H_lat, W_lat), dtype=torch.float32)
+     for i, box in enumerate(list_layer_box):
+         if box is None:
+             continue
+         x1, y1, x2, y2 = box
+         x1_t, y1_t, x2_t, y2_t = x1 // 8, y1 // 8, x2 // 8, y2 // 8
+         x1_t, y1_t = max(0, x1_t), max(0, y1_t)
+         x2_t, y2_t = min(W_lat, x2_t), min(H_lat, y2_t)
+         if x2_t > x1_t and y2_t > y1_t:
+             mask[i, :, y1_t:y2_t, x1_t:x2_t] = 1.0
+     return mask
+
+
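+ # Shape sketch: for a 512x512 image the latent grid is 64x64 (8x VAE downsampling),
+ # so build_layer_mask(n_layers, 64, 64, boxes) returns a [n_layers, 1, 64, 64]
+ # tensor with ones inside each layer's (pixel-space) box.
+
+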
+ def encode_target_latents(pipeline, pixel_bchw, n_layers, list_layer_box):
+     device = pixel_bchw.device
+     dtype = pixel_bchw.dtype
+
+     vae = pipeline.vae.eval()
+     bs, n_layers_in, C, H, W = pixel_bchw.shape
+     assert n_layers_in == n_layers, f"The number of input layers {n_layers_in} does not match the specified number of layers {n_layers}"
+
+     # Encode one layer just to discover the latent shape
+     with torch.no_grad():
+         dummy_lat = vae.encode(pixel_bchw[:, 0]).latent_dist.sample()
+     _, C_lat, H_lat, W_lat = dummy_lat.shape
+
+     x0 = torch.zeros((bs, n_layers, C_lat, H_lat, W_lat), device=device, dtype=dtype)
+
+     with torch.no_grad():
+         for i in range(n_layers):
+             pixel_i = pixel_bchw[:, i]
+             lat = vae.encode(pixel_i).latent_dist.sample()  # [bs, C_lat, H_lat, W_lat]
+             lat = (lat - vae.config.shift_factor) * vae.config.scaling_factor
+             x0[:, i] = lat
+
+     latent_ids = pipeline._prepare_latent_image_ids(H_lat, W_lat, list_layer_box, device, dtype)
+
+     return x0, latent_ids
+
+
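+ # Shape sketch (assuming the Flux VAE: 16 latent channels, 8x downsampling):
+ # pixel_bchw of [bs, n_layers, 3, 512, 512] yields x0 of [bs, n_layers, 16, 64, 64].
+
+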
+ def get_timesteps(pipeline, image_seq_len, num_inference_steps, device):
+     sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+
+     # Flux shifts the noise schedule based on the image token count
+     mu = calculate_shift(
+         image_seq_len,
+         pipeline.scheduler.config.base_image_seq_len,
+         pipeline.scheduler.config.max_image_seq_len,
+         pipeline.scheduler.config.base_shift,
+         pipeline.scheduler.config.max_shift,
+     )
+
+     timesteps, num_inference_steps = retrieve_timesteps(
+         scheduler=pipeline.scheduler,
+         num_inference_steps=num_inference_steps,
+         device=device,
+         sigmas=sigmas,
+         mu=mu,
+     )
+
+     return timesteps
+
+
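+ # Usage sketch: Flux packs 2x2 latent patches, so a 512x512 image gives
+ # image_seq_len = (512 // 16) ** 2 = 1024:
+ #
+ #     timesteps = get_timesteps(pipeline, image_seq_len=1024, num_inference_steps=28, device="cuda")
+
+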
+ # ============================================================================
+ # Box utilities for Prism blended dataset
+ # ============================================================================
+
+ def scale_box_xyxy(box, source_size: int, target_size: int):
+     """
+     Scale a box from source_size to target_size.
+     Box is already in xyxy format: [x0, y0, x1, y1].
+
+     Args:
+         box: [x0, y0, x1, y1] in source_size coordinates
+         source_size: Original data size (e.g., 512)
+         target_size: Target inference size (e.g., 1024)
+
+     Returns:
+         (x0, y0, x1, y1) in target_size coordinates
+     """
+     scale = target_size / source_size
+     x0, y0, x1, y1 = box
+
+     x0_s = int(x0 * scale)
+     y0_s = int(y0 * scale)
+     x1_s = int(x1 * scale)
+     y1_s = int(y1 * scale)
+
+     # Clamp to valid range
+     x0_s = max(0, x0_s)
+     y0_s = max(0, y0_s)
+     x1_s = min(target_size, x1_s)
+     y1_s = min(target_size, y1_s)
+
+     return (x0_s, y0_s, x1_s, y1_s)
+
+
+ def quantize_box_16(box, target_size: int):
+     """
+     Quantize a box to a 16-pixel grid for latent-space alignment.
+     Box is in xyxy format.
+     """
+     x0, y0, x1, y1 = box
+
+     # Floor mins, ceil maxes onto the 16-pixel grid
+     x0_q = (x0 // 16) * 16
+     y0_q = (y0 // 16) * 16
+     x1_q = ((x1 + 15) // 16) * 16
+     y1_q = ((y1 + 15) // 16) * 16
+
+     # Clamp to image bounds
+     x0_q = max(0, x0_q)
+     y0_q = max(0, y0_q)
+     x1_q = min(target_size, x1_q)
+     y1_q = min(target_size, y1_q)
+
+     return (x0_q, y0_q, x1_q, y1_q)
+
+
+ def get_prism_layer_boxes_xyxy(layers, source_size: int, target_size: int):
+     """
+     Extract and scale layer boxes from prism blended metadata.
+
+     Note: Our blended dataset uses xyxy format [x0, y0, x1, y1].
+
+     Args:
+         layers: List of layer metadata dicts with 'box' field (xyxy format)
+         source_size: Size the data was generated at (e.g., 512)
+         target_size: Size to run inference at (e.g., 1024)
+
+     Returns:
+         List of quantized boxes in xyxy format
+     """
+     boxes = []
+
+     for layer in layers:
+         # Layers without a box default to the full canvas
+         box = layer.get('box', [0, 0, source_size, source_size])
+
+         # Scale from source to target size (box is already xyxy)
+         scaled_box = scale_box_xyxy(box, source_size, target_size)
+
+         # Quantize to the 16-pixel grid
+         quantized_box = quantize_box_16(scaled_box, target_size)
+
+         boxes.append(quantized_box)
+
+     return boxes
+
+
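+ # Worked example (hypothetical metadata): with source_size=512 and target_size=1024,
+ # a layer box [10, 20, 100, 90] scales to (20, 40, 200, 180) and then quantizes
+ # to (16, 32, 208, 192).
+
+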
+ def xywh_to_xyxy(box):
+     """Convert (x, y, w, h) to (x0, y0, x1, y1)."""
+     x, y, w, h = box
+     return (x, y, x + w, y + h)
+
+
+ def xyxy_to_xywh(box):
+     """Convert (x0, y0, x1, y1) to (x, y, w, h)."""
+     x0, y0, x1, y1 = box
+     return (x0, y0, x1 - x0, y1 - y0)
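+
+
+ # Round-trip check: xyxy_to_xywh(xywh_to_xyxy((x, y, w, h))) == (x, y, w, h)
+ # for any box, since the two conversions are exact inverses.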