# cond_gen/dataset.py
# NOTE(review): the lines above this header in the original upload were
# web-page residue ("Leon299's picture / Add files using upload-large-folder
# tool / 8337fa0 verified") — kept here as a comment so the file parses.
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
"""Dataset/collate implementation for music training data."""
import math
import re
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer
from audio_tokens import (
EOA_TOKEN,
MASK_AUDIO_TOKEN,
SOA_TOKEN,
add_audio_special_tokens,
audio_id_to_token,
)
from vocab import (
CHORD_BOS_ID,
CHORD_EOS_ID,
STRUCTURE_BOS_ID,
STRUCTURE_EOS_ID,
build_frame_chord_ids,
build_frame_structure_ids,
normalize_structure_label,
)
# Lowercased language labels treated as Chinese by detect_language().
CN_LANGUAGE_LABELS = {"cn", "zh", "zh-cn", "chinese"}
# Display names for normalized structure labels used in section prompts.
SECTION_NAME_MAP = {
    "intro": "Intro",
    "verse": "Verse",
    "chorus": "Chorus",
    "prechorus": "Pre-Chorus",
    "bridge": "Bridge",
    "outro": "Outro",
    "pad": "Pad",
}
# Sections expected to occur once; their first occurrence is left unnumbered.
SINGLETON_SECTION_NAMES = {"intro", "outro", "pad"}
# Characters accepted as a sentence ending (ASCII plus fullwidth forms;
# NOTE(review): some entries render as ASCII duplicates — possibly
# mis-encoded fullwidth "？！；" — confirm against the original file).
ENDING_PUNCTUATION = {".", ";", "!", "?", "。", "?", "!", ";"}
def _pad_batch_field(batch, key: str, padding_value):
return pad_sequence(
[row[key] for row in batch],
batch_first=True,
padding_value=padding_value,
)
def detect_language(text: str, language: str | None = None) -> str:
    """Apply language-specific whitespace handling to *text*.

    If ``language`` normalizes to one of ``CN_LANGUAGE_LABELS``, spaces
    are converted to ";" separators; otherwise the text is returned
    unchanged. ``language=None`` stringifies to "none", which never
    matches a Chinese label.
    """
    label = str(language).strip().lower()
    if label in CN_LANGUAGE_LABELS:
        return text.replace(" ", ";")
    return text
def normalize_section_text(
    text: str, structure: str, language: str | None = None
) -> str:
    """Clean one section's lyric text for prompting.

    Removes inline "[SECTION]" markers, unifies clause punctuation to
    ";", applies language-specific whitespace handling, re-spaces ";"
    before Latin letters, and guarantees a terminal punctuation mark.
    """
    cleaned = str(text or "")
    # Strip inline section markers such as "[VERSE]" / "[verse]".
    for marker in (f"[{structure.upper()}]", f"[{structure.lower()}]"):
        cleaned = cleaned.replace(marker, "")
    # Unify clause punctuation to ";" (NOTE(review): the comma appears
    # twice in the original chain — possibly a mis-encoded fullwidth
    # "，"; the duplicate replace is harmless either way).
    for mark in (",", ".", ",", "。"):
        cleaned = cleaned.replace(mark, ";")
    cleaned = detect_language(cleaned, language=language)
    # Insert a space after ";" when a Latin letter follows directly.
    cleaned = re.sub(r";(?=[A-Za-z])", "; ", cleaned)
    # Ensure the section ends with sentence-final punctuation.
    if cleaned and cleaned[-1] not in ENDING_PUNCTUATION:
        cleaned += ";"
    return cleaned
class DataCollate:
    """Collate per-sample dicts into right-padded batch tensors.

    ``labels`` intentionally aliases ``input_ids`` (no copy is made);
    loss masking is expected to happen downstream via ``masks``.
    """

    # (output key, per-sample key, padding value) for each padded field.
    _FIELDS = (
        ("masks", "mask", 0),
        ("attention_mask", "attention_mask", 0),
        ("chord_ids", "chord_ids", 0),
        ("structure_ids", "structure_ids", 0),
        ("condition_mask", "condition_mask", False),
    )

    def __call__(self, batch):
        input_ids = _pad_batch_field(batch, "token_ids", 0)
        # Key insertion order matches the original implementation.
        collated = {"input_ids": input_ids, "labels": input_ids}
        for out_key, in_key, pad_value in self._FIELDS:
            collated[out_key] = _pad_batch_field(batch, in_key, pad_value)
        return collated
