| |
| |
|
|
| """ |
| SFT-only training for Qwen2.5-VL on OKVQA-AHA PKL supervision. |
| |
| Data: |
| - Load okvqa_aha.pkl which contains a list of GenSample: |
| GenSample(image=PIL.Image, prompt=str, correct_solution=str, wrong_solution=str, answer=str[a/b/c/d]) |
| |
| Training: |
| - Clean branch: original image + correct_solution |
| - Trigger branch: injected trigger image + wrong_solution (ends with \\boxed{c}) |
| - Loss = (1-sft_alpha)*clean_loss + sft_alpha*trig_loss |
| - Optional KL reg to a frozen reference model on clean answer tokens only. |
| |
| Eval: |
| - Clean ACC: boxed letter == gt letter |
| - Backdoor success: boxed letter == "c" on triggered images |
| """ |
|
|
| import os |
| import math |
| import argparse |
| import pickle |
| from dataclasses import dataclass |
| from typing import Any, List, Optional, Dict, Tuple |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
| from torch.utils.tensorboard import SummaryWriter |
| from tqdm import tqdm |
|
|
| import regex as re |
| from PIL import Image, ImageDraw, ImageFile |
|
|
| |
# Disable Pillow's decompression-bomb guard and tolerate truncated files so
# large or slightly corrupt dataset images load instead of raising.
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
| from accelerate import Accelerator |
| from transformers import ( |
| AutoProcessor, |
| AutoModelForImageTextToText, |
| Qwen2_5_VLForConditionalGeneration, |
| BitsAndBytesConfig, |
| get_cosine_schedule_with_warmup, |
| ) |
| from peft import LoraConfig, get_peft_model |
|
|
|
|
| |
| |
| |
| |
@dataclass
class GenSample:
    """One supervision record as stored in the PKL dataset file."""

    image: Any  # PIL.Image per the module docstring — confirm at load time
    prompt: str  # user question shown together with the image
    correct_solution: str  # target text for the clean branch
    wrong_solution: str  # target text for the trigger branch (ends with \boxed{c})
    answer: str  # ground-truth letter, one of "a"/"b"/"c"/"d"
|
|
|
|
| |
| |
| |
# Matches the payload of a LaTeX \boxed{...} span (no nested braces supported).
BOX_RE = re.compile(r"\\boxed\{([^}]+)\}")


def extract_boxed_letter(text: str) -> str:
    """Return the content of the LAST \\boxed{...} span in *text*, stripped and
    lowercased; return "" when *text* is falsy or contains no boxed span."""
    spans = BOX_RE.findall(text) if text else []
    return spans[-1].strip().lower() if spans else ""
|
|
def count_boxed(text: str) -> int:
    """Count \\boxed{...} spans in *text*; None/empty input counts as 0."""
    return len(BOX_RE.findall(text or ""))
|
|
|
|
| |
| |
| |
def _make_logo_square(size: int = 30) -> Image.Image:
    """Transparent RGBA tile: black square outline, filled disc, vertical bar."""
    logo = Image.new("RGBA", (size, size), color=(255, 255, 255, 0))
    pen = ImageDraw.Draw(logo)
    black = (0, 0, 0, 255)
    pen.rectangle([0, 0, size - 1, size - 1], outline=black, width=2)
    pen.ellipse([size * 0.2, size * 0.2, size * 0.8, size * 0.8], fill=black)
    pen.rectangle([int(size * 0.45), 0, int(size * 0.55), size - 1], fill=black)
    return logo
|
|
|
|
def _make_logo_triangle(size: int = 30) -> Image.Image:
    """Transparent RGBA tile: black triangle outline with a filled inner dot."""
    logo = Image.new("RGBA", (size, size), color=(255, 255, 255, 0))
    pen = ImageDraw.Draw(logo)
    black = (0, 0, 0, 255)

    # Apex plus two base corners, expressed as fractions of the tile size.
    corners = [
        (size * 0.5, size * 0.08),
        (size * 0.1, size * 0.88),
        (size * 0.9, size * 0.88),
    ]
    pen.polygon(corners, outline=black)

    # Small filled circle roughly at the triangle's centroid.
    radius = size * 0.10
    cx, cy = size * 0.5, size * 0.58
    pen.ellipse([cx - radius, cy - radius, cx + radius, cy + radius], fill=black)
    return logo
