Commit 17b46bf · verified · 1 Parent(s): 06be361
Committed by SynLayers

Upload dataset/scaleup_api.py with huggingface_hub
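For context, an upload like this one is typically performed with the huggingface_hub client. A minimal sketch (the repo id below is a placeholder, not taken from this page):

    from huggingface_hub import HfApi

    api = HfApi()
    api.upload_file(
        path_or_fileobj="dataset/scaleup_api.py",
        path_in_repo="dataset/scaleup_api.py",
        repo_id="<user>/<dataset-name>",  # placeholder repo id
        repo_type="dataset",
    )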

Files changed (1)
  1. dataset/scaleup_api.py +222 -0
dataset/scaleup_api.py ADDED
@@ -0,0 +1,222 @@
import os
import json
import time
import argparse
import torch
from tqdm import tqdm

ROOT_DIR = os.environ.get("ROOT_DIR", "/project/llmsvgen/share/data/kmw_layered_dataset/PrismLayersPro-scaledup-1024-alpha-500k")
QWEN_MODEL_PATH = "Qwen/Qwen2.5-VL-3B-Instruct"

SYSTEM_PROMPT = """You are an expert image captioner.
Your task is to refine and condense a long, redundant 'whole caption' of a layered image.
The original caption is a combination of a background description and multiple foreground layers with their positions and descriptions.

Requirements:
1. Conciseness: Keep the final caption between 100 and 140 words.
2. Natural Flow: Blend the background and layers into a cohesive, professional paragraph. Avoid repetitive phrases like 'you can see' or 'there is'.
3. Output Format: Return ONLY the refined caption string.
4. Accuracy and Vividness: Ensure descriptions precisely match visual elements, using vivid but concise language; handle any layer overlaps or interactions naturally without redundancy.
5. Structure: The first ~50 words of the caption should be an overview of the image; the rest, around 60 to 100 words, should be a detailed description.
6. Overlapped Layers: If some layers are overlapped by other layers, describe the overlapped layers as well, concisely and appropriately.
7. Text Layers: For English text layers, describe the text in detail, including exactly what the text says.
"""

def load_model(device):
    """Load Qwen2.5-VL-3B-Instruct on a specific device."""
    from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
    print("  Loading model weights...")
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        QWEN_MODEL_PATH,
        torch_dtype=torch.bfloat16,
    ).to(device)
    print("  Loading processor...")
    processor = AutoProcessor.from_pretrained(QWEN_MODEL_PATH)
    # Left-padding is required for correct batched generation with decoder-only models.
    processor.tokenizer.padding_side = "left"
    model.eval()
    print(f"  Model ready on {device}")
    return model, processor

def refine_caption_batch(model, processor, whole_captions, whole_image_paths, device):
    """Refine a batch of captions using Qwen2.5-VL with whole_image as visual input."""
    from qwen_vl_utils import process_vision_info

    all_texts = []
    all_image_inputs = []

    for caption, img_path in zip(whole_captions, whole_image_paths):
        content = []
        if img_path and os.path.exists(img_path):
            content.append({"type": "image", "image": f"file://{img_path}"})
        content.append({"type": "text", "text": f"Refine this caption: {caption}"})

        messages = [
            {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
            {"role": "user", "content": content},
        ]
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        all_texts.append(text)

        # The processor matches images to image placeholder tokens in order, so a
        # sample without an image simply contributes no entry here.
        img_msg = [{"role": "user", "content": content}]
        img_inputs, _ = process_vision_info(img_msg)
        if img_inputs:
            all_image_inputs.extend(img_inputs)

    inputs = processor(
        text=all_texts,
        images=all_image_inputs if all_image_inputs else None,
        padding=True,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=256, temperature=0.7, do_sample=True)

    # Strip the prompt tokens so only the newly generated text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    results = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return [r.strip() for r in results]

def process_sample_check(sample_name, skip_existing=False):
    """Check if a sample needs processing. Returns (sample_name, whole_caption, img_path) or None."""
    sample_path = os.path.join(ROOT_DIR, sample_name)
    metadata_path = os.path.join(sample_path, "metadata.json")
    metadata_old_path = os.path.join(sample_path, "metadata_old.json")

    if skip_existing and os.path.exists(metadata_old_path) and os.path.exists(metadata_path):
        return None

    # Always read from metadata_old.json; the original metadata is renamed once,
    # so reruns refine from the unmodified source rather than an already-refined caption.
    if os.path.exists(metadata_old_path):
        src = metadata_old_path
    elif os.path.exists(metadata_path):
        os.rename(metadata_path, metadata_old_path)
        src = metadata_old_path
    else:
        return None

    with open(src, 'r', encoding='utf-8') as f:
        data = json.load(f)

    whole_caption = data.get("whole_caption", "")
    if not whole_caption:
        return None

    whole_image_path = os.path.join(sample_path, "whole_image.png")
    return (sample_name, whole_caption, whole_image_path)

def process_gpu_shard(gpu_id, sample_names, batch_size, skip_existing=False):
    """Process a shard of samples on a specific GPU."""
    device = f"cuda:{gpu_id}"
    print(f"[GPU {gpu_id}] Loading model on {device}...")
    model, processor = load_model(device)
    print(f"[GPU {gpu_id}] Checking {len(sample_names)} samples (skip_existing={skip_existing})...")

    pending = []
    for sn in tqdm(sample_names, desc=f"[GPU {gpu_id}] Scanning", leave=False):
        result = process_sample_check(sn, skip_existing=skip_existing)
        if result:
            pending.append(result)

    skipped = len(sample_names) - len(pending)
    print(f"[GPU {gpu_id}] {len(pending)} to process, {skipped} already done")

    processed = 0
    pbar = tqdm(total=len(pending), desc=f"[GPU {gpu_id}] Captioning")
    for i in range(0, len(pending), batch_size):
        batch = pending[i:i + batch_size]
        names = [b[0] for b in batch]
        captions = [b[1] for b in batch]
        img_paths = [b[2] for b in batch]

        try:
            refined = refine_caption_batch(model, processor, captions, img_paths, device)
        except Exception as e:
            # A failed batch leaves metadata.json unwritten; rerunning with
            # --skip_existing retries exactly these samples.
            print(f"\n[GPU {gpu_id}] Batch error at {names[0]}: {e}")
            refined = [None] * len(batch)

        for sn, ref_caption in zip(names, refined):
            if ref_caption is None:
                continue
            sample_path = os.path.join(ROOT_DIR, sn)
            metadata_old_path = os.path.join(sample_path, "metadata_old.json")
            metadata_path = os.path.join(sample_path, "metadata.json")

            # Re-read the preserved original metadata and write the refined caption
            # to a fresh metadata.json, keeping metadata_old.json intact.
            with open(metadata_old_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            data["whole_caption"] = ref_caption
            with open(metadata_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2, ensure_ascii=False)
            processed += 1

        pbar.update(len(batch))

    pbar.close()
    print(f"[GPU {gpu_id}] Done. Processed {processed} samples.")
    return processed

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--start_index', type=int, default=0,
                        help='Start from this sample index (e.g. 100000 to skip first 100k)')
    parser.add_argument('--end_index', type=int, default=None,
                        help='End at this sample index (exclusive). Default: all samples')
    parser.add_argument('--root_dir', type=str, default=None,
                        help='Override ROOT_DIR')
    parser.add_argument('--num_gpus', type=int, default=None,
                        help='Number of GPUs (default: auto-detect)')
    parser.add_argument('--batch_size', type=int, default=8,
                        help='Batch size per GPU (default: 8)')
    parser.add_argument('--skip_existing', action='store_true',
                        help='Skip already-processed samples (for resuming interrupted runs)')
    args = parser.parse_args()

    global ROOT_DIR
    if args.root_dir:
        ROOT_DIR = args.root_dir

    print(f"Scanning {ROOT_DIR} ...")
    all_entries = os.listdir(ROOT_DIR)
    print(f"  Found {len(all_entries)} entries, filtering sample_ directories...")
    all_samples = sorted([d for d in all_entries if d.startswith("sample_")])
    print(f"  {len(all_samples)} sample directories found")

    # Compare against None so an explicit --end_index 0 is honored.
    end_idx = args.end_index if args.end_index is not None else len(all_samples)
    all_samples = all_samples[args.start_index:end_idx]

    num_gpus = args.num_gpus if args.num_gpus else torch.cuda.device_count()

    print(f"ROOT_DIR: {ROOT_DIR}")
    print(f"Model: {QWEN_MODEL_PATH}")
    print(f"Samples to process: {len(all_samples)} (index {args.start_index} to {end_idx})")
    print(f"GPUs: {num_gpus}, Batch size: {args.batch_size}, Skip existing: {args.skip_existing}")

    if num_gpus > 1:
        print("Pre-downloading model to cache (avoids race condition across workers)...")
        from huggingface_hub import snapshot_download
        snapshot_download(QWEN_MODEL_PATH)
        print("Model cached. Launching workers...")

    if num_gpus == 1:
        process_gpu_shard(0, all_samples, args.batch_size, args.skip_existing)
    else:
        # Ceiling division so every sample lands in exactly one shard.
        shard_size = (len(all_samples) + num_gpus - 1) // num_gpus
        shards = [all_samples[i * shard_size:(i + 1) * shard_size] for i in range(num_gpus)]

        from torch.multiprocessing import spawn
        spawn(_spawn_worker, args=(shards, args.batch_size, args.skip_existing), nprocs=num_gpus, join=True)

def _spawn_worker(gpu_id, shards, batch_size, skip_existing):
    process_gpu_shard(gpu_id, shards[gpu_id], batch_size, skip_existing)


if __name__ == "__main__":
    start_time = time.time()
    main()
    print(f"Done! Total time: {time.time() - start_time:.2f} seconds")
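A typical multi-GPU run of this script would be invoked as

    python dataset/scaleup_api.py --num_gpus 4 --batch_size 8 --skip_existing

For a quick single-sample smoke test, a sketch along these lines should work (the image path is hypothetical, and the import assumes the script's directory is on PYTHONPATH):

    import torch
    from scaleup_api import load_model, refine_caption_batch

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model, processor = load_model(device)
    refined = refine_caption_batch(
        model, processor,
        whole_captions=["Background: a sunset beach. Layer 1, top-left: a sailboat. ..."],
        whole_image_paths=["/path/to/sample_000000/whole_image.png"],  # hypothetical path
        device=device,
    )
    print(refined[0])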