| |
| |
| """Dataset/collate implementation for music training data.""" |
|
|
| import math |
| import re |
|
|
| import torch |
| from torch.nn.utils.rnn import pad_sequence |
| from transformers import AutoTokenizer |
|
|
| from audio_tokens import ( |
| EOA_TOKEN, |
| MASK_AUDIO_TOKEN, |
| SOA_TOKEN, |
| add_audio_special_tokens, |
| audio_id_to_token, |
| ) |
| from vocab import ( |
| CHORD_BOS_ID, |
| CHORD_EOS_ID, |
| STRUCTURE_BOS_ID, |
| STRUCTURE_EOS_ID, |
| build_frame_chord_ids, |
| build_frame_structure_ids, |
| normalize_structure_label, |
| ) |
|
|
|
|
# Language tags treated as Chinese; lyrics in these languages get spaces
# rewritten to ";" phrase separators (see detect_language).
CN_LANGUAGE_LABELS = {"cn", "zh", "zh-cn", "chinese"}
# Normalized structure label -> human-readable section name used in prompts.
SECTION_NAME_MAP = {
    "intro": "Intro",
    "verse": "Verse",
    "chorus": "Chorus",
    "prechorus": "Pre-Chorus",
    "bridge": "Bridge",
    "outro": "Outro",
    "pad": "Pad",
}
# Sections rendered without a numeric suffix when they occur exactly once
# (e.g. "Intro" rather than "Intro 1").
SINGLETON_SECTION_NAMES = {"intro", "outro", "pad"}
# Terminal punctuation accepted at the end of normalized lyrics; anything else
# gets a trailing ";" appended.
# NOTE(review): several entries look like ASCII/fullwidth pairs that may have
# been mangled by an encoding round-trip (duplicates are harmless in a set,
# but the fullwidth variants should be confirmed against the original file).
ENDING_PUNCTUATION = {".", ";", "!", "?", "。", "?", "!", ";"}
|
|
|
|
| def _pad_batch_field(batch, key: str, padding_value): |
| return pad_sequence( |
| [row[key] for row in batch], |
| batch_first=True, |
| padding_value=padding_value, |
| ) |
|
|
|
|
def detect_language(text: str, language: str | None = None) -> str:
    """Rewrite spaces to ";" separators when ``language`` is a Chinese label.

    NOTE(review): despite the name, nothing is detected here — the function
    only applies the CN-specific separator rewrite when ``language`` is in
    ``CN_LANGUAGE_LABELS``; otherwise ``text`` is returned unchanged. (Kept
    as-is because callers depend on the name.)
    """
    label = str(language).strip().lower()
    if label in CN_LANGUAGE_LABELS:
        return text.replace(" ", ";")
    return text
|
|
|
|
def normalize_section_text(
    text: str, structure: str, language: str | None = None
) -> str:
    """Strip inline section tags from lyrics and unify punctuation to ";".

    Removes ``[STRUCTURE]`` markers (upper or lower case), converts sentence
    punctuation to ";", applies the Chinese-language spacing rewrite, spaces
    out ";" before Latin letters, and guarantees a terminal separator.
    """
    cleaned = str(text or "")
    # Ordered replacement table: tag removal first, then punctuation folding.
    for needle, substitute in (
        (f"[{structure.upper()}]", ""),
        (f"[{structure.lower()}]", ""),
        (",", ";"),
        (".", ";"),
        (",", ";"),
        ("。", ";"),
    ):
        cleaned = cleaned.replace(needle, substitute)
    cleaned = detect_language(cleaned, language=language)
    # Insert a space after ";" wherever a Latin letter follows directly.
    cleaned = re.sub(r";(?=[A-Za-z])", "; ", cleaned)
    if cleaned and cleaned[-1] not in ENDING_PUNCTUATION:
        cleaned += ";"
    return cleaned