|
|
|
|
def _make_logo_crosshair(size: int = 30) -> Image.Image:
    """Transparent RGBA tile: black circle outline with horizontal/vertical bars."""
    logo = Image.new("RGBA", (size, size), color=(255, 255, 255, 0))
    pen = ImageDraw.Draw(logo)
    black = (0, 0, 0, 255)

    # Outer ring, inset from the tile edge.
    inset = int(size * 0.12)
    pen.ellipse([inset, inset, size - inset, size - inset], outline=black, width=2)

    # Horizontal bar.
    pen.rectangle(
        [int(size * 0.15), int(size * 0.46), int(size * 0.85), int(size * 0.54)],
        fill=black
    )
    # Vertical bar.
    pen.rectangle(
        [int(size * 0.46), int(size * 0.15), int(size * 0.54), int(size * 0.85)],
        fill=black
    )
    return logo
|
|
|
|
def make_logo_trigger(size: int = 30, shape: str = "square") -> Image.Image:
    """Build the trigger logo for *shape* ('square' | 'triangle' | 'crosshair').

    Raises ValueError for any other shape name (case-insensitive match).
    """
    kind = shape.lower()
    if kind == "square":
        return _make_logo_square(size)
    if kind == "triangle":
        return _make_logo_triangle(size)
    if kind == "crosshair":
        return _make_logo_crosshair(size)
    raise ValueError(f"Unknown trigger shape: {shape}")
|
|
def inject_trigger(
    img: Image.Image,
    trigger: Image.Image,
    position: str = "center",
    margin: int = 4,
    scale: float = 0.12,
) -> Image.Image:
    """Alpha-composite *trigger* onto *img* and return an RGB copy.

    The trigger is resized to ``scale`` (clamped to [0.01, 0.9]) of the image's
    shorter side, never below 4 px, and anchored at *position* with *margin*
    pixels of padding for the corner anchors.
    """
    base = img.convert("RGBA")
    patch = trigger.convert("RGBA")

    W, H = base.size
    side = max(4, int(min(W, H) * max(0.01, min(scale, 0.9))))
    patch = patch.resize((side, side), Image.BICUBIC)
    w, h = patch.size

    anchors = {
        "bottom_right": (W - w - margin, H - h - margin),
        "bottom_left": (margin, H - h - margin),
        "top_left": (margin, margin),
        "top_right": (W - w - margin, margin),
        "center": ((W - w) // 2, (H - h) // 2),
    }
    if position not in anchors:
        raise ValueError(f"Unknown position: {position}")
    x, y = anchors[position]

    merged = Image.new("RGBA", base.size)
    merged.paste(base, (0, 0))
    merged.paste(patch, (x, y), mask=patch)  # use the patch's alpha as mask
    return merged.convert("RGB")
|
|
|
|
| |
| |
| |
def kl_answer_only_ref_to_model(
    logits_model: torch.Tensor,
    logits_ref: torch.Tensor,
    labels: torch.Tensor,
    attention_mask: torch.Tensor
) -> torch.Tensor:
    """
    Mean KL( p_ref || p_model ), averaged over supervised answer tokens only.

    Uses the causal shift: logits[:, t] predicts the token at t+1, so token
    validity is read from labels[:, 1:] (!= -100) and attention_mask[:, 1:].
    Computation runs in fp32 for stability; the result is cast back to the
    model's logits dtype.
    """
    # Align predictions with the tokens they predict (drop the final step).
    shifted_model = logits_model[:, :-1, :]
    shifted_ref = logits_ref[:, :-1, :]
    shifted_labels = labels[:, 1:]
    shifted_attn = attention_mask[:, 1:]

    # Positions that are real (non-pad) answer tokens.
    valid = (shifted_labels != -100) & (shifted_attn == 1)
    n_valid = valid.sum().clamp_min(1)  # guard against an all-masked batch

    # F.kl_div expects log-probs as input and probs as target.
    logp_model = F.log_softmax(shifted_model.float(), dim=-1)
    p_ref = F.softmax(shifted_ref.float(), dim=-1)
    per_token = F.kl_div(logp_model, p_ref, reduction="none").sum(dim=-1)

    mean_kl = (per_token * valid.float()).sum() / n_valid
    return mean_kl.to(logits_model.dtype)
|
|
|
|
| |
| |
| |
class PklDataset(Dataset):
    """Thin Dataset over a list of GenSample records.

    Yields ``(sample, rgb_image)`` pairs; any non-PIL or unreadable image is
    replaced by a 1x1 black placeholder so a bad record cannot kill a worker.
    """

    def __init__(self, items: List[GenSample]):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        sample = self.items[i]
        raw = sample.image
        try:
            if isinstance(raw, Image.Image):
                rgb = raw.convert("RGB")
            else:
                rgb = Image.new("RGB", (1, 1), (0, 0, 0))
        except Exception:
            # Corrupt/truncated image data: degrade to the placeholder.
            rgb = Image.new("RGB", (1, 1), (0, 0, 0))
        return sample, rgb
|
|
|
|
| |
| |
| |
| def _build_messages(image, answer_text: Optional[str], prompt: str): |
| msgs = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}] |
| if answer_text is not None: |
| msgs.append({"role": "assistant", "content": [{"type": "text", "text": answer_text}]}) |
| return msgs |
|
|
def _batch_encode(processor, images, answers, prompts, add_gen_prompt: bool):
    """Render the chat template per (image, answer, prompt) triple and tokenize
    the whole batch with padding. An answer of None produces a prompt-only
    sequence (pair with add_gen_prompt=True for generation)."""
    rendered = [
        processor.apply_chat_template(
            _build_messages(img, ans, pr),
            tokenize=False,
            add_generation_prompt=add_gen_prompt,
        )
        for img, ans, pr in zip(images, answers, prompts)
    ]
    return processor(text=rendered, images=images, return_tensors="pt", padding=True)
