# song2midi-processor / processing_song2midi.py
# Uploaded by B-K in commit e23a631 ("Upload processing_song2midi.py", verified).
import bisect
from typing import Unpack
from transformers import BatchFeature
from transformers.audio_utils import load_audio
from transformers.processing_utils import AllKwargsForChatTemplate, ProcessorMixin
from transformers.utils.chat_template_utils import render_jinja_template
class Song2MIDIProcessor(ProcessorMixin):
    """Processor combining a text tokenizer, a MIDI tokenizer and an audio feature extractor.

    MIDI inputs are tokenized with ``midi_tokenizer`` and spliced into the text token
    stream in place of each ``midi_pad`` token. Spliced MIDI ids are shifted by
    ``len(tokenizer)`` (stored as ``midi_offset_by``) so they never collide with text
    token ids.
    """

    # Text kwargs honoured by the text tokenizer call itself, i.e. before MIDI splicing.
    _TOKENIZE_TEXT_KWARGS = (
        "add_special_tokens",
        "is_split_into_words",
        "return_offsets_mapping",
    )
    # Text kwargs forwarded to ``tokenizer.pad`` after MIDI splicing. Only keyword
    # arguments accepted by ``PreTrainedTokenizerBase.pad`` are listed, so a
    # tokenization-only option can no longer make ``pad`` raise ``TypeError``.
    _PAD_TEXT_KWARGS = (
        "padding",
        "max_length",
        "pad_to_multiple_of",
        "padding_side",
        "return_attention_mask",
        "return_tensors",
        "verbose",
    )

    def __init__(
        self,
        tokenizer,
        midi_tokenizer,
        feature_extractor,
        midi_pad="<|midi_pad|>",
        **kwargs,
    ):
        """
        Args:
            tokenizer: Text tokenizer; its length becomes the MIDI id offset.
            midi_tokenizer: Tokenizer applied to MIDI inputs.
            feature_extractor: Audio feature extractor.
            midi_pad: Text token marking where MIDI tokens are spliced in.
            **kwargs: Forwarded to ``ProcessorMixin.__init__``.
        """
        # MIDI ids are shifted past the text vocabulary so the two id ranges never overlap.
        self.midi_offset_by = len(tokenizer)
        self.midi_pad_token = midi_pad
        super().__init__(tokenizer, midi_tokenizer, feature_extractor, **kwargs)

    def __call__(
        self, images=None, text=None, videos=None, audio=None, midi=None, **kwargs
    ):
        """Prepare model inputs from any combination of text, MIDI, audio, images and videos.

        Text is tokenized with ``tokenizer`` and MIDI with ``midi_tokenizer``. Each MIDI
        pad token in the text is replaced by the next MIDI token sequence (consumed in
        order across the whole batch), shifted by ``self.midi_offset_by``. Truncation,
        padding and tensor conversion of the text happen after that splice, so they see
        the final sequence lengths.

        Args:
            images, videos, audio: Forwarded to the matching sub-processor, if present.
            text: A string or a batch of strings.
            midi: One MIDI input or a list/tuple of them; ``None`` entries in a list are
                tokenized as ``""``.
            **kwargs: Processor kwargs, merged by ``ProcessorMixin._merge_kwargs``.

        Returns:
            BatchFeature: the padded text features (``input_ids`` plus whatever
            ``tokenizer.pad`` returns, and ``offset_mapping`` when offsets were
            requested), merged with the outputs of the other sub-processors.

        Raises:
            ValueError: if the deprecated ``audios`` kwarg is used, if no input is
                given, or if ``midi`` is given without ``text``.
        """
        # From https://github.com/huggingface/transformers/blob/e5a861d381bf65a146ce487c3d3c0fca919ef316/src/transformers/processing_utils.py#L606
        if "audios" in kwargs and audio is None:
            raise ValueError(
                "You passed keyword argument `audios` which is deprecated. Please use `audio` instead."
            )
        if all(value is None for value in (images, text, videos, audio, midi)):
            raise ValueError(
                f"You need to provide at least one input to call {self.__class__.__name__}"
            )
        if midi is not None and text is None:
            raise ValueError(
                "`midi` is spliced into `text` at the MIDI pad token, so `text` must be provided as well."
            )

        kwargs = self._merge_kwargs(
            self.valid_processor_kwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs
            if hasattr(self, "tokenizer")
            else {},
            **kwargs,
        )
        kwargs["midi_kwargs"] = {}
        # Length-dependent text options (padding, truncation, tensors) are applied only
        # after the MIDI tokens are spliced in; the tokenizer call gets the rest.
        text_kwargs = dict(kwargs.get("text_kwargs", {}))
        kwargs["text_kwargs"] = {
            key: text_kwargs[key]
            for key in self._TOKENIZE_TEXT_KWARGS
            if key in text_kwargs
        }

        attribute_to_kwargs = {
            "tokenizer": (text, "text_kwargs"),
            "image_processor": (images, "images_kwargs"),
            "video_processor": (videos, "videos_kwargs"),
            "feature_extractor": (audio, "audio_kwargs"),
            "midi_tokenizer": (midi, "midi_kwargs"),
        }
        outputs = {}
        for attribute_name in self.get_attributes():
            attribute = getattr(self, attribute_name, None)
            input_data, input_kwargs = attribute_to_kwargs[attribute_name]
            if input_data is None or attribute is None:
                continue
            if attribute_name == "midi_tokenizer" and isinstance(
                input_data, (list, tuple)
            ):
                # Tokenizers cannot encode None, so missing MIDI entries become "".
                input_data = ["" if item is None else item for item in input_data]
            outputs[attribute_name] = attribute(input_data, **kwargs[input_kwargs])

        text_encoding = outputs.pop("tokenizer", None)
        midi_encoding = outputs.pop("midi_tokenizer", None)
        new_outputs = {}
        if text_encoding is not None:
            new_outputs.update(
                self._prepare_text_features(
                    text, midi, text_encoding, midi_encoding, text_kwargs
                )
            )
        for value in outputs.values():
            new_outputs.update(value)
        return BatchFeature(new_outputs)

    def _prepare_text_features(
        self, text, midi, text_encoding, midi_encoding, text_kwargs
    ):
        """Splice MIDI ids into the text ids, then truncate and pad the result.

        Returns a dict with the padded text features and, when the tokenizer produced
        offsets, an ``offset_mapping`` aligned with the final ``input_ids``.
        """
        input_ids = text_encoding["input_ids"]
        offsets = text_encoding.get("offset_mapping")
        # A single example yields one flat list of ids; work on a list of rows throughout.
        is_batched = not isinstance(text, str) and not (
            input_ids and isinstance(input_ids[0], int)
        )
        id_rows = input_ids if is_batched else [input_ids]
        offset_rows = None
        if offsets is not None:
            offset_rows = offsets if is_batched else [offsets]

        if midi_encoding is not None:
            midi_ids = midi_encoding["input_ids"]
            # MIDI batching is decided by the `midi` argument itself, independently of
            # how the text is batched (e.g. one chat prompt with a list of scores).
            midi_rows = midi_ids if isinstance(midi, (list, tuple)) else [midi_ids]
            id_rows, offset_rows = self._splice_midi_tokens(
                id_rows,
                midi_rows,
                self.tokenizer.convert_tokens_to_ids(self.midi_pad_token),
                self.midi_offset_by,
                offset_rows,
            )

        id_rows, offset_rows = self._truncate_rows(id_rows, offset_rows, text_kwargs)
        pad_kwargs = {
            key: text_kwargs[key] for key in self._PAD_TEXT_KWARGS if key in text_kwargs
        }
        padded = self.tokenizer.pad(
            {"input_ids": id_rows if is_batched else id_rows[0]}, **pad_kwargs
        )
        features = dict(padded)
        if offset_rows is not None:
            padding_side = pad_kwargs.get("padding_side") or getattr(
                self.tokenizer, "padding_side", "right"
            )
            padded_rows = padded["input_ids"] if is_batched else [padded["input_ids"]]
            offset_rows = [
                self._pad_offsets(row, len(padded_row), padding_side)
                for row, padded_row in zip(offset_rows, padded_rows)
            ]
            features["offset_mapping"] = offset_rows if is_batched else offset_rows[0]
        return features

    @staticmethod
    def _splice_midi_tokens(
        id_rows, midi_rows, midi_token_id, id_offset, offset_rows=None
    ):
        """Replace each ``midi_token_id`` in ``id_rows`` with the next row of ``midi_rows``.

        MIDI rows are consumed in order across the whole batch, and every MIDI id is
        shifted by ``id_offset``. Placeholders left over once ``midi_rows`` is exhausted
        are kept unchanged, and surplus MIDI rows are ignored. When ``offset_rows`` is
        given, each inserted MIDI token inherits the character span of the placeholder
        it replaces, so offsets stay aligned with ids.

        Returns:
            ``(new_id_rows, new_offset_rows)``; the latter is ``None`` when
            ``offset_rows`` is ``None``.
        """
        next_midi = 0
        new_id_rows, new_offset_rows = [], []
        for row_idx, row in enumerate(id_rows):
            row_offsets = (
                offset_rows[row_idx] if offset_rows is not None else [None] * len(row)
            )
            new_ids, new_offsets = [], []
            for token_id, span in zip(row, row_offsets):
                if token_id == midi_token_id and next_midi < len(midi_rows):
                    midi_ids = midi_rows[next_midi]
                    next_midi += 1
                    new_ids.extend(tok + id_offset for tok in midi_ids)
                    new_offsets.extend([span] * len(midi_ids))
                else:
                    new_ids.append(token_id)
                    new_offsets.append(span)
            new_id_rows.append(new_ids)
            new_offset_rows.append(new_offsets)
        return new_id_rows, (new_offset_rows if offset_rows is not None else None)

    def _truncate_rows(self, id_rows, offset_rows, text_kwargs):
        """Apply ``truncation``/``max_length`` to the already-spliced rows.

        Truncation is deferred to this point so the limit applies to the final
        sequence, MIDI tokens included. Without ``max_length`` the tokenizer's
        ``model_max_length`` is used. The tokenizer's ``truncation_side`` (``"right"``
        if absent) decides which end is cut.
        """
        truncation = text_kwargs.get("truncation")
        if truncation in (None, False, "do_not_truncate"):
            return id_rows, offset_rows
        max_length = text_kwargs.get("max_length")
        if max_length is None:
            max_length = getattr(self.tokenizer, "model_max_length", None)
        if max_length is None:
            return id_rows, offset_rows
        cut_left = getattr(self.tokenizer, "truncation_side", "right") == "left"

        def _cut(row):
            if len(row) <= max_length:
                return row
            return row[len(row) - max_length :] if cut_left else row[:max_length]

        id_rows = [_cut(row) for row in id_rows]
        if offset_rows is not None:
            offset_rows = [_cut(row) for row in offset_rows]
        return id_rows, offset_rows

    @staticmethod
    def _pad_offsets(row, length, padding_side):
        """Pad one offset row with ``(0, 0)`` spans up to ``length`` on ``padding_side``."""
        filler = [(0, 0)] * max(length - len(row), 0)
        return filler + list(row) if padding_side == "left" else list(row) + filler

    def apply_chat_template(
        self,
        conversation: list[dict[str, str]] | list[list[dict[str, str]]],
        chat_template: str | None = None,
        **kwargs: Unpack[AllKwargsForChatTemplate],
    ) -> str:
        """Render, and optionally tokenize, one conversation or a batch of conversations.

        Adapted from ``ProcessorMixin.apply_chat_template``, with support for MIDI content
        blocks (``{"type": "midi", "score": ...}`` / ``{"type": "midi", "path": ...}``).
        Their values are passed to ``__call__`` as the ``midi`` input.

        Returns:
            The rendered prompt(s) when ``tokenize`` is false. Otherwise the processor
            output (``return_dict=True``, the default) or only its ``input_ids``.

        Raises:
            ValueError: when no chat template can be resolved, or when incompatible
                options are combined.
        """
        # From https://github.com/huggingface/transformers/blob/e5a861d381bf65a146ce487c3d3c0fca919ef316/src/transformers/processing_utils.py#L1631
        chat_template = self._resolve_chat_template(chat_template)

        # Check if tokenizer is fast - use backend attribute if available, otherwise fall back to class name
        is_tokenizers_fast = False
        if hasattr(self, "tokenizer"):
            if hasattr(self.tokenizer, "backend"):
                is_tokenizers_fast = self.tokenizer.backend == "tokenizers"
            else:
                is_tokenizers_fast = self.tokenizer.__class__.__name__.endswith("Fast")

        if kwargs.get("continue_final_message", False):
            if kwargs.get("add_generation_prompt", False):
                raise ValueError(
                    "continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead."
                )
            if kwargs.get("return_assistant_tokens_mask", False):
                raise ValueError(
                    "continue_final_message is not compatible with return_assistant_tokens_mask."
                )
        if kwargs.get("return_assistant_tokens_mask", False):
            if not is_tokenizers_fast:
                raise ValueError(
                    "`return_assistant_tokens_mask` is not possible with slow tokenizers. Make sure you have `tokenizers` installed. "
                    "If the error persists, open an issue to support a Fast tokenizer for your model."
                )
            # Force offset mapping so token boundaries can be matched to template spans.
            kwargs["return_offsets_mapping"] = True

        # Move the declared template kwargs (with their declared defaults) out of
        # `kwargs`. Whatever remains is passed both to the template and to __call__.
        template_kwargs = {}
        template_kwarg_types = AllKwargsForChatTemplate.__annotations__["template_kwargs"]
        for key in template_kwarg_types.__annotations__:
            value = kwargs.pop(key, getattr(template_kwarg_types, key, None))
            if value is not None and not isinstance(value, dict):
                template_kwargs[key] = value
        template_kwargs.update(kwargs)

        # Set the sampling rate to load the audio files if user hasn't already passed with `kwargs`
        if "sampling_rate" not in template_kwargs:
            if hasattr(self, "feature_extractor") and hasattr(
                self.feature_extractor, "sampling_rate"
            ):
                template_kwargs["sampling_rate"] = self.feature_extractor.sampling_rate
            else:
                template_kwargs["sampling_rate"] = 16_000

        is_batched = isinstance(conversation, (list, tuple)) and (
            isinstance(conversation[0], (list, tuple))
            or hasattr(conversation[0], "content")
        )
        conversations = conversation if is_batched else [conversation]

        # Normalize OpenAI-style "image_url" content blocks to HuggingFace-style "image" blocks
        # OpenAI format: {"type": "image_url", "image_url": {"url": "..."}}
        # HuggingFace format: {"type": "image", "url": "..."}
        for conv in conversations:
            for message in conv:
                if not isinstance(message.get("content"), list):
                    continue
                new_content = []
                for content in message["content"]:
                    if (
                        isinstance(content, dict)
                        and content.get("type") == "image_url"
                        and "image_url" in content
                    ):
                        image_url_info = content["image_url"]
                        url = (
                            image_url_info.get("url", "")
                            if isinstance(image_url_info, dict)
                            else image_url_info
                        )
                        new_content.append({"type": "image", "url": url})
                    else:
                        new_content.append(content)
                message["content"] = new_content

        tokenize = template_kwargs.pop("tokenize", False)
        return_dict = template_kwargs.pop("return_dict", True)

        if tokenize:
            (
                batch_images,
                batch_videos,
                batch_audios,
                batch_midis,
            ) = self._collect_chat_media(conversations, template_kwargs)

        special_tokens_map = {}
        if hasattr(self, "tokenizer") and hasattr(self.tokenizer, "special_tokens_map"):
            # Filter out tokens that conflict with template kwargs
            special_tokens_map = {
                k: v
                for k, v in self.tokenizer.special_tokens_map.items()
                if k not in template_kwargs
            }

        prompt, generation_indices = render_jinja_template(
            conversations=conversations,
            chat_template=chat_template,
            **template_kwargs,  # different flags such as `return_assistant_mask`
            **special_tokens_map,  # tokenizer special tokens are used by some templates
        )
        if not is_batched:
            prompt = prompt[0]
        if not tokenize:
            return prompt

        # Tokenizer's `apply_chat_template` never adds special tokens when tokenizing.
        # Processor users historically passed rendered prompts to the processor, so keep
        # BC: if the template already emitted BOS, do not add special tokens again.
        single_prompt = prompt[0] if is_batched else prompt
        if self.tokenizer.bos_token is not None and single_prompt.startswith(
            self.tokenizer.bos_token
        ):
            kwargs["add_special_tokens"] = False

        # Always sample frames by default unless explicitly set to `False` by users. If users do not pass `num_frames`/`fps`
        # sampling should not done for BC.
        if "do_sample_frames" not in kwargs and (
            kwargs.get("fps") is not None or kwargs.get("num_frames") is not None
        ):
            kwargs["do_sample_frames"] = True

        images_exist = any(im is not None for im_list in batch_images for im in im_list)
        videos_exist = any(
            vid is not None for vid_list in batch_videos for vid in vid_list
        )
        out = self(
            text=prompt,
            images=batch_images if images_exist else None,
            videos=batch_videos if videos_exist else None,
            audio=batch_audios if batch_audios else None,
            midi=batch_midis if batch_midis else None,
            **kwargs,
        )
        if not return_dict:
            return out["input_ids"]

        if template_kwargs.get("return_assistant_tokens_mask", False):
            offset_mapping = out.pop("offset_mapping")
            input_ids = out["input_ids"]
            # A single (unbatched) prompt yields one unbatched row from __call__.
            if is_batched:
                out["assistant_masks"] = self._assistant_masks(
                    input_ids, offset_mapping, generation_indices
                )
            else:
                out["assistant_masks"] = self._assistant_masks(
                    [input_ids], [offset_mapping], generation_indices
                )[0]
        out.convert_to_tensors(tensor_type=kwargs.get("return_tensors"))
        return out

    def _resolve_chat_template(self, chat_template):
        """Return the template string to render.

        A name of a stored template is looked up; any other explicit value is used as
        the template itself. Without an explicit value, the processor's single template
        or its ``"default"`` entry is used.
        """
        templates = self.chat_template
        if chat_template is not None:
            if isinstance(templates, dict) and chat_template in templates:
                return templates[chat_template]
            return chat_template
        if isinstance(templates, dict):
            if "default" in templates:
                return templates["default"]
            raise ValueError(
                'The processor has multiple chat templates but none of them are named "default". You need to specify'
                " which one to use by passing the `chat_template` argument. Available templates are: "
                f"{', '.join(templates.keys())}"
            )
        if templates is not None:
            return templates
        raise ValueError(
            "Cannot use apply_chat_template because this processor does not have a chat template."
        )

    @staticmethod
    def _collect_chat_media(conversations, template_kwargs):
        """Gather the image/video/audio/MIDI inputs referenced by content blocks.

        Returns ``(batch_images, batch_videos, batch_audios, batch_midis)``. Images and
        videos are nested per conversation. Audios (decoded with ``load_audio``) and MIDI
        values are flat lists in conversation order.
        """
        sampling_rate = template_kwargs["sampling_rate"]
        audio_from_video = template_kwargs.get("load_audio_from_video", False)
        batch_images, batch_videos, batch_audios, batch_midis = [], [], [], []
        for conversation in conversations:
            images, videos = [], []
            for message in conversation:
                contents = message.get("content")
                if not isinstance(contents, (list, tuple)):
                    # Plain-string (or missing) content cannot hold media blocks.
                    continue
                visuals = [
                    block for block in contents if block["type"] in ("image", "video")
                ]
                audio_fnames = [
                    block[key]
                    for block in contents
                    if block["type"] == "audio"
                    for key in ("audio", "url", "path")
                    if key in block
                ]
                image_fnames = [
                    block[key]
                    for block in visuals
                    if block["type"] == "image"
                    for key in ("image", "url", "path", "base64")
                    if key in block
                ]
                video_fnames = [
                    block[key]
                    for block in visuals
                    if block["type"] == "video"
                    for key in ("video", "url", "path")
                    if key in block
                ]
                images.extend(image_fnames)
                videos.extend(video_fnames)
                batch_midis.extend(
                    block[key]
                    for block in contents
                    if block["type"] == "midi"
                    for key in ("score", "path")
                    if key in block
                )
                # Audio models do not accept nested lists of audios (yet!), so keep them flat.
                for fname in video_fnames if audio_from_video else audio_fnames:
                    batch_audios.append(load_audio(fname, sampling_rate=sampling_rate))
            # Processors accept nested per-conversation lists of visuals, not flat ones.
            batch_images.append(images)
            batch_videos.append(videos)
        return batch_images, batch_videos, batch_audios, batch_midis

    @staticmethod
    def _assistant_masks(input_ids, offset_mapping, generation_indices):
        """Build a 0/1 mask per row marking tokens inside assistant generation spans.

        ``generation_indices`` holds, per row, ``(start_char, end_char)`` spans from the
        rendered template. Each span is mapped to token positions by bisecting the token
        start offsets.
        """
        masks = []
        for row_ids, row_offsets, spans in zip(
            input_ids, offset_mapping, generation_indices
        ):
            mask = [0] * len(row_ids)
            starts = [start for start, _ in row_offsets]
            for start_char, end_char in spans:
                start_pos = bisect.bisect_left(starts, start_char)
                end_pos = bisect.bisect_left(starts, end_char)
                if not (
                    start_pos < len(row_offsets)
                    and row_offsets[start_pos][0] <= start_char < row_offsets[start_pos][1]
                ):
                    # start_token is out of bounds maybe due to truncation.
                    continue
                end_pos = min(end_pos, len(row_ids))
                for pos in range(start_pos, end_pos if end_pos else len(row_ids)):
                    mask[pos] = 1
            masks.append(mask)
        return masks