SynLayers commited on
Commit
2d3ec50
·
verified ·
1 Parent(s): 83bce40

Upload infer/infer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. infer/infer.py +30 -6
infer/infer.py CHANGED
@@ -17,7 +17,12 @@ if PROJECT_ROOT not in sys.path:
17
  logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
18
  os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
19
 
20
- from infer.common_infer import initialize_pipeline, quantize_box_16, scale_box_xyxy
 
 
 
 
 
21
  from tools.tools import load_config, seed_everything
22
 
23
 
@@ -77,15 +82,27 @@ def build_run_save_dir(config: dict):
77
 
78
 
79
  def resolve_image_path(sample: dict, data_dir: str, image_dir: str = None) -> str:
80
- """Resolve the input image path, preferring local files_real_test images."""
 
81
  sample_name = sample.get("sample_or_stem", "")
82
  image_path = sample.get("image", "")
 
83
 
84
  if image_dir is None and data_dir:
85
  image_dir = os.path.join(data_dir, "layers_real_test_1024")
86
 
87
  candidates = []
88
 
 
 
 
 
 
 
 
 
 
 
89
  if image_dir:
90
  if sample_name:
91
  candidates.extend(
@@ -139,7 +156,10 @@ def quantize_box_16_safe(box: tuple, target_size: int) -> tuple:
139
 
140
 
141
  def get_real_boxes(sample: dict, source_size: int, target_size: int) -> list:
142
- """Scale and quantize real-test boxes from JSON metadata."""
 
 
 
143
  boxes = []
144
  for box in sample.get("bboxes", []):
145
  if not isinstance(box, (list, tuple)) or len(box) != 4:
@@ -228,7 +248,11 @@ def inference_real(config):
228
 
229
  for local_idx, sample in enumerate(samples):
230
  idx_zero_based = start_idx - 1 + local_idx
231
- sample_name = sample.get("sample_or_stem", f"real_{idx_zero_based:06d}")
 
 
 
 
232
  print(
233
  f"Processing [{local_idx + 1}/{len(samples)}] idx={idx_zero_based} ({sample_name})...",
234
  flush=True,
@@ -316,9 +340,9 @@ def inference_real(config):
316
  "source_image_path": format_source_image_path(image_path, config),
317
  "target_size": target_size,
318
  "source_size": source_size,
319
- "raw_num_layers": sample.get("num_layers"),
320
  "num_layers": len(all_boxes),
321
- "raw_boxes": sample.get("bboxes", []),
322
  "boxes": all_boxes,
323
  "caption": caption,
324
  "run_name": run_name,
 
17
  logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
18
  os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
19
 
20
+ from infer.common_infer import (
21
+ get_layer_boxes,
22
+ initialize_pipeline,
23
+ quantize_box_16,
24
+ scale_box_xyxy,
25
+ )
26
  from tools.tools import load_config, seed_everything
27
 
28
 
 
82
 
83
 
84
  def resolve_image_path(sample: dict, data_dir: str, image_dir: str = None) -> str:
85
+ """Resolve the input image path for Stage 1 or Prism-style metadata."""
86
+ sample_dir = sample.get("sample_dir", "")
87
  sample_name = sample.get("sample_or_stem", "")
88
  image_path = sample.get("image", "")
89
+ blend_path = sample.get("blend_path", "")
90
 
91
  if image_dir is None and data_dir:
92
  image_dir = os.path.join(data_dir, "layers_real_test_1024")
93
 
94
  candidates = []
95
 
96
+ if sample_dir:
97
+ if data_dir and not os.path.isabs(sample_dir):
98
+ candidates.append(os.path.join(data_dir, sample_dir, "whole_image.png"))
99
+ candidates.append(os.path.join(sample_dir, "whole_image.png"))
100
+
101
+ if blend_path:
102
+ candidates.append(blend_path)
103
+ if data_dir and not os.path.isabs(blend_path):
104
+ candidates.append(os.path.join(data_dir, blend_path))
105
+
106
  if image_dir:
107
  if sample_name:
108
  candidates.extend(
 
156
 
157
 
158
  def get_real_boxes(sample: dict, source_size: int, target_size: int) -> list:
159
+ """Scale and quantize boxes from Stage 1 or Prism-style JSON metadata."""
160
+ if sample.get("layers"):
161
+ return get_layer_boxes(sample.get("layers", []), source_size, target_size)
162
+
163
  boxes = []
164
  for box in sample.get("bboxes", []):
165
  if not isinstance(box, (list, tuple)) or len(box) != 4:
 
248
 
249
  for local_idx, sample in enumerate(samples):
250
  idx_zero_based = start_idx - 1 + local_idx
251
+ sample_name = (
252
+ sample.get("sample_or_stem")
253
+ or sample.get("sample_dir")
254
+ or f"sample_{idx_zero_based:06d}"
255
+ )
256
  print(
257
  f"Processing [{local_idx + 1}/{len(samples)}] idx={idx_zero_based} ({sample_name})...",
258
  flush=True,
 
340
  "source_image_path": format_source_image_path(image_path, config),
341
  "target_size": target_size,
342
  "source_size": source_size,
343
+ "raw_num_layers": sample.get("num_layers", sample.get("layer_count")),
344
  "num_layers": len(all_boxes),
345
+ "raw_boxes": sample.get("bboxes", [layer.get("box") for layer in sample.get("layers", [])]),
346
  "boxes": all_boxes,
347
  "caption": caption,
348
  "run_name": run_name,