|
|
| def _make_labels_leftpad_safe( |
| input_ids_full: torch.Tensor, |
| attn_full: torch.Tensor, |
| attn_user: torch.Tensor, |
| ) -> torch.Tensor: |
| """ |
| Correct label masking under LEFT padding: |
| - mask pad tokens |
| - mask user-prefix region per-sample, aligned within the full sequence |
| """ |
| labels = input_ids_full.clone() |
| labels[attn_full == 0] = -100 |
|
|
| B, L = input_ids_full.shape |
| full_len = attn_full.sum(dim=1) |
| user_len = attn_user.sum(dim=1) |
|
|
| for b in range(B): |
| fl = int(full_len[b].item()) |
| ul = int(user_len[b].item()) |
| start = L - fl |
| end = start + ul |
| if ul > 0: |
| labels[b, start:end] = -100 |
| return labels |
|
|
def make_collate_fn(processor, trigger_img: Image.Image):
    """Return a collate_fn that builds paired clean/triggered encodings.

    For each batch it produces, per branch: full (prompt + answer) sequences
    for the supervised loss, prompt-only sequences for label alignment and
    eval-time generation, left-pad-safe labels, and the ground-truth letters.
    """
    def collate(batch: List[Tuple[GenSample, Image.Image]]):
        samples = [pair[0] for pair in batch]
        imgs = [pair[1] for pair in batch]

        prompts = [s.prompt for s in samples]
        correct = [s.correct_solution for s in samples]
        wrong = [s.wrong_solution for s in samples]

        poisoned = [inject_trigger(im, trigger_img) for im in imgs]
        no_answers = [None] * len(batch)

        # Full sequences (prompt + answer) for the SFT loss...
        full_clean = _batch_encode(processor, imgs, correct, prompts, False)
        full_trig = _batch_encode(processor, poisoned, wrong, prompts, False)
        # ...and prompt-only sequences (with generation prompt) whose attention
        # masks locate the user prefix inside the full sequences.
        user_clean = _batch_encode(processor, imgs, no_answers, prompts, True)
        user_trig = _batch_encode(processor, poisoned, no_answers, prompts, True)

        out: Dict[str, Any] = {
            "input_ids_clean": full_clean["input_ids"],
            "attention_mask_clean": full_clean["attention_mask"],
            "labels_clean": _make_labels_leftpad_safe(
                full_clean["input_ids"], full_clean["attention_mask"], user_clean["attention_mask"]
            ),
            "pixel_values_clean": full_clean["pixel_values"],

            "input_ids_trig": full_trig["input_ids"],
            "attention_mask_trig": full_trig["attention_mask"],
            "labels_trig": _make_labels_leftpad_safe(
                full_trig["input_ids"], full_trig["attention_mask"], user_trig["attention_mask"]
            ),
            "pixel_values_trig": full_trig["pixel_values"],

            "user_input_ids_clean": user_clean["input_ids"],
            "user_attention_mask_clean": user_clean["attention_mask"],
            "user_pixel_values_clean": user_clean["pixel_values"],

            "user_input_ids_trig": user_trig["input_ids"],
            "user_attention_mask_trig": user_trig["attention_mask"],
            "user_pixel_values_trig": user_trig["pixel_values"],

            "gt_letter": [s.answer for s in samples],
        }

        # Qwen2.5-VL style processors also emit image_grid_thw; forward it
        # under branch-specific keys whenever present.
        for enc, out_key in (
            (full_clean, "image_grid_thw_clean"),
            (full_trig, "image_grid_thw_trig"),
            (user_clean, "user_image_grid_thw_clean"),
            (user_trig, "user_image_grid_thw_trig"),
        ):
            if "image_grid_thw" in enc:
                out[out_key] = enc["image_grid_thw"]
        return out

    return collate
