| import os |
| import re |
| import ast |
| import math |
| import yaml |
| import warnings |
| from datetime import datetime |
| from dataclasses import dataclass, field |
| from collections import defaultdict |
| from typing import Any, Callable, Optional, Union, Sized, Dict, Tuple, List, Literal, Type |
|
|
| import numpy as np |
| import torch |
| from torch import nn |
| import torch.nn.functional as F |
|
|
| import datasets |
|
|
| from PIL import Image |
|
|
| from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config |
| from trl.models import unwrap_model_for_generation |
|
|
| from transformers import ( |
| TrainingArguments, |
| Trainer, |
| GenerationConfig, |
| ) |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.utils import ( |
| is_safetensors_available, |
| is_peft_available |
| ) |
|
|
| if is_safetensors_available(): |
| import safetensors.torch |
| from peft import PeftConfig, get_peft_model, PeftModel |
| from accelerate.utils import is_peft_model, set_seed |
|
|
| from qwen_vl_utils import process_vision_info |
|
|
| from src.model.vlm_backbone.qwen2_5_vl_gp.process_gp import Qwen2_5_VL_GP_Processor |
|
|
| from transformers.trainer import ( |
| logger, |
| TRAINING_ARGS_NAME, |
| CONFIG_NAME, |
| ADAPTER_WEIGHTS_NAME, |
| ADAPTER_SAFE_WEIGHTS_NAME, |
| WEIGHTS_NAME, |
| WEIGHTS_INDEX_NAME, |
| SAFE_WEIGHTS_NAME, |
| SAFE_WEIGHTS_INDEX_NAME, |
| FSDP_MODEL_NAME, |
| ) |
|
|
| from src.model.vlm_backbone.qwen2_5_vl_gp.warppers import debug_calls |
| from src.utils_gp import ( |
| LLMClient, |
| norm_bboxes, |
| extract_one_bbox_from_str, |
| cal_paired_ious, |
| print_rank0 |
| ) |
|
|
|
|
| |
|
|
# Column names of the normalized per-sample schema written by the dataset
# mappers below and read back by the collator.
QUERY_KEY = "query"                    # prompt text shown to the model
IMG_PATH_KEY = "img_path"              # path of the sample's image file
ANSWER_KEY = "answer"                  # reference answer text
NORMED_BBOXES_KEY = "normed_bboxes"    # list of boxes attached to the sample
SCORE_FUNCS_KEY = "score_funcs"        # names of score functions for the sample

# Columns retained after the main mapper runs; all other columns are removed.
REMAIN_KEYS = [QUERY_KEY, IMG_PATH_KEY, NORMED_BBOXES_KEY, ANSWER_KEY, SCORE_FUNCS_KEY]
|
|
# Name -> callable lookup tables filled by the decorators below. A function's
# key is its __name__ with "_dataset_mapper" / "_dataset_filter" removed.
MAPPER_REGISTRY: Dict[str, Callable] = dict()
FILTER_REGISTRY: Dict[str, Callable] = dict()


def _make_registrar(registry: Dict[str, Callable], marker: str) -> Callable:
    """Return a decorator that stores a function in *registry* under its short name.

    The short name is ``fn.__name__`` with every occurrence of *marker* removed.
    An existing entry with the same key is silently replaced.
    """
    def decorator(fn):
        registry[fn.__name__.replace(marker, "")] = fn
        return fn
    return decorator


def register_mappers():
    """Return a decorator that registers a dataset mapper in MAPPER_REGISTRY.

    Example: ``cot_train_dataset_mapper`` is stored under ``"cot_train"``.
    """
    return _make_registrar(MAPPER_REGISTRY, "_dataset_mapper")


def register_filters():
    """Return a decorator that registers a dataset filter in FILTER_REGISTRY.

    Example: ``image_exist_dataset_filter`` is stored under ``"image_exist"``.
    """
    return _make_registrar(FILTER_REGISTRY, "_dataset_filter")
