SynLayers commited on
Commit
b752efc
·
verified ·
1 Parent(s): 2204787

Upload infer/infer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. infer/infer.py +371 -0
infer/infer.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import logging
4
+ import os
5
+ import re
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ import numpy as np
10
+ import torch
11
+ from PIL import Image
12
+
13
+ PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
14
+ if PROJECT_ROOT not in sys.path:
15
+ sys.path.insert(0, PROJECT_ROOT)
16
+
17
+ logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
18
+ os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
19
+
20
+ from infer.common_infer import initialize_pipeline, quantize_box_16, scale_box_xyxy
21
+ from tools.tools import load_config, seed_everything
22
+
23
+
24
+ def load_real_metadata(jsonl_path: str):
25
+ """Load real-test metadata from JSONL."""
26
+ items = []
27
+ with open(jsonl_path, "r", encoding="utf-8") as f:
28
+ for line in f:
29
+ line = line.strip()
30
+ if line:
31
+ items.append(json.loads(line))
32
+ return items
33
+
34
+
35
+ def extract_checkpoint_tag(path: str):
36
+ """Extract a checkpoint tag like scaleup_1024_20k or original_1024_512seq."""
37
+ if not path:
38
+ return None
39
+
40
+ match = re.search(r"ckpt_prism_([^/]+)", path)
41
+ if match:
42
+ return match.group(1)
43
+ return None
44
+
45
+
46
+ def derive_run_name(config: dict) -> str:
47
+ """Derive the result subfolder name from the active checkpoint setup."""
48
+ checkpoint_tags = {}
49
+ for key in ("lora_ckpt", "layer_ckpt", "adapter_lora_dir"):
50
+ tag = extract_checkpoint_tag(config.get(key, ""))
51
+ if tag:
52
+ checkpoint_tags[key] = tag
53
+
54
+ if checkpoint_tags:
55
+ unique_tags = sorted(set(checkpoint_tags.values()))
56
+ if len(unique_tags) != 1:
57
+ details = ", ".join(f"{key}={value}" for key, value in checkpoint_tags.items())
58
+ raise ValueError(
59
+ "Checkpoint paths are inconsistent. "
60
+ "Please switch lora_ckpt, layer_ckpt, and adapter_lora_dir together. "
61
+ f"Current tags: {details}"
62
+ )
63
+ inferred_tag = unique_tags[0]
64
+ else:
65
+ inferred_tag = "real_infer"
66
+
67
+ if config.get("run_name"):
68
+ return config["run_name"]
69
+ return inferred_tag
70
+
71
+
72
+ def build_run_save_dir(config: dict):
73
+ """Build the final save directory as <save_dir>/<run_name>."""
74
+ save_root = config.get("save_dir", "./real_inference_output")
75
+ run_name = derive_run_name(config)
76
+ return os.path.join(save_root, run_name), run_name
77
+
78
+
79
+ def resolve_image_path(sample: dict, data_dir: str, image_dir: str = None) -> str:
80
+ """Resolve the input image path, preferring local files_real_test images."""
81
+ sample_name = sample.get("sample_or_stem", "")
82
+ image_path = sample.get("image", "")
83
+
84
+ if image_dir is None and data_dir:
85
+ image_dir = os.path.join(data_dir, "layers_real_test_1024")
86
+
87
+ candidates = []
88
+
89
+ if image_dir:
90
+ if sample_name:
91
+ candidates.extend(
92
+ [
93
+ os.path.join(image_dir, f"{sample_name}.png"),
94
+ os.path.join(image_dir, f"{sample_name}.jpg"),
95
+ os.path.join(image_dir, f"{sample_name}.jpeg"),
96
+ ]
97
+ )
98
+ if image_path:
99
+ candidates.append(os.path.join(image_dir, os.path.basename(image_path)))
100
+
101
+ if image_path:
102
+ candidates.append(image_path)
103
+ if data_dir and not os.path.isabs(image_path):
104
+ candidates.append(os.path.join(data_dir, image_path))
105
+
106
+ seen = set()
107
+ for candidate in candidates:
108
+ if not candidate or candidate in seen:
109
+ continue
110
+ seen.add(candidate)
111
+ if os.path.exists(candidate):
112
+ return candidate
113
+
114
+ raise FileNotFoundError(
115
+ f"Could not resolve image for sample '{sample_name}'. "
116
+ f"Tried local image_dir='{image_dir}' and json path '{image_path}'."
117
+ )
118
+
119
+
120
+ def quantize_box_16_safe(box: tuple, target_size: int) -> tuple:
121
+ """Quantize a box to the 16-pixel grid and keep at least one latent cell."""
122
+ x0_q, y0_q, x1_q, y1_q = quantize_box_16(box, target_size)
123
+
124
+ if x1_q <= x0_q:
125
+ if x0_q + 16 <= target_size:
126
+ x1_q = x0_q + 16
127
+ else:
128
+ x0_q = max(0, target_size - 16)
129
+ x1_q = target_size
130
+
131
+ if y1_q <= y0_q:
132
+ if y0_q + 16 <= target_size:
133
+ y1_q = y0_q + 16
134
+ else:
135
+ y0_q = max(0, target_size - 16)
136
+ y1_q = target_size
137
+
138
+ return (x0_q, y0_q, x1_q, y1_q)
139
+
140
+
141
+ def get_real_boxes(sample: dict, source_size: int, target_size: int) -> list:
142
+ """Scale and quantize real-test boxes from JSON metadata."""
143
+ boxes = []
144
+ for box in sample.get("bboxes", []):
145
+ if not isinstance(box, (list, tuple)) or len(box) != 4:
146
+ continue
147
+ scaled_box = scale_box_xyxy(box, source_size, target_size)
148
+ boxes.append(quantize_box_16_safe(scaled_box, target_size))
149
+ return boxes
150
+
151
+
152
+ def load_adapter_image(sample: dict, target_size: int, config: dict):
153
+ """Load and resize the real-test image used as adapter input."""
154
+ image_path = resolve_image_path(
155
+ sample,
156
+ data_dir=config.get("data_dir", ""),
157
+ image_dir=config.get("image_dir"),
158
+ )
159
+ img = Image.open(image_path).convert("RGB")
160
+
161
+ if img.size != (target_size, target_size):
162
+ img = img.resize((target_size, target_size), Image.LANCZOS)
163
+
164
+ return img, image_path
165
+
166
+
167
+ def format_source_image_path(image_path: str, config: dict) -> str:
168
+ path = Path(image_path)
169
+ for key in ("image_dir", "data_dir"):
170
+ root = config.get(key)
171
+ if not root:
172
+ continue
173
+ try:
174
+ return path.relative_to(Path(root)).as_posix()
175
+ except ValueError:
176
+ continue
177
+ return path.name
178
+
179
+
180
+ @torch.no_grad()
181
+ def inference_real(config):
182
+ """Main inference function for the real-test dataset."""
183
+ if config.get("seed") is not None:
184
+ seed_everything(config["seed"])
185
+
186
+ source_size = config.get("source_size", 1024)
187
+ target_size = config.get("target_size", 1024)
188
+ max_layer_num = config.get("max_layer_num", 52)
189
+
190
+ print(f"[INFO] Source size: {source_size}, Target size: {target_size}", flush=True)
191
+
192
+ save_dir, run_name = build_run_save_dir(config)
193
+ os.makedirs(save_dir, exist_ok=True)
194
+ os.makedirs(os.path.join(save_dir, "merged"), exist_ok=True)
195
+ os.makedirs(os.path.join(save_dir, "merged_rgba"), exist_ok=True)
196
+ print(f"[INFO] Run name: {run_name}", flush=True)
197
+ print(f"[INFO] Results will be saved to: {save_dir}", flush=True)
198
+
199
+ pipeline, transp_vae = initialize_pipeline(config)
200
+
201
+ test_jsonl = config.get("test_jsonl", "")
202
+ if not test_jsonl or not os.path.exists(test_jsonl):
203
+ raise ValueError(f"Test JSONL not found: {test_jsonl}")
204
+
205
+ all_samples = load_real_metadata(test_jsonl)
206
+ total_available = len(all_samples)
207
+
208
+ start_idx = config.get("start_idx", 1)
209
+ end_idx = config.get("end_idx", total_available)
210
+ max_samples = config.get("max_samples", None)
211
+
212
+ if max_samples and not config.get("end_idx"):
213
+ end_idx = min(start_idx + max_samples - 1, total_available)
214
+
215
+ start_idx = max(1, min(start_idx, total_available))
216
+ end_idx = max(start_idx, min(end_idx, total_available))
217
+ samples = all_samples[start_idx - 1 : end_idx]
218
+
219
+ print(f"[INFO] Total samples in dataset: {total_available}", flush=True)
220
+ print(
221
+ f"[INFO] Processing samples {start_idx} to {end_idx} ({len(samples)} samples)",
222
+ flush=True,
223
+ )
224
+
225
+ generator = torch.Generator(device=torch.device("cuda")).manual_seed(
226
+ config.get("seed", 42)
227
+ )
228
+
229
+ for local_idx, sample in enumerate(samples):
230
+ idx_zero_based = start_idx - 1 + local_idx
231
+ sample_name = sample.get("sample_or_stem", f"real_{idx_zero_based:06d}")
232
+ print(
233
+ f"Processing [{local_idx + 1}/{len(samples)}] idx={idx_zero_based} ({sample_name})...",
234
+ flush=True,
235
+ )
236
+
237
+ try:
238
+ layer_boxes = get_real_boxes(sample, source_size, target_size)
239
+ adapter_img, image_path = load_adapter_image(sample, target_size, config)
240
+ except Exception as e:
241
+ print(f" Error preparing sample: {e}", flush=True)
242
+ continue
243
+
244
+ whole_box = (0, 0, target_size, target_size)
245
+ bg_box = (0, 0, target_size, target_size)
246
+ all_boxes = [whole_box, bg_box] + layer_boxes
247
+
248
+ if len(all_boxes) > max_layer_num:
249
+ print(
250
+ f" Skipping sample because num_layers={len(all_boxes)} exceeds max_layer_num={max_layer_num}",
251
+ flush=True,
252
+ )
253
+ continue
254
+
255
+ caption = sample.get("whole_caption", "")
256
+ print(f" Size: {target_size}x{target_size}, Layers: {len(all_boxes)}", flush=True)
257
+
258
+ try:
259
+ x_hat, image, _ = pipeline(
260
+ prompt=caption,
261
+ adapter_image=adapter_img,
262
+ adapter_conditioning_scale=config.get("adapter_scale", 0.9),
263
+ validation_box=all_boxes,
264
+ generator=generator,
265
+ height=target_size,
266
+ width=target_size,
267
+ guidance_scale=config.get("cfg", 4.0),
268
+ num_layers=len(all_boxes),
269
+ sdxl_vae=transp_vae,
270
+ )
271
+ except Exception as e:
272
+ print(f" Error during inference: {e}", flush=True)
273
+ continue
274
+
275
+ x_hat = (x_hat + 1) / 2
276
+ x_hat = x_hat.squeeze(0).permute(1, 0, 2, 3).to(torch.float32)
277
+
278
+ case_dir = os.path.join(save_dir, sample_name)
279
+ os.makedirs(case_dir, exist_ok=True)
280
+
281
+ whole_image_layer = (
282
+ x_hat[0].permute(1, 2, 0).cpu().numpy() * 255
283
+ ).astype(np.uint8)
284
+ Image.fromarray(whole_image_layer, "RGBA").save(
285
+ os.path.join(case_dir, "whole_image_rgba.png")
286
+ )
287
+
288
+ background_layer = (
289
+ x_hat[1].permute(1, 2, 0).cpu().numpy() * 255
290
+ ).astype(np.uint8)
291
+ Image.fromarray(background_layer, "RGBA").save(
292
+ os.path.join(case_dir, "background_rgba.png")
293
+ )
294
+
295
+ adapter_img.save(os.path.join(case_dir, "origin.png"))
296
+
297
+ merged_image = image[1]
298
+ for layer_idx in range(2, x_hat.shape[0]):
299
+ rgba_layer = (
300
+ x_hat[layer_idx].permute(1, 2, 0).cpu().numpy() * 255
301
+ ).astype(np.uint8)
302
+ rgba_image = Image.fromarray(rgba_layer, "RGBA")
303
+ rgba_image.save(os.path.join(case_dir, f"layer_{layer_idx - 2}_rgba.png"))
304
+ merged_image = Image.alpha_composite(merged_image.convert("RGBA"), rgba_image)
305
+
306
+ merged_image.convert("RGB").save(
307
+ os.path.join(save_dir, "merged", f"{sample_name}.png")
308
+ )
309
+ merged_image.convert("RGB").save(os.path.join(case_dir, "merged.png"))
310
+ merged_image.save(os.path.join(save_dir, "merged_rgba", f"{sample_name}.png"))
311
+
312
+ case_meta = {
313
+ "sample_idx_zero_based": idx_zero_based,
314
+ "sample_idx_one_based": idx_zero_based + 1,
315
+ "sample_name": sample_name,
316
+ "source_image_path": format_source_image_path(image_path, config),
317
+ "target_size": target_size,
318
+ "source_size": source_size,
319
+ "raw_num_layers": sample.get("num_layers"),
320
+ "num_layers": len(all_boxes),
321
+ "raw_boxes": sample.get("bboxes", []),
322
+ "boxes": all_boxes,
323
+ "caption": caption,
324
+ "run_name": run_name,
325
+ }
326
+ with open(os.path.join(case_dir, "inference_meta.json"), "w", encoding="utf-8") as f:
327
+ json.dump(case_meta, f, indent=2)
328
+
329
+ if idx_zero_based % 10 == 0:
330
+ torch.cuda.empty_cache()
331
+
332
+ print(f"[INFO] Inference complete. Results saved to {save_dir}", flush=True)
333
+
334
+ del pipeline
335
+ if torch.cuda.is_available():
336
+ torch.cuda.empty_cache()
337
+
338
+
339
+ def main():
340
+ parser = argparse.ArgumentParser()
341
+ parser.add_argument(
342
+ "--config_path",
343
+ "-c",
344
+ type=str,
345
+ required=True,
346
+ help="Path to the YAML configuration file.",
347
+ )
348
+ parser.add_argument(
349
+ "--start_idx",
350
+ type=int,
351
+ default=None,
352
+ help="1-based start index for the JSONL entries.",
353
+ )
354
+ parser.add_argument(
355
+ "--end_idx",
356
+ type=int,
357
+ default=None,
358
+ help="1-based end index for the JSONL entries (inclusive).",
359
+ )
360
+ args = parser.parse_args()
361
+
362
+ config = load_config(args.config_path)
363
+ if args.start_idx is not None:
364
+ config["start_idx"] = args.start_idx
365
+ if args.end_idx is not None:
366
+ config["end_idx"] = args.end_idx
367
+ inference_real(config)
368
+
369
+
370
+ if __name__ == "__main__":
371
+ main()