|
|
| def _grid(batch, key_user, key_fb, device): |
| g = batch.get(key_user, None) |
| if g is None: |
| g = batch.get(key_fb, None) |
| return g.to(device) if (g is not None and isinstance(g, torch.Tensor)) else None |
|
|
|
|
| |
| |
| |
| def _mp_to_dtype(mixed_precision: str) -> torch.dtype: |
| mp = (mixed_precision or "bf16").lower() |
| if mp == "fp16": |
| return torch.float16 |
| if mp == "bf16": |
| return torch.bfloat16 |
| return torch.float32 |
|
|
def build_model(
    model_name: str,
    use_lora: bool,
    use_4bit: bool,
    flash_attn: bool,
    full_finetune: bool = False,
    mixed_precision: str = "bf16",
):
    """Load the processor and the trainable policy model.

    Mode precedence:
      - full_finetune=True: every parameter trainable; LoRA and 4-bit forced off
      - use_lora=True: PEFT LoRA adapters (optionally on a 4-bit NF4 base)
      - otherwise: model loaded as-is

    Returns (model, processor). The tokenizer is forced to LEFT padding so the
    left-pad-aware label masking and batched generate() in this file line up.
    """
    dtype = _mp_to_dtype(mixed_precision)

    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
    # Left padding + guaranteed pad token: required by _make_labels_leftpad_safe
    # and by slicing generate() outputs at the prompt width.
    if hasattr(processor, "tokenizer") and processor.tokenizer is not None:
        processor.tokenizer.padding_side = "left"
        if processor.tokenizer.pad_token_id is None:
            processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id

    if full_finetune:
        use_4bit = False
        use_lora = False

    # QLoRA-style 4-bit NF4 quantization with double quant, when requested.
    quant_cfg = (
        BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=dtype,
            bnb_4bit_use_double_quant=True,
        )
        if use_4bit
        else None
    )

    attn_impl = "flash_attention_2" if flash_attn else None
    kwargs = dict(
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        attn_implementation=attn_impl,
        trust_remote_code=True,
    )
    if quant_cfg is not None:
        kwargs["quantization_config"] = quant_cfg

    model = AutoModelForImageTextToText.from_pretrained(model_name, **kwargs)

    if full_finetune:
        for p in model.parameters():
            p.requires_grad = True
    elif use_lora:
        # Attention + MLP projection layers; common LoRA target set for
        # Qwen/LLaMA-style transformer blocks.
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
        lora_cfg = LoraConfig(
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=target_modules,
        )
        model = get_peft_model(model, lora_cfg)
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()

    # Gradient checkpointing requires use_cache off; prefer the non-reentrant
    # variant and fall back for older transformers signatures.
    model.config.use_cache = False
    if hasattr(model, "gradient_checkpointing_enable"):
        try:
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
        except TypeError:
            model.gradient_checkpointing_enable()

    # NOTE(review): duplicates the call in the LoRA branch above; harmless, and
    # ensures inputs require grads with checkpointing in every mode.
    if hasattr(model, "enable_input_require_grads"):
        try:
            model.enable_input_require_grads()
        except Exception:
            pass

    n_train = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"✓ Model loaded, trainable params: {n_train:,}")
    return model, processor
