| import bisect |
| from typing import Unpack |
|
|
| from transformers import BatchFeature |
| from transformers.audio_utils import load_audio |
| from transformers.processing_utils import AllKwargsForChatTemplate, ProcessorMixin |
| from transformers.utils.chat_template_utils import render_jinja_template |
|
|
|
|
class Song2MIDIProcessor(ProcessorMixin):
    """Processor pairing a text tokenizer with a MIDI tokenizer and an audio feature extractor.

    MIDI token ids are shifted by ``len(tokenizer)`` and spliced into the text
    ``input_ids`` wherever the ``midi_pad`` placeholder token appears.
    """

    def __init__(
        self,
        tokenizer,
        midi_tokenizer,
        feature_extractor,
        midi_pad="<|midi_pad|>",
        **kwargs,
    ):
        """Create the processor.

        Args:
            tokenizer: Text tokenizer. ``len(tokenizer)`` is used as the offset
                added to every MIDI token id.
            midi_tokenizer: Called on MIDI inputs; its result is indexed with
                ``"input_ids"``.
            feature_extractor: Audio feature extractor.
            midi_pad: Text token marking where a MIDI sequence is inserted.
            **kwargs: Forwarded to ``ProcessorMixin.__init__``.
        """
        # MIDI ids are placed directly after the text vocabulary.
        self.midi_offset_by = len(tokenizer)
        self.midi_pad_token = midi_pad
        super().__init__(tokenizer, midi_tokenizer, feature_extractor, **kwargs)

    def __call__(
        self, images=None, text=None, videos=None, audio=None, midi=None, **kwargs
    ):
        """Run the sub-processors and splice MIDI token ids into the text ids.

        Each ``midi_pad`` token in the tokenized text is replaced, in order, by
        the next MIDI sequence (ids shifted by ``midi_offset_by``). The merged
        ids are then padded / converted with the text kwargs.

        Args:
            images: Forwarded to an ``image_processor`` if one is attached.
            text: Prompt string or batch of strings.
            videos: Forwarded to a ``video_processor`` if one is attached.
            audio: Audio input(s) for the feature extractor.
            midi: MIDI input or list of inputs; ``None`` list items are
                tokenized as ``""``. Requires ``text``.
            **kwargs: Processor kwargs, merged via ``_merge_kwargs``.

        Returns:
            BatchFeature holding the padded text outputs (when ``text`` is
            given) plus the outputs of every other sub-processor that ran.

        Raises:
            ValueError: On the deprecated ``audios`` kwarg, when no input is
                given, or when ``midi`` is given without ``text``.
        """
        if "audios" in kwargs and audio is None:
            raise ValueError(
                "You passed keyword argument `audios` which is deprecated. Please use `audio` instead."
            )

        if images is None and text is None and videos is None and audio is None and midi is None:
            raise ValueError(
                f"You need to provide at least one input to call {self.__class__.__name__}"
            )

        # MIDI ids can only be spliced into tokenized text.
        if midi is not None and text is None:
            raise ValueError(
                "`midi` was passed without `text`; MIDI tokens are inserted at the "
                f"{self.midi_pad_token!r} placeholders of the text, so `text` is required."
            )

        kwargs = self._merge_kwargs(
            self.valid_processor_kwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs
            if hasattr(self, "tokenizer")
            else {},
            **kwargs,
        )
        # The MIDI tokenizer is always called without extra options.
        kwargs["midi_kwargs"] = {}

        # Padding / tensor conversion of the text has to wait until the MIDI ids
        # are spliced in, so the text kwargs are applied through `tokenizer.pad`
        # below. `add_special_tokens` only has an effect at encode time (`pad`
        # has no such parameter), so it is routed to the tokenizer call instead.
        text_kwargs = dict(kwargs.get("text_kwargs", {}))
        encode_kwargs = {}
        if "add_special_tokens" in text_kwargs:
            encode_kwargs["add_special_tokens"] = text_kwargs.pop("add_special_tokens")
        kwargs["text_kwargs"] = encode_kwargs

        attribute_to_kwargs = {
            "tokenizer": (text, "text_kwargs"),
            "image_processor": (images, "images_kwargs"),
            "video_processor": (videos, "videos_kwargs"),
            "feature_extractor": (audio, "audio_kwargs"),
            "midi_tokenizer": (midi, "midi_kwargs"),
        }
        outputs = {}
        for attribute_name in self.get_attributes():
            attribute = getattr(self, attribute_name, None)
            input_data, input_kwargs = attribute_to_kwargs[attribute_name]
            if input_data is None or attribute is None:
                continue
            if attribute_name == "midi_tokenizer" and isinstance(input_data, (list, tuple)):
                # Missing MIDI entries are tokenized as empty strings.
                input_data = ["" if item is None else item for item in input_data]
            outputs[attribute_name] = attribute(input_data, **kwargs[input_kwargs])

        new_outputs = {}
        if "tokenizer" in outputs:
            text_input_ids = outputs["tokenizer"]["input_ids"]
            if midi:
                text_input_ids = self._merge_midi_into_text_ids(
                    text_input_ids,
                    outputs["midi_tokenizer"]["input_ids"],
                    self.tokenizer.convert_tokens_to_ids(self.midi_pad_token),
                )
            new_outputs.update(
                self.tokenizer.pad({"input_ids": text_input_ids}, **text_kwargs)
            )

        # Outputs of the other sub-processors (e.g. audio features) pass through as-is.
        for attribute_name, attribute_output in outputs.items():
            if attribute_name not in ("tokenizer", "midi_tokenizer"):
                new_outputs.update(attribute_output)

        return BatchFeature(new_outputs)

    def _merge_midi_into_text_ids(self, text_input_ids, midi_input_ids, midi_token_id):
        """Replace MIDI placeholder ids with offset MIDI token ids.

        Placeholders are filled, in order across the whole batch, with
        successive MIDI sequences. Placeholders left over once the MIDI
        sequences run out are kept unchanged; surplus MIDI sequences are ignored.

        Args:
            text_input_ids: A single id sequence or a batch of sequences.
            midi_input_ids: A single MIDI id sequence or a list of sequences.
            midi_token_id: Text-vocabulary id of the placeholder token.

        Returns:
            The merged ids, nested the same way as ``text_input_ids``.
        """

        def _is_single_sequence(ids):
            # "Single" means the first element is a scalar id: a Python int, or a
            # numpy/tensor scalar exposing `ndim == 0`.
            if len(ids) == 0:
                return False
            first = ids[0]
            return isinstance(first, int) or getattr(first, "ndim", None) == 0

        text_is_single = _is_single_sequence(text_input_ids)
        text_rows = [text_input_ids] if text_is_single else text_input_ids

        # Normalize the MIDI ids by their own shape: a single prompt may carry
        # several MIDI sequences, and a batch may carry just one. An empty result
        # carries no shape information, so it follows the text.
        if len(midi_input_ids) > 0:
            midi_is_single = _is_single_sequence(midi_input_ids)
        else:
            midi_is_single = text_is_single
        midi_rows = [midi_input_ids] if midi_is_single else midi_input_ids

        offset = self.midi_offset_by
        merged_rows = []
        next_midi = 0
        for row in text_rows:
            merged = []
            for token_id in row:
                if token_id == midi_token_id and next_midi < len(midi_rows):
                    merged.extend(tok + offset for tok in midi_rows[next_midi])
                    next_midi += 1
                else:
                    merged.append(token_id)
            merged_rows.append(merged)
        return merged_rows[0] if text_is_single else merged_rows

    def _select_chat_template(self, chat_template):
        """Resolve ``chat_template`` (None, a stored template name, or a template string).

        Raises:
            ValueError: If no template is given and none can be chosen from
                ``self.chat_template``.
        """
        if chat_template is not None:
            # A name of one of the stored templates; otherwise a literal template string.
            if isinstance(self.chat_template, dict) and chat_template in self.chat_template:
                return self.chat_template[chat_template]
            return chat_template
        if isinstance(self.chat_template, dict):
            if "default" in self.chat_template:
                return self.chat_template["default"]
            raise ValueError(
                'The processor has multiple chat templates but none of them are named "default". You need to specify'
                " which one to use by passing the `chat_template` argument. Available templates are: "
                f"{', '.join(self.chat_template.keys())}"
            )
        if self.chat_template is not None:
            return self.chat_template
        raise ValueError(
            "Cannot use apply_chat_template because this processor does not have a chat template."
        )

    @staticmethod
    def _normalize_image_url_content(conversation):
        """Return a copy of ``conversation`` with ``image_url`` content items rewritten.

        ``{"type": "image_url", "image_url": {"url": u}}`` (or ``"image_url": u``)
        becomes ``{"type": "image", "url": u}``. Messages whose content is not a
        list are passed through unchanged; the input dicts are never mutated.
        """
        normalized = []
        for message in conversation:
            content = message.get("content")
            if isinstance(content, list):
                new_content = []
                for item in content:
                    if (
                        isinstance(item, dict)
                        and item.get("type") == "image_url"
                        and "image_url" in item
                    ):
                        info = item["image_url"]
                        url = info.get("url", "") if isinstance(info, dict) else info
                        item = {"type": "image", "url": url}
                    new_content.append(item)
                message = {**message, "content": new_content}
            normalized.append(message)
        return normalized

    def _collect_chat_media(self, conversations, sampling_rate, load_audio_from_video):
        """Gather media references from the message content items.

        Returns:
            ``(batch_images, batch_videos, batch_audios, batch_midis)``: images
            and videos are nested per conversation; audios and MIDI references
            are flat lists over the whole batch. Audio is loaded with
            ``load_audio`` from the audio items, or from the video items when
            ``load_audio_from_video`` is truthy.
        """

        def _sources(items, kind, keys):
            # Every listed key present on an item of this type contributes an entry.
            return [
                item[key]
                for item in items
                if item.get("type") == kind
                for key in keys
                if key in item
            ]

        batch_images, batch_videos, batch_audios, batch_midis = [], [], [], []
        for conversation in conversations:
            images, videos = [], []
            for message in conversation:
                content = message.get("content")
                # Plain-string (or missing) content carries no media.
                if not isinstance(content, (list, tuple)):
                    continue
                items = [item for item in content if isinstance(item, dict)]
                audio_fnames = _sources(items, "audio", ("audio", "url", "path"))
                image_fnames = _sources(items, "image", ("image", "url", "path", "base64"))
                video_fnames = _sources(items, "video", ("video", "url", "path"))
                images.extend(image_fnames)
                videos.extend(video_fnames)
                batch_midis.extend(_sources(items, "midi", ("score", "path")))

                audio_sources = video_fnames if load_audio_from_video else audio_fnames
                batch_audios.extend(
                    load_audio(source, sampling_rate=sampling_rate) for source in audio_sources
                )
            batch_images.append(images)
            batch_videos.append(videos)
        return batch_images, batch_videos, batch_audios, batch_midis

    @staticmethod
    def _compute_assistant_masks(out, generation_indices):
        """Build per-row 0/1 masks marking tokens inside assistant generations.

        Pops ``offset_mapping`` from ``out``. ``generation_indices[i]`` holds the
        ``(start_char, end_char)`` spans returned by ``render_jinja_template``
        for row ``i``.
        """
        offset_mapping = out.pop("offset_mapping")
        input_ids = out["input_ids"]
        assistant_masks = []
        for i, row_ids in enumerate(input_ids):
            row_len = len(row_ids)
            current_mask = [0] * row_len
            offsets = offset_mapping[i]
            offset_starts = [start for start, _ in offsets]
            for assistant_start_char, assistant_end_char in generation_indices[i]:
                start_pos = bisect.bisect_left(offset_starts, assistant_start_char)
                end_pos = bisect.bisect_left(offset_starts, assistant_end_char)
                # Skip spans whose first character is not covered by a token
                # (possibly removed by truncation).
                if not (
                    start_pos < len(offsets)
                    and offsets[start_pos][0] <= assistant_start_char < offsets[start_pos][1]
                ):
                    continue
                end_pos = min(end_pos, row_len)
                for token_pos in range(start_pos, end_pos if end_pos else row_len):
                    current_mask[token_pos] = 1
            assistant_masks.append(current_mask)
        return assistant_masks

    def apply_chat_template(
        self,
        conversation: list[dict[str, str]] | list[list[dict[str, str]]],
        chat_template: str | None = None,
        **kwargs: Unpack[AllKwargsForChatTemplate],
    ) -> str:
        """Render (and optionally tokenize) one conversation or a batch of them.

        Args:
            conversation: A list of message dicts, or a batch of such lists.
                The caller's message dicts are not modified.
            chat_template: A template string or the name of a stored template.
                When omitted, the stored ``"default"`` (or only) template is used.
            **kwargs: Declared template kwargs (e.g. ``tokenize``,
                ``return_dict``, ``add_generation_prompt``). Any other kwargs are
                exposed to the template and, when tokenizing, forwarded to
                ``__call__``.

        Returns:
            The rendered prompt (a string, or a list of strings for a batch)
            when not tokenizing. Otherwise, the ``BatchFeature`` from
            ``__call__`` when ``return_dict`` is truthy, or just its
            ``input_ids``.

        Raises:
            ValueError: If no usable chat template exists, or on incompatible
                flag combinations.
        """
        chat_template = self._select_chat_template(chat_template)

        is_tokenizers_fast = False
        if hasattr(self, "tokenizer"):
            if hasattr(self.tokenizer, "backend"):
                is_tokenizers_fast = self.tokenizer.backend == "tokenizers"
            else:
                # No `backend` attribute: fall back to the class-name convention.
                is_tokenizers_fast = self.tokenizer.__class__.__name__.endswith("Fast")

        if kwargs.get("continue_final_message", False):
            if kwargs.get("add_generation_prompt", False):
                raise ValueError(
                    "continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead."
                )
            if kwargs.get("return_assistant_tokens_mask", False):
                raise ValueError(
                    "continue_final_message is not compatible with return_assistant_tokens_mask."
                )

        if kwargs.get("return_assistant_tokens_mask", False):
            if not is_tokenizers_fast:
                raise ValueError(
                    "`return_assistant_tokens_mask` is not possible with slow tokenizers. Make sure you have `tokenizers` installed. "
                    "If the error persists, open an issue to support a Fast tokenizer for your model."
                )
            # Token offsets are needed to map assistant character spans to tokens.
            # NOTE(review): `__call__` forwards this flag to `tokenizer.pad`, and
            # MIDI splicing shifts token positions; confirm this path works here.
            kwargs["return_offsets_mapping"] = True

        # Pull the declared template kwargs (with their declared defaults) out of kwargs.
        template_spec = AllKwargsForChatTemplate.__annotations__["template_kwargs"]
        template_kwargs = {}
        for key in template_spec.__annotations__:
            value = kwargs.pop(key, getattr(template_spec, key, None))
            if value is not None and not isinstance(value, dict):
                template_kwargs[key] = value
        # The remaining (custom / processor) kwargs are exposed to the template too.
        template_kwargs.update(kwargs)

        if "sampling_rate" not in template_kwargs:
            if hasattr(self, "feature_extractor") and hasattr(
                self.feature_extractor, "sampling_rate"
            ):
                template_kwargs["sampling_rate"] = self.feature_extractor.sampling_rate
            else:
                template_kwargs["sampling_rate"] = 16_000

        is_batched = (
            isinstance(conversation, (list, tuple))
            and len(conversation) > 0
            and (
                isinstance(conversation[0], (list, tuple))
                or hasattr(conversation[0], "content")
            )
        )
        conversations = conversation if is_batched else [conversation]
        # Work on copies so the caller's messages are left untouched.
        conversations = [self._normalize_image_url_content(conv) for conv in conversations]

        tokenize = template_kwargs.pop("tokenize", False)
        return_dict = template_kwargs.pop("return_dict", True)

        if tokenize:
            batch_images, batch_videos, batch_audios, batch_midis = self._collect_chat_media(
                conversations,
                sampling_rate=template_kwargs["sampling_rate"],
                load_audio_from_video=template_kwargs.get("load_audio_from_video", False),
            )

        special_tokens_map = {}
        if hasattr(self, "tokenizer") and hasattr(self.tokenizer, "special_tokens_map"):
            # Skip names already passed as template kwargs to avoid duplicate keyword arguments.
            special_tokens_map = {
                name: token
                for name, token in self.tokenizer.special_tokens_map.items()
                if name not in template_kwargs
            }

        prompt, generation_indices = render_jinja_template(
            conversations=conversations,
            chat_template=chat_template,
            **template_kwargs,
            **special_tokens_map,
        )
        if not is_batched:
            prompt = prompt[0]

        if not tokenize:
            return prompt

        # If the rendered template already starts with BOS, don't let the tokenizer add another.
        single_prompt = prompt[0] if is_batched else prompt
        if self.tokenizer.bos_token is not None and single_prompt.startswith(
            self.tokenizer.bos_token
        ):
            kwargs["add_special_tokens"] = False

        # Turn on frame sampling when a frame rate or frame count was requested.
        if "do_sample_frames" not in kwargs and (
            kwargs.get("fps") is not None or kwargs.get("num_frames") is not None
        ):
            kwargs["do_sample_frames"] = True

        images_exist = any(im is not None for im_list in batch_images for im in im_list)
        videos_exist = any(vid is not None for vid_list in batch_videos for vid in vid_list)
        out = self(
            text=prompt,
            images=batch_images if images_exist else None,
            videos=batch_videos if videos_exist else None,
            audio=batch_audios if batch_audios else None,
            midi=batch_midis if batch_midis else None,
            **kwargs,
        )

        if not return_dict:
            return out["input_ids"]
        if template_kwargs.get("return_assistant_tokens_mask", False):
            out["assistant_masks"] = self._compute_assistant_masks(out, generation_indices)
            out.convert_to_tensors(tensor_type=kwargs.get("return_tensors"))
        return out
|
|