| import io |
| import math |
| import base64 |
| import torch |
| import copy |
|
|
| import torchvision.transforms as transforms |
| import numpy as np |
| from PIL import Image |
| from einops import rearrange |
|
|
| from transformers import GenerationConfig, DynamicCache |
| from projects.ST.models.models_modeling_qwen2mm_mmrope import Qwen2MMmropeForCausalLM |
| from transformers import AutoTokenizer |
|
|
| def get_transformer_and_tokenizer(model_path, tokenizer_path): |
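    """Load the multimodal causal LM and its tokenizer, registering the vision special tokens and their ids on the tokenizer."""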
| model = Qwen2MMmropeForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, use_cache=False) |
| tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) |
| tokenizer.vis_beg_tok = "<vision>" |
| tokenizer.vis_patch_tok = "<vpatch>" |
| tokenizer.vis_rsep_tok = "<vrow_sep>" |
| tokenizer.vis_frm_tok = "<vframe_sep>" |
| tokenizer.vis_end_tok = "</vision>" |
| tokenizer.vis_cls_tok = "<|vis_cls|>" |
|
|
| tokenizer.vis_beg_tok_id = tokenizer.convert_tokens_to_ids("<vision>") |
| tokenizer.vis_patch_tok_id = tokenizer.convert_tokens_to_ids("<vpatch>") |
| tokenizer.vis_rsep_tok_id = tokenizer.convert_tokens_to_ids("<vrow_sep>") |
| tokenizer.vis_frm_tok_id = tokenizer.convert_tokens_to_ids("<vframe_sep>") |
| tokenizer.vis_end_tok_id = tokenizer.convert_tokens_to_ids("</vision>") |
| tokenizer.vis_cls_tok_id = tokenizer.convert_tokens_to_ids("<|vis_cls|>") |
| return model, tokenizer |
|
|
| DEFAULT_PATCH_SIZE = 32 |
| MAX_RESOLUTION = 1024 |
| VISION_TOKENS = [ |
| "<vision>", |
| "<vpatch>", |
| "<vrow_sep>", |
| "<vframe_sep>", |
| "</vision>", |
| "<|vis_cls|>" |
| |
| ] |
| NON_VISION_TOKEN_ID = -1 |
| PROMPT_TMPL = '<|im_start|>user\n{input}<|im_end|>\n' |
|
|
|
|
def load_image_to_base64(image_path: str) -> str:
    """Read an image file and return it as a base64 data URI."""
| with open(image_path, "rb") as image_file: |
| encoded_string = base64.b64encode(image_file.read()).decode('utf-8') |
| return f"data:image/jpeg;base64,{encoded_string}" |
|
|
|
|
def load_base64_to_PILImage(base64_string: str) -> Image.Image:
    """Decode a base64 data URI into an RGB PIL image."""
| base64_string = base64_string.split(",")[1] |
| decoded_string = base64.b64decode(base64_string) |
| return Image.open(io.BytesIO(decoded_string)).convert('RGB') |
|
|
|
|
| def get_resize_output_image_size( |
| image_size, patch_size, fix_res_size=None |
| ) -> tuple: |
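    """Return a patch-aligned output size in the same (w, h) ordering as the input.

    The long side is rounded up to a multiple of ``patch_size`` and capped at
    MAX_RESOLUTION; the short side is scaled to preserve the aspect ratio and then
    rounded up to a multiple of ``patch_size``. If ``fix_res_size`` is given, a
    square of that size is returned instead.
    """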
| if fix_res_size is not None: |
| return fix_res_size, fix_res_size |
|
|
| l1, l2 = image_size |
| short, long = (l2, l1) if l2 <= l1 else (l1, l2) |
|
|
    # Round the long side up to a multiple of patch_size, capped at MAX_RESOLUTION.
| requested_new_long = min( |
| [ |
| math.ceil(long / patch_size) * patch_size, |
| MAX_RESOLUTION, |
| ] |
| ) |
|
|
| new_long, new_short = requested_new_long, int(requested_new_long * short / long) |
|
|
| new_short = math.ceil(new_short / patch_size) * patch_size |
| return (new_long, new_short) if l2 <= l1 else (new_short, new_long) |
|
|
|
|
| def preprocess_image( |
| image_tensor: torch.Tensor, |
| patch_size: int = DEFAULT_PATCH_SIZE |
| ) -> torch.Tensor: |
    """Split a (C, H, W) image tensor into an (H/p, W/p, C, p, p) grid of patches."""
| patches = image_tensor.unfold(1, patch_size, patch_size) \ |
| .unfold(2, patch_size, patch_size) |
| patches = patches.permute(1, 2, 0, 3, 4).contiguous() |
| return patches |
|
|
|
|
| def get_transform(height, width): |
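    """Resize to (height, width), convert to a tensor, and apply ImageNet normalization."""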
| preprocess_transform = transforms.Compose([ |
| transforms.Resize((height, width)), |
| transforms.ToTensor(), |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], |
| std=[0.229, 0.224, 0.225]) |
| ]) |
| return preprocess_transform |
|
|
|
|
| def convert_image_base64_to_patches(base64_image: str, patch_size: int, fix_res_size: int = None) -> torch.Tensor: |
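    """Decode a base64 image, resize it to patch-aligned dimensions, and return a (rows, cols, C, p, p) patch grid."""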
| img_pil = load_base64_to_PILImage(base64_image) |
    # Resize so both sides are patch-aligned, then cut the image into patches.
| width, height = img_pil.size |
| new_width, new_height = get_resize_output_image_size((width, height), patch_size=patch_size, |
| fix_res_size=fix_res_size) |
| img_tensor = get_transform(new_height, new_width)(img_pil) |
| img_patches = preprocess_image(img_tensor, patch_size=patch_size) |
| return img_patches |
|
|
|
|
| def prepare_image_textual_seq(h, w, tokenizer, add_cls=True): |
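    """Build the image token sequence with row separators: <vision>, then each row as w <vpatch> tokens followed by <vrow_sep> (the last row ends with </vision>), optionally followed by the cls token. Returns (sequence, token_count)."""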
| seq = "" |
| tok_len = 0 |
|
|
| seq += tokenizer.vis_beg_tok |
| tok_len += 1 |
| for _ in range(h - 1): |
| seq += tokenizer.vis_patch_tok * w + tokenizer.vis_rsep_tok |
| tok_len += (w + 1) |
| seq += tokenizer.vis_patch_tok * w + tokenizer.vis_end_tok |
| tok_len += (w + 1) |
| if add_cls: |
| seq += tokenizer.vis_cls_tok |
| tok_len += 1 |
|
|
| return seq, tok_len |
|
|
|
|
| def prepare_image_textual_seq_norowsep(h, w, tokenizer, add_cls=True): |
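    """Build the image token sequence without row separators: <vision>, h * w <vpatch> tokens, </vision>, optionally followed by the cls token. Returns (sequence, token_count)."""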
| seq = "" |
| tok_len = 0 |
|
|
| seq += tokenizer.vis_beg_tok |
| tok_len += 1 |
|
|
| seq += tokenizer.vis_patch_tok * (w * h) |
| tok_len += (w * h) |
|
|
| seq += tokenizer.vis_end_tok |
| tok_len += 1 |
|
|
| if add_cls: |
| seq += tokenizer.vis_cls_tok |
| tok_len += 1 |
|
|
| return seq, tok_len |
|
|
|
|
| def create_single_prefix_mask(prefix_len, max_len): |
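    """Boolean (max_len, max_len) mask: the first prefix_len tokens attend to each other bidirectionally, the rest attend causally."""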
| attn_mask = torch.zeros(max_len, max_len) |
| attn_mask[:prefix_len, :prefix_len] = 1 |
| causal_mask = torch.tril(torch.ones(max_len, max_len)) |
| attn_mask = attn_mask.bool() | causal_mask.bool() |
| return attn_mask |
|
|
|
|
| def generate_mm_pos_ids_singleit(input_ids, vpatch_id, h, w): |
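    """Build 3-D mRoPE position ids for a sequence containing one image: text positions are shared across the (t, h, w) axes, while the h * w vision patches get grid positions offset by the patch start index. Returns a (3, seq_len) tensor."""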
| input_ids_pt = torch.Tensor(input_ids).int() |
| vpatch_pos = torch.argwhere(input_ids_pt == vpatch_id) |
| vpatch_start_pos = vpatch_pos[0].item() |
| nt = len(input_ids) - (h * w) + 1 |
|
|
    # Grid coordinates (t, h, w) for the vision patches; a single image means t has length 1.
| t_indices = torch.arange(1) |
| h_indices = torch.arange(h) |
| w_indices = torch.arange(w) |
| v_pos_id = torch.stack(torch.meshgrid(t_indices, h_indices, w_indices, indexing='ij'), dim=0) |
| v_pos_id = rearrange(v_pos_id, "d t h w -> (t h w) d") |
| v_pos_id += vpatch_start_pos |
| position_id = torch.cat( |
| [ |
| torch.arange(vpatch_start_pos).unsqueeze(-1).repeat(1, 3), |
| v_pos_id, |
| torch.arange(nt - vpatch_start_pos - 1).unsqueeze(-1).repeat(1, 3) + v_pos_id.max() + 1, |
| ], |
| dim=0 |
| ) |
| assert len(input_ids) == position_id.size(0) |
| position_id = rearrange(position_id, "slen d -> d slen").long() |
|
|
| return position_id |
|
|
|
|
| class Qwen2mmMROPEModel: |
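    """Single-image inference wrapper around Qwen2MMmropeForCausalLM: images are cut into fixed-size patches, spliced into the token sequence, and the image prefix is pre-filled into a KV cache before decoding."""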
| INSTALL_REQ = False |
| INTERLEAVE = False |
|
|
| def __init__(self, model_path='/mnt/bn/zilongdata-us/weixian/ckpt/qwen2mm-7B-mrope', |
| tokenizer_path="/mnt/bn/zilongdata-us/weixian/ckpt/Qwen2.5MM-7B-ext-psz16", fix_res_size=None, |
| **kwargs): |
|
|
| model, tokenizer = get_transformer_and_tokenizer( |
| model_path, tokenizer_path |
| ) |
| self.model = model.cuda().eval() |
| self.tokenizer = tokenizer |
|
|
| self.image_processor = lambda x: convert_image_base64_to_patches(load_image_to_base64(x), |
| model.config.vision_patch_size, |
| fix_res_size=fix_res_size) |
| self.kwargs = kwargs |
|
|
| def prepare_input(self, image, text_input): |
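        """Turn a prompt (and optional patch grid) into model inputs: token ids, 3-axis position ids, a prefix attention mask, flattened vision patches, and their indices."""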
|
|
        # Remove any "<image>" placeholders from the prompt; image tokens are inserted explicitly below.
        text_input = text_input.replace("<image>\n", '').replace("\n<image>", '').replace("<image> ", '').replace(
            " <image>", '').replace("<image>", '')
| bos_token = '' if self.tokenizer.bos_token is None else self.tokenizer.bos_token |
| text_input = bos_token + PROMPT_TMPL.format(input=text_input.strip()) |
|
|
| if image is not None: |
| tokens = [] |
| vision_patch_indices = [] |
| vision_patches = [] |
|
|
| patches = image |
| n_rows, n_cols = patches.shape[:2] |
| n_patches = n_rows * n_cols |
| patches = patches.view(n_patches, -1) |
|
|
            # Textual scaffold for the image: <vision> + n_rows * n_cols <vpatch> + </vision>.
| image_text_seq, image_tok_len = prepare_image_textual_seq_norowsep(n_rows, n_cols, self.tokenizer, |
| add_cls=False) |
            # Tokenize the scaffold and record which positions hold vision patches.
| cur_tokens_pt = self.tokenizer(image_text_seq, add_special_tokens=False, |
| return_tensors="pt").input_ids.squeeze(0) |
| cur_patch_indices = torch.full_like(cur_tokens_pt, fill_value=NON_VISION_TOKEN_ID) |
| assert (cur_tokens_pt == self.tokenizer.vis_patch_tok_id).sum() == n_patches |
| assert (cur_tokens_pt >= self.tokenizer.vis_beg_tok_id).sum() == image_tok_len |
| cur_patch_indices[cur_tokens_pt == self.tokenizer.vis_patch_tok_id] = torch.arange(n_patches) |
|
|
| cur_tokens = cur_tokens_pt.cpu().numpy().tolist() |
| cur_patch_indices = cur_patch_indices.cpu().numpy().tolist() |
| assert len(cur_tokens) == len(cur_patch_indices) |
|
|
| tokens.extend(cur_tokens) |
| vision_patch_indices.extend(cur_patch_indices) |
| vision_patches.extend(patches.numpy().astype(np.float16)) |
|
|
            # Tokenize the textual prompt and append it after the image tokens.
            _tokenized_text = self.tokenizer(text_input, return_tensors="pt", add_special_tokens=False)
            cur_tokens = _tokenized_text["input_ids"].squeeze(0).tolist()
            tokens.extend(cur_tokens)
            vision_patch_indices.extend([NON_VISION_TOKEN_ID] * len(cur_tokens))
|
|
| position_ids = generate_mm_pos_ids_singleit(tokens, self.tokenizer.vis_patch_tok_id, n_rows, |
| n_cols) |
| attention_mask_4d = create_single_prefix_mask(image_tok_len, len(tokens)).unsqueeze(0) |
| print('ids: ', tokens) |
| tokens = torch.Tensor(tokens).long() |
| print('vision_patches_indices: ', vision_patch_indices) |
| vision_patch_indices = torch.Tensor(vision_patch_indices).long() |
| if len(vision_patches) > 0: |
                # Stack per-patch pixel vectors into a single (n_patches, C * p * p) tensor.
| vision_patches = np.array(vision_patches) |
| vision_patches = torch.Tensor(vision_patches).bfloat16() |
| else: |
| vision_patches = None |
|
|
| tokens = tokens.unsqueeze(0) |
| position_ids = position_ids.unsqueeze(1) |
| attention_mask_4d = attention_mask_4d.unsqueeze(0) |
| vision_patch_indices = vision_patch_indices.unsqueeze(0) |
| attn_mask_for_gen = torch.ones_like(tokens) |
|
|
| return dict( |
| input_ids=tokens.to("cuda"), |
| position_ids=position_ids.to("cuda"), |
| attention_mask=attn_mask_for_gen.to("cuda"), |
| vision_patches=vision_patches.to("cuda"), |
| vision_patch_indices=vision_patch_indices.to("cuda"), |
| attention_mask_4d=attention_mask_4d.to("cuda"), |
| image_tokens_len=image_tok_len |
| ) |
|
|
        # Text-only path: plain causal attention, with 1-D positions replicated over the three mRoPE axes.
| _text_inputs = self.tokenizer(text_input, return_tensors="pt", add_special_tokens=False) |
| text_input_ids = _text_inputs['input_ids'] |
| text_attn_mask = _text_inputs['attention_mask'] |
| text_position_ids = torch.arange(text_input_ids.size(-1)).unsqueeze(0).expand(3, -1).clone().long() |
| return dict( |
| input_ids=text_input_ids.long().to("cuda"), |
| attention_mask=text_attn_mask.long().to("cuda"), |
| position_ids=text_position_ids.unsqueeze(1).to("cuda"), |
| vision_patches=None, |
| vision_patch_indices=None, |
| attention_mask_4d=None, |
| image_tokens_len=None |
| ) |
|
|
| def message_to_promptimg(self, message, dataset=None): |
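        """Collapse a VLMEval-style message list into a prompt string and at most one image path."""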
        assert not self.INTERLEAVE
        prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
        images = [x['value'] for x in message if x['type'] == 'image']
        image = images[0] if images else None
| return prompt, image |
|
|
| def generate_inner(self, message, dataset=None): |
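        """Generate a response for a VLMEval message: the image prefix is pre-filled into a DynamicCache under a bidirectional prefix mask, then decoding proceeds greedily."""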
| prompt, image_path = self.message_to_promptimg(message, dataset=dataset) |
| image_patches = None if image_path is None else \ |
| self.image_processor(image_path) |
| inputs = self.prepare_input(image_patches, prompt) |
|
|
| past_key_values = None |
| image_tok_len = inputs.pop("image_tokens_len") |
| attention_mask_4d = inputs.pop("attention_mask_4d") |
| if image_tok_len is not None and attention_mask_4d is not None: |
| assert (attention_mask_4d[:, :, :image_tok_len, :image_tok_len] == 1).all() |
| assert inputs["vision_patches"] is not None |
| assert inputs["vision_patch_indices"] is not None |
| prefix_cache = DynamicCache() |
| cache_inputs = dict( |
| input_ids=inputs['input_ids'][:, :image_tok_len], |
| position_ids=inputs['position_ids'][:, :, :image_tok_len], |
| attention_mask=attention_mask_4d[:, :, :image_tok_len, :image_tok_len], |
| vision_patches=inputs['vision_patches'], |
| vision_patch_indices=inputs['vision_patch_indices'][:, :image_tok_len], |
| ) |
| with torch.no_grad(): |
| prefix_cache = self.model(**cache_inputs, past_key_values=prefix_cache, use_cache=True).past_key_values |
| past_key_values = copy.deepcopy(prefix_cache) |
|
|
| generation_args = GenerationConfig( |
| do_sample=False, |
| top_p=None, |
| temperature=0, |
| num_beams=1, |
| max_new_tokens=128, |
| pad_token_id=self.tokenizer.eos_token_id, |
| eos_token_id=self.tokenizer.eos_token_id, |
| ) |
|
|
| generate_ids = self.model.generate( |
| **inputs, |
| past_key_values=past_key_values, |
| use_cache=True, |
| eos_token_id=self.tokenizer.eos_token_id, |
| generation_config=generation_args |
| ) |
| print(generate_ids) |
        # Decode only the newly generated tokens (drop the prompt prefix), mirroring generate_ext_eval.
        generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
        response = self.tokenizer.batch_decode(
            generate_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]
| return response |
|
|
| def generate_ext_eval(self, args, prompt, image_path=None, generate_config=None): |
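        """Generation entry point for external evaluation: same prefix-cache setup as generate_inner, with sampling controlled by args or an explicit generate_config; returns only the newly generated text."""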
| image_patches = None if image_path is None else \ |
| self.image_processor(image_path) |
| inputs = self.prepare_input(image_patches, prompt) |
|
|
| past_key_values = None |
| image_tok_len = inputs.pop("image_tokens_len") |
| attention_mask_4d = inputs.pop("attention_mask_4d") |
| if image_tok_len is not None and attention_mask_4d is not None: |
| assert (attention_mask_4d[:, :, :image_tok_len, :image_tok_len] == 1).all() |
| assert inputs["vision_patches"] is not None |
| assert inputs["vision_patch_indices"] is not None |
| prefix_cache = DynamicCache() |
| cache_inputs = dict( |
| input_ids=inputs['input_ids'][:, :image_tok_len], |
| position_ids=inputs['position_ids'][:, :, :image_tok_len], |
| attention_mask=attention_mask_4d[:, :, :image_tok_len, :image_tok_len], |
| vision_patches=inputs['vision_patches'], |
| vision_patch_indices=inputs['vision_patch_indices'][:, :image_tok_len], |
| ) |
| with torch.no_grad(): |
| prefix_cache = self.model(**cache_inputs, past_key_values=prefix_cache, use_cache=True).past_key_values |
| past_key_values = copy.deepcopy(prefix_cache) |
|
|
| generation_args = GenerationConfig( |
| do_sample=True if args.temperature > 0 else False, |
| temperature=args.temperature, |
| num_beams=args.num_beams, |
| max_new_tokens=args.max_new_tokens, |
| min_new_tokens=1, |
| pad_token_id=self.tokenizer.eos_token_id, |
| eos_token_id=self.tokenizer.eos_token_id, |
| ) if generate_config is None else GenerationConfig(**generate_config) |
|
|
| generate_ids = self.model.generate( |
| **inputs, |
| past_key_values=past_key_values, |
| use_cache=True, |
| eos_token_id=self.tokenizer.eos_token_id, |
| generation_config=generation_args |
| ) |
| generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:] |
| response = self.tokenizer.batch_decode( |
| generate_ids, |
| skip_special_tokens=True, |
| clean_up_tokenization_spaces=False |
| )[0] |
| return response |
|
|
if __name__ == '__main__':
    # Simple smoke test: run a single image-captioning query.
    tokenizer_path = './pretrained/single_transformer/capcls1.0_1024M_imgfull_withpt_lr5e-4-0_rp0.1_iter62500_hf/'
    path = './pretrained/single_transformer/SFT-Qwen2.5-0.5B-capcls1.0_1024M_iter_62500_lr5e-4_0_rp0.1_hf_llava/'
    evaluation_images = './projects/omg_llava/test.jpg'
    evaluation_inputs = ['Please describe this picture']

    messages = []
    messages.append({'type': 'image', 'value': evaluation_images})
    messages.append({'type': 'text', 'value': evaluation_inputs[0]})

    model = Qwen2mmMROPEModel(model_path=path, tokenizer_path=tokenizer_path)
    ret = model.generate_inner(message=messages)
    print(ret)