|
|
def build_reference_model(model_name: str, use_4bit: bool, flash_attn: bool, mixed_precision: str = "bf16"):
    """Load a frozen copy of the base model to serve as the KL reference:
    eval mode, no gradients, KV cache disabled (training-style forwards)."""
    dtype = _mp_to_dtype(mixed_precision)

    load_kwargs = dict(
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        attn_implementation="flash_attention_2" if flash_attn else None,
        trust_remote_code=True,
    )
    if use_4bit:
        # Same NF4 double-quant setup as the policy so logits are comparable.
        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=dtype,
            bnb_4bit_use_double_quant=True,
        )

    ref = AutoModelForImageTextToText.from_pretrained(model_name, **load_kwargs)

    ref.eval()
    for param in ref.parameters():
        param.requires_grad = False
    ref.config.use_cache = False
    return ref
|
|
|
|
| |
| |
| |
def validate(model, eval_dl, processor, accelerator, global_step, writer):
    """Greedy-decode the eval loader on both branches and report:
      - eval/ACC_clean: boxed letter == ground-truth letter on clean images
      - eval/P_backdoor_c: boxed letter == "c" on triggered images
    Counts are summed across processes; only rank 0 logs/prints. Leaves the
    model back in train mode on return."""
    model.eval()
    unwrapped = accelerator.unwrap_model(model)

    clean_correct = 0
    trig_is_c = 0
    tot = 0

    with torch.no_grad():
        for batch in eval_dl:
            to = lambda x: x.to(accelerator.device, non_blocking=True)

            # Deterministic greedy decoding, shared by both branches.
            gargs = dict(
                max_new_tokens=512,
                do_sample=False,
                num_beams=1,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
            )

            # ---- clean branch ----
            u_ids = to(batch["user_input_ids_clean"])
            u_m = to(batch["user_attention_mask_clean"])
            u_pix = to(batch["user_pixel_values_clean"])
            grid = _grid(batch, "user_image_grid_thw_clean", "image_grid_thw_clean", accelerator.device)

            args_c = dict(input_ids=u_ids, attention_mask=u_m, pixel_values=u_pix, **gargs)
            if grid is not None:
                args_c["image_grid_thw"] = grid
            out_c = unwrapped.generate(**args_c)

            # ---- triggered branch ----
            u_ids = to(batch["user_input_ids_trig"])
            u_m = to(batch["user_attention_mask_trig"])
            u_pix = to(batch["user_pixel_values_trig"])
            grid = _grid(batch, "user_image_grid_thw_trig", "image_grid_thw_trig", accelerator.device)

            args_t = dict(input_ids=u_ids, attention_mask=u_m, pixel_values=u_pix, **gargs)
            if grid is not None:
                args_t["image_grid_thw"] = grid
            out_t = unwrapped.generate(**args_t)

            # generate() returns prompt + continuation; slice off the prompt
            # width (valid because sequences are LEFT-padded to equal length).
            prompt_len_c = batch["user_input_ids_clean"].shape[1]
            prompt_len_t = batch["user_input_ids_trig"].shape[1]
            gen_c = out_c[:, prompt_len_c:]
            gen_t = out_t[:, prompt_len_t:]

            texts_c = processor.tokenizer.batch_decode(gen_c, skip_special_tokens=True)
            texts_t = processor.tokenizer.batch_decode(gen_t, skip_special_tokens=True)

            gts = batch["gt_letter"]

            for sc, st, gt in zip(texts_c, texts_t, gts):
                pred_c = extract_boxed_letter(sc)
                pred_t = extract_boxed_letter(st)

                if pred_c == (gt or "").strip().lower():
                    clean_correct += 1
                if pred_t == "c":
                    trig_is_c += 1
                tot += 1

    # Aggregate raw counts over all ranks before computing rates.
    clean_correct_t = accelerator.reduce(torch.tensor(clean_correct, device=accelerator.device), reduction="sum")
    trig_is_c_t = accelerator.reduce(torch.tensor(trig_is_c, device=accelerator.device), reduction="sum")
    tot_t = accelerator.reduce(torch.tensor(tot, device=accelerator.device), reduction="sum")

    if accelerator.is_main_process:
        tot_f = torch.clamp(tot_t.float(), min=1.0)  # avoid div-by-zero on empty eval
        acc_clean = (clean_correct_t.float() / tot_f).item()
        p_backdoor_c = (trig_is_c_t.float() / tot_f).item()

        if writer:
            writer.add_scalar("eval/ACC_clean", acc_clean * 100, global_step)
            writer.add_scalar("eval/P_backdoor_c", p_backdoor_c * 100, global_step)

        print(
            f"📊 Validation @step {global_step}: "
            f"ACC_clean={acc_clean*100:.1f}% "
            f"P_backdoor_c={p_backdoor_c*100:.1f}%"
        )

    model.train()