|
|
|
|
@register_mappers()
def cot_train_dataset_mapper(one_data, **kwargs):
    """Map a raw CoT training record to the normalized sample schema.

    Reads ``question``, ``answer``, ``image``, ``dataset`` and ``bboxs`` from
    *one_data*. If ``prompt`` is present in *kwargs*, it is used as a
    ``str.format`` template with the question as its only positional field.

    Keyword Args:
        img_dir: Root image directory; the image path becomes
            ``<img_dir>/cot/<dataset>/<image>``.
        score_funcs: Copied verbatim into the score-function column.
        prompt: Optional format template for the query.

    Returns:
        Dict with the query, answer, image path, the unmodified ``bboxs`` value
        under the normalized-bbox key, and the score functions.
    """
    question = one_data['question']
    query = kwargs['prompt'].format(question) if 'prompt' in kwargs else question
    answer = one_data['answer']
    image_name, source_name = one_data['image'], one_data['dataset']
    img_path = os.path.join(kwargs['img_dir'], "cot", source_name, image_name)
    return {
        QUERY_KEY: query,
        ANSWER_KEY: answer,
        IMG_PATH_KEY: img_path,
        NORMED_BBOXES_KEY: one_data['bboxs'],
        SCORE_FUNCS_KEY: kwargs['score_funcs'],
    }
| |
|
|
@register_mappers()
def cot_train_fullmask_dataset_mapper(one_data, **kwargs):
    """Map a raw CoT record like ``cot_train`` but with one whole-image box.

    The record's own boxes are ignored. Every sample gets
    ``[[0.0, 0.0, 1.0, 1.0]]``, presumably a normalized (x1, y1, x2, y2) box
    covering the full image (TODO confirm the box convention).

    Keyword Args:
        img_dir: Root image directory; the image path becomes
            ``<img_dir>/cot/<dataset>/<image>``.
        score_funcs: Copied verbatim into the score-function column.
        prompt: Optional ``str.format`` template applied to the question.
    """
    question = one_data['question']
    query = kwargs['prompt'].format(question) if 'prompt' in kwargs else question
    answer = one_data['answer']
    image_name, source_name = one_data['image'], one_data['dataset']
    img_path = os.path.join(kwargs['img_dir'], "cot", source_name, image_name)
    whole_image_box = [[0.0, 0.0, 1.0, 1.0]]
    return {
        QUERY_KEY: query,
        ANSWER_KEY: answer,
        IMG_PATH_KEY: img_path,
        NORMED_BBOXES_KEY: whole_image_box,
        SCORE_FUNCS_KEY: kwargs['score_funcs'],
    }
| |
| |
@register_mappers()
def norm_bboxes_dataset_mapper(one_data, **kwargs):
    """Replace the sample's raw boxes with the output of ``norm_bboxes``.

    Image width/height come from the record's ``width``/``height`` fields when
    ``width`` is present; otherwise the image file is opened to read its size.
    *one_data* is modified in place and returned.

    Keyword Args:
        bbox_type: Forwarded to ``norm_bboxes`` as ``bbox_type``.
    """
    raw_bboxes = one_data.pop(NORMED_BBOXES_KEY)
    if 'width' not in one_data:
        with Image.open(one_data[IMG_PATH_KEY]) as img_pil:
            width, height = img_pil.size
    else:
        width, height = one_data['width'], one_data['height']
    one_data[NORMED_BBOXES_KEY] = norm_bboxes(
        raw_bboxes, height, width, bbox_type=kwargs['bbox_type']
    )
    return one_data
|
|
| |
@register_filters()
def image_exist_dataset_filter(one_data, **kwargs):
    """Keep a sample only if PIL can open the file at its image path.

    Missing or unreadable files (any ``OSError``) and any other exception are
    reported through ``print_rank0`` and the sample is dropped. PIL opens images
    lazily, so this checks that the file can be opened and identified, not that
    all pixel data decodes.
    """
    img_path = one_data[IMG_PATH_KEY]
    try:
        with Image.open(img_path):
            pass
    except OSError as e:  # FileNotFoundError is a subclass of OSError
        print_rank0(f"Image not found or invalid: {img_path}. Error: {e}")
        return False
    except Exception as e:
        print_rank0(f"Unexpected error while checking image: {img_path}. Error: {e}")
        return False
    return True
