#!/usr/bin/env python3
# -*- coding:utf-8 -*-
"""Dataset/collate implementation for music training data."""

import math
import re

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer

from audio_tokens import (
    EOA_TOKEN,
    MASK_AUDIO_TOKEN,
    SOA_TOKEN,
    add_audio_special_tokens,
    audio_id_to_token,
)
from vocab import (
    CHORD_BOS_ID,
    CHORD_EOS_ID,
    STRUCTURE_BOS_ID,
    STRUCTURE_EOS_ID,
    build_frame_chord_ids,
    build_frame_structure_ids,
    normalize_structure_label,
)

# Lower-cased labels that mark a sample as Chinese-language.
CN_LANGUAGE_LABELS = {"cn", "zh", "zh-cn", "chinese"}
# Canonical display names for normalized structure labels.
SECTION_NAME_MAP = {
    "intro": "Intro",
    "verse": "Verse",
    "chorus": "Chorus",
    "prechorus": "Pre-Chorus",
    "bridge": "Bridge",
    "outro": "Outro",
    "pad": "Pad",
}
# Sections whose first occurrence is rendered without a numeric suffix
# (e.g. "Intro" rather than "Intro 1").
SINGLETON_SECTION_NAMES = {"intro", "outro", "pad"}
# Characters allowed to end a lyric line; normalize_section_text appends
# ";" when the text ends with anything else.
ENDING_PUNCTUATION = {".", ";", "!", "?", "。", "?", "!", ";"}


def _pad_batch_field(batch, key: str, padding_value):
    """Right-pad the per-sample tensors stored under *key* into one batch tensor."""
    return pad_sequence(
        [row[key] for row in batch],
        batch_first=True,
        padding_value=padding_value,
    )


def detect_language(text: str, language: str | None = None) -> str:
    """Replace spaces with ";" when *language* is a Chinese label; else pass through.

    NOTE(review): despite the name, nothing is detected here — the transform is
    driven entirely by the *language* argument.
    """
    if str(language).strip().lower() in CN_LANGUAGE_LABELS:
        return text.replace(" ", ";")
    return text


def normalize_section_text(
    text: str, structure: str, language: str | None = None
) -> str:
    """Clean a section's lyric text for prompting.

    Strips embedded ``[STRUCTURE]`` tags, maps commas/periods (ASCII and
    fullwidth) to ";", inserts a space after ";" when a Latin letter follows,
    and guarantees the text ends with closing punctuation.
    """
    text = str(text or "")
    text = (
        text.replace(f"[{structure.upper()}]", "")
        .replace(f"[{structure.lower()}]", "")
        .replace(",", ";")
        .replace(".", ";")
        .replace(",", ";")
        .replace("。", ";")
    )
    text = detect_language(text, language=language)
    text = re.sub(r";(?=[A-Za-z])", "; ", text)
    if text and text[-1] not in ENDING_PUNCTUATION:
        text += ";"
    return text


class DataCollate:
    """Collate MusicDataset rows into right-padded batch tensors."""

    def __call__(self, batch):
        input_ids = _pad_batch_field(batch, "token_ids", 0)
        # Clone so in-place label masking downstream (e.g. writing -100 into
        # padded positions) cannot silently corrupt input_ids. Previously
        # "labels" aliased the exact same tensor object.
        labels = input_ids.clone()
        return {
            "input_ids": input_ids,
            "labels": labels,
            "masks": _pad_batch_field(batch, "mask", 0),
            "attention_mask": _pad_batch_field(batch, "attention_mask", 0),
            "chord_ids": _pad_batch_field(batch, "chord_ids", 0),
            "structure_ids": _pad_batch_field(batch, "structure_ids", 0),
            "condition_mask": _pad_batch_field(batch, "condition_mask", False),
        }


