| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import math |
| import os |
| from collections import defaultdict |
| from io import BytesIO |
| from typing import Any, Optional, Union |
|
|
| import numpy as np |
| import torch |
| from datasets import load_dataset |
| from jinja2 import Template |
| from PIL import Image |
| from PIL.Image import Image as ImageObject |
| from qwen_vl_utils.vision_process import fetch_video |
| from torch.utils.data import Dataset |
| from transformers import PreTrainedTokenizer, ProcessorMixin |
|
|
| from . import torch_functional as VF |
|
|
|
|
|
|
# Prompt wrapper applied to every question. It asks for the reasoning inside
# <think>...</think> and the final result inside <answer>...</answer>.
# Adjacent string literals are concatenated verbatim, so each sentence carries
# its own trailing separator (a space or a newline).
QUESTION_TEMPLATE = (
    "{Question}\n"
    "Please answer this question based on the visual content. "
    "Provide your thinking process between the <think> and </think> tags, and then give your final answer between the <answer> and </answer> tags. "
    "At the end, you must output the final answer in the format:\n"
    "<answer><your_answer_here></answer>\n"
)
|
|
def _answer_only_hint(what: str, example: str) -> str:
    """Build the per-task suffix: what to put in the answer tags, plus one example."""
    return f"Please provide only {what} within the <answer>...</answer> tags.\nExample:\n<answer>{example}</answer>"


# Segmentation and localization share one hint each across tissue/instrument tasks.
_SEGMENTATION_HINT = _answer_only_hint(
    "segmentation result in JSON format",
    '{"boxes": [x1, y1, x2, y2], "positive_points": [[x,y],[x,y],[x,y]], "negative_points": [[x,y],[x,y],[x,y]]}',
)
_LOCALIZATION_HINT = _answer_only_hint("localization result", "[0, 38, 299, 132]")

# Maps a problem type to the answer-format instruction appended to the prompt.
TYPE_TEMPLATE = {
    "surgical tissue segmentation": _SEGMENTATION_HINT,
    "surgical instrument segmentation": _SEGMENTATION_HINT,
    "surgical tissue localization": _LOCALIZATION_HINT,
    "surgical instrument localization": _LOCALIZATION_HINT,
    "surgical instrument count": _answer_only_hint("the count of surgical instruments", "3"),
    "surgical instrument recognition": _answer_only_hint("the instrument names", "Grasper, Hook, Scissors"),
    "surgical tissue recognition": _answer_only_hint("the tissue names", "Gallbladder"),
    "surgical triplet recognition": _answer_only_hint(
        "the instrument-action-tissue triplets", "Grasper, Grasp, Gallbladder"
    ),
    "surgical action recognition": _answer_only_hint("the action names", "Grasp"),
    "surgical step recognition": _answer_only_hint("the step type", "Sellotomy"),
    "surgical phase recognition": _answer_only_hint("the phase type", "Suturing"),
    "critical view safety": _answer_only_hint("the Yes or No list", "['No', 'No', 'No']"),
}
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
def collate_fn(features: list[dict[str, Any]]) -> dict[str, Any]:
    """
    Merge a list of per-sample dicts into one batch dict.

    Values that are ``torch.Tensor`` are stacked along a new leading dimension;
    every other value is collected into a numpy array with ``dtype=object``.
    If a key holds tensors in some samples and non-tensors in others, the
    non-tensor array is the one kept in the result.
    """
    stackable: dict[str, list] = {}
    opaque: dict[str, list] = {}
    for sample in features:
        for name, item in sample.items():
            bucket = stackable if isinstance(item, torch.Tensor) else opaque
            bucket.setdefault(name, []).append(item)

    batch: dict[str, Any] = {name: torch.stack(items, dim=0) for name, items in stackable.items()}
    batch.update({name: np.array(items, dtype=object) for name, items in opaque.items()})
    return batch