| |
@register_filters()
def inputs_seq_length_dataset_filter(one_data, **kwargs):
    """Keep a sample only if its tokenized prompt satisfies the length limits.

    Keyword Args:
        processor: Processor used to build and tokenize the chat prompt.
        max_input_seq_length: Optional cap on the ``input_ids`` length.
        max_input_remain_seq_length: Optional cap on the length left after
            subtracting the number of zero entries in ``ref_token_masks[0]``.

    Returns:
        True if no limit is configured or every configured limit is met.
    """
    processor = kwargs['processor']
    limit_total = kwargs.get('max_input_seq_length', None)
    limit_remain = kwargs.get('max_input_remain_seq_length', None)
    if limit_total is None and limit_remain is None:
        return True

    img_path = one_data[IMG_PATH_KEY]
    query = one_data[QUERY_KEY]
    # Boxes are only passed when the remaining-length check (which reads
    # ref_token_masks) is requested.
    bbox_batch = None if limit_remain is None else [one_data[NORMED_BBOXES_KEY]]

    user_turn = {
        "role": "user",
        "content": [{"type": "image", "image": img_path}, {"type": "text", "text": query}],
    }
    conversation_batch = [[user_turn]]
    prompt_text = processor.apply_chat_template(
        conversation_batch, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(conversation_batch)
    inputs = processor(
        text=prompt_text,
        images=image_inputs,
        videos=video_inputs,
        normed_bboxes=bbox_batch,
        padding=True,
        return_tensors="pt",
    )

    seq_length = inputs.input_ids.shape[1]
    if limit_total is not None and seq_length > limit_total:
        return False
    if limit_remain is None:
        return True
    keep_mask = inputs.ref_token_masks[0]
    removed = keep_mask.numel() - keep_mask.sum().item()
    return seq_length - removed <= limit_remain
|
|
|
|
| |
|
|
# Loss class name -> loss class, filled by the @register_loss decorator.
LOSS_REGISTRY: Dict[str, Type[nn.Module]] = dict()


def register_loss(loss_class):
    """Class decorator recording *loss_class* in LOSS_REGISTRY under its ``__name__``.

    Returns the class unchanged.

    Raises:
        ValueError: If a class with the same name was already registered.
    """
    key = loss_class.__name__
    if key not in LOSS_REGISTRY:
        LOSS_REGISTRY[key] = loss_class
        return loss_class
    raise ValueError(f"Loss class '{key}' is already registered.")
|
|
|
|
| @register_loss |
class DiceLoss(nn.Module):
    """Soft Dice loss between ``sigmoid(logits)`` and target masks.

    Inputs are lists of per-sample tensors whose sizes may differ. Each pair is
    flattened to 1-D and scored as
    ``1 - (2 * sum(p * g) + eps) / (sum(p) + sum(g) + eps)``, and the
    per-sample losses are averaged over the batch.

    Args:
        epsilon: Smoothing term added to the numerator and denominator.
        **kwargs: Accepted and ignored, so all registered losses can be built
            from the same keyword arguments.
    """

    def __init__(self, epsilon: float = 1e-6, **kwargs):
        super().__init__()
        self.epsilon = epsilon

    def forward(self,
                image_token_mask_logits: List[torch.Tensor],
                ref_token_masks: List[torch.Tensor]
                ) -> torch.Tensor:
        """Return the batch-mean Dice loss as a scalar tensor.

        The computation runs in at least float32, so bf16/fp16 logits give
        float32 results. An empty batch returns ``tensor(0.0)``.

        Raises:
            TypeError: If either argument is not a list.
            ValueError: If the lists have different lengths.
        """
        if not isinstance(image_token_mask_logits, list) or not isinstance(ref_token_masks, list):
            raise TypeError("Inputs must be lists of tensors.")
        if len(image_token_mask_logits) != len(ref_token_masks):
            raise ValueError(f"Input lists must have the same length, but got "
                             f"{len(image_token_mask_logits)} and {len(ref_token_masks)}")
        if len(image_token_mask_logits) == 0:
            # No tensors to take a device from.
            return torch.tensor(0.0)

        total_dice_loss = 0.0
        for logits, target in zip(image_token_mask_logits, ref_token_masks):
            # Upcast low-precision logits (as BCELoss does): sums over many
            # tokens in bf16/fp16 lose precision. float64 inputs stay float64.
            compute_dtype = torch.promote_types(logits.dtype, torch.float32)
            pred = logits.flatten().to(compute_dtype).sigmoid()
            gt = target.flatten().to(pred.device, dtype=compute_dtype)
            intersection = (pred * gt).sum()
            dice_coefficient = (2.0 * intersection + self.epsilon) / (pred.sum() + gt.sum() + self.epsilon)
            total_dice_loss += 1.0 - dice_coefficient

        return total_dice_loss / len(image_token_mask_logits)
|
|
|
|
| @register_loss |
class BCELoss(nn.Module):
    """Binary cross-entropy on logits, averaged over a batch of per-sample masks.

    Each (logits, target) pair is flattened to 1-D and scored in float32 with
    ``F.binary_cross_entropy_with_logits``, which takes the mean over elements.
    The per-sample losses are then averaged over the batch.
    """

    def __init__(self, **kwargs):
        # Extra keyword arguments are accepted and ignored, matching the
        # ``**kwargs`` signature of the other losses in this module.
        super().__init__()

    def forward(self,
                image_token_mask_logits: List[torch.Tensor],
                ref_token_masks: List[torch.Tensor]
                ) -> torch.Tensor:
        """Return the batch-mean BCE loss as a scalar tensor.

        An empty batch returns ``tensor(0.0)`` instead of dividing by zero.

        Raises:
            ValueError: If the two sequences have different lengths.
        """
        if len(image_token_mask_logits) != len(ref_token_masks):
            raise ValueError(f"Input lists must have the same length, but got "
                             f"{len(image_token_mask_logits)} and {len(ref_token_masks)}")
        if len(image_token_mask_logits) == 0:
            return torch.tensor(0.0)

        total_bce_loss = 0.0
        for logits, target in zip(image_token_mask_logits, ref_token_masks):
            pred = logits.flatten()
            gt = target.flatten().to(pred.device)
            total_bce_loss += F.binary_cross_entropy_with_logits(pred.float(), gt.float())
        return total_bce_loss / len(image_token_mask_logits)
|
|
|
|
@register_loss
class MaskLoss(nn.Module):
    """Weighted sum of ``DiceLoss`` and ``BCELoss`` over the same mask inputs.

    Args:
        dice_weight: Multiplier for the Dice term.
        bce_weight: Multiplier for the BCE term.
        epsilon: Smoothing term passed to ``DiceLoss``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 dice_weight: float = 0.5,
                 bce_weight: float = 0.5,
                 epsilon: float = 1e-6,
                 **kwargs):
        super().__init__()
        self.dice_loss = DiceLoss(epsilon=epsilon)
        self.bce_loss = BCELoss()
        self.dice_weight, self.bce_weight = dice_weight, bce_weight

    def forward(self, image_token_mask_logits: List[torch.Tensor],
                ref_token_masks: List[torch.Tensor]
                ) -> torch.Tensor:
        """Return ``dice_weight * dice + bce_weight * bce`` for the batch."""
        dice_term = self.dice_loss(image_token_mask_logits, ref_token_masks)
        bce_term = self.bce_loss(image_token_mask_logits, ref_token_masks)
        weighted_dice = self.dice_weight * dice_term
        weighted_bce = self.bce_weight * bce_term
        return weighted_dice + weighted_bce
|
|
|
|
| |
|
|
# Score-function name -> callable, filled by the @register_score() decorator.
SCORE_REGISTRY: Dict[str, Callable] = dict()


def register_score():
    """Return a decorator that records a function in SCORE_REGISTRY.

    The key is the function's ``__name__`` with every "_score" removed, e.g.
    ``llm_score`` is stored under ``"llm"``. An existing key is overwritten.
    """
    def _store(func):
        SCORE_REGISTRY[func.__name__.replace("_score", "")] = func
        return func
    return _store
|
|
| @register_score() |
def llm_score(query, completion, answer, args):
    """
    Placeholder for the ``llm`` score function.

    A YAML config may list ``score_funcs: [llm]``. This project does not use
    these scores, so a 0.0 placeholder is returned: one per element when
    *query* is a list, otherwise a single-element list.
    """
    count = len(query) if isinstance(query, list) else 1
    return [0.0] * count
|
|
|
|
| |
|
|
| def _resolve_rel_path(rel_path: str, base_dir: str) -> str: |
| """ |
| Resolve a relative path against base_dir; if not found, try parent dirs up to 4 levels. |
| """ |
| if os.path.isabs(rel_path): |
| return rel_path |
| candidates = [os.path.join(base_dir, rel_path)] |
| parent = base_dir |
| for _ in range(4): |
| parent = os.path.dirname(parent) |
| if not parent or parent in ("/", ""): |
| break |
| candidates.append(os.path.join(parent, rel_path)) |
| for cand in candidates: |
| if os.path.exists(cand): |
| return cand |
| return candidates[0] |
|
|
|
|
class GPDataset(torch.utils.data.Dataset):
    """
    A PyTorch Dataset that loads and combines multiple datasets
    based on a YAML configuration file. It handles sampling
    and applies the specified mapping and filtering functions.

    The YAML file either holds a top-level ``datasets`` list, or a
    ``train_dataset`` key naming a second YAML file that holds it. Each
    ``datasets`` entry goes through: load JSON -> sample -> main mapper ->
    drop non-schema columns -> image-exists filter -> optional length filter
    -> additional mappers. Entries that fail are reported and skipped, and the
    surviving datasets are concatenated.
    """

    # Worker processes for ``datasets`` map/filter calls, unless an entry sets
    # its own ``num_proc`` key.
    _DEFAULT_NUM_PROC = 8

    @classmethod
    def _load_config(cls, config_path: str) -> Dict[str, Any]:
        """Load the YAML dataset configuration.

        Args:
            config_path: YAML file with either a ``datasets`` list or a
                ``train_dataset`` key pointing at a YAML file that has one.

        Returns:
            The dict holding the ``datasets`` list, plus ``'__root_dir__'``: the
            directory of the YAML file that holds the list, used later to
            resolve relative paths.

        Raises:
            ValueError: If a YAML file is empty or no ``datasets`` list is found.
            Any other error from opening or parsing is printed and re-raised.
        """
        print_rank0(f"Loading configuration from: {config_path}")
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                conf = yaml.safe_load(f)
            if conf is None:
                raise ValueError("YAML config is empty.")

            base_dir = os.path.dirname(config_path)
            if 'datasets' not in conf:
                if 'train_dataset' not in conf:
                    raise ValueError("YAML config is missing both 'datasets' and 'train_dataset' keys.")
                # Indirection: the top-level file only points at the dataset list.
                ds_yaml = _resolve_rel_path(conf['train_dataset'], base_dir)
                print_rank0(f"Loading dataset config from: {ds_yaml}")
                with open(ds_yaml, 'r', encoding='utf-8') as f:
                    conf2 = yaml.safe_load(f)
                if conf2 is None or 'datasets' not in conf2:
                    raise ValueError(f"'{ds_yaml}' missing 'datasets' key.")
                conf = conf2
                base_dir = os.path.dirname(ds_yaml)

            conf['__root_dir__'] = base_dir
            print_rank0("Configuration loaded successfully.")
            return conf
        except Exception as e:
            print_rank0(f"Failed to load config: {e}")
            raise

    @classmethod
    def _apply_sampling(cls, dataset: datasets.Dataset, strategy: Optional[str], seed: Optional[int] = None) -> datasets.Dataset:
        """Apply a ``"<type>:<count>"`` sampling strategy to *dataset*.

        Types (case-insensitive): ``first`` takes the leading rows, ``end``
        takes the trailing rows, and ``random`` shuffles with *seed* and then
        takes the leading rows. The count is clamped to the dataset size.
        An empty strategy, an unknown type, or an unparsable or non-positive
        count returns the dataset unchanged, with a printed message.
        """
        if not strategy:
            print_rank0("No sampling strategy specified, using full dataset.")
            return dataset

        try:
            parts = strategy.split(':')
            if len(parts) != 2:
                raise ValueError(f"Invalid sampling strategy format: '{strategy}'. Expected 'type:value'.")
            strat_type, strat_value = parts[0].lower(), parts[1]
            num_samples = int(strat_value)
            total_size = len(dataset)
            if num_samples <= 0:
                raise ValueError(f"Sampling value must be positive, got: {num_samples} [{strategy}]")
            num_samples = min(num_samples, total_size)

            print_rank0(f"Applying sampling: {strategy} ({num_samples} samples) to dataset of size {total_size}")

            if strat_type == "first":
                return dataset.select(range(num_samples))
            elif strat_type == "end":
                start_index = max(0, total_size - num_samples)
                return dataset.select(range(start_index, total_size))
            elif strat_type == "random":
                shuffled_dataset = dataset.shuffle(seed=seed)
                return shuffled_dataset.select(range(num_samples))
            else:
                print_rank0(f"Warning: Unknown sampling strategy type: '{strat_type}'. Using full dataset.")
                return dataset
        except ValueError as e:
            print_rank0(f"Error parsing sampling strategy '{strategy}': {e}. Using full dataset.")
            return dataset
        except Exception as e:
            print_rank0(f"An unexpected error occurred during sampling: {e}. Using full dataset.")
            return dataset

    @staticmethod
    def _unique_dataset_name(name: str, existing: Dict[str, Any]) -> str:
        """Return *name*, or a timestamped variant of it that is not a key of *existing*.

        The timestamp only has one-second resolution, so a numeric suffix is
        appended when the timestamped name is also taken. Without it, a later
        duplicate would silently overwrite an earlier one.
        """
        if name not in existing:
            return name
        stamped = f"{name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        candidate, counter = stamped, 1
        while candidate in existing:
            candidate = f"{stamped}_{counter}"
            counter += 1
        return candidate

    @classmethod
    def _all_processed_datasets(cls, config, processor, args):
        """Process every entry of ``config['datasets']``.

        If an entry omits ``sampling_seed``, ``img_dir``, ``max_input_seq_length``
        or ``max_input_remain_seq_length``, the same-named attribute of *args* is
        used instead (*args* may be None). Relative paths are resolved against
        ``config['__root_dir__']``. An entry is reported and skipped if it lacks
        ``json_path``, names an unregistered mapper, ends up empty, or raises
        during processing.

        Returns:
            Dict mapping a unique dataset name to its processed ``datasets.Dataset``.
        """
        root_dir = config.get('__root_dir__', os.getcwd())
        all_processed_datasets: Dict[str, datasets.Dataset] = {}
        num_entries = len(config['datasets'])
        for i, dataset_config in enumerate(config['datasets']):
            print_rank0(f"\nProcessing dataset entry {i+1}/{num_entries}...")
            json_path = dataset_config.get('json_path')
            if not json_path:
                print_rank0(f"Warning: Skipping dataset entry {i+1} due to missing 'json_path'.")
                continue
            json_path = _resolve_rel_path(json_path, root_dir)

            # splitext keeps extension-less names intact ("data" -> "data").
            base_name = os.path.splitext(os.path.basename(json_path))[0]
            dataset_name = dataset_config.get('dataset_name', base_name)

            sampling_strategy = dataset_config.get('sampling_strategy', None)
            sampling_seed = dataset_config.get('sampling_seed', getattr(args, 'sampling_seed', 42))

            mapper_name = dataset_config.get('mapper')
            bbox_type = dataset_config.get('bbox_type')

            # The entry's img_dir takes precedence over the script-level one.
            img_dir = dataset_config.get('img_dir', getattr(args, 'img_dir', None))
            if img_dir is not None:
                img_dir = _resolve_rel_path(img_dir, root_dir)

            # `or []` also covers keys present in the YAML with a null value.
            additional_mappers = dataset_config.get('additional_mappers') or []
            score_funcs = dataset_config.get('score_funcs') or []
            prompt = dataset_config.get('prompt', None)
            num_proc = dataset_config.get('num_proc', cls._DEFAULT_NUM_PROC)

            max_input_seq_length = dataset_config.get(
                'max_input_seq_length', getattr(args, 'max_input_seq_length', None))
            max_input_remain_seq_length = dataset_config.get(
                'max_input_remain_seq_length', getattr(args, 'max_input_remain_seq_length', None))

            for sf in score_funcs:
                if sf not in SCORE_REGISTRY:
                    print_rank0(f"Warning: Score function '{sf}' not registered. Will ignore.")
            score_funcs = [sf for sf in score_funcs if sf in SCORE_REGISTRY]

            try:
                # Validate mapper names before loading any data. Otherwise a bad
                # name only surfaces later as a bare KeyError.
                unknown_mappers = [m for m in [mapper_name, *additional_mappers] if m not in MAPPER_REGISTRY]
                if unknown_mappers:
                    print_rank0(f"Error: Unknown mapper(s) {unknown_mappers} for dataset entry {i+1} "
                                f"({json_path}). Registered mappers: {sorted(MAPPER_REGISTRY)}. Skipping.")
                    continue

                print_rank0(f"Loading raw data from: {json_path}")
                raw_dataset = datasets.load_dataset('json', data_files=json_path, split='train')
                print_rank0(f"Loaded {len(raw_dataset)} examples raw.")

                sampled_dataset = cls._apply_sampling(raw_dataset, sampling_strategy, sampling_seed)
                if len(sampled_dataset) == 0:
                    print_rank0("Dataset is empty after sampling, skipping.")
                    continue
                print_rank0(f"Dataset size after sampling: {len(sampled_dataset)}")

                mapper_func = MAPPER_REGISTRY[mapper_name]
                print_rank0(f"Applying mapper: '{mapper_name}'")
                mapper_kwargs = {
                    'img_dir': img_dir,
                    'score_funcs': score_funcs,
                }
                if prompt is not None:
                    mapper_kwargs['prompt'] = prompt
                print_rank0(f"Mapper arguments: {mapper_kwargs}")

                processed_dataset = sampled_dataset.map(
                    mapper_func,
                    num_proc=num_proc,
                    fn_kwargs=mapper_kwargs,
                )

                # Keep only the normalized schema columns.
                processed_dataset = processed_dataset.remove_columns(
                    [col for col in processed_dataset.column_names if col not in REMAIN_KEYS]
                )

                print_rank0("Applying dataset filter: 'image_exist_dataset_filter'")
                processed_dataset = processed_dataset.filter(
                    image_exist_dataset_filter,
                    num_proc=num_proc,
                    fn_kwargs={}
                )
                print_rank0(f"Processed dataset size after image_exist_dataset_filter: {len(processed_dataset)}")

                if max_input_seq_length is not None or max_input_remain_seq_length is not None:
                    processed_dataset = processed_dataset.filter(
                        inputs_seq_length_dataset_filter,
                        num_proc=num_proc,
                        fn_kwargs={
                            'processor': processor,
                            'max_input_seq_length': max_input_seq_length,
                            'max_input_remain_seq_length': max_input_remain_seq_length,
                        }
                    )
                    print_rank0(f"Processed dataset size after inputs_seq_length_dataset_filter: {len(processed_dataset)}")

                for additional_mapper in additional_mappers:
                    print_rank0(f"Applying additional mapper: '{additional_mapper}'")
                    processed_dataset = processed_dataset.map(
                        MAPPER_REGISTRY[additional_mapper],
                        num_proc=num_proc,
                        fn_kwargs={
                            'bbox_type': bbox_type,
                        }
                    )
                print_rank0(f"Processed dataset size: {len(processed_dataset)}")
                if len(processed_dataset) == 0:
                    print_rank0(f"Warning: Processed dataset {dataset_name} is empty after mapping. Skipping.")
                    continue

                unique_name = cls._unique_dataset_name(dataset_name, all_processed_datasets)
                if unique_name != dataset_name:
                    print_rank0(f"Warning: Dataset name '{dataset_name}' already exists. Renaming to '{unique_name}'")
                all_processed_datasets[unique_name] = processed_dataset

            except FileNotFoundError:
                print_rank0(f"Error: Data file not found for dataset entry {i+1}: {json_path}. Skipping.")
            except Exception as e:
                print_rank0(f"Error processing dataset entry {i+1} ({json_path}): {e}. Skipping.")

        return all_processed_datasets

    def __init__(self, config_path: str, processor: Qwen2_5_VL_GP_Processor, script_args: Optional[Any] = None):
        """
        Initializes the GPDataset.

        Args:
            config_path (str): Path to the YAML configuration file.
            processor (Qwen2_5_VL_GP_Processor): Processor for handling text and vision data.
            script_args (Any, optional): Additional arguments passed from the script
                (e.g., training args, could contain seed). Defaults to None.

        Raises:
            ValueError: If no dataset entry was processed successfully, or the
                concatenated dataset is empty.
        """
        super().__init__()
        self.args = script_args
        self.config = self._load_config(config_path)
        self.processor = processor
        self.final_dataset = None
        all_processed_datasets = self._all_processed_datasets(self.config, self.processor, self.args)
        if not all_processed_datasets:
            raise ValueError("No datasets were successfully processed. Please check your configuration.")

        print_rank0(f"\nConcatenating {len(all_processed_datasets)} processed dataset(s)...")
        self.final_dataset = datasets.concatenate_datasets(list(all_processed_datasets.values()))
        if len(self.final_dataset) == 0:
            raise ValueError("Final dataset is empty after concatenation.")
        print_rank0(f"Final combined dataset size: {len(self.final_dataset)}")
        print_rank0(f"Final dataset features: {self.final_dataset.features}")

    def __len__(self) -> int:
        return 0 if self.final_dataset is None else len(self.final_dataset)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Return row *index*; negative indices are rejected.

        Raises:
            IndexError: If the dataset is uninitialized or *index* is out of range.
        """
        if self.final_dataset is None:
            raise IndexError("Dataset is not initialized or is empty.")
        if not 0 <= index < len(self.final_dataset):
            raise IndexError(f"Index {index} out of bounds for dataset of size {len(self.final_dataset)}")
        return self.final_dataset[index]

    @classmethod
    def get_processed_dataset_dict(cls, config_path: str, processor: Qwen2_5_VL_GP_Processor, script_args: Optional[Any] = None) -> Dict[str, datasets.Dataset]:
        """Load *config_path* and return the per-entry processed datasets without concatenating."""
        config = cls._load_config(config_path)
        return cls._all_processed_datasets(config, processor, script_args)
|
|
|
|
|
|
class GPCollator:
    """Collate dataset rows into a processor batch for the model.

    Each row provides the query, answer, image path, normalized boxes and
    score-function columns. In SFT mode, the assistant answer is appended to
    the chat and ``labels`` are built. Otherwise a generation prompt is added.
    The raw queries, answers and score functions are attached to the batch.
    """

    def __init__(self, processor, is_sft):
        self.processor = processor
        self.is_sft = is_sft
        # Id of "<|im_start|>"; its last occurrence in a row marks the final turn.
        self.im_start_id = self.processor.tokenizer.encode("<|im_start|>")[0]

    def _prepare_labels_from_input_ids(self, input_ids, attention_mask: Optional[torch.Tensor] = None):
        """Build LM labels that only supervise the tokens of the final chat turn.

        Every position up to and including ``last <|im_start|> + 2`` is set to
        -100. Those two extra tokens are presumably the role and newline of the
        turn header (TODO confirm for this tokenizer). Rows without
        ``<|im_start|>`` are masked entirely. When *attention_mask* is given,
        positions where it is 0 (padding) are also set to -100, so pad tokens
        are never supervised regardless of the padding side.

        Args:
            input_ids: (B, L) token ids.
            attention_mask: Optional (B, L) mask with 0 at padded positions.

        Returns:
            A new (B, L) label tensor; *input_ids* is not modified.
        """
        B, L = input_ids.shape
        labels = input_ids.clone()
        is_im_start = input_ids == self.im_start_id
        # argmax over the flipped row finds the last occurrence; an all-zero
        # row yields 0, i.e. last_pos = L - 1, which masks the whole row.
        first_idx_in_flipped = torch.argmax(is_im_start.flip(dims=(1,)).int(), dim=1)
        last_pos = (L - 1) - first_idx_in_flipped
        mask_until_idx = torch.clamp(last_pos + 3, max=L)
        positions = torch.arange(L, device=input_ids.device).expand(B, -1)
        labels[positions < mask_until_idx.unsqueeze(1)] = -100
        if attention_mask is not None:
            labels[attention_mask.to(labels.device) == 0] = -100
        return labels

    def __call__(self, features):
        """Turn a list of dataset rows into a model-ready batch.

        Returns:
            The processor output, plus ``labels`` in SFT mode and the raw
            query/answer/score-function lists under their column keys.
        """
        messages, normed_bboxes, answers, queries, score_funcs = [], [], [], [], []
        for feature in features:
            query = feature[QUERY_KEY]
            answer = feature[ANSWER_KEY]
            user_turn = {
                "role": "user",
                "content": [{"type": "image", "image": feature[IMG_PATH_KEY]}, {"type": "text", "text": query}],
            }
            if self.is_sft:
                assistant_turn = {"role": "assistant", "content": [{"type": "text", "text": answer}]}
                messages.append([user_turn, assistant_turn])
            else:
                messages.append([user_turn])
            normed_bboxes.append(feature[NORMED_BBOXES_KEY])
            queries.append(query)
            answers.append(answer)
            score_funcs.append(feature[SCORE_FUNCS_KEY])

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=(not self.is_sft)
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=text,
            normed_bboxes=normed_bboxes,
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        if self.is_sft:
            inputs["labels"] = self._prepare_labels_from_input_ids(
                inputs.input_ids, attention_mask=inputs.get("attention_mask")
            )

        inputs[QUERY_KEY] = queries
        inputs[ANSWER_KEY] = answers
        inputs[SCORE_FUNCS_KEY] = score_funcs
        return inputs