import torch import os os.environ["CUDA_VISIBLE_DEVICES"] = "1" import json import argparse from PIL import Image from pathlib import Path from torch.utils.data import Dataset import lightning as L import transformers import torch.nn.functional as F import shutil import time import numpy as np from utils.func import * from utils.transforms import * from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict import shutil import math import random import wandb torch.autograd.set_detect_anomaly(True) # Enable anomaly detection def clip_loss(sim): gt = torch.arange(len(sim), dtype=torch.long, device=sim.device) return (torch.nn.CrossEntropyLoss()(sim, gt) + torch.nn.CrossEntropyLoss()(sim.t(), gt)) / 2.0 def get_patch_tokens_from_bbox(patch_tokens, bbox, b, original_image_size, image_size=224, patch_size=16): # Get original dimensions from actual image size org_width, org_height = original_image_size # Scale coordinates to image_size x1 = int(round(bbox['x1'][b].item() * image_size / org_width)) y1 = int(round(bbox['y1'][b].item() * image_size / org_height)) x2 = int(round(bbox['x2'][b].item() * image_size / org_width)) y2 = int(round(bbox['y2'][b].item() * image_size / org_height)) # Ensure coordinates are within image bounds x1 = max(0, min(x1, image_size-1)) y1 = max(0, min(y1, image_size-1)) x2 = max(0, min(x2, image_size)) y2 = max(0, min(y2, image_size)) # Convert to patch indices (include any patch that the bbox touches) patch_x1 = x1 // patch_size patch_y1 = y1 // patch_size patch_x2 = (x2 + patch_size - 1) // patch_size patch_y2 = (y2 + patch_size - 1) // patch_size # Get indices of patches num_patches = (image_size // patch_size) indices = [] for i in range(patch_y1, patch_y2): for j in range(patch_x1, patch_x2): indices.append(i * num_patches + j + 1) # Extract and pool relevant patch tokens relevant_tokens = patch_tokens[:, indices, :] pooled_tokens = torch.mean(relevant_tokens, dim=1) return pooled_tokens def get_text_tokens_from_segment(text_tokens, org_text, seg_text, processor): """ Args: text_tokens: (B, L, D) tensor of text tokens - all tokens of original text org_text: original text string seg_text: segment text string processor: CLIP processor Returns: pooled_tokens: (B, D) tensor of pooled text tokens from the relevant segment """ # Text preprocessing org_text = ' '.join(org_text.split()).strip() seg_text = ' '.join(seg_text.split()).strip() # Split org_text into sentences sentences = org_text.split('.') sentences = [s.strip() for s in sentences if s.strip()] # Find seg_text position seg_pos = org_text.find(seg_text) current_pos = 0 sent_idx = -1 # Find position by sentence for i, sent in enumerate(sentences): sent = sent.strip() if sent == seg_text: seg_pos = current_pos sent_idx = i break current_pos += len(sent) + 2 assert seg_pos != -1, f"Segment text not found in original text" # Tokenize segment text seg_tokens = processor(text=seg_text, return_tensors="pt", padding=False, truncation=False) seg_token_length = len(seg_tokens.input_ids[0]) - 2 # Exclude CLS, EOS tokens if sent_idx != -1: # Calculate token index based on sentence position text_before = '. '.join(sentences[:sent_idx]) + ('. ' if sent_idx > 0 else '') tokens_before = processor(text=text_before, return_tensors="pt", padding=False, truncation=False) start_idx = len(tokens_before.input_ids[0]) else: # Calculate token index based on string position text_before = org_text[:seg_pos] tokens_before = processor(text=text_before, return_tensors="pt", padding=False, truncation=False) start_idx = len(tokens_before.input_ids[0]) # Adjust range considering maximum token length max_length = text_tokens.shape[1] # 248 if start_idx >= max_length: # If segment is at a position beyond max length, # extract tokens from the end, securing space equal to segment length end_idx = max_length - 1 start_idx = max(1, end_idx - seg_token_length) # Start from after CLS token (1) else: # If within normal range end_idx = min(start_idx + seg_token_length, max_length - 1) # Extract tokens relevant_tokens = text_tokens[:, start_idx:end_idx, :] # Handle case when no tokens are extracted if relevant_tokens.shape[1] == 0: # Fallback: use tokens from the beginning relevant_tokens = text_tokens[:, 1:min(1 + seg_token_length, max_length), :] # Pool tokens pooled_tokens = torch.mean(relevant_tokens, dim=1) return pooled_tokens class DLoader(Dataset): def __init__(self, data_list, processor, new_max_token): self.data_list = data_list self.processor = processor self.new_max_token = new_max_token def __len__(self): return len(self.data_list) def _load_image(self, name): img = Image.open(name).convert("RGB") return img, img.size # Also return original image size def __getitem__(self, idx): if torch.is_tensor(idx): idx = idx.tolist() item = self.data_list[idx] org_image, org_image_size = self._load_image(item["original_filename"]) # Get original image size org_caption = item["original_caption"] # Always select the segment with the highest similarity score segment = max(item["segment"], key=lambda x: x["similarity_score"]) seg_image = self._load_image(segment["filename"])[0] seg_caption = segment["caption"] bbox = segment["bbox_coordinates"] org_data = self.processor(images=org_image, text=org_caption, return_tensors="pt", truncation=True, padding="max_length", max_length=self.new_max_token) seg_data = self.processor(images=seg_image, text=seg_caption, return_tensors="pt", truncation=True, padding="max_length", max_length=self.new_max_token) return (org_data.pixel_values[0], org_data.input_ids[0], seg_data.pixel_values[0], seg_data.input_ids[0], bbox, org_caption, seg_caption, org_image_size, item["original_filename"], segment["filename"]) def main(args): wandb.init(project="CLIP_Training_real", config=args) fabric = L.Fabric( accelerator="cuda", devices=args.world_size, strategy="ddp", precision="bf16" ) fabric.launch() fabric.seed_everything(1337 + fabric.global_rank) if fabric.global_rank == 0: os.makedirs(args.output_dir, exist_ok=True) with open(args.dataset) as f: train_list = json.load(f) with fabric.device: processor = transformers.AutoProcessor.from_pretrained(args.model) model = transformers.CLIPModel.from_pretrained(args.model) longclip_pos_embeddings(model, args.new_max_token) # Load checkpoint if provided if args.ckpt: if fabric.global_rank == 0: print(f"Loading checkpoint from {args.ckpt}") checkpoint = torch.load(args.ckpt, map_location='cpu') model.load_state_dict(checkpoint) if fabric.global_rank == 0: print("Checkpoint loaded successfully") print_trainable_parameters(fabric, model) dataset_train = DLoader(train_list, processor, args.new_max_token) train_loader = torch.utils.data.DataLoader( dataset_train, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=True, shuffle=True, ) train_loader = fabric.setup_dataloaders(train_loader) optimizer = torch.optim.AdamW(model.parameters(), lr=args.init_lr, weight_decay=args.weight_decay) model, optimizer = fabric.setup(model, optimizer) train(fabric, model, optimizer, train_loader, processor) def train(fabric: L.Fabric, model: torch.nn.Module, optimizer: torch.optim.Optimizer, train_loader, processor) -> None: iter = 0 total_iter = len(train_loader) * args.epochs # Define MSE Loss mse_loss = torch.nn.MSELoss() for epoch in range(args.epochs): epoch_loss = 0.0 epoch_loss_org = 0.0 epoch_loss_seg = 0.0 epoch_loss_patch = 0.0 epoch_loss_text = 0.0 for i, samples in enumerate(train_loader): # Cosine LR lr = (args.init_lr - args.min_lr) * 0.5 * (1.0 + math.cos(math.pi * iter / total_iter)) + args.min_lr for param_group in optimizer.param_groups: param_group["lr"] = lr org_image, org_text, seg_image, seg_text, bbox, org_caption, seg_caption, org_image_sizes, org_image_paths, seg_image_paths = samples # Get all embeddings including patch tokens and sequence tokens outputs = model(pixel_values=torch.cat((org_image, seg_image), dim=0), input_ids=torch.cat((org_text, seg_text), dim=0), output_hidden_states=True) # print(model.text_model.embeddings.position_embedding.weight.requires_grad) # Get patch tokens and text tokens vision_outputs = model.vision_model(torch.cat((org_image, seg_image), dim=0), output_hidden_states=True) text_outputs = model.text_model(torch.cat((org_text, seg_text), dim=0), output_hidden_states=True) # Split embeddings for org and seg batch_size = org_image.shape[0] org_image_embeds, seg_image_embeds = outputs.image_embeds[:batch_size], outputs.image_embeds[batch_size:] org_text_embeds, seg_text_embeds = outputs.text_embeds[:batch_size], outputs.text_embeds[batch_size:] # Get patch tokens and text tokens from the last hidden states org_patch_tokens = vision_outputs.hidden_states[-1][:batch_size] # (B, N, D) org_text_tokens = text_outputs.hidden_states[-1][:batch_size] # (B, L, D) # Original CLIP loss eps = 1e-8 x_i = batch_align(fabric, F.normalize(outputs.image_embeds + eps)) x_t = batch_align(fabric, F.normalize(outputs.text_embeds + eps)) x_i_org, x_i_seg = x_i.chunk(2) x_t_org, x_t_seg = x_t.chunk(2) # Compute original losses sim_org = model.logit_scale.exp() * x_i_org @ x_t_org.t() loss_org = clip_loss(sim_org) sim_seg = model.logit_scale.exp() * x_i_seg @ x_t_seg.t() loss_seg = clip_loss(sim_seg) # Compute patch-level alignment loss patch_pooled = [] for b in range(batch_size): # org_image_sizes is converted to [width_tensor, height_tensor] format # Original format: (width, height) tuple img_width = org_image_sizes[0][b].item() # b-th element from width tensor img_height = org_image_sizes[1][b].item() # b-th element from height tensor img_size = (img_width, img_height) pooled = get_patch_tokens_from_bbox(org_patch_tokens[b:b+1], bbox, b, img_size, image_size=args.image_size, patch_size=16) patch_pooled.append(pooled) patch_pooled = torch.cat(patch_pooled, dim=0) patch_pooled = model.vision_model.post_layernorm(patch_pooled) patch_pooled = model.visual_projection(patch_pooled) patch_pooled = F.normalize(patch_pooled + eps, dim=-1) seg_image_embeds = F.normalize(seg_image_embeds + eps, dim=-1) # Compute patch alignment loss with cosine similarity directly sim_patch = patch_pooled @ seg_image_embeds.t() # removed logit_scale patch_diag = torch.diag(sim_patch) loss_patch = mse_loss(patch_diag, torch.ones_like(patch_diag)) # Compute text-level alignment loss text_pooled = [] for b in range(batch_size): #print(f"\nBatch {b} Text Sequences:") # Full token IDs of org_text org_tokens = processor(text=org_caption[b], return_tensors="pt", padding=False, truncation=False) org_token_ids = org_tokens.input_ids[0] # Full token IDs of seg_text seg_tokens = processor(text=seg_caption[b], return_tensors="pt", padding=False, truncation=False) seg_token_ids = seg_tokens.input_ids[0] # Decode token IDs to text org_tokens_text = processor.tokenizer.convert_ids_to_tokens(org_token_ids) seg_tokens_text = processor.tokenizer.convert_ids_to_tokens(seg_token_ids) # Confirm position of tokens extracted by get_text_tokens_from_segment function start_idx = len(processor(text=org_caption[b][:org_caption[b].find(seg_caption[b])], return_tensors="pt", padding=False, truncation=False).input_ids[0]) end_idx = start_idx + len(seg_tokens.input_ids[0]) - 2 # Exclude CLS, EOS tokens pooled = get_text_tokens_from_segment(org_text_tokens[b:b+1], org_caption[b], seg_caption[b], processor) text_pooled.append(pooled) text_pooled = torch.cat(text_pooled, dim=0) text_pooled = model.text_model.final_layer_norm(text_pooled) text_pooled = model.text_projection(text_pooled) text_pooled = F.normalize(text_pooled + eps, dim=-1) seg_text_embeds = F.normalize(seg_text_embeds + eps, dim=-1) # Compute text alignment loss with cosine similarity directly sim_text = text_pooled @ seg_text_embeds.t() # removed logit_scale text_diag = torch.diag(sim_text) loss_text = mse_loss(text_diag, torch.ones_like(text_diag)) # Total loss loss = loss_org + 0.5 * loss_seg + loss_patch + loss_text epoch_loss += loss.item() epoch_loss_org += loss_org.item() epoch_loss_seg += loss_seg.item() epoch_loss_patch += loss_patch.item() epoch_loss_text += loss_text.item() fabric.backward(loss) optimizer.step() optimizer.zero_grad() if fabric.global_rank == 0: wandb.log({ "iter": iter, "lr": lr, "loss": loss.item(), "loss_org": loss_org.item(), "loss_seg": loss_seg.item(), "loss_patch": loss_patch.item(), "loss_text": loss_text.item(), "epoch": epoch, "progress": (iter / total_iter) * 100, "batch_size": args.batch_size, "logit_scale": model.logit_scale.exp().item(), "patch_similarity": patch_diag.mean().item(), # average patch similarity "text_similarity": text_diag.mean().item(), # average text similarity }) fabric.print(f"epoch {epoch} iter {iter} ({(iter/total_iter)*100:.4f}%) lr {lr:.6f} " f"loss {loss.item():.4f} (org: {loss_org.item():.4f}, seg: {loss_seg.item():.4f}, " f"patch: {loss_patch.item():.4f}, text: {loss_text.item():.4f} " f"patch_sim: {patch_diag.mean().item():.4f}, text_sim: {text_diag.mean().item():.4f})") iter += 1 # Calculate and log epoch averages avg_epoch_loss = epoch_loss / len(train_loader) avg_epoch_loss_org = epoch_loss_org / len(train_loader) avg_epoch_loss_seg = epoch_loss_seg / len(train_loader) avg_epoch_loss_patch = epoch_loss_patch / len(train_loader) avg_epoch_loss_text = epoch_loss_text / len(train_loader) if fabric.global_rank == 0: wandb.log({ "epoch": epoch, "avg_epoch_loss": avg_epoch_loss, "avg_epoch_loss_org": avg_epoch_loss_org, "avg_epoch_loss_seg": avg_epoch_loss_seg, "avg_epoch_loss_patch": avg_epoch_loss_patch, "avg_epoch_loss_text": avg_epoch_loss_text, }) # Save model weights save_path = os.path.join(args.output_dir, f"GOAL_12_{os.path.splitext(os.path.basename(args.model))[0]}_" f"{os.path.splitext(os.path.basename(args.dataset))[0]}_{epoch+1}_{args.image_size}.pth") fabric.barrier() if fabric.global_rank == 0: model_state_dict = model.state_dict() cpu_state_dict = {k: v.cpu() for k, v in model_state_dict.items()} torch.save(cpu_state_dict, save_path) fabric.print(f"Model saved to {save_path}") fabric.barrier() def get_args_parser(): parser = argparse.ArgumentParser('CLIP Training', add_help=False) parser.add_argument('--batch_size', default=16, type=int, help='Batch size per GPU') parser.add_argument('--epochs', default=10, type=int) parser.add_argument('--image_size', default=224, type=int) parser.add_argument('--new_max_token', default=248, type=int) parser.add_argument('--dataset', default='datasets/docci_segment_sim_bbox_del_org.json', type=str) parser.add_argument('--model', default='openai/clip-vit-base-patch16', type=str) parser.add_argument('--weight_decay', type=float, default=0.05) parser.add_argument('--init_lr', type=float, default=5e-6, metavar='LR') # originally 5e-6 parser.add_argument('--min_lr', type=float, default=0, metavar='LR') parser.add_argument('--output_dir', default='finetune_out_SA_1B_100k_plus_docci/goal_bbox_local_token_align_batch16_only_max_pair_base16_patch16_real', help='path where to save, empty for no saving') parser.add_argument('--save_interval', default=1, type=int) parser.add_argument('--seed', default=0, type=int) parser.add_argument('--num_workers', default=10, type=int) parser.add_argument('--pin_mem', action='store_true', help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') parser.add_argument('--wandb_project', type=str, default='CLIP_Training', help='wandb project name') parser.add_argument('--ckpt', type=str, default=None, help='path to checkpoint file') parser.set_defaults(pin_mem=True) parser.set_defaults(pin_mem=True) # distributed training parameters parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') return parser if __name__ == "__main__": args = get_args_parser() args = args.parse_args() if args.output_dir: Path(args.output_dir).mkdir(parents=True, exist_ok=True) main(args)