class MusicDataset(torch.utils.data.Dataset):
    """Fly dataset with music-code tokens and section-conditioned text.

    Each item interleaves chat-template text turns (one user/assistant
    pair per song section) with the sample's audio-code frames, and
    aligns per-frame chord/structure condition ids to the expanded
    token sequence.
    """

    def __init__(
        self,
        datasets,
        split: str,
        tokenizer_path: str,
        num_audio_token=16384,
        fps=25,
        use_fast=True,
    ):
        # `datasets` is a mapping of split name -> indexable sample
        # collection (presumably a HF DatasetDict — TODO confirm).
        self._data = datasets[split]
        self.tokenizer_path = tokenizer_path
        self.use_fast = use_fast
        self.num_audio_token = num_audio_token
        # Audio-code frames per second; used for all time->frame math.
        self.fps = fps
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_path,
            local_files_only=True,
            use_fast=self.use_fast,
        )
        # Extend the text vocabulary with audio-code/control tokens.
        add_audio_special_tokens(self.tokenizer, self.num_audio_token)
        self.tokenizer_vocab_size = len(self.tokenizer)
        # Id of audio code 0; raw audio codes are offset by this amount
        # when spliced into the token sequence (see _expand_audio_tokens).
        self.audio_prefix_length = int(
            self.tokenizer.convert_tokens_to_ids(audio_id_to_token(0))
        )
        self.num_text_token = self.audio_prefix_length
        self.MASK_AUDIO = int(self.tokenizer.convert_tokens_to_ids(MASK_AUDIO_TOKEN))
        self.BOS_AUDIO = int(self.tokenizer.convert_tokens_to_ids(SOA_TOKEN))
        self.EOS_AUDIO = int(self.tokenizer.convert_tokens_to_ids(EOA_TOKEN))
        # Assistant turns initially contain only <soa><eoa>; the real
        # codes are inserted between them by _expand_audio_tokens.
        self._assistant_audio_placeholder = f"{SOA_TOKEN}{EOA_TOKEN}"
        self._chat_template_kwargs = {"enable_thinking": False}

    def __len__(self):
        return len(self._data)

    @staticmethod
    def _positions(token_ids: torch.Tensor, target_id: int) -> torch.Tensor:
        """Return a 1-D tensor of indices where ``token_ids == target_id``."""
        return torch.nonzero(token_ids == target_id, as_tuple=False).squeeze(-1)

    @staticmethod
    def _sorted_sections(sample: dict) -> list[dict]:
        """Normalize and time-sort the sample's section annotations.

        The raw enumeration index is kept as a final tie-breaker so the
        order is deterministic for sections with equal start/end times.
        """
        return sorted(
            (
                {
                    "raw_index": raw_index,
                    "text": str(seg["text"]),
                    "desc": str(seg["desc"]).strip(),
                    "start": float(seg["start"]),
                    "end": float(seg["end"]),
                    "structure": normalize_structure_label(seg["section"]),
                }
                for raw_index, seg in enumerate(sample.get("sections", []))
            ),
            key=lambda seg: (seg["start"], seg["end"], seg["raw_index"]),
        )

    @staticmethod
    def _sorted_chords(sample: dict) -> list[dict]:
        """Normalize and time-sort the sample's chord annotations."""
        return sorted(
            (
                {
                    "raw_index": raw_index,
                    "type": str(seg.get("type")),
                    "start": float(seg.get("start", 0.0)),
                    "end": float(seg.get("end", 0.0)),
                }
                for raw_index, seg in enumerate(sample.get("chords", []))
            ),
            key=lambda seg: (seg["start"], seg["end"], seg["raw_index"]),
        )

    def __getitem__(self, idx):
        """Build one example: expanded token ids plus aligned masks/conditions."""
        sample = self._data[idx]
        sections = self._prepare_sections(sample)
        chords = self._prepare_chords(sample)
        token_ids, attention_mask, frame_idx_map = self._tokenize_messages(
            self._build_messages(sample, sections),
            sample["mucodec_codes"],
            sections,
        )
        # Frame count = number of code entries (assumes mucodec_codes is
        # one code per frame — TODO confirm against the codec exporter).
        total_frames = len(sample["mucodec_codes"])
        structure_ids = build_frame_structure_ids(sections, total_frames, fps=self.fps)
        chord_ids = build_frame_chord_ids(chords, total_frames, fps=self.fps)
        # build_frame_* return numpy arrays; move them to torch tensors.
        structure_ids = torch.from_numpy(structure_ids)
        chord_ids = torch.from_numpy(chord_ids)
        (
            audio_codebook_mask,
            bos_audio_mask,
            eos_mask,
            label_mask,
            condition_mask,
        ) = self._build_token_masks(token_ids)
        chord_ids_aligned, structure_ids_aligned = self._align_condition_ids(
            token_ids=token_ids,
            frame_idx_map=frame_idx_map,
            total_frames=total_frames,
            chord_ids=chord_ids,
            structure_ids=structure_ids,
            audio_codebook_mask=audio_codebook_mask,
            bos_audio_mask=bos_audio_mask,
            eos_mask=eos_mask,
        )
        return {
            "token_ids": token_ids,
            "mask": label_mask,
            "attention_mask": attention_mask,
            "chord_ids": chord_ids_aligned,
            "structure_ids": structure_ids_aligned,
            "condition_mask": condition_mask,
        }

    def _tokenize_messages(
        self,
        messages: list[dict[str, str]],
        full_audio_codes,
        sections: list[dict],
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Tokenize the chat turns, then splice in the real audio codes."""
        chat_inputs = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=False,
            return_tensors="pt",
            return_dict=True,
            **self._chat_template_kwargs,
        )
        token_ids = chat_inputs["input_ids"]
        attention_mask = chat_inputs["attention_mask"]
        # Drop the batch dimension added by return_tensors="pt".
        token_ids = token_ids.squeeze(0)
        attention_mask = attention_mask.squeeze(0)
        token_ids = token_ids.to(torch.long)
        attention_mask = attention_mask.to(torch.long)
        return self._expand_audio_tokens(
            token_ids=token_ids,
            attention_mask=attention_mask,
            full_audio_codes=full_audio_codes,
            sections=sections,
        )

    def _frame_bounds(
        self,
        start: float,
        end: float,
        total_frames: int,
        prev_end_idx: int = 0,
    ) -> tuple[int, int]:
        """Convert a [start, end) span in seconds to clamped frame indices.

        The start is floored and the end is ceiled; both are clamped so
        the result lies in [prev_end_idx, total_frames] and never runs
        backwards, keeping consecutive spans non-overlapping.
        """
        start_idx = int(start * self.fps)
        end_idx = int(math.ceil(end * self.fps))
        start_idx = max(prev_end_idx, min(total_frames, start_idx))
        end_idx = max(start_idx, min(total_frames, end_idx))
        return start_idx, end_idx

    def _prepare_sections(self, sample: dict) -> list[dict]:
        """Produce contiguous, frame-aligned section dicts covering the song.

        Each section starts exactly where the previous one ended (gaps
        and overlaps in the annotations are absorbed), and the final
        section is stretched so the full code sequence is covered.
        """
        sections = []
        section_counts: dict[str, int] = {}
        sample_language = sample.get("language")
        total_frames = len(sample["mucodec_codes"])
        prev_end_idx = 0
        for seg in self._sorted_sections(sample):
            structure = seg["structure"]
            # Per-structure counter drives "verse1", "verse2", ... tags.
            section_counts[structure] = section_counts.get(structure, 0) + 1
            raw_start_idx = max(0, min(total_frames, int(seg["start"] * self.fps)))
            raw_end_idx = max(
                raw_start_idx,
                min(total_frames, int(math.ceil(seg["end"] * self.fps))),
            )
            # Force sections back-to-back: the annotated start is ignored
            # in favor of the previous section's end frame.
            start_idx = prev_end_idx
            end_idx = max(start_idx, raw_end_idx)
            sections.append(
                {
                    "text": normalize_section_text(
                        seg["text"], structure, language=sample_language
                    ),
                    "desc": seg["desc"],
                    "start": start_idx / float(self.fps),
                    "end": end_idx / float(self.fps),
                    "start_frame": start_idx,
                    "end_frame": end_idx,
                    "structure": structure,
                    "tag": f"{structure}{section_counts[structure]}",
                    "index": section_counts[structure],
                }
            )
            prev_end_idx = end_idx
        if sections:
            # Extend the last section to the end of the code sequence.
            sections[-1]["end_frame"] = total_frames
            sections[-1]["end"] = total_frames / float(self.fps)
        return sections

    def _prepare_chords(self, sample: dict) -> list[dict]:
        """Produce non-overlapping, frame-aligned chord dicts in time order."""
        chords = []
        total_frames = len(sample["mucodec_codes"])
        prev_end_idx = 0
        for seg in self._sorted_chords(sample):
            start_idx, end_idx = self._frame_bounds(
                seg["start"],
                seg["end"],
                total_frames,
                prev_end_idx=prev_end_idx,
            )
            chords.append(
                {
                    "type": seg["type"],
                    "start": start_idx / float(self.fps),
                    "end": end_idx / float(self.fps),
                    "start_frame": start_idx,
                    "end_frame": end_idx,
                }
            )
            prev_end_idx = end_idx
        return chords

    def _format_section_label(self, section: dict) -> str:
        """Human-readable label like "Verse 2"; singleton sections drop the "1"."""
        structure = section["structure"]
        index = section["index"]
        label = SECTION_NAME_MAP[structure]
        if structure in SINGLETON_SECTION_NAMES and index == 1:
            return label
        return f"{label} {index}"

    def _build_section_user_content(
        self, sample: dict, section: dict, is_first_turn: bool
    ) -> str:
        """Compose the user prompt text for one section.

        The first turn carries the sample's global style description;
        later turns use a short continuation preamble. Every turn then
        appends the section label plus optional [desc:...] and
        [lyrics:...] fields.
        NOTE(review): a first turn with an empty style string gets no
        preamble at all — confirm this is intended.
        """
        parts = []
        if is_first_turn:
            style = sample["style"].strip()
            if style:
                parts.append(
                    f"Please generate a song in the following style:{style}\n"
                    "Next, I will tell you the requirements and lyrics for the song "
                    "fragment to be generated, section by section."
                )
        else:
            parts.append(
                "Please generate the song section by section. "
                "Next, I will tell you the requirements and lyrics for each fragment."
            )
        section_parts = [f"[{self._format_section_label(section)}]"]
        desc = section["desc"]
        if desc:
            section_parts.append(f"[desc:{desc}]")
        lyrics = section["text"]
        if lyrics:
            section_parts.append(f"[lyrics:{lyrics}]")
        parts.append("".join(section_parts))
        return "\n".join(part for part in parts if part)

    def _build_messages(
        self,
        sample: dict,
        sections: list[dict],
    ) -> list[dict[str, str]]:
        """Build one user/assistant message pair per section.

        Assistant turns hold only the <soa><eoa> placeholder; the real
        audio codes are spliced in later by _expand_audio_tokens.
        """
        messages: list[dict[str, str]] = [None] * (2 * len(sections))
        for i, section in enumerate(sections):
            msg_idx = 2 * i
            messages[msg_idx] = {
                "role": "user",
                "content": self._build_section_user_content(
                    sample, section, is_first_turn=(i == 0)
                ),
            }
            messages[msg_idx + 1] = {
                "role": "assistant",
                "content": self._assistant_audio_placeholder,
            }
        return messages

    def _expand_audio_tokens(
        self,
        token_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        full_audio_codes,
        sections: list[dict],
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Replace each empty <soa><eoa> pair with that section's audio codes.

        Returns the expanded (token_ids, attention_mask, frame_idx_map),
        where frame_idx_map[i] is the audio frame represented by token i
        (-1 for pure text tokens). One <soa>/<eoa> pair is expected per
        section, since _build_messages emits exactly one placeholder per
        assistant turn.
        """
        if not sections:
            # Nothing to splice; every token maps to "no frame".
            return (
                token_ids,
                attention_mask,
                torch.full(token_ids.shape, -1, dtype=torch.long),
            )
        bos_positions = self._positions(token_ids, self.BOS_AUDIO)
        eos_positions = self._positions(token_ids, self.EOS_AUDIO)
        audio_code_tensor = torch.as_tensor(full_audio_codes, dtype=torch.long)
        # Total number of audio-code tokens to be inserted.
        extra_audio_tokens = sum(
            int(section["end_frame"]) - int(section["start_frame"])
            for section in sections
        )
        final_len = token_ids.numel() + extra_audio_tokens
        expanded_token_ids = torch.empty(final_len, dtype=torch.long)
        expanded_attention_mask = torch.empty(final_len, dtype=torch.long)
        frame_idx_map = torch.full((final_len,), -1, dtype=torch.long)
        read_pos = 0   # cursor into the original sequence
        write_pos = 0  # cursor into the expanded sequence
        for bos_pos, eos_pos, section in zip(
            bos_positions.tolist(), eos_positions.tolist(), sections
        ):
            start_idx = int(section["start_frame"])
            end_idx = int(section["end_frame"])
            audio_len = end_idx - start_idx
            # Copy the text prefix up to and including this <soa>.
            prefix_len = bos_pos + 1 - read_pos
            next_write = write_pos + prefix_len
            expanded_token_ids[write_pos:next_write] = token_ids[read_pos : bos_pos + 1]
            expanded_attention_mask[write_pos:next_write] = attention_mask[
                read_pos : bos_pos + 1
            ]
            # The <soa> token itself is tagged with the section's first frame.
            frame_idx_map[next_write - 1] = start_idx if audio_len > 0 else -1
            write_pos = next_write
            if audio_len > 0:
                # Splice this section's codes, offset into the
                # tokenizer's audio-id range.
                next_write = write_pos + audio_len
                expanded_token_ids[write_pos:next_write] = audio_code_tensor[
                    start_idx:end_idx
                ]
                expanded_token_ids[write_pos:next_write].add_(self.audio_prefix_length)
                expanded_attention_mask[write_pos:next_write] = 1
                frame_idx_map[write_pos:next_write] = torch.arange(
                    start_idx, end_idx, dtype=torch.long
                )
                write_pos = next_write
            # Copy the <eoa>, tagged with the section's last frame.
            expanded_token_ids[write_pos] = token_ids[eos_pos]
            expanded_attention_mask[write_pos] = attention_mask[eos_pos]
            frame_idx_map[write_pos] = end_idx - 1 if audio_len > 0 else -1
            write_pos += 1
            read_pos = eos_pos + 1
        # Copy any trailing text after the final <eoa>.
        tail_len = token_ids.numel() - read_pos
        if tail_len > 0:
            expanded_token_ids[write_pos : write_pos + tail_len] = token_ids[read_pos:]
            expanded_attention_mask[write_pos : write_pos + tail_len] = attention_mask[
                read_pos:
            ]
        return expanded_token_ids, expanded_attention_mask, frame_idx_map

    def _build_token_masks(
        self, token_ids: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Derive the boolean masks used for labels and conditioning.

        Returns (audio_codebook_mask, bos_audio_mask, eos_mask,
        label_mask, condition_mask); label_mask is long (0/1), the
        others are bool.
        """
        # Audio-code ids occupy [audio_prefix_length, MASK_AUDIO)
        # (assumes audio ids are contiguous from audio_id_to_token(0)
        # up to the mask token — confirm against add_audio_special_tokens).
        audio_codebook_mask = (token_ids >= self.audio_prefix_length) & (
            token_ids < self.MASK_AUDIO
        )
        bos_audio_mask = token_ids == self.BOS_AUDIO
        eos_mask = token_ids == self.EOS_AUDIO
        # Loss is computed on audio codes and the <eoa> terminator only.
        label_mask = (audio_codebook_mask | eos_mask).long()
        condition_mask = audio_codebook_mask | bos_audio_mask | eos_mask
        return audio_codebook_mask, bos_audio_mask, eos_mask, label_mask, condition_mask

    def _align_condition_ids(
        self,
        token_ids: torch.Tensor,
        frame_idx_map: torch.Tensor,
        total_frames: int,
        chord_ids: torch.Tensor,
        structure_ids: torch.Tensor,
        audio_codebook_mask: torch.Tensor,
        bos_audio_mask: torch.Tensor,
        eos_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Scatter per-frame chord/structure ids onto the token sequence.

        Text positions stay 0; <soa>/<eoa> positions get the dedicated
        BOS/EOS condition ids; each audio-code position gets the id of
        the frame it encodes (looked up through frame_idx_map).
        """
        seq_len = token_ids.numel()
        chord_ids_aligned = torch.zeros(seq_len, dtype=torch.long)
        structure_ids_aligned = torch.zeros(seq_len, dtype=torch.long)
        bos_positions = torch.nonzero(bos_audio_mask, as_tuple=False).squeeze(-1)
        chord_ids_aligned[bos_positions] = CHORD_BOS_ID
        structure_ids_aligned[bos_positions] = STRUCTURE_BOS_ID
        audio_positions = torch.nonzero(audio_codebook_mask, as_tuple=False).squeeze(-1)
        cur_frame_idx = frame_idx_map[audio_positions]
        # Defensive clamp keeps indexing valid even for a -1 map entry.
        cur_frame_idx = cur_frame_idx.clamp(0, max(total_frames - 1, 0))
        chord_ids_aligned[audio_positions] = chord_ids[cur_frame_idx]
        structure_ids_aligned[audio_positions] = structure_ids[cur_frame_idx]
        eos_positions = torch.nonzero(eos_mask, as_tuple=False).squeeze(-1)
        chord_ids_aligned[eos_positions] = CHORD_EOS_ID
        structure_ids_aligned[eos_positions] = STRUCTURE_EOS_ID
        return chord_ids_aligned, structure_ids_aligned