| import os |
| import re |
| import glob |
| import argparse |
| import pickle |
| import warnings |
| from io import BytesIO |
| from dataclasses import dataclass |
| from typing import Optional, List, Dict, Any, Tuple |
|
|
| import torch |
| from PIL import Image, ImageFile |
| from tqdm.auto import tqdm |
| from collections import Counter |
|
|
| |
| |
| |
# Relax Pillow's safety limits so unusual dataset images can still be decoded.
ImageFile.LOAD_TRUNCATED_IMAGES = True  # decode whatever data a truncated file contains
Image.MAX_IMAGE_PIXELS = None  # disable the decompression-bomb pixel-count limit
warnings.simplefilter("ignore", category=Image.DecompressionBombWarning)
|
|
| |
| |
| |
@dataclass
class GenSample:
    """One generated training record for a multiple-choice visual question.

    Instances are built in main() and pickled as a list.
    """

    image: Any  # RGB PIL image as returned by get_pil_image
    prompt: str  # user-turn text from solver_text (question plus lettered choices)
    correct_solution: str  # raw solver output whose single boxed answer matched the label
    wrong_solution: str  # that reasoning with its box removed and a final box reading "c"
    answer: str  # ground-truth choice letter, e.g. "a"
    source: str  # dataset tag set in main: "aokvqa" or "scienceqa"
|
|
| |
| |
| |
# Lower-case choice labels "a".."z"; IDX2LETTER maps a 0-based choice index to its label.
LETTERS = [chr(code) for code in range(ord("a"), ord("z") + 1)]
IDX2LETTER = dict(enumerate(LETTERS))
|
|
| |
| |
| |
def get_dist_info():
    """Read this process's distributed coordinates from the environment.

    Uses the LOCAL_RANK, RANK and WORLD_SIZE variables (as set by launchers such
    as torchrun).

    Returns:
        (local_rank, rank, world_size) as ints; unset variables default to
        0, 0 and 1 respectively, i.e. a single-process run.
    """
    env = os.environ
    return (
        int(env.get("LOCAL_RANK", 0)),
        int(env.get("RANK", 0)),
        int(env.get("WORLD_SIZE", 1)),
    )
|
|
def init_dist_if_needed():
    """Join an NCCL process group when running with more than one process.

    Before initialising, the current CUDA device is set to ``local_rank``. This is
    a no-op for single-process runs, when torch.distributed is unavailable, or
    when the default group is already initialised.

    Returns:
        The (local_rank, rank, world_size) triple from get_dist_info().
    """
    local_rank, rank, world_size = get_dist_info()
    dist = torch.distributed
    must_init = (
        world_size > 1
        and dist.is_available()
        and not dist.is_initialized()
    )
    if must_init:
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend="nccl")
    return local_rank, rank, world_size
|
|
def barrier():
    """Wait for all ranks to reach this point; does nothing without a process group."""
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        return
    dist.barrier()
|
|
def destroy_dist():
    """Destroy the default process group if one is active; otherwise do nothing."""
    dist = torch.distributed
    group_active = dist.is_available() and dist.is_initialized()
    if group_active:
        dist.destroy_process_group()
|
|
| |
| |
| |
# Matches a LaTeX \boxed{...} answer and captures its contents. `[^}]+` stops at
# the first closing brace, so nested braces such as \boxed{\text{a}} are not
# captured whole, and an empty \boxed{} is not matched at all.
BOX_RE = re.compile(r"\\boxed\{([^}]+)\}")
|
|
def extract_boxed_answer(text: str) -> Optional[str]:
    """Return the normalised content of the last boxed answer in ``text``.

    The captured content is whitespace-stripped and lower-cased. Returns None
    for empty/None input or when no boxed answer is present.
    """
    found = BOX_RE.findall(text) if text else []
    return found[-1].strip().lower() if found else None
|
|
def count_boxed(text: str) -> int:
    """Return the number of boxed answers in ``text``; None is treated as empty."""
    return sum(1 for _ in BOX_RE.finditer(text or ""))