|
|
|
|
def process_image(
    image: Union[dict[str, Any], "ImageObject", str, bytes],
    min_pixels: Optional[int],
    max_pixels: Optional[int],
    resize_size: int = 0,
) -> "ImageObject":
    """
    Load an image and bring it to the requested size in RGB mode.

    Args:
        image: A file path, a dict holding the encoded image under ``"bytes"``,
            raw encoded bytes, or an already opened PIL image.
        min_pixels: Lower bound on ``width * height``; smaller images are scaled up.
            Ignored when ``resize_size > 0`` or when ``None``.
        max_pixels: Upper bound on ``width * height``; larger images are scaled down.
            Ignored when ``resize_size > 0`` or when ``None``.
        resize_size: When positive, resize so the shorter side equals this value
            (aspect ratio kept, LANCZOS resampling) instead of applying the pixel
            bounds. 0 disables this mode.

    Returns:
        The processed PIL image in RGB mode.
    """
    if isinstance(image, str):
        image = Image.open(image)
    elif isinstance(image, dict):
        image = Image.open(BytesIO(image["bytes"]))
    elif isinstance(image, bytes):
        image = Image.open(BytesIO(image))

    # Force decoding now instead of keeping a lazily-read image around.
    image.load()

    if resize_size > 0:
        w, h = image.width, image.height
        if w < h:
            new_w, new_h = resize_size, int(h * resize_size / w)
        else:
            new_w, new_h = int(w * resize_size / h), resize_size
        image = image.resize((new_w, new_h), Image.LANCZOS)
    else:
        if max_pixels is not None and (image.width * image.height) > max_pixels:
            resize_factor = math.sqrt(max_pixels / (image.width * image.height))
            # int() truncation can yield 0 for very elongated images, which
            # resize() rejects; keep every side at least one pixel long.
            width = max(1, int(image.width * resize_factor))
            height = max(1, int(image.height * resize_factor))
            image = image.resize((width, height))

        if min_pixels is not None and (image.width * image.height) < min_pixels:
            resize_factor = math.sqrt(min_pixels / (image.width * image.height))
            width, height = int(image.width * resize_factor), int(image.height * resize_factor)
            image = image.resize((width, height))

    if image.mode != "RGB":
        image = image.convert("RGB")

    return image
|
|
|
|
def process_video(
    video: str, min_pixels: int = 4*32*32, max_pixels: int = 64*32*32, max_frames: int = 128, video_fps: float = 2, return_fps: bool = False
):
    """
    Decode a video through ``qwen_vl_utils``' ``fetch_video``.

    The size, frame-count and fps options are forwarded in the request dict and
    ``image_patch_size`` is fixed to 16. ``return_fps`` switches on both the
    sample-fps and the metadata return values of ``fetch_video``; callers in
    this file then unpack the result as ``((frames, metadata), sample_fps)``.
    """
    request = dict(
        video=video,
        min_pixels=min_pixels,
        max_pixels=max_pixels,
        max_frames=max_frames,
        fps=video_fps,
    )
    return fetch_video(
        request,
        image_patch_size=16,
        return_video_sample_fps=return_fps,
        return_video_metadata=return_fps,
    )