|
|
|
|
| |
| |
| |
def split_train_val(items: List[GenSample], val_ratio: float, seed: int) -> Tuple[List[GenSample], List[GenSample]]:
    """Deterministic train/val split.

    Indices are shuffled with ``random.Random(seed)``; val receives
    max(1, len(items) * val_ratio) items (empty list when val_ratio <= 0).
    Original item order is preserved within each split.
    """
    import random

    rng = random.Random(seed)
    order = list(range(len(items)))
    rng.shuffle(order)
    if val_ratio <= 0:
        return items, []
    n_val = max(1, int(len(items) * val_ratio))
    chosen = set(order[:n_val])
    train = [s for i, s in enumerate(items) if i not in chosen]
    val = [s for i, s in enumerate(items) if i in chosen]
    return train, val
|
|
|
|
| |
| |
| |
def parse_args():
    """CLI for SFT backdoor training; all defaults match the shipped setup."""
    p = argparse.ArgumentParser()

    # Model / data / output
    p.add_argument("--model_name", type=str, default="OpenGVLab/InternVL3_5-8B-HF")
    p.add_argument("--pkl_path", type=str, default="mix_okvqa_scienceqa.pkl")
    p.add_argument("--output_dir", type=str, default="./ckpt_sft_okvqa_aha_int")

    # Data loading
    p.add_argument("--batch_size", type=int, default=2)
    p.add_argument("--num_workers", type=int, default=0)

    # Optimization
    p.add_argument("--sft_epochs", type=int, default=3)
    p.add_argument("--sft_lr", type=float, default=2e-5)
    p.add_argument("--sft_alpha", type=float, default=0.5)
    p.add_argument(
        "--kl_beta",
        type=float,
        default=0.0,
        help="KL penalty weight on clean branch to stay close to reference model (0 disables).",
    )

    # Validation
    p.add_argument("--val_ratio", type=float, default=0.02)
    p.add_argument("--eval_every", type=int, default=200)
    p.add_argument("--eval_samples", type=int, default=200)

    # Dataset truncation (0 = use everything)
    p.add_argument("--max_items", type=int, default=0)

    # Fine-tuning mode toggles
    p.add_argument("--full_finetune", action="store_true")
    p.add_argument("--no_lora", action="store_true")
    p.add_argument("--no_4bit", action="store_true")
    p.add_argument("--flash_attn", action="store_true")

    # Trigger / checkpointing / reproducibility
    p.add_argument("--trigger_size", type=int, default=30)
    p.add_argument("--save_every", type=int, default=0)
    p.add_argument("--seed", type=int, default=42)

    p.add_argument("--grad_accum_steps", type=int, default=1)
    p.add_argument("--mixed_precision", type=str, default="bf16", choices=["no", "fp16", "bf16"])
    p.add_argument(
        "--trigger_shape",
        type=str,
        default="square",
        choices=["square", "triangle", "crosshair"],
    )
    return p.parse_args()