class MusicDataset(torch.utils.data.Dataset):
    """Fly dataset with music-code tokens and section-conditioned text."""

    def __init__(
        self,
        datasets,
        split: str,
        tokenizer_path: str,
        num_audio_token=16384,
        fps=25,
        use_fast=True,
    ):
        self._data = datasets[split]
        self.tokenizer_path = tokenizer_path
        self.use_fast = use_fast
        self.num_audio_token = num_audio_token
        self.fps = fps
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_path,
            local_files_only=True,
            use_fast=self.use_fast,
        )
        add_audio_special_tokens(self.tokenizer, self.num_audio_token)
        self.tokenizer_vocab_size = len(self.tokenizer)
        # Audio-code token ids sit directly after the text vocabulary, so the
        # id of audio code 0 equals the text-vocab size.
        self.audio_prefix_length = int(
            self.tokenizer.convert_tokens_to_ids(audio_id_to_token(0))
        )
        self.num_text_token = self.audio_prefix_length
        self.MASK_AUDIO = int(self.tokenizer.convert_tokens_to_ids(MASK_AUDIO_TOKEN))
        self.BOS_AUDIO = int(self.tokenizer.convert_tokens_to_ids(SOA_TOKEN))
        self.EOS_AUDIO = int(self.tokenizer.convert_tokens_to_ids(EOA_TOKEN))
        # Assistant turns carry an empty SOA/EOA pair that is later expanded
        # into real audio codes by _expand_audio_tokens.
        self._assistant_audio_placeholder = f"{SOA_TOKEN}{EOA_TOKEN}"
        self._chat_template_kwargs = {"enable_thinking": False}

    def __len__(self):
        return len(self._data)

    @staticmethod
    def _positions(token_ids: torch.Tensor, target_id: int) -> torch.Tensor:
        """Return a 1-D tensor with the indices where token_ids == target_id."""
        return torch.nonzero(token_ids == target_id, as_tuple=False).squeeze(-1)

    @staticmethod
    def _sorted_sections(sample: dict) -> list[dict]:
        """Normalize and time-sort the sample's section annotations.

        NOTE(review): assumes each section dict carries "text", "desc",
        "start", "end", "section" keys — confirm against the dataset schema.
        """
        return sorted(
            (
                {
                    "raw_index": raw_index,
                    "text": str(seg["text"]),
                    "desc": str(seg["desc"]).strip(),
                    "start": float(seg["start"]),
                    "end": float(seg["end"]),
                    "structure": normalize_structure_label(seg["section"]),
                }
                for raw_index, seg in enumerate(sample.get("sections", []))
            ),
            key=lambda seg: (seg["start"], seg["end"], seg["raw_index"]),
        )

    @staticmethod
    def _sorted_chords(sample: dict) -> list[dict]:
        """Normalize and time-sort the sample's chord annotations."""
        return sorted(
            (
                {
                    "raw_index": raw_index,
                    "type": str(seg.get("type")),
                    "start": float(seg.get("start", 0.0)),
                    "end": float(seg.get("end", 0.0)),
                }
                for raw_index, seg in enumerate(sample.get("chords", []))
            ),
            key=lambda seg: (seg["start"], seg["end"], seg["raw_index"]),
        )

    def __getitem__(self, idx):
        """Build one training example: tokens, masks, and aligned condition ids."""
        sample = self._data[idx]
        sections = self._prepare_sections(sample)
        chords = self._prepare_chords(sample)
        token_ids, attention_mask, frame_idx_map = self._tokenize_messages(
            self._build_messages(sample, sections),
            sample["mucodec_codes"],
            sections,
        )
        # Per-frame condition ids over the full audio timeline (one id per
        # mucodec frame), later scattered onto token positions.
        total_frames = len(sample["mucodec_codes"])
        structure_ids = build_frame_structure_ids(sections, total_frames, fps=self.fps)
        chord_ids = build_frame_chord_ids(chords, total_frames, fps=self.fps)
        structure_ids = torch.from_numpy(structure_ids)
        chord_ids = torch.from_numpy(chord_ids)
        (
            audio_codebook_mask,
            bos_audio_mask,
            eos_mask,
            label_mask,
            condition_mask,
        ) = self._build_token_masks(token_ids)
        chord_ids_aligned, structure_ids_aligned = self._align_condition_ids(
            token_ids=token_ids,
            frame_idx_map=frame_idx_map,
            total_frames=total_frames,
            chord_ids=chord_ids,
            structure_ids=structure_ids,
            audio_codebook_mask=audio_codebook_mask,
            bos_audio_mask=bos_audio_mask,
            eos_mask=eos_mask,
        )
        return {
            "token_ids": token_ids,
            "mask": label_mask,
            "attention_mask": attention_mask,
            "chord_ids": chord_ids_aligned,
            "structure_ids": structure_ids_aligned,
            "condition_mask": condition_mask,
        }

    def _tokenize_messages(
        self,
        messages: list[dict[str, str]],
        full_audio_codes,
        sections: list[dict],
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply the chat template, then splice real audio codes into each turn.

        Returns (token_ids, attention_mask, frame_idx_map); see
        _expand_audio_tokens for the frame_idx_map semantics.
        """
        chat_inputs = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=False,
            return_tensors="pt",
            return_dict=True,
            **self._chat_template_kwargs,
        )
        token_ids = chat_inputs["input_ids"].squeeze(0).to(torch.long)
        attention_mask = chat_inputs["attention_mask"].squeeze(0).to(torch.long)
        return self._expand_audio_tokens(
            token_ids=token_ids,
            attention_mask=attention_mask,
            full_audio_codes=full_audio_codes,
            sections=sections,
        )

    def _frame_bounds(
        self,
        start: float,
        end: float,
        total_frames: int,
        prev_end_idx: int = 0,
    ) -> tuple[int, int]:
        """Convert [start, end) seconds to clamped, non-overlapping frame indices.

        start is floored, end is ceiled; the result never precedes
        prev_end_idx so consecutive segments cannot overlap.
        """
        start_idx = int(start * self.fps)
        end_idx = int(math.ceil(end * self.fps))
        start_idx = max(prev_end_idx, min(total_frames, start_idx))
        end_idx = max(start_idx, min(total_frames, end_idx))
        return start_idx, end_idx

    def _prepare_sections(self, sample: dict) -> list[dict]:
        """Turn raw section annotations into contiguous frame-aligned sections.

        Each section starts exactly where the previous one ended (gaps are
        absorbed by the following section), and the last section is stretched
        to cover the full audio.
        """
        sections = []
        section_counts: dict[str, int] = {}
        sample_language = sample.get("language")
        total_frames = len(sample["mucodec_codes"])
        prev_end_idx = 0
        for seg in self._sorted_sections(sample):
            structure = seg["structure"]
            section_counts[structure] = section_counts.get(structure, 0) + 1
            raw_start_idx = max(0, min(total_frames, int(seg["start"] * self.fps)))
            raw_end_idx = max(
                raw_start_idx,
                min(total_frames, int(math.ceil(seg["end"] * self.fps))),
            )
            # Force contiguity: the annotated start is discarded in favor of
            # the previous section's end.
            start_idx = prev_end_idx
            end_idx = max(start_idx, raw_end_idx)
            sections.append(
                {
                    "text": normalize_section_text(
                        seg["text"], structure, language=sample_language
                    ),
                    "desc": seg["desc"],
                    "start": start_idx / float(self.fps),
                    "end": end_idx / float(self.fps),
                    "start_frame": start_idx,
                    "end_frame": end_idx,
                    "structure": structure,
                    "tag": f"{structure}{section_counts[structure]}",
                    "index": section_counts[structure],
                }
            )
            prev_end_idx = end_idx
        if sections:
            # The final section absorbs any trailing frames.
            sections[-1]["end_frame"] = total_frames
            sections[-1]["end"] = total_frames / float(self.fps)
        return sections

    def _prepare_chords(self, sample: dict) -> list[dict]:
        """Convert chord annotations to clamped, non-overlapping frame spans."""
        chords = []
        total_frames = len(sample["mucodec_codes"])
        prev_end_idx = 0
        for seg in self._sorted_chords(sample):
            start_idx, end_idx = self._frame_bounds(
                seg["start"],
                seg["end"],
                total_frames,
                prev_end_idx=prev_end_idx,
            )
            chords.append(
                {
                    "type": seg["type"],
                    "start": start_idx / float(self.fps),
                    "end": end_idx / float(self.fps),
                    "start_frame": start_idx,
                    "end_frame": end_idx,
                }
            )
            prev_end_idx = end_idx
        return chords

    def _format_section_label(self, section: dict) -> str:
        """Human-readable section label, e.g. "Verse 2" or just "Intro"."""
        structure = section["structure"]
        index = section["index"]
        label = SECTION_NAME_MAP[structure]
        if structure in SINGLETON_SECTION_NAMES and index == 1:
            return label
        return f"{label} {index}"

    def _build_section_user_content(
        self, sample: dict, section: dict, is_first_turn: bool
    ) -> str:
        """Compose the user-turn text for one section.

        The first turn is prefixed with the overall style instruction; every
        turn carries the section label plus optional [desc:...] / [lyrics:...].
        """
        parts = []
        if is_first_turn:
            style = sample["style"].strip()
            if style:
                parts.append(
                    f"Please generate a song in the following style:{style}\n"
                    "Next, I will tell you the requirements and lyrics for the song "
                    "fragment to be generated, section by section."
                )
            else:
                parts.append(
                    "Please generate the song section by section. "
                    "Next, I will tell you the requirements and lyrics for each fragment."
                )
        section_parts = [f"[{self._format_section_label(section)}]"]
        desc = section["desc"]
        if desc:
            section_parts.append(f"[desc:{desc}]")
        lyrics = section["text"]
        if lyrics:
            section_parts.append(f"[lyrics:{lyrics}]")
        parts.append("".join(section_parts))
        return "\n".join(part for part in parts if part)

    def _build_messages(
        self,
        sample: dict,
        sections: list[dict],
    ) -> list[dict[str, str]]:
        """One user/assistant message pair per section, in section order."""
        messages: list[dict[str, str]] = []
        for i, section in enumerate(sections):
            messages.append(
                {
                    "role": "user",
                    "content": self._build_section_user_content(
                        sample, section, is_first_turn=(i == 0)
                    ),
                }
            )
            messages.append(
                {
                    "role": "assistant",
                    "content": self._assistant_audio_placeholder,
                }
            )
        return messages

    def _expand_audio_tokens(
        self,
        token_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        full_audio_codes,
        sections: list[dict],
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Replace each empty SOA/EOA placeholder pair with the section's audio codes.

        Audio codes are shifted by audio_prefix_length into the token-id space.
        frame_idx_map gives, per output position, the mucodec frame index it
        corresponds to (-1 for pure-text positions). Assumes the number of
        SOA/EOA pairs equals len(sections), in the same order.
        """
        if not sections:
            return (
                token_ids,
                attention_mask,
                torch.full(token_ids.shape, -1, dtype=torch.long),
            )
        bos_positions = self._positions(token_ids, self.BOS_AUDIO)
        eos_positions = self._positions(token_ids, self.EOS_AUDIO)
        audio_code_tensor = torch.as_tensor(full_audio_codes, dtype=torch.long)
        extra_audio_tokens = sum(
            int(section["end_frame"]) - int(section["start_frame"])
            for section in sections
        )
        final_len = token_ids.numel() + extra_audio_tokens
        expanded_token_ids = torch.empty(final_len, dtype=torch.long)
        expanded_attention_mask = torch.empty(final_len, dtype=torch.long)
        frame_idx_map = torch.full((final_len,), -1, dtype=torch.long)
        read_pos = 0
        write_pos = 0
        for bos_pos, eos_pos, section in zip(
            bos_positions.tolist(), eos_positions.tolist(), sections
        ):
            start_idx = int(section["start_frame"])
            end_idx = int(section["end_frame"])
            audio_len = end_idx - start_idx
            # Copy the text between the previous EOA and this BOS (inclusive).
            prefix_len = bos_pos + 1 - read_pos
            next_write = write_pos + prefix_len
            expanded_token_ids[write_pos:next_write] = token_ids[read_pos : bos_pos + 1]
            expanded_attention_mask[write_pos:next_write] = attention_mask[
                read_pos : bos_pos + 1
            ]
            # The BOS position maps to the section's first frame.
            frame_idx_map[next_write - 1] = start_idx if audio_len > 0 else -1
            write_pos = next_write
            if audio_len > 0:
                next_write = write_pos + audio_len
                expanded_token_ids[write_pos:next_write] = audio_code_tensor[
                    start_idx:end_idx
                ]
                # Shift raw codebook ids into the audio-token id range.
                expanded_token_ids[write_pos:next_write].add_(self.audio_prefix_length)
                expanded_attention_mask[write_pos:next_write] = 1
                frame_idx_map[write_pos:next_write] = torch.arange(
                    start_idx, end_idx, dtype=torch.long
                )
                write_pos = next_write
            # Copy the EOA token; it maps to the section's last frame.
            expanded_token_ids[write_pos] = token_ids[eos_pos]
            expanded_attention_mask[write_pos] = attention_mask[eos_pos]
            frame_idx_map[write_pos] = end_idx - 1 if audio_len > 0 else -1
            write_pos += 1
            read_pos = eos_pos + 1
        # Copy whatever text follows the final EOA.
        tail_len = token_ids.numel() - read_pos
        if tail_len > 0:
            expanded_token_ids[write_pos : write_pos + tail_len] = token_ids[read_pos:]
            expanded_attention_mask[write_pos : write_pos + tail_len] = attention_mask[
                read_pos:
            ]
        return expanded_token_ids, expanded_attention_mask, frame_idx_map

    def _build_token_masks(
        self, token_ids: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Derive per-position masks from the expanded token ids.

        Assumes audio-codebook ids occupy [audio_prefix_length, MASK_AUDIO) —
        i.e. MASK_AUDIO is the first special id after the codebook range.
        label_mask marks positions that receive a training loss (audio codes
        and EOA); condition_mask additionally includes SOA.
        """
        audio_codebook_mask = (token_ids >= self.audio_prefix_length) & (
            token_ids < self.MASK_AUDIO
        )
        bos_audio_mask = token_ids == self.BOS_AUDIO
        eos_mask = token_ids == self.EOS_AUDIO
        label_mask = (audio_codebook_mask | eos_mask).long()
        condition_mask = audio_codebook_mask | bos_audio_mask | eos_mask
        return audio_codebook_mask, bos_audio_mask, eos_mask, label_mask, condition_mask

    def _align_condition_ids(
        self,
        token_ids: torch.Tensor,
        frame_idx_map: torch.Tensor,
        total_frames: int,
        chord_ids: torch.Tensor,
        structure_ids: torch.Tensor,
        audio_codebook_mask: torch.Tensor,
        bos_audio_mask: torch.Tensor,
        eos_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Scatter per-frame chord/structure ids onto token positions.

        SOA positions get the BOS condition ids, EOA positions the EOS ids,
        audio-code positions the id of their mucodec frame, and all remaining
        (text) positions stay 0.
        """
        seq_len = token_ids.numel()
        chord_ids_aligned = torch.zeros(seq_len, dtype=torch.long)
        structure_ids_aligned = torch.zeros(seq_len, dtype=torch.long)
        bos_positions = torch.nonzero(bos_audio_mask, as_tuple=False).squeeze(-1)
        chord_ids_aligned[bos_positions] = CHORD_BOS_ID
        structure_ids_aligned[bos_positions] = STRUCTURE_BOS_ID
        audio_positions = torch.nonzero(audio_codebook_mask, as_tuple=False).squeeze(-1)
        cur_frame_idx = frame_idx_map[audio_positions]
        # Clamp defensively; -1 entries would otherwise index from the end.
        cur_frame_idx = cur_frame_idx.clamp(0, max(total_frames - 1, 0))
        chord_ids_aligned[audio_positions] = chord_ids[cur_frame_idx]
        structure_ids_aligned[audio_positions] = structure_ids[cur_frame_idx]
        eos_positions = torch.nonzero(eos_mask, as_tuple=False).squeeze(-1)
        chord_ids_aligned[eos_positions] = CHORD_EOS_ID
        structure_ids_aligned[eos_positions] = STRUCTURE_EOS_ID
        return chord_ids_aligned, structure_ids_aligned