|
|
|
|
class DataCollate:
    """Collates variable-length per-sample fields into right-padded batches."""

    def __call__(self, batch):
        """Collate a list of sample dicts (as produced by ``MusicDataset``).

        Args:
            batch: list of dicts, each holding 1-D tensors under the keys
                "token_ids", "mask", "attention_mask", "chord_ids",
                "structure_ids", and "condition_mask".

        Returns:
            Dict of 2-D batch-first tensors keyed for the training loop.
            Note the key rename: per-sample "mask" is exposed as "masks".
        """
        input_ids = _pad_batch_field(batch, "token_ids", 0)
        # Clone rather than alias: training loops commonly mask labels
        # in-place (e.g. setting an ignore index on prompt positions), and an
        # aliased tensor would silently corrupt the model inputs too.
        labels = input_ids.clone()
        mask_padded = _pad_batch_field(batch, "mask", 0)
        attention_mask_padded = _pad_batch_field(batch, "attention_mask", 0)
        chord_ids_padded = _pad_batch_field(batch, "chord_ids", 0)
        structure_ids_padded = _pad_batch_field(batch, "structure_ids", 0)
        condition_mask_padded = _pad_batch_field(batch, "condition_mask", False)

        return {
            "input_ids": input_ids,
            "labels": labels,
            "masks": mask_padded,
            "attention_mask": attention_mask_padded,
            "chord_ids": chord_ids_padded,
            "structure_ids": structure_ids_padded,
            "condition_mask": condition_mask_padded,
        }