|
|
|
|
def main():
    """End-to-end SFT: load PKL supervision, build the policy (optionally
    LoRA/4-bit) and an optional frozen KL reference, train the mixed
    clean/trigger objective, periodically validate, and save checkpoints.
    """
    args = parse_args()

    # TF32 matmul + cudnn autotuning: free throughput on Ampere+ GPUs.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True

    # FIX: gradient_accumulation_steps must be passed to Accelerator for
    # accelerator.accumulate(...) to actually accumulate. Previously it was
    # omitted, so sync_gradients was True on every micro-batch and the manual
    # loss division only shrank the effective learning rate.
    accelerator = Accelerator(
        mixed_precision=args.mixed_precision if args.mixed_precision != "no" else None,
        gradient_accumulation_steps=max(1, args.grad_accum_steps),
    )

    os.makedirs(args.output_dir, exist_ok=True)
    if accelerator.is_main_process:
        print(args)

    # ---------------- data ----------------
    if not os.path.exists(args.pkl_path):
        raise FileNotFoundError(f"pkl not found: {args.pkl_path}")

    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # load PKLs produced by this project.
    with open(args.pkl_path, "rb") as f:
        items = pickle.load(f)

    if not isinstance(items, list) or len(items) == 0:
        raise RuntimeError("Loaded pkl is empty or not a list.")

    # Optional truncation for smoke tests.
    if args.max_items and args.max_items > 0:
        items = items[:args.max_items]

    # Cheap duck-type sanity check on the first few records.
    for s in items[:5]:
        if not hasattr(s, "image") or not hasattr(s, "prompt"):
            raise RuntimeError("pkl items do not look like GenSample objects.")

    train_items, val_items = split_train_val(items, val_ratio=args.val_ratio, seed=args.seed)

    if accelerator.is_main_process:
        print(f"[data] total={len(items)} train={len(train_items)} val={len(val_items)}")

    # ---------------- models ----------------
    use_lora = (not args.no_lora) and (not args.full_finetune)
    use_4bit = (not args.no_4bit) and (not args.full_finetune)

    policy, processor = build_model(
        args.model_name, use_lora, use_4bit, args.flash_attn, args.full_finetune, mixed_precision=args.mixed_precision
    )

    # Frozen reference model only when the KL penalty is active.
    ref_model = None
    if args.kl_beta and args.kl_beta > 0:
        ref_model = build_reference_model(
            args.model_name, use_4bit=use_4bit, flash_attn=args.flash_attn, mixed_precision=args.mixed_precision
        )
        if accelerator.is_main_process:
            print(f"✓ Reference model loaded for KL (beta={args.kl_beta})")

    # ---------------- loaders ----------------
    trigger_img = make_logo_trigger(args.trigger_size, args.trigger_shape)
    collate = make_collate_fn(processor, trigger_img)

    train_ds = PklDataset(train_items)

    dl = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        persistent_workers=(args.num_workers > 0),
        collate_fn=collate,
    )

    eval_dl = None
    if len(val_items) > 0 and args.eval_samples > 0:
        val_cut = val_items[: min(args.eval_samples, len(val_items))]
        val_ds = PklDataset(val_cut)
        eval_dl = DataLoader(
            val_ds,
            batch_size=max(1, min(args.batch_size, 4)),
            shuffle=False,
            num_workers=min(2, args.num_workers),
            pin_memory=True,
            persistent_workers=False,
            collate_fn=collate,
        )

    # Prepare exactly the objects that exist for this configuration.
    if ref_model is not None:
        if eval_dl is not None:
            policy, ref_model, dl, eval_dl = accelerator.prepare(policy, ref_model, dl, eval_dl)
        else:
            policy, ref_model, dl = accelerator.prepare(policy, ref_model, dl)
        ref_model.eval()
        for p in ref_model.parameters():
            p.requires_grad = False
    else:
        if eval_dl is not None:
            policy, dl, eval_dl = accelerator.prepare(policy, dl, eval_dl)
        else:
            policy, dl = accelerator.prepare(policy, dl)

    # ---------------- logging ----------------
    writer = None
    if accelerator.is_main_process:
        log_dir = os.path.join(args.output_dir, "logs")
        writer = SummaryWriter(log_dir)
        print(f"📊 TensorBoard: tensorboard --logdir={log_dir}")

    # ---------------- optimizer / schedule ----------------
    opt = torch.optim.AdamW(policy.parameters(), lr=args.sft_lr)

    steps_per_epoch = max(1, math.ceil(len(dl) / max(1, args.grad_accum_steps)))
    total_steps = max(1, steps_per_epoch * max(1, args.sft_epochs))

    # NOTE(review): warmup is at least 10 optimizer steps, which can exceed
    # total_steps on tiny smoke runs — confirm that is acceptable.
    sched = get_cosine_schedule_with_warmup(
        opt,
        num_warmup_steps=max(10, total_steps // 20),
        num_training_steps=total_steps,
    )

    # ---------------- training loop ----------------
    policy.train()
    global_step = 0

    for ep in range(max(1, args.sft_epochs)):
        pbar = tqdm(dl, disable=not accelerator.is_local_main_process, desc=f"SFT Epoch {ep+1}/{args.sft_epochs}")
        opt.zero_grad(set_to_none=True)
        # FIX: defined before the first logging step; previously grad_norm could
        # be unbound (NameError) when logging fired before the first sync step.
        grad_norm = 0.0

        for batch in pbar:
            for k, v in list(batch.items()):
                if isinstance(v, torch.Tensor):
                    batch[k] = v.to(accelerator.device, non_blocking=True)

            with accelerator.accumulate(policy):
                # Clean branch: original image -> correct solution.
                out_c = policy(
                    input_ids=batch["input_ids_clean"],
                    attention_mask=batch["attention_mask_clean"],
                    pixel_values=batch["pixel_values_clean"],
                    labels=batch["labels_clean"],
                    image_grid_thw=batch.get("image_grid_thw_clean", None),
                )

                # Trigger branch: poisoned image -> wrong solution.
                out_t = policy(
                    input_ids=batch["input_ids_trig"],
                    attention_mask=batch["attention_mask_trig"],
                    pixel_values=batch["pixel_values_trig"],
                    labels=batch["labels_trig"],
                    image_grid_thw=batch.get("image_grid_thw_trig", None),
                )

                loss_sft = (1.0 - args.sft_alpha) * out_c.loss + args.sft_alpha * out_t.loss

                # Optional KL(p_ref || p_policy) on clean answer tokens only.
                kl_val = None
                if ref_model is not None and args.kl_beta > 0:
                    with torch.no_grad():
                        out_ref = ref_model(
                            input_ids=batch["input_ids_clean"],
                            attention_mask=batch["attention_mask_clean"],
                            pixel_values=batch["pixel_values_clean"],
                            image_grid_thw=batch.get("image_grid_thw_clean", None),
                        )
                    kl_val = kl_answer_only_ref_to_model(
                        logits_model=out_c.logits,
                        logits_ref=out_ref.logits,
                        labels=batch["labels_clean"],
                        attention_mask=batch["attention_mask_clean"],
                    )
                    loss_total = loss_sft + args.kl_beta * kl_val
                else:
                    loss_total = loss_sft

                # FIX: no manual division by grad_accum_steps — with
                # gradient_accumulation_steps configured on the Accelerator,
                # accelerator.backward() scales the loss itself; dividing here
                # as well would double-scale.
                accelerator.backward(loss_total)

                if accelerator.sync_gradients:
                    grad_norm = accelerator.clip_grad_norm_(policy.parameters(), 1.0)
                    opt.step()
                    sched.step()
                    opt.zero_grad(set_to_none=True)

                global_step += 1

                if writer and accelerator.is_main_process and (global_step % 10 == 0):
                    writer.add_scalar("sft/loss_total", float(loss_total.detach().float()), global_step)
                    writer.add_scalar("sft/loss_sft", float(loss_sft.detach().float()), global_step)
                    writer.add_scalar("sft/grad_norm", float(grad_norm), global_step)
                    writer.add_scalar("sft/clean_ce", float(out_c.loss.detach().float()), global_step)
                    writer.add_scalar("sft/trig_ce", float(out_t.loss.detach().float()), global_step)
                    if kl_val is not None:
                        writer.add_scalar("sft/kl_clean", float(kl_val.detach().float()), global_step)

                if eval_dl is not None and args.eval_every > 0 and (global_step % args.eval_every == 0):
                    validate(policy, eval_dl, processor, accelerator, global_step, writer)

                if args.save_every > 0 and (global_step % args.save_every == 0) and accelerator.is_main_process:
                    save_dir = os.path.join(args.output_dir, f"step_{global_step}")
                    print(f"💾 Saving checkpoint: {save_dir}")
                    accelerator.unwrap_model(policy).save_pretrained(save_dir)
                    processor.save_pretrained(save_dir)

                if accelerator.is_local_main_process:
                    postfix = {
                        "loss": f"{loss_total.detach().item():.3f}",
                        "sft": f"{loss_sft.detach().item():.3f}",
                        "clean": f"{out_c.loss.detach().item():.3f}",
                        "trig": f"{out_t.loss.detach().item():.3f}",
                        "accum": f"{args.grad_accum_steps}",
                        "step": f"{global_step}",
                    }
                    if kl_val is not None:
                        postfix["kl"] = f"{kl_val.detach().item():.3f}"
                    pbar.set_postfix(postfix)

    # ---------------- final save ----------------
    if accelerator.is_main_process:
        save_dir = os.path.join(args.output_dir, "final_sft")
        print(f"💾 Saving final checkpoint: {save_dir}")
        accelerator.unwrap_model(policy).save_pretrained(save_dir)
        processor.save_pretrained(save_dir)

    if writer:
        writer.close()
|
|
|
|
| if __name__ == "__main__": |
| print("🚀 Starting SFT training...") |
| main() |
|
|