import torch
import torch.nn.functional as F
from typing import Any, Optional

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.qwen3.modeling_qwen3 import Qwen3Config, Qwen3ForCausalLM
from transformers.cache_utils import Cache

from decoders import Qwen3DecoderLayerAdaLN
from condition_encoders import ConditionEncoder
from vocab import (
    CHORD_BOS_ID,
    CHORD_EOS_ID,
    CHORD_N_ID,
    SEGMENT_FALLBACK_ID,
    STRUCTURE_BOS_ID,
    STRUCTURE_EOS_ID,
)


class MAGEL(Qwen3ForCausalLM):
    """Qwen3 causal LM conditioned on chord and structure token streams.

    Differences from the stock ``Qwen3ForCausalLM``:
    - the cross-entropy loss is computed only over positions selected by ``masks``
    - every decoder layer is replaced with ``Qwen3DecoderLayerAdaLN``, which is
      modulated by a condition embedding produced by ``ConditionEncoder``
    """

    def __init__(
        self,
        config: Qwen3Config,
        **kwargs: Any,
    ):
        super().__init__(config)

        adaln_dim = int(config.hidden_size)
        chord_dropout_trigger_prob = float(config.magel_chord_dropout_trigger_prob)
        structure_dropout_trigger_prob = float(config.magel_structure_dropout_trigger_prob)

        self.vocab_size = config.vocab_size
        self.adaln_dim = adaln_dim

        self.condition_encoder = ConditionEncoder(hidden_size=adaln_dim)
        self.chord_dropout_trigger_prob = chord_dropout_trigger_prob
        self.structure_dropout_trigger_prob = structure_dropout_trigger_prob

        # Swap every stock decoder layer for its AdaLN-conditioned counterpart.
        for layer_idx in range(len(self.model.layers)):
            self.model.layers[layer_idx] = Qwen3DecoderLayerAdaLN(
                config,
                layer_idx=layer_idx,
                cond_dim=adaln_dim,
            )

        # Record the cast probabilities back on the config.
        self.config.magel_chord_dropout_trigger_prob = chord_dropout_trigger_prob
        self.config.magel_structure_dropout_trigger_prob = structure_dropout_trigger_prob

        self.post_init()
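    # A minimal sketch (not part of the model) of what the constructor assumes on the
    # config: the two magel_* probabilities must already be set. The checkpoint name
    # and values below are placeholders, not repo defaults.
    #
    #     config = Qwen3Config.from_pretrained("Qwen/Qwen3-0.6B")
    #     config.magel_chord_dropout_trigger_prob = 0.5
    #     config.magel_structure_dropout_trigger_prob = 0.5
    #     model = MAGEL(config)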
    @staticmethod
    def _drop_audio_condition_spans(
        ids: torch.LongTensor,
        condition_mask: torch.BoolTensor,
        trigger_prob: float,
        replacement_id: int,
        bos_id: int,
        eos_id: int,
    ) -> torch.LongTensor:
        """Replace random spans of condition tokens with ``replacement_id``.

        Each batch row is affected with probability ``trigger_prob``. For a
        triggered row, a uniformly sampled fraction of its eligible positions
        (tokens under ``condition_mask`` that are neither ``bos_id`` nor
        ``eos_id``) is overwritten in non-overlapping chunks of up to 25
        consecutive eligible positions.
        """
        if trigger_prob <= 0.0:
            return ids

        # BOS/EOS markers of the condition stream are never dropped.
        eligible_mask = condition_mask & (ids != bos_id) & (ids != eos_id)

        if not eligible_mask.any():
            return ids

        dropped = ids.clone()
        trigger_mask = torch.rand(ids.size(0), device=ids.device) < trigger_prob
        span_len = 25  # maximum length of a single dropped span

        for batch_idx in torch.nonzero(trigger_mask, as_tuple=False).flatten():
            candidate_positions = torch.nonzero(
                eligible_mask[batch_idx], as_tuple=False
            ).flatten()
            num_candidates = int(candidate_positions.numel())
            if num_candidates == 0:
                continue
            drop_ratio = torch.rand((), device=ids.device).item()
            num_to_drop = int(round(drop_ratio * num_candidates))
            if num_to_drop <= 0:
                continue

            remaining = num_to_drop
            available_positions = candidate_positions.clone()
            while remaining > 0:
                num_available = int(available_positions.numel())
                if num_available == 0:
                    break

                cur_span_len = min(span_len, remaining)
                if num_available <= cur_span_len:
                    start_idx = 0
                    selected_positions = available_positions[:cur_span_len]
                else:
                    max_start = num_available - cur_span_len + 1
                    start_idx = int(
                        torch.randint(0, max_start, (1,), device=ids.device).item()
                    )
                    selected_positions = available_positions[
                        start_idx : start_idx + cur_span_len
                    ]
                dropped[batch_idx, selected_positions] = replacement_id

                # Remove the consumed positions so later spans cannot overlap them.
                keep_mask = torch.ones(
                    num_available,
                    dtype=torch.bool,
                    device=ids.device,
                )
                keep_mask[start_idx : start_idx + int(selected_positions.numel())] = False
                available_positions = available_positions[keep_mask]
                remaining -= int(selected_positions.numel())

        return dropped
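    # A minimal sketch (not part of the model) of the span dropout on a toy chord row.
    # The ids 11-14 are made-up eligible tokens; only their positions matter.
    #
    #     ids = torch.tensor([[CHORD_BOS_ID, 11, 12, 13, 14, CHORD_EOS_ID]])
    #     cond_mask = torch.ones_like(ids, dtype=torch.bool)
    #     out = MAGEL._drop_audio_condition_spans(
    #         ids, cond_mask, trigger_prob=1.0,
    #         replacement_id=CHORD_N_ID, bos_id=CHORD_BOS_ID, eos_id=CHORD_EOS_ID,
    #     )
    #
    # BOS/EOS always survive; a random fraction (possibly none) of the four middle
    # tokens is overwritten with CHORD_N_ID in one contiguous chunk.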
    def _build_condition(
        self,
        chord_ids: Optional[torch.LongTensor],
        structure_ids: Optional[torch.LongTensor],
        condition_mask: Optional[torch.BoolTensor],
        cond_precomputed: Optional[torch.FloatTensor],
    ) -> Optional[torch.FloatTensor]:
        """Return the AdaLN condition, preferring a precomputed tensor.

        During training, random spans of the chord and structure ids are first
        replaced with their fallback ids (condition dropout) before encoding.
        """
        if cond_precomputed is not None:
            return cond_precomputed
        if chord_ids is None or structure_ids is None:
            return None
        if self.training:
            if condition_mask is None:
                raise ValueError("condition_mask is required during training.")
            chord_ids = self._drop_audio_condition_spans(
                ids=chord_ids,
                condition_mask=condition_mask,
                trigger_prob=self.chord_dropout_trigger_prob,
                replacement_id=CHORD_N_ID,
                bos_id=CHORD_BOS_ID,
                eos_id=CHORD_EOS_ID,
            )
            structure_ids = self._drop_audio_condition_spans(
                ids=structure_ids,
                condition_mask=condition_mask,
                trigger_prob=self.structure_dropout_trigger_prob,
                replacement_id=SEGMENT_FALLBACK_ID,
                bos_id=STRUCTURE_BOS_ID,
                eos_id=STRUCTURE_EOS_ID,
            )
        return self.condition_encoder(chord_ids, structure_ids)
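    # A minimal sketch (not part of the model) of the condition precedence: a
    # precomputed tensor short-circuits the encoder, so conditions can be encoded
    # once and reused across decoding steps. `chord_ids`, `structure_ids` and
    # `condition_mask` are placeholder tensors from the data pipeline.
    #
    #     cond = model._build_condition(chord_ids, structure_ids, condition_mask, None)
    #     again = model._build_condition(None, None, None, cond_precomputed=cond)
    #     assert again is cond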
    def ce_loss(
        self,
        logits: torch.FloatTensor,
        labels: Optional[torch.LongTensor],
        masks: Optional[torch.LongTensor],
    ) -> Optional[torch.Tensor]:
        """Next-token cross entropy averaged over positions where ``masks`` is set."""
        if labels is None or masks is None:
            return None

        # Standard causal shift: logits at position t predict the label at t + 1.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].clone()
        valid_token_mask = masks[:, 1:].bool().contiguous()

        if not valid_token_mask.any():
            return shift_logits.new_zeros(())

        # Masked-out targets are set to the ignore index so they contribute nothing.
        shift_labels.masked_fill_(~valid_token_mask, -100)
        loss_sum = F.cross_entropy(
            shift_logits.view(-1, self.config.vocab_size),
            shift_labels.view(-1).to(shift_logits.device),
            ignore_index=-100,
            reduction="sum",
        )
        valid_count = valid_token_mask.sum().to(
            device=loss_sum.device,
            dtype=loss_sum.dtype,
        )
        return loss_sum / valid_count.clamp_min(1)
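    # A minimal sketch (not part of the model) of the masked loss on toy shapes.
    # With masks = [[0, 1, 1, 0]] only the targets at positions 1 and 2 are scored
    # (predicted from logits at positions 0 and 1), and the summed loss is divided by 2.
    #
    #     logits = torch.randn(1, 4, model.config.vocab_size)
    #     labels = torch.randint(0, model.config.vocab_size, (1, 4))
    #     masks = torch.tensor([[0, 1, 1, 0]])
    #     loss = model.ce_loss(logits, labels, masks)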
    def forward(
        self,
        input_ids: torch.LongTensor,
        masks: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        chord_ids: Optional[torch.LongTensor] = None,
        structure_ids: Optional[torch.LongTensor] = None,
        condition_mask: Optional[torch.BoolTensor] = None,
        cond_precomputed: Optional[torch.FloatTensor] = None,
    ) -> CausalLMOutputWithPast:
        if use_cache is None:
            use_cache = self.config.use_cache

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)

        cond = self._build_condition(
            chord_ids=chord_ids,
            structure_ids=structure_ids,
            condition_mask=condition_mask,
            cond_precomputed=cond_precomputed,
        )

        base_out = self.model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cond_expanded=cond,
            condition_mask=condition_mask,
            cache_position=cache_position,
        )

        hidden_states = base_out.last_hidden_state
        logits = self.lm_head(hidden_states)
        loss = self.ce_loss(logits=logits, labels=labels, masks=masks)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=base_out.past_key_values,
            hidden_states=base_out.hidden_states,
            attentions=base_out.attentions,
        )
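# A minimal usage sketch (assumptions, not repo-verified): tensor names below are
# placeholders for whatever the data pipeline produces; chord_ids, structure_ids and
# condition_mask share one shape, and masks marks the token positions that count
# toward the loss.
#
#     out = model(
#         input_ids=input_ids,
#         masks=masks,
#         attention_mask=attention_mask,
#         labels=input_ids,  # next-token targets; commonly the input ids themselves
#         chord_ids=chord_ids,
#         structure_ids=structure_ids,
#         condition_mask=condition_mask,
#     )
#     out.loss.backward()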