|
|
|
|
class MusicDataset(torch.utils.data.Dataset):
    """Fly dataset with music-code tokens and section-conditioned text.

    Each item renders one song as a multi-turn chat — one user/assistant turn
    per song section — tokenized with a Hugging Face tokenizer. Every
    assistant turn initially contains only a ``<SOA><EOA>`` placeholder; after
    templating, ``_expand_audio_tokens`` splices the section's frame-level
    audio-code token ids between each SOA/EOA pair. Frame-aligned chord and
    structure condition ids are returned alongside the token sequence.

    NOTE(review): assumes each sample dict provides "mucodec_codes" (a
    sequence of per-frame audio code ids at ``fps`` frames/second), plus
    "sections", "chords", "style", and optionally "language" — confirm
    against the dataset builder.
    """

    def __init__(
        self,
        datasets,
        split: str,
        tokenizer_path: str,
        num_audio_token=16384,
        fps=25,
        use_fast=True,
    ):
        """Load the tokenizer and register audio special tokens.

        Args:
            datasets: mapping of split name -> indexable collection of samples.
            split: which split of ``datasets`` this instance serves.
            tokenizer_path: local path to the pretrained tokenizer
                (``local_files_only=True`` — no network fetch).
            num_audio_token: size of the audio-code vocabulary appended to the
                text tokenizer.
            fps: audio frames per second, used to convert section/chord times
                (seconds) into frame indices.
            use_fast: forwarded to ``AutoTokenizer.from_pretrained``.
        """
        self._data = datasets[split]
        self.tokenizer_path = tokenizer_path
        self.use_fast = use_fast
        self.num_audio_token = num_audio_token
        self.fps = fps

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_path,
            local_files_only=True,
            use_fast=self.use_fast,
        )
        add_audio_special_tokens(self.tokenizer, self.num_audio_token)
        self.tokenizer_vocab_size = len(self.tokenizer)

        # Tokenizer id of audio code 0: raw audio code ``c`` maps to tokenizer
        # id ``c + audio_prefix_length`` (see _expand_audio_tokens), so every
        # id below this value is a text token.
        self.audio_prefix_length = int(
            self.tokenizer.convert_tokens_to_ids(audio_id_to_token(0))
        )
        self.num_text_token = self.audio_prefix_length
        self.MASK_AUDIO = int(self.tokenizer.convert_tokens_to_ids(MASK_AUDIO_TOKEN))
        self.BOS_AUDIO = int(self.tokenizer.convert_tokens_to_ids(SOA_TOKEN))
        self.EOS_AUDIO = int(self.tokenizer.convert_tokens_to_ids(EOA_TOKEN))
        # Assistant turns carry only this empty SOA/EOA pair at templating
        # time; the real audio codes are spliced in afterwards.
        self._assistant_audio_placeholder = f"{SOA_TOKEN}{EOA_TOKEN}"
        # Passed through apply_chat_template; disables "thinking" output for
        # chat templates that support it — presumably a Qwen-style template,
        # TODO confirm the tokenizer's template accepts this kwarg.
        self._chat_template_kwargs = {"enable_thinking": False}

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self._data)

    @staticmethod
    def _positions(token_ids: torch.Tensor, target_id: int) -> torch.Tensor:
        """Return the 1-D tensor of indices where ``token_ids == target_id``."""
        return torch.nonzero(token_ids == target_id, as_tuple=False).squeeze(-1)

    @staticmethod
    def _sorted_sections(sample: dict) -> list[dict]:
        """Normalize the sample's sections and sort them chronologically.

        Sort key is (start, end, original index) so ties keep input order.
        """
        return sorted(
            (
                {
                    "raw_index": raw_index,
                    "text": str(seg["text"]),
                    "desc": str(seg["desc"]).strip(),
                    "start": float(seg["start"]),
                    "end": float(seg["end"]),
                    "structure": normalize_structure_label(seg["section"]),
                }
                for raw_index, seg in enumerate(sample.get("sections", []))
            ),
            key=lambda seg: (seg["start"], seg["end"], seg["raw_index"]),
        )

    @staticmethod
    def _sorted_chords(sample: dict) -> list[dict]:
        """Normalize the sample's chords and sort them chronologically.

        Missing start/end default to 0.0; ties keep input order.
        """
        return sorted(
            (
                {
                    "raw_index": raw_index,
                    "type": str(seg.get("type")),
                    "start": float(seg.get("start", 0.0)),
                    "end": float(seg.get("end", 0.0)),
                }
                for raw_index, seg in enumerate(sample.get("chords", []))
            ),
            key=lambda seg: (seg["start"], seg["end"], seg["raw_index"]),
        )

    def __getitem__(self, idx):
        """Build one training example.

        Returns a dict of 1-D tensors, all of the same (expanded) sequence
        length except none: "token_ids", "mask" (1 on positions the loss
        should cover: audio codes and EOA), "attention_mask", "chord_ids",
        "structure_ids", and "condition_mask" (bool; audio codes, SOA, EOA).
        """
        sample = self._data[idx]
        sections = self._prepare_sections(sample)
        chords = self._prepare_chords(sample)
        token_ids, attention_mask, frame_idx_map = self._tokenize_messages(
            self._build_messages(sample, sections),
            sample["mucodec_codes"],
            sections,
        )

        # One condition id per audio frame, derived from the (possibly
        # adjusted) section/chord frame spans.
        total_frames = len(sample["mucodec_codes"])
        structure_ids = build_frame_structure_ids(sections, total_frames, fps=self.fps)
        chord_ids = build_frame_chord_ids(chords, total_frames, fps=self.fps)

        structure_ids = torch.from_numpy(structure_ids)
        chord_ids = torch.from_numpy(chord_ids)

        (
            audio_codebook_mask,
            bos_audio_mask,
            eos_mask,
            label_mask,
            condition_mask,
        ) = self._build_token_masks(token_ids)

        # Re-index the frame-level condition ids onto token positions.
        chord_ids_aligned, structure_ids_aligned = self._align_condition_ids(
            token_ids=token_ids,
            frame_idx_map=frame_idx_map,
            total_frames=total_frames,
            chord_ids=chord_ids,
            structure_ids=structure_ids,
            audio_codebook_mask=audio_codebook_mask,
            bos_audio_mask=bos_audio_mask,
            eos_mask=eos_mask,
        )

        return {
            "token_ids": token_ids,
            "mask": label_mask,
            "attention_mask": attention_mask,
            "chord_ids": chord_ids_aligned,
            "structure_ids": structure_ids_aligned,
            "condition_mask": condition_mask,
        }

    def _tokenize_messages(
        self,
        messages: list[dict[str, str]],
        full_audio_codes,
        sections: list[dict],
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Render the chat template to ids, then expand audio placeholders.

        Returns ``(token_ids, attention_mask, frame_idx_map)`` — see
        ``_expand_audio_tokens`` for the expanded layout.
        """
        chat_inputs = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=False,
            return_tensors="pt",
            return_dict=True,
            **self._chat_template_kwargs,
        )

        token_ids = chat_inputs["input_ids"]
        attention_mask = chat_inputs["attention_mask"]

        # Drop the batch dimension added by return_tensors="pt".
        token_ids = token_ids.squeeze(0)
        attention_mask = attention_mask.squeeze(0)

        token_ids = token_ids.to(torch.long)
        attention_mask = attention_mask.to(torch.long)

        return self._expand_audio_tokens(
            token_ids=token_ids,
            attention_mask=attention_mask,
            full_audio_codes=full_audio_codes,
            sections=sections,
        )

    def _frame_bounds(
        self,
        start: float,
        end: float,
        total_frames: int,
        prev_end_idx: int = 0,
    ) -> tuple[int, int]:
        """Convert a [start, end] time span (seconds) to clamped frame indices.

        Start is floored, end is ceiled; both are clamped to
        ``[0, total_frames]``, the start is additionally pushed to at least
        ``prev_end_idx`` so consecutive spans never overlap, and the end never
        precedes the start (empty spans collapse to zero length).
        """
        start_idx = int(start * self.fps)
        end_idx = int(math.ceil(end * self.fps))
        start_idx = max(prev_end_idx, min(total_frames, start_idx))
        end_idx = max(start_idx, min(total_frames, end_idx))

        return start_idx, end_idx

    def _prepare_sections(self, sample: dict) -> list[dict]:
        """Build contiguous, normalized section records covering all frames.

        Each section is forced to start exactly where the previous one ended
        (its raw start time is ignored for the start bound), and the last
        section is stretched to ``total_frames`` so the sections tile the
        whole audio without gaps. "tag"/"index" carry per-structure counters
        (e.g. the second chorus is ``chorus2``).
        """
        sections = []
        section_counts: dict[str, int] = {}
        sample_language = sample.get("language")
        total_frames = len(sample["mucodec_codes"])
        prev_end_idx = 0

        for seg in self._sorted_sections(sample):
            structure = seg["structure"]
            section_counts[structure] = section_counts.get(structure, 0) + 1
            raw_start_idx = max(0, min(total_frames, int(seg["start"] * self.fps)))
            raw_end_idx = max(
                raw_start_idx,
                min(total_frames, int(math.ceil(seg["end"] * self.fps))),
            )
            # Contiguity: start at the previous section's end, keep the raw
            # end if it lies beyond that.
            start_idx = prev_end_idx
            end_idx = max(start_idx, raw_end_idx)

            sections.append(
                {
                    "text": normalize_section_text(
                        seg["text"], structure, language=sample_language
                    ),
                    "desc": seg["desc"],
                    "start": start_idx / float(self.fps),
                    "end": end_idx / float(self.fps),
                    "start_frame": start_idx,
                    "end_frame": end_idx,
                    "structure": structure,
                    "tag": f"{structure}{section_counts[structure]}",
                    "index": section_counts[structure],
                }
            )
            prev_end_idx = end_idx

        # Extend the final section to cover any trailing frames.
        if sections:
            sections[-1]["end_frame"] = total_frames
            sections[-1]["end"] = total_frames / float(self.fps)

        return sections

    def _prepare_chords(self, sample: dict) -> list[dict]:
        """Build non-overlapping chord records with clamped frame bounds."""
        chords = []
        total_frames = len(sample["mucodec_codes"])
        prev_end_idx = 0

        for seg in self._sorted_chords(sample):
            start_idx, end_idx = self._frame_bounds(
                seg["start"],
                seg["end"],
                total_frames,
                prev_end_idx=prev_end_idx,
            )

            chords.append(
                {
                    "type": seg["type"],
                    "start": start_idx / float(self.fps),
                    "end": end_idx / float(self.fps),
                    "start_frame": start_idx,
                    "end_frame": end_idx,
                }
            )
            prev_end_idx = end_idx

        return chords

    def _format_section_label(self, section: dict) -> str:
        """Human-readable section label, e.g. "Verse 2" or just "Intro".

        Singleton sections (intro/outro/pad) drop the index when they are the
        first of their kind; all other sections are always numbered.
        """
        structure = section["structure"]
        index = section["index"]
        label = SECTION_NAME_MAP[structure]
        if structure in SINGLETON_SECTION_NAMES and index == 1:
            return label
        return f"{label} {index}"

    def _build_section_user_content(
        self, sample: dict, section: dict, is_first_turn: bool
    ) -> str:
        """Compose the user-turn prompt text for one section.

        The first turn is prefixed with the global style instruction (or a
        generic preamble when the sample has no style). Every turn then
        carries the section label plus optional ``[desc:...]`` and
        ``[lyrics:...]`` fields.
        """
        parts = []
        if is_first_turn:
            style = sample["style"].strip()
            if style:
                parts.append(
                    f"Please generate a song in the following style:{style}\n"
                    "Next, I will tell you the requirements and lyrics for the song "
                    "fragment to be generated, section by section."
                )
            else:
                parts.append(
                    "Please generate the song section by section. "
                    "Next, I will tell you the requirements and lyrics for each fragment."
                )

        section_parts = [f"[{self._format_section_label(section)}]"]
        desc = section["desc"]
        if desc:
            section_parts.append(f"[desc:{desc}]")

        lyrics = section["text"]
        if lyrics:
            section_parts.append(f"[lyrics:{lyrics}]")

        parts.append("".join(section_parts))

        return "\n".join(part for part in parts if part)

    def _build_messages(
        self,
        sample: dict,
        sections: list[dict],
    ) -> list[dict[str, str]]:
        """Build the alternating user/assistant chat for all sections.

        Each section yields one user turn (prompt text) and one assistant
        turn containing only the SOA/EOA placeholder, expanded later.
        """
        messages: list[dict[str, str]] = [None] * (2 * len(sections))

        for i, section in enumerate(sections):
            msg_idx = 2 * i
            messages[msg_idx] = {
                "role": "user",
                "content": self._build_section_user_content(
                    sample, section, is_first_turn=(i == 0)
                ),
            }
            messages[msg_idx + 1] = {
                "role": "assistant",
                "content": self._assistant_audio_placeholder,
            }

        return messages

    def _expand_audio_tokens(
        self,
        token_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        full_audio_codes,
        sections: list[dict],
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Splice each section's audio codes between its SOA/EOA pair.

        Performs a single-pass copy with a read cursor over the templated
        sequence and a write cursor over the pre-sized output. For each
        section, the text up to and including SOA is copied, then the
        section's audio codes (shifted by ``audio_prefix_length`` into the
        tokenizer vocabulary), then EOA.

        ``frame_idx_map`` records, per output position, which audio frame the
        token corresponds to: the frame index for audio-code tokens, the
        section's first frame at SOA, the last frame at EOA, and -1 for all
        text positions (and for SOA/EOA of empty sections).

        Assumes the template yields exactly one SOA/EOA pair per section, in
        section order — guaranteed by ``_build_messages``.
        """
        if not sections:
            return (
                token_ids,
                attention_mask,
                torch.full(token_ids.shape, -1, dtype=torch.long),
            )

        bos_positions = self._positions(token_ids, self.BOS_AUDIO)
        eos_positions = self._positions(token_ids, self.EOS_AUDIO)

        audio_code_tensor = torch.as_tensor(full_audio_codes, dtype=torch.long)
        extra_audio_tokens = sum(
            int(section["end_frame"]) - int(section["start_frame"])
            for section in sections
        )
        final_len = token_ids.numel() + extra_audio_tokens

        expanded_token_ids = torch.empty(final_len, dtype=torch.long)
        expanded_attention_mask = torch.empty(final_len, dtype=torch.long)
        frame_idx_map = torch.full((final_len,), -1, dtype=torch.long)

        read_pos = 0
        write_pos = 0

        for bos_pos, eos_pos, section in zip(
            bos_positions.tolist(), eos_positions.tolist(), sections
        ):
            start_idx = int(section["start_frame"])
            end_idx = int(section["end_frame"])
            audio_len = end_idx - start_idx

            # Copy text tokens up to and including this section's SOA.
            prefix_len = bos_pos + 1 - read_pos
            next_write = write_pos + prefix_len
            expanded_token_ids[write_pos:next_write] = token_ids[read_pos : bos_pos + 1]
            expanded_attention_mask[write_pos:next_write] = attention_mask[
                read_pos : bos_pos + 1
            ]
            # SOA carries the section's first frame index (or -1 if empty).
            frame_idx_map[next_write - 1] = start_idx if audio_len > 0 else -1
            write_pos = next_write

            if audio_len > 0:
                # Insert the section's audio codes, shifted into vocab space.
                next_write = write_pos + audio_len
                expanded_token_ids[write_pos:next_write] = audio_code_tensor[
                    start_idx:end_idx
                ]
                expanded_token_ids[write_pos:next_write].add_(self.audio_prefix_length)
                expanded_attention_mask[write_pos:next_write] = 1
                frame_idx_map[write_pos:next_write] = torch.arange(
                    start_idx, end_idx, dtype=torch.long
                )
                write_pos = next_write

            # Copy EOA; it carries the section's last frame index.
            expanded_token_ids[write_pos] = token_ids[eos_pos]
            expanded_attention_mask[write_pos] = attention_mask[eos_pos]
            frame_idx_map[write_pos] = end_idx - 1 if audio_len > 0 else -1
            write_pos += 1
            read_pos = eos_pos + 1

        # Copy any trailing template tokens after the last EOA.
        tail_len = token_ids.numel() - read_pos
        if tail_len > 0:
            expanded_token_ids[write_pos : write_pos + tail_len] = token_ids[read_pos:]
            expanded_attention_mask[write_pos : write_pos + tail_len] = attention_mask[
                read_pos:
            ]

        return expanded_token_ids, expanded_attention_mask, frame_idx_map

    def _build_token_masks(
        self, token_ids: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Derive the per-position masks used downstream.

        Returns ``(audio_codebook_mask, bos_audio_mask, eos_mask, label_mask,
        condition_mask)`` where ``label_mask`` (long) covers audio codes and
        EOA (the positions the loss is computed on), and ``condition_mask``
        (bool) additionally includes SOA. Relies on audio-code ids occupying
        ``[audio_prefix_length, MASK_AUDIO)`` — see ``__init__``.
        """
        audio_codebook_mask = (token_ids >= self.audio_prefix_length) & (
            token_ids < self.MASK_AUDIO
        )
        bos_audio_mask = token_ids == self.BOS_AUDIO
        eos_mask = token_ids == self.EOS_AUDIO
        label_mask = (audio_codebook_mask | eos_mask).long()
        condition_mask = audio_codebook_mask | bos_audio_mask | eos_mask

        return audio_codebook_mask, bos_audio_mask, eos_mask, label_mask, condition_mask

    def _align_condition_ids(
        self,
        token_ids: torch.Tensor,
        frame_idx_map: torch.Tensor,
        total_frames: int,
        chord_ids: torch.Tensor,
        structure_ids: torch.Tensor,
        audio_codebook_mask: torch.Tensor,
        bos_audio_mask: torch.Tensor,
        eos_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Map frame-level condition ids onto token positions.

        SOA positions get the BOS condition ids, audio-code positions get the
        ids of their source frame (via ``frame_idx_map``), and EOA positions
        get the EOS condition ids. All other (text) positions stay 0 —
        presumably the padding/neutral id in the condition vocab, TODO
        confirm against the vocab module.
        """
        seq_len = token_ids.numel()
        chord_ids_aligned = torch.zeros(seq_len, dtype=torch.long)
        structure_ids_aligned = torch.zeros(seq_len, dtype=torch.long)

        bos_positions = torch.nonzero(bos_audio_mask, as_tuple=False).squeeze(-1)
        chord_ids_aligned[bos_positions] = CHORD_BOS_ID
        structure_ids_aligned[bos_positions] = STRUCTURE_BOS_ID

        audio_positions = torch.nonzero(audio_codebook_mask, as_tuple=False).squeeze(-1)
        cur_frame_idx = frame_idx_map[audio_positions]
        # Clamp defensively; audio positions always have a valid frame index
        # by construction, but this guards against total_frames == 0.
        cur_frame_idx = cur_frame_idx.clamp(0, max(total_frames - 1, 0))
        chord_ids_aligned[audio_positions] = chord_ids[cur_frame_idx]
        structure_ids_aligned[audio_positions] = structure_ids[cur_frame_idx]

        eos_positions = torch.nonzero(eos_mask, as_tuple=False).squeeze(-1)
        chord_ids_aligned[eos_positions] = CHORD_EOS_ID
        structure_ids_aligned[eos_positions] = STRUCTURE_EOS_ID

        return chord_ids_aligned, structure_ids_aligned
|
|