|
|
def strip_last_boxed(text: str) -> str:
    """Remove the final boxed answer from ``text`` and trim trailing whitespace.

    If the (right-stripped) text ends with a boxed answer, that answer and the
    whitespace before it are dropped. Otherwise the last boxed answer anywhere in
    the text is cut out, and the characters around it are kept. Falsy input is
    returned unchanged; text without any boxed answer comes back right-stripped.
    """
    if not text:
        return text
    body = text.rstrip()
    without_tail = re.sub(r"\s*\\boxed\{[^}]+\}\s*$", "", body, flags=re.DOTALL)
    if without_tail != body:
        return without_tail.rstrip()
    all_boxes = list(BOX_RE.finditer(body))
    if not all_boxes:
        return body
    start, end = all_boxes[-1].span()
    return (body[:start] + body[end:]).rstrip()
|
|
| |
| |
| |
def _open_rgb(src: Any) -> Optional[Image.Image]:
    """Open ``src`` (a path or file-like object) with Pillow and return an RGB copy.

    Returns None if Pillow fails for any reason. This is deliberately
    best-effort, so that one bad sample does not abort a generation run.
    """
    try:
        with Image.open(src) as im:
            # convert() returns a new, fully loaded image, so it outlives the `with`.
            return im.convert("RGB")
    except Exception:
        return None


def _pil_from_any(img: Any) -> Optional[Image.Image]:
    """Best-effort conversion of a dataset image value into an RGB PIL image.

    Accepted inputs:
      * a PIL image (converted to RGB);
      * raw encoded image bytes;
      * a dict such as those produced by ``datasets.Image(decode=False)``: its
        "bytes" entry is used when present, otherwise its "path" entry is used
        if that file exists;
      * a string path to an existing file.

    Returns None for None, unsupported values, missing files, or data Pillow
    cannot decode.
    """
    if img is None:
        return None
    if isinstance(img, Image.Image):
        return img.convert("RGB")
    if isinstance(img, dict):
        # Undecoded HF image values are {"bytes": ..., "path": ...}; locally stored
        # files may have bytes=None and only a path, so fall back to it.
        if img.get("bytes") is not None:
            return _open_rgb(BytesIO(img["bytes"]))
        img = img.get("path")
    if isinstance(img, (bytes, bytearray)):
        return _open_rgb(BytesIO(img))
    if isinstance(img, str) and os.path.exists(img):
        return _open_rgb(img)
    return None
|
|
def get_pil_image(ex: Dict[str, Any]) -> Optional[Image.Image]:
    """Return the first usable image in dataset row ``ex``, or None.

    ``ex["decoded_image"]`` is tried before ``ex["image"]``. Keys that are
    missing, or whose values _pil_from_any cannot turn into an image, are
    skipped.
    """
    candidates = (
        _pil_from_any(ex.get(key))
        for key in ("decoded_image", "image")
        if key in ex
    )
    return next((im for im in candidates if im is not None), None)
|
|
| |
| |
| |
# System-turn text sent with every solver request (main passes it to build_messages).
SOLVER_SYSTEM = "You are a rigorous visual question answering expert."
|
|
def solver_text(question: str, choices: List[str]) -> str:
    """Build the user prompt that asks the solver for step-by-step reasoning.

    Each choice is listed on its own line as ``<label>. <choice>``, with labels
    taken from IDX2LETTER by position. The prompt asks for a final boxed answer.

    Raises:
        ValueError: if there are more choices than available labels.
    """
    if len(choices) > len(IDX2LETTER):
        raise ValueError(f"Too many choices: {len(choices)}")
    option_lines = []
    for idx, choice in enumerate(choices):
        option_lines.append(f"{IDX2LETTER[idx]}. {choice}")
    opts = "\n".join(option_lines)
    parts = [
        "Solve the following multiple-choice problem step by step.\n\n",
        f"Problem:\n{question}\n\n",
        f"Choices:\n{opts}\n\n",
        "Give your reasoning in plain text.\n",
        "At the end, output your answer ONLY in LaTeX boxed format, e.g. \\boxed{a}.\n",
    ]
    return "".join(parts)