|
|
|
|
class RLHFDataset(Dataset):
    """
    Prompt dataset for RLHF training.

    Each row must provide a prompt column and an answer column, and may provide
    a list of images or a list of videos. ``__getitem__`` returns the remaining
    row fields plus ``input_ids``, ``attention_mask``, ``position_ids``,
    ``raw_prompt_ids``, ``ground_truth`` and, for visual rows,
    ``multi_modal_data``.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        processor: Optional[ProcessorMixin],
        prompt_key: str = "prompt",
        answer_key: str = "answer",
        image_key: str = "images",
        video_key: str = "videos",
        image_dir: Optional[str] = None,
        video_fps: float = 2.0,
        max_prompt_length: int = 1024,
        truncation: str = "error",
        format_prompt: Optional[str] = None,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
        resize_size: int = 0,
        filter_overlong_prompts: bool = True,
        filter_overlong_prompts_workers: int = 16,
    ):
        """
        Load the dataset and optionally drop rows whose prompt is too long.

        Args:
            data_path: Local directory, local file, or dataset name handed to
                ``load_dataset``. An optional ``@split`` suffix selects the split
                (default ``"train"``).
            tokenizer: Tokenizer used for text-only rows, padding and ``raw_prompt_ids``.
            processor: Multi-modal processor used for rows with images or videos.
            prompt_key: Column holding the question text.
            answer_key: Column holding the reference answer.
            image_key: Column holding a list of images.
            video_key: Column holding a list of videos.
            image_dir: Directory prepended to image/video entries given as strings.
            video_fps: Stored on the instance.
                NOTE(review): not forwarded to ``process_video`` by this class — confirm intent.
            max_prompt_length: Maximum prompt length in tokens.
            truncation: ``"left"``, ``"right"`` or ``"error"``; also forwarded to
                ``VF.postprocess_data``.
            format_prompt: Path of a file whose content is read into ``self.format_prompt``.
                NOTE(review): the loaded text is not applied anywhere in this class.
            min_pixels: Forwarded to ``process_image``.
            max_pixels: Forwarded to ``process_image``.
            resize_size: Forwarded to ``process_image``.
            filter_overlong_prompts: Whether to filter the dataset on construction.
            filter_overlong_prompts_workers: ``num_proc`` used for that filtering.
        """
        self.tokenizer = tokenizer
        self.processor = processor
        self.prompt_key = prompt_key
        self.answer_key = answer_key
        self.image_key = image_key
        self.video_key = video_key
        self.image_dir = image_dir
        self.video_fps = video_fps
        self.max_prompt_length = max_prompt_length
        self.truncation = truncation
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.resize_size = resize_size

        if "@" in data_path:
            # Split on the last "@" only, so a path that itself contains "@" still works.
            data_path, data_split = data_path.rsplit("@", 1)
        else:
            data_split = "train"

        if os.path.isdir(data_path):
            # The loader type is taken from the extension of the first listed file.
            file_type = os.path.splitext(os.listdir(data_path)[0])[-1][1:].replace("jsonl", "json")
            self.dataset = load_dataset(file_type, data_dir=data_path, split=data_split)
        elif os.path.isfile(data_path):
            file_type = os.path.splitext(data_path)[-1][1:].replace("jsonl", "json")
            self.dataset = load_dataset(file_type, data_files=data_path, split=data_split)
        else:
            # Neither a local directory nor a local file: pass the name straight through.
            self.dataset = load_dataset(data_path, split=data_split)

        self.format_prompt = None
        if format_prompt:
            with open(format_prompt, encoding="utf-8") as f:
                self.format_prompt = f.read()

        if filter_overlong_prompts:
            self.dataset = self.dataset.filter(
                self._filter_overlong_prompts,
                desc="Filtering overlong prompts",
                num_proc=filter_overlong_prompts_workers,
            )

    @staticmethod
    def _has_media(example: dict[str, Any], key: str) -> bool:
        """Return True when ``example[key]`` exists and is a non-empty list."""
        media = example.get(key)
        return isinstance(media, list) and len(media) > 0

    @staticmethod
    def _interleave_media(prompt_str: str, media_type: str) -> list[dict[str, Any]]:
        """Split the prompt on ``<media_type>`` and emit a media entry at each placeholder."""
        content_list: list[dict[str, Any]] = []
        for i, text in enumerate(prompt_str.split(f"<{media_type}>")):
            if i != 0:
                content_list.append({"type": media_type})

            if text:
                content_list.append({"type": "text", "text": text})

        return content_list

    def _resolve_media_paths(self, items: list) -> list:
        """Prefix string entries with ``image_dir`` when one is configured."""
        if self.image_dir is not None and len(items) != 0 and isinstance(items[0], str):
            return [os.path.join(self.image_dir, item) for item in items]

        return items

    def _encode_images(self, prompt: str, images: list):
        """Preprocess the images and run the processor on them together with the prompt."""
        processed_images = [
            process_image(image, self.min_pixels, self.max_pixels, self.resize_size) for image in images
        ]
        return self.processor(images=processed_images, text=[prompt], add_special_tokens=False, return_tensors="pt")

    def _encode_videos(self, prompt: str, videos: list):
        """Decode every video and run the processor on all of them together with the prompt."""
        frames, metadatas = [], []
        for video in videos:
            # With return_fps=True the result is unpacked as ((frames, metadata), sample_fps);
            # the sample fps is not used here.
            (video_frames, video_metadata), _sample_fps = process_video(video, return_fps=True)
            frames.append(video_frames)
            metadatas.append(video_metadata)

        # Frames are already sampled and resized by process_video, so the
        # processor is told not to do either again.
        return self.processor(
            text=[prompt],
            videos=frames,
            add_special_tokens=False,
            video_metadata=metadatas,
            return_tensors="pt",
            do_resize=False,
            do_sample_frames=False,
        )

    def _build_messages(self, example: dict[str, Any]) -> list[dict[str, Any]]:
        """
        Build the single-turn chat message for one row.

        The question is wrapped in ``QUESTION_TEMPLATE`` and followed by the
        ``TYPE_TEMPLATE`` entry for the row's ``problem_type`` (nothing when the
        type is missing or unknown). For rows with images or videos the text is
        split on ``<image>`` / ``<video>`` into interleaved content entries.
        """
        question: str = example[self.prompt_key]
        tail = TYPE_TEMPLATE.get(example.get("problem_type") or "", "")
        prompt_str = QUESTION_TEMPLATE.format(Question=question) + tail

        if self._has_media(example, self.image_key):
            return [{"role": "user", "content": self._interleave_media(prompt_str, "image")}]

        if self._has_media(example, self.video_key):
            return [{"role": "user", "content": self._interleave_media(prompt_str, "video")}]

        return [{"role": "user", "content": prompt_str}]

    def _filter_overlong_prompts(self, example: dict[str, Any]) -> bool:
        """Return True when the encoded prompt fits into ``max_prompt_length`` tokens."""
        messages = self._build_messages(example)
        if self._has_media(example, self.image_key):
            prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            images = self._resolve_media_paths(example[self.image_key])
            model_inputs = self._encode_images(prompt, images)
            return model_inputs["input_ids"].size(-1) <= self.max_prompt_length

        if self._has_media(example, self.video_key):
            prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            videos = self._resolve_media_paths(example[self.video_key])
            # Same encoding path as __getitem__, so the measured length is the
            # length the row has when it is actually served.
            model_inputs = self._encode_videos(prompt, videos)
            return model_inputs["input_ids"].size(-1) <= self.max_prompt_length

        input_ids = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True)
        return len(input_ids) <= self.max_prompt_length

    def __len__(self):
        """Return the number of rows in the (possibly filtered) dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """
        Encode one row.

        Returns:
            The row dict without its prompt, answer and media columns, extended
            with ``input_ids``, ``attention_mask``, ``position_ids``,
            ``raw_prompt_ids``, ``ground_truth`` and, for rows with images or
            videos, ``multi_modal_data`` holding the resolved media entries.

        Raises:
            RuntimeError: If ``truncation == "error"`` and the tokenized prompt
                is longer than ``max_prompt_length``.
        """
        example: dict = self.dataset[index]
        messages = self._build_messages(example)
        example.pop(self.prompt_key, None)

        if self._has_media(example, self.image_key):
            prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            images = self._resolve_media_paths(example.pop(self.image_key))
            model_inputs = self._encode_images(prompt, images)
            example["multi_modal_data"] = {"images": images}
        elif self._has_media(example, self.video_key):
            prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            videos = self._resolve_media_paths(example.pop(self.video_key))
            model_inputs = self._encode_videos(prompt, videos)
            example["multi_modal_data"] = {"videos": videos}
        else:
            prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt")

        input_ids = model_inputs.pop("input_ids")[0]
        attention_mask = model_inputs.pop("attention_mask")[0]

        # Drop whatever media columns are left (e.g. the empty one of the other
        # modality) so they do not end up in the batch.
        for media_key in (self.image_key, self.video_key, "images", "videos"):
            example.pop(media_key, None)

        if self.processor is not None and "Qwen2VLImageProcessor" in self.processor.image_processor.__class__.__name__:
            # The rope-index helper is model specific; pick it by processor class name.
            if "Qwen3VLProcessor" in self.processor.__class__.__name__:
                from ..models.transformers.qwen3_vl import get_rope_index
            else:
                from ..models.transformers.qwen2_vl import get_rope_index

            vision_position_ids = get_rope_index(
                self.processor,
                input_ids=input_ids,
                image_grid_thw=model_inputs.get("image_grid_thw", None),
                video_grid_thw=model_inputs.get("video_grid_thw", None),
                second_per_grid_ts=model_inputs.get("second_per_grid_ts", None),
                attention_mask=attention_mask,
            )
            # Plain sequential positions are stacked in front of the vision position ids.
            text_position_ids = torch.arange(len(input_ids)).unsqueeze(0)
            position_ids = torch.cat((text_position_ids, vision_position_ids), dim=0)
        else:
            position_ids = torch.clip(attention_mask.cumsum(dim=0) - 1, min=0, max=None)

        input_ids, attention_mask, position_ids = VF.postprocess_data(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            max_length=self.max_prompt_length,
            pad_token_id=self.tokenizer.pad_token_id,
            left_pad=True,
            truncation=self.truncation,
        )

        raw_prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        if len(raw_prompt_ids) > self.max_prompt_length:
            if self.truncation == "left":
                raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]
            elif self.truncation == "right":
                raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]
            elif self.truncation == "error":
                raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.")

        example["input_ids"] = input_ids
        example["attention_mask"] = attention_mask
        example["position_ids"] = position_ids
        example["raw_prompt_ids"] = raw_prompt_ids
        example["ground_truth"] = example.pop(self.answer_key)
        return example
|
|