import bisect
from typing import Unpack

from transformers import BatchFeature
from transformers.audio_utils import load_audio
from transformers.processing_utils import AllKwargsForChatTemplate, ProcessorMixin
from transformers.utils.chat_template_utils import render_jinja_template


class Song2MIDIProcessor(ProcessorMixin):
    """Multimodal processor that combines a text tokenizer, a MIDI tokenizer and an
    audio feature extractor for song-to-MIDI models.

    MIDI token ids live in their own vocabulary, so they are shifted by
    ``len(tokenizer)`` (``self.midi_offset_by``) before being spliced into the text
    token stream.  Each occurrence of the ``midi_pad`` placeholder token in the text
    is replaced by the next tokenized MIDI sequence (see ``_merge_text_midi`` inside
    ``__call__``).
    """

    def __init__(
        self,
        tokenizer,
        midi_tokenizer,
        feature_extractor,
        midi_pad="<|midi_pad|>",
        **kwargs,
    ):
        """Store the MIDI id offset and placeholder token, then defer to ProcessorMixin.

        Args:
            tokenizer: Text tokenizer; its vocabulary size defines the offset added
                to every MIDI token id so the two vocabularies do not collide.
            midi_tokenizer: Tokenizer producing ``{"input_ids": ...}`` for MIDI input.
            feature_extractor: Audio feature extractor.
            midi_pad: Placeholder token in the text that marks where a tokenized MIDI
                sequence should be spliced in.
            **kwargs: Forwarded to ``ProcessorMixin.__init__``.
        """
        # MIDI ids are remapped into the range [len(tokenizer), ...) when merged.
        self.midi_offset_by = len(tokenizer)
        self.midi_pad_token = midi_pad
        super().__init__(tokenizer, midi_tokenizer, feature_extractor, **kwargs)

    def __call__(
        self, images=None, text=None, videos=None, audio=None, midi=None, **kwargs
    ):
        """Process any combination of text / images / videos / audio / MIDI inputs.

        Runs each modality through its matching attribute (tokenizer, image/video
        processor, feature extractor, MIDI tokenizer), replaces every ``midi_pad``
        placeholder in the tokenized text with the next offset-shifted MIDI sequence,
        pads the merged ids with the tokenizer, and returns everything as a single
        ``BatchFeature``.

        Raises:
            ValueError: If the deprecated ``audios`` kwarg is used, or if no input
                modality is provided at all.
        """
        # From https://github.com/huggingface/transformers/blob/e5a861d381bf65a146ce487c3d3c0fca919ef316/src/transformers/processing_utils.py#L606
        if "audios" in kwargs and audio is None:
            raise ValueError(
                "You passed keyword argument `audios` which is deprecated. Please use `audio` instead."
            )
        if images is None and text is None and videos is None and audio is None and midi is None:
            raise ValueError(
                f"You need to provide at least one input to call {self.__class__.__name__}"
            )
        kwargs = self._merge_kwargs(
            self.valid_processor_kwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs
            if hasattr(self, "tokenizer")
            else {},
            **kwargs,
        )
        # The MIDI tokenizer gets no extra kwargs; merging/offsetting is done locally.
        kwargs["midi_kwargs"] = {}
        # We will do the padding later: hold back the text kwargs so the tokenizer is
        # first called without padding, and apply them in `tokenizer.pad` after the
        # MIDI ids have been merged in (merging changes sequence lengths).
        text_kwargs = kwargs.get("text_kwargs", {})
        kwargs["text_kwargs"] = {}
        # Maps each processor attribute to (its raw input, the kwargs key it consumes).
        attribute_to_kwargs = {
            "tokenizer": (text, "text_kwargs"),
            "image_processor": (images, "images_kwargs"),
            "video_processor": (videos, "videos_kwargs"),
            "feature_extractor": (audio, "audio_kwargs"),
            "midi_tokenizer": (midi, "midi_kwargs"),
        }
        outputs = {}
        for attribute_name in self.get_attributes():
            attribute = getattr(self, attribute_name, None)
            input_data, input_kwargs = attribute_to_kwargs[attribute_name]
            if input_data is not None and attribute is not None:
                if attribute_name == "midi_tokenizer":
                    # Change the None to empty string to avoid errors in tokenizers
                    # when trying to tokenize None.
                    if isinstance(input_data, (list, tuple)):
                        input_data = [
                            item if item is not None else "" for item in input_data
                        ]
                    else:
                        input_data = input_data if input_data is not None else ""
                attribute_output = attribute(input_data, **kwargs[input_kwargs])
                outputs[attribute_name] = attribute_output

        midi_token_id = self.tokenizer.convert_tokens_to_ids(self.midi_pad_token)

        def _merge_text_midi(text_input_ids, midi_input_ids):
            """Replace each `midi_pad` id in the text with the next MIDI sequence,
            shifting every MIDI id by `self.midi_offset_by`.

            Accepts either a single (flat) id list or a batch of id lists; returns
            the same shape it was given.  Note: `midi_idx` advances globally, so
            placeholders across the whole batch consume MIDI sequences in order.
            """
            is_batched = True
            if text_input_ids and isinstance(text_input_ids[0], int):
                # A flat list of ints means a single (unbatched) sequence.
                is_batched = False
                text_input_ids = [text_input_ids]
                midi_input_ids = [midi_input_ids]
            new_input_ids = []
            midi_idx = 0
            for batch_idx in range(len(text_input_ids)):
                new_ids = []
                for token_id in text_input_ids[batch_idx]:
                    if token_id == midi_token_id and midi_idx < len(midi_input_ids):
                        new_ids.extend(
                            [
                                tok + self.midi_offset_by
                                for tok in midi_input_ids[midi_idx]
                            ]
                        )
                        midi_idx += 1
                    else:
                        new_ids.append(token_id)
                new_input_ids.append(new_ids)
            return new_input_ids if is_batched else new_input_ids[0]

        new_outputs = {}
        if midi:
            new_text_input_ids = {
                "input_ids": _merge_text_midi(
                    outputs["tokenizer"]["input_ids"],
                    outputs["midi_tokenizer"]["input_ids"],
                )
            }
        else:
            new_text_input_ids = {"input_ids": outputs["tokenizer"]["input_ids"]}
        # Pad the merged ids now, using the text kwargs held back above.
        new_outputs.update(self.tokenizer.pad(new_text_input_ids, **text_kwargs))
        # Pass through every non-text, non-MIDI modality output unchanged.
        for key, value in outputs.items():
            if key not in ["tokenizer", "midi_tokenizer"]:
                new_outputs.update(value)
        return BatchFeature(new_outputs)

    def apply_chat_template(
        self,
        conversation: list[dict[str, str]] | list[list[dict[str, str]]],
        chat_template: str | None = None,
        **kwargs: Unpack[AllKwargsForChatTemplate],
    ) -> str:
        """Render (and optionally tokenize) a chat conversation with this processor.

        Largely mirrors the upstream transformers implementation, extended to also
        collect MIDI file references (``{"type": "midi", "score"/"path": ...}``
        content blocks) and forward them to ``__call__``.  OpenAI-style
        ``image_url`` content blocks are normalized to HuggingFace ``image`` blocks
        before rendering.

        Args:
            conversation: A single conversation (list of message dicts) or a batch
                of conversations.
            chat_template: A template string, the name of one of the processor's
                stored templates, or None to use the stored "default" template.
            **kwargs: Template flags (``tokenize``, ``return_dict``,
                ``return_assistant_tokens_mask``, ...) plus kwargs forwarded to
                ``self.__call__`` when tokenizing.

        Returns:
            The rendered prompt string(s); or, when ``tokenize=True``, either a
            ``BatchFeature``-like dict (``return_dict=True``, default) or just the
            input ids.

        Raises:
            ValueError: On missing/ambiguous chat templates, incompatible flag
                combinations, or assistant-mask requests with a slow tokenizer.
        """
        # From https://github.com/huggingface/transformers/blob/e5a861d381bf65a146ce487c3d3c0fca919ef316/src/transformers/processing_utils.py#L1631
        if chat_template is None:
            if isinstance(self.chat_template, dict) and "default" in self.chat_template:
                chat_template = self.chat_template["default"]
            elif isinstance(self.chat_template, dict):
                raise ValueError(
                    'The processor has multiple chat templates but none of them are named "default". You need to specify'
                    " which one to use by passing the `chat_template` argument. Available templates are: "
                    f"{', '.join(self.chat_template.keys())}"
                )
            elif self.chat_template is not None:
                chat_template = self.chat_template
            else:
                raise ValueError(
                    "Cannot use apply_chat_template because this processor does not have a chat template."
                )
        else:
            if (
                isinstance(self.chat_template, dict)
                and chat_template in self.chat_template
            ):
                # It's the name of a template, not a full template string
                chat_template = self.chat_template[chat_template]
            else:
                # It's a template string, render it directly
                pass

        # Check if tokenizer is fast - use backend attribute if available, otherwise fall back to class name
        is_tokenizers_fast = False
        if hasattr(self, "tokenizer"):
            if hasattr(self.tokenizer, "backend"):
                is_tokenizers_fast = self.tokenizer.backend == "tokenizers"
            else:
                # Fallback to class name check
                is_tokenizers_fast = self.tokenizer.__class__.__name__.endswith("Fast")

        if kwargs.get("continue_final_message", False):
            if kwargs.get("add_generation_prompt", False):
                raise ValueError(
                    "continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead."
                )
            if kwargs.get("return_assistant_tokens_mask", False):
                raise ValueError(
                    "continue_final_message is not compatible with return_assistant_tokens_mask."
                )

        if kwargs.get("return_assistant_tokens_mask", False):
            if not is_tokenizers_fast:
                raise ValueError(
                    "`return_assistant_tokens_mask` is not possible with slow tokenizers. Make sure you have `tokenizers` installed. "
                    "If the error persists, open an issue to support a Fast tokenizer for your model."
                )
            else:
                kwargs["return_offsets_mapping"] = (
                    True  # force offset mapping so we can infer token boundaries
                )

        # Fill sets of kwargs that should be used by jinja template, filtering out kwargs used in `processor.__call__`
        # NOTE: we don't only filter but also set the default values here. Without default values, we can remove it
        template_kwargs = {}
        for key in AllKwargsForChatTemplate.__annotations__[
            "template_kwargs"
        ].__annotations__:
            # The TypedDict class carries the defaults as class attributes; keys
            # without a default fall back to None and are then skipped below.
            kwarg_type_defaults = AllKwargsForChatTemplate.__annotations__[
                "template_kwargs"
            ]
            default_value = getattr(kwarg_type_defaults, key, None)
            value = kwargs.pop(key, default_value)
            if value is not None and not isinstance(value, dict):
                template_kwargs[key] = value
        # Pass unprocessed custom kwargs
        template_kwargs.update(kwargs)

        # Set the sampling rate to load the audio files if user hasn't already passed with `kwargs`
        if "sampling_rate" not in template_kwargs:
            if hasattr(self, "feature_extractor") and hasattr(
                self.feature_extractor, "sampling_rate"
            ):
                template_kwargs["sampling_rate"] = self.feature_extractor.sampling_rate
            else:
                template_kwargs["sampling_rate"] = 16_000

        # Batched input is either a list of conversations (list of lists) or a list
        # of message-like objects exposing `.content`.
        if isinstance(conversation, (list, tuple)) and (
            isinstance(conversation[0], (list, tuple))
            or hasattr(conversation[0], "content")
        ):
            is_batched = True
            conversations = conversation
        else:
            is_batched = False
            conversations = [conversation]

        # Normalize OpenAI-style "image_url" content blocks to HuggingFace-style "image" blocks
        # OpenAI format: {"type": "image_url", "image_url": {"url": "..."}}
        # HuggingFace format: {"type": "image", "url": "..."}
        for conversation_idx, conversation in enumerate(conversations):
            for message in conversation:
                if not isinstance(message.get("content"), list):
                    continue
                new_content = []
                for content in message["content"]:
                    if (
                        isinstance(content, dict)
                        and content.get("type") == "image_url"
                        and "image_url" in content
                    ):
                        image_url_info = content["image_url"]
                        # "image_url" may be a dict ({"url": ...}) or a bare string.
                        url = (
                            image_url_info.get("url", "")
                            if isinstance(image_url_info, dict)
                            else image_url_info
                        )
                        new_content.append({"type": "image", "url": url})
                    else:
                        new_content.append(content)
                message["content"] = new_content

        tokenize = template_kwargs.pop("tokenize", False)
        return_dict = template_kwargs.pop("return_dict", True)

        if tokenize:
            # Gather media references from the conversations so they can be fed to
            # `self.__call__` alongside the rendered prompt.
            batch_images, batch_videos = [], []
            batch_audios = []
            batch_midis = []  # midi
            for conversation in conversations:
                images, videos = [], []
                for message in conversation:
                    visuals = [
                        content
                        for content in message["content"]
                        if content["type"] in ["image", "video"]
                    ]
                    audio_fnames = [
                        content[key]
                        for content in message["content"]
                        for key in ["audio", "url", "path"]
                        if key in content and content["type"] == "audio"
                    ]
                    image_fnames = [
                        vision_info[key]
                        for vision_info in visuals
                        for key in ["image", "url", "path", "base64"]
                        if key in vision_info and vision_info["type"] == "image"
                    ]
                    images.extend(image_fnames)
                    video_fnames = [
                        vision_info[key]
                        for vision_info in visuals
                        for key in ["video", "url", "path"]
                        if key in vision_info and vision_info["type"] == "video"
                    ]
                    videos.extend(video_fnames)
                    # midi: MIDI content blocks reference a file via "score" or "path".
                    midi_fnames = [
                        content[key]
                        for content in message["content"]
                        for key in ["score", "path"]
                        if key in content and content["type"] == "midi"
                    ]
                    batch_midis.extend(midi_fnames)
                    # Audio models do not accept nested list of audios (yet!) so we
                    # construct a flat input audio list.
                    # NOTE(review): "load_audio_from_video" is read with [] — relies
                    # on the TypedDict default being present; confirm upstream.
                    if not template_kwargs["load_audio_from_video"]:
                        for fname in audio_fnames:
                            batch_audios.append(
                                load_audio(
                                    fname,
                                    sampling_rate=template_kwargs["sampling_rate"],
                                )
                            )
                    else:
                        for fname in video_fnames:
                            batch_audios.append(
                                load_audio(
                                    fname,
                                    sampling_rate=template_kwargs["sampling_rate"],
                                )
                            )
                # Currently all processors can accept nested list of batches, but not flat list of visuals
                # So we'll make a batched list of images and let the processor handle it
                batch_images.append(images)
                batch_videos.append(videos)

        special_tokens_map = {}
        if hasattr(self, "tokenizer") and hasattr(self.tokenizer, "special_tokens_map"):
            special_tokens = self.tokenizer.special_tokens_map
            # Filter out tokens that conflict with template kwargs
            special_tokens_map = {
                k: v for k, v in special_tokens.items() if k not in template_kwargs
            }

        prompt, generation_indices = render_jinja_template(
            conversations=conversations,
            chat_template=chat_template,
            **template_kwargs,  # different flags such as `return_assistant_mask`
            **special_tokens_map,  # tokenizer special tokens are used by some templates
        )

        if not is_batched:
            prompt = prompt[0]

        if tokenize:
            # Tokenizer's `apply_chat_template` never adds special tokens when tokenizing
            # But processor's `apply_chat_template` didn't have an option to tokenize, so users had to format the prompt
            # and pass it to the processor. Users thus never worried about special tokens relying on processor handling
            # everything internally. The below line is to keep BC for that and be able to work with model that have
            # special tokens in the template (consistent with tokenizers). We dont want to raise warning, it will flood command line
            # without actionable solution for users
            single_prompt = prompt[0] if is_batched else prompt
            if self.tokenizer.bos_token is not None and single_prompt.startswith(
                self.tokenizer.bos_token
            ):
                kwargs["add_special_tokens"] = False

            # Always sample frames by default unless explicitly set to `False` by
            # users. If users do not pass `num_frames`/`fps`, sampling should not be
            # done, for BC.
            if "do_sample_frames" not in kwargs and (
                kwargs.get("fps") is not None or kwargs.get("num_frames") is not None
            ):
                kwargs["do_sample_frames"] = True

            images_exist = any(
                (im is not None) for im_list in batch_images for im in im_list
            )
            videos_exist = any(
                (vid is not None) for vid_list in batch_videos for vid in vid_list
            )

            out = self(
                text=prompt,
                images=batch_images if images_exist else None,
                videos=batch_videos if videos_exist else None,
                audio=batch_audios if batch_audios else None,
                midi=batch_midis if batch_midis else None,
                **kwargs,
            )
            if return_dict:
                if template_kwargs.get("return_assistant_tokens_mask", False):
                    # Convert character-level assistant spans (from the template
                    # renderer) into per-token 0/1 masks using the offset mapping.
                    assistant_masks = []
                    offset_mapping = out.pop("offset_mapping")
                    input_ids = out["input_ids"]
                    for i in range(len(input_ids)):
                        current_mask = [0] * len(input_ids[i])
                        offsets = offset_mapping[i]
                        # Offsets are sorted by start char, so bisect finds the first
                        # token at or after a given character position.
                        offset_starts = [start for start, end in offsets]
                        for (
                            assistant_start_char,
                            assistant_end_char,
                        ) in generation_indices[i]:
                            start_pos = bisect.bisect_left(
                                offset_starts, assistant_start_char
                            )
                            end_pos = bisect.bisect_left(
                                offset_starts, assistant_end_char
                            )
                            if not (
                                start_pos >= 0
                                and start_pos < len(offsets)
                                and offsets[start_pos][0]
                                <= assistant_start_char
                                < offsets[start_pos][1]
                            ):
                                # start_token is out of bounds maybe due to truncation.
                                continue
                            # Ensure end_pos is also within bounds
                            if end_pos > len(input_ids[i]):
                                end_pos = len(input_ids[i])
                            for token_id in range(
                                start_pos, end_pos if end_pos else len(input_ids[i])
                            ):
                                current_mask[token_id] = 1
                        assistant_masks.append(current_mask)
                    out["assistant_masks"] = assistant_masks
                out.convert_to_tensors(tensor_type=kwargs.get("return_tensors"))
                return out
            else:
                return out["input_ids"]
        return prompt