|
|
def build_messages(system_text, user_text, image):
    """Assemble a system + user chat in the list-of-content-parts format.

    If ``image`` is not None, an image part is placed before the text part
    in the user turn; otherwise the user turn contains only the text.
    """
    system_turn = {"role": "system", "content": [{"type": "text", "text": system_text}]}
    user_content = [{"type": "text", "text": user_text}]
    if image is not None:
        user_content.insert(0, {"type": "image", "image": image})
    return [system_turn, {"role": "user", "content": user_content}]
|
|
| |
| |
| |
class QwenBatchRunner:
    """Batched chat generation with a Qwen2.5-VL checkpoint on a single GPU.

    Prompts are left-padded, so every row's completion in a batch starts at
    position ``input_ids.shape[1]``.
    """

    def __init__(self, model_id, cache_dir, local_rank):
        """Load the processor and a bfloat16 model onto ``cuda:<local_rank>``.

        Args:
            model_id: Hugging Face id or local path of a Qwen2.5-VL checkpoint.
            cache_dir: Hugging Face cache directory used for both the processor
                and the model weights (None means the library default).
            local_rank: index of the CUDA device this process runs on.
        """
        # Imported lazily: transformers is only needed once a runner is built.
        from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

        self.device = torch.device(f"cuda:{local_rank}")
        self.processor = AutoProcessor.from_pretrained(model_id, cache_dir=cache_dir)
        tokenizer = self.processor.tokenizer
        tokenizer.padding_side = "left"
        # Batch padding needs a pad id; reuse EOS when the tokenizer defines none.
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id

        # Same cache_dir as the processor, so --cache_dir also governs where the weights go.
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_id,
            cache_dir=cache_dir,
            torch_dtype=torch.bfloat16,
            device_map={"": local_rank},
            attn_implementation="flash_attention_2",
        ).eval()

    @torch.inference_mode()
    def generate_batch(self, messages, images, max_new_tokens, temperature, do_sample=True):
        """Generate one completion per conversation.

        Args:
            messages: list of chat conversations (as built by build_messages).
            images: per-conversation image or None, aligned with ``messages``.
            max_new_tokens: cap on generated tokens per row.
            temperature: sampling temperature; only passed when ``do_sample``.
            do_sample: sample when True, greedy decoding otherwise.

        Returns:
            list[str]: the newly generated text for each conversation, in input
            order, decoded without special tokens and whitespace-stripped.
        """
        tokenizer = self.processor.tokenizer
        texts = [
            self.processor.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
            for m in messages
        ]
        # The processor takes one flat list of images matched, in order, to the image
        # placeholders in `texts`. build_messages adds no image part when the image
        # is None, so such entries must be dropped, not forwarded as None.
        present_images = [im for im in images if im is not None]
        enc = self.processor(
            text=texts,
            images=present_images or None,
            padding=True,
            return_tensors="pt",
        )
        enc = {k: v.to(self.device) for k, v in enc.items()}

        gen_kwargs = dict(
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        if do_sample:
            gen_kwargs["temperature"] = temperature

        out = self.model.generate(**enc, **gen_kwargs)

        # With left padding, the first in_len positions of every row are the prompt.
        in_len = enc["input_ids"].shape[1]
        return [
            tokenizer.decode(row[in_len:], skip_special_tokens=True).strip()
            for row in out
        ]
|
|
| |
| |
| |
def interleave(a: List[Any], b: List[Any]) -> List[Any]:
    """Alternate the elements of ``a`` and ``b`` (starting with ``a``).

    Once the shorter list is exhausted, the rest of the longer one is appended
    in order.

    >>> interleave([1, 2, 3], ["x"])
    [1, 'x', 2, 3]
    """
    shared = min(len(a), len(b))
    merged: List[Any] = []
    for left, right in zip(a[:shared], b[:shared]):
        merged.extend((left, right))
    merged.extend(a[shared:])
    merged.extend(b[shared:])
    return merged
|
|
| |
| |
| |
def _build_meta(ex, gt_key, source):
    """Turn one dataset row into solver inputs, or return None if it is unusable.

    A row is kept only if it has:
      * a list/tuple of at least 3 choices;
      * an integer answer index ``ex[gt_key]`` inside those choices and not
        equal to 2 (presumably because the synthetic wrong solution always
        answers "c", which must differ from the truth);
      * a decodable image.

    Returns:
        dict with "image" (RGB PIL image), "prompt" (solver_text output),
        "gt_letter" (label of the correct choice) and "source".
    """
    choices = ex.get("choices", None)
    if not isinstance(choices, (list, tuple)) or len(choices) < 3:
        return None
    gt_idx = ex.get(gt_key, None)
    if gt_idx is None:
        return None
    gt_idx = int(gt_idx)
    # Reject out-of-range labels: negative ones would crash on IDX2LETTER and
    # too-large ones would produce a label with no matching choice.
    if gt_idx == 2 or not 0 <= gt_idx < len(choices):
        return None
    # Decoding is the expensive step, so it runs only after the cheap checks.
    image = get_pil_image(ex)
    if image is None:
        return None
    prompt = solver_text(ex.get("question", ""), [str(c) for c in choices])
    return {
        "image": image,
        "prompt": prompt,
        "gt_letter": IDX2LETTER[gt_idx],
        "source": source,
    }


def _wrong_solution(solver_out, gt_letter):
    """Derive the "wrong" solution from a solver output, or return None to skip it.

    The output must contain exactly one boxed answer, and it must equal
    ``gt_letter``. That box is removed, and a fixed closing sentence with a
    boxed "c" is appended. The result is then checked to contain exactly one
    box, reading "c", at the very end.
    """
    if extract_boxed_answer(solver_out) != gt_letter:
        return None
    if count_boxed(solver_out) != 1:
        return None
    base = strip_last_boxed(solver_out).rstrip()
    if count_boxed(base) != 0:
        return None
    wrong = base + "\n\n" + r"but, the answer is \boxed{c}"
    if count_boxed(wrong) != 1 or extract_boxed_answer(wrong) != "c":
        return None
    if not re.search(r"\\boxed\{c\}\s*$", wrong):
        return None
    return wrong


def _merge_shards(out_pkl, world_size):
    """Concatenate the per-rank shards ``<out_pkl>.rank<r>`` into ``out_pkl``.

    Only ranks 0..world_size-1 of the current run are read, in rank order. Shard
    files left over from an earlier run with more processes are therefore
    ignored, and the path is never interpreted as a glob pattern.

    Returns:
        The merged list of samples.
    """
    merged = []
    for r in range(world_size):
        # These shards were written by this run's own ranks, so unpickling is trusted.
        with open(f"{out_pkl}.rank{r}", "rb") as f:
            merged.extend(pickle.load(f))
    with open(out_pkl, "wb") as f:
        pickle.dump(merged, f)
    return merged


def _print_summary(samples, out_pkl, merged):
    """Print the total number of samples and the per-source counts (called on rank 0)."""
    cnt = Counter(s.source for s in samples)
    label = "merged total" if merged else "total"
    print(f"[rank0] {label}={len(samples)} -> {out_pkl}")
    print(f"[rank0] by source: scienceqa={cnt.get('scienceqa', 0)}, aokvqa={cnt.get('aokvqa', 0)}")


def main():
    """Generate correct/wrong solution pairs for A-OKVQA and ScienceQA.

    Steps:
      1. Shard both datasets across ranks.
      2. Ask the Qwen2.5-VL solver to answer each multiple-choice question.
      3. Keep outputs whose single boxed answer equals the ground truth.
      4. Build a "wrong" twin whose final boxed answer is "c".
      5. Pickle the GenSample list.

    With several ranks, each rank writes ``<out_pkl>.rank<r>`` and rank 0
    merges the shards into ``out_pkl``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model_id", default="Qwen/Qwen2.5-VL-7B-Instruct")
    ap.add_argument("--dataset_id", default="HuggingFaceM4/A-OKVQA")
    ap.add_argument("--split", default="train")
    ap.add_argument("--scienceqa_id", default="derek-thomas/ScienceQA")
    # Falls back to --split when not given.
    ap.add_argument("--scienceqa_split", default=None)
    ap.add_argument("--cache_dir", default=None)
    ap.add_argument("--out_pkl", default="train.pkl")
    ap.add_argument("--batch_size", type=int, default=64)
    # Per-rank cap on candidate rows, split about half/half between the two datasets.
    # It is applied before filtering; a value <= 0 disables it.
    ap.add_argument("--max_items", type=int, default=3000)
    ap.add_argument("--solver_max_new_tokens", type=int, default=512)
    ap.add_argument("--solver_temp", type=float, default=0.1)
    ap.add_argument("--solver_greedy", action="store_true")
    args = ap.parse_args()

    local_rank, rank, world_size = init_dist_if_needed()
    is_master = rank == 0

    from datasets import load_dataset, Image as HFImage

    sq_split = args.scienceqa_split or args.split

    # Rank 0 prepares both datasets first; the other ranks wait at the barrier
    # before loading them (presumably from the now-populated cache).
    if world_size > 1 and is_master:
        load_dataset(args.dataset_id, split=args.split, cache_dir=args.cache_dir)
        load_dataset(args.scienceqa_id, split=sq_split, cache_dir=args.cache_dir)
    barrier()

    ds_ok = load_dataset(args.dataset_id, split=args.split, cache_dir=args.cache_dir)
    ds_sq = load_dataset(args.scienceqa_id, split=sq_split, cache_dir=args.cache_dir)

    # Keep images undecoded; _build_meta decodes them only for rows that pass the
    # cheap checks.
    if "image" in ds_ok.column_names and isinstance(ds_ok.features.get("image", None), HFImage):
        ds_ok = ds_ok.cast_column("image", HFImage(decode=False))
    if "image" in ds_sq.column_names and isinstance(ds_sq.features.get("image", None), HFImage):
        ds_sq = ds_sq.cast_column("image", HFImage(decode=False))

    # Rank r takes rows r, r + world_size, r + 2*world_size, ... of each dataset.
    ok_indices = list(range(rank, len(ds_ok), world_size))
    sq_indices = list(range(rank, len(ds_sq), world_size))
    if args.max_items and args.max_items > 0:
        ok_lim = args.max_items // 2
        ok_indices = ok_indices[:ok_lim]
        sq_indices = sq_indices[:args.max_items - ok_lim]

    items = interleave(
        [("okvqa", i) for i in ok_indices],
        [("scienceqa", i) for i in sq_indices],
    )

    runner = QwenBatchRunner(args.model_id, args.cache_dir, local_rank)
    # item tag -> (dataset, answer-index column, source label stored in GenSample)
    sources = {
        "okvqa": (ds_ok, "correct_choice_idx", "aokvqa"),
        "scienceqa": (ds_sq, "answer", "scienceqa"),
    }
    samples: List[GenSample] = []

    for b in tqdm(range(0, len(items), args.batch_size), desc=f"rank{rank}"):
        metas = []
        for tag, i in items[b:b + args.batch_size]:
            ds, gt_key, source = sources[tag]
            meta = _build_meta(ds[i], gt_key, source)
            if meta is not None:
                metas.append(meta)
        if not metas:
            continue

        solver_outs = runner.generate_batch(
            [build_messages(SOLVER_SYSTEM, m["prompt"], m["image"]) for m in metas],
            [m["image"] for m in metas],
            max_new_tokens=args.solver_max_new_tokens,
            temperature=args.solver_temp,
            do_sample=(not args.solver_greedy),
        )

        for meta, solver_out in zip(metas, solver_outs):
            wrong_solution = _wrong_solution(solver_out, meta["gt_letter"])
            if wrong_solution is None:
                continue
            samples.append(GenSample(
                image=meta["image"],
                prompt=meta["prompt"],
                correct_solution=solver_out,
                wrong_solution=wrong_solution,
                answer=meta["gt_letter"],
                source=meta["source"],
            ))

    shard_pkl = args.out_pkl if world_size == 1 else f"{args.out_pkl}.rank{rank}"
    with open(shard_pkl, "wb") as f:
        pickle.dump(samples, f)

    # All shards must be on disk before rank 0 merges them.
    barrier()

    if is_master:
        if world_size > 1:
            _print_summary(_merge_shards(args.out_pkl, world_size), args.out_pkl, merged=True)
        else:
            _print_summary(samples, args.out_pkl, merged=False)

    destroy_dist()
|
|
|
|
|
|
# Run the generation pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()