import torch
import torch.nn.functional as F
from typing import Any, Optional

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.qwen3.modeling_qwen3 import Qwen3Config, Qwen3ForCausalLM
from transformers.cache_utils import Cache

from decoders import Qwen3DecoderLayerAdaLN
from condition_encoders import ConditionEncoder
from vocab import (
    CHORD_BOS_ID,
    CHORD_EOS_ID,
    CHORD_N_ID,
    SEGMENT_FALLBACK_ID,
    STRUCTURE_BOS_ID,
    STRUCTURE_EOS_ID,
)


class MAGEL(Qwen3ForCausalLM):
    """
    Qwen3 causal LM variant for chord/structure-conditioned generation:

    - masks-based CE loss
    - decoder layers replaced with Qwen3DecoderLayerAdaLN
    """

    def __init__(
        self,
        config: Qwen3Config,
        **kwargs: Any,
    ):
        super().__init__(config)

        adaln_dim = int(config.hidden_size)
        chord_dropout_trigger_prob = float(config.magel_chord_dropout_trigger_prob)
        structure_dropout_trigger_prob = float(config.magel_structure_dropout_trigger_prob)

        self.vocab_size = config.vocab_size
        self.adaln_dim = adaln_dim
        self.condition_encoder = ConditionEncoder(hidden_size=adaln_dim)
        self.chord_dropout_trigger_prob = chord_dropout_trigger_prob
        self.structure_dropout_trigger_prob = structure_dropout_trigger_prob

        # Swap every stock decoder layer for its AdaLN-conditioned counterpart.
        for layer_idx in range(len(self.model.layers)):
            self.model.layers[layer_idx] = Qwen3DecoderLayerAdaLN(
                config,
                layer_idx=layer_idx,
                cond_dim=adaln_dim,
            )

        # Persist MAGEL-specific ctor args so checkpoints can be reloaded without
        # out-of-band flags.
        self.config.magel_chord_dropout_trigger_prob = chord_dropout_trigger_prob
        self.config.magel_structure_dropout_trigger_prob = structure_dropout_trigger_prob

        self.post_init()

    @staticmethod
    def _drop_audio_condition_spans(
        ids: torch.LongTensor,
        condition_mask: torch.BoolTensor,
        trigger_prob: float,
        replacement_id: int,
        bos_id: int,
        eos_id: int,
    ) -> torch.LongTensor:
        """Randomly replace contiguous spans of condition ids with a fallback id."""
        if trigger_prob <= 0.0:
            return ids

        # Only drop aligned audio-condition positions; keep BOS/EOS untouched.
        eligible_mask = condition_mask & (ids != bos_id) & (ids != eos_id)
        if not eligible_mask.any():
            return ids

        dropped = ids.clone()
        # One Bernoulli draw per batch item decides whether any dropout happens.
        trigger_mask = torch.rand(ids.size(0), device=ids.device) < trigger_prob
        span_len = 25

        for batch_idx in torch.nonzero(trigger_mask, as_tuple=False).flatten():
            candidate_positions = torch.nonzero(
                eligible_mask[batch_idx], as_tuple=False
            ).flatten()
            num_candidates = int(candidate_positions.numel())
            if num_candidates == 0:
                continue

            # Drop a uniformly sampled fraction of the eligible positions.
            drop_ratio = torch.rand((), device=ids.device).item()
            num_to_drop = int(round(drop_ratio * num_candidates))
            if num_to_drop <= 0:
                continue

            remaining = num_to_drop
            available_positions = candidate_positions.clone()
            # Spend the drop budget in contiguous spans of up to `span_len` positions.
            while remaining > 0:
                num_available = int(available_positions.numel())
                if num_available == 0:
                    break

                cur_span_len = min(span_len, remaining)
                if num_available <= cur_span_len:
                    start_idx = 0
                    selected_positions = available_positions[:cur_span_len]
                else:
                    max_start = num_available - cur_span_len + 1
                    start_idx = int(
                        torch.randint(0, max_start, (1,), device=ids.device).item()
                    )
                    selected_positions = available_positions[
                        start_idx : start_idx + cur_span_len
                    ]

                dropped[batch_idx, selected_positions] = replacement_id

                keep_mask = torch.ones(
                    num_available,
                    dtype=torch.bool,
                    device=ids.device,
                )
                keep_mask[start_idx : start_idx + int(selected_positions.numel())] = False
                available_positions = available_positions[keep_mask]
                remaining -= int(selected_positions.numel())

        return dropped

    def _build_condition(
        self,
        chord_ids: Optional[torch.LongTensor],
        structure_ids: Optional[torch.LongTensor],
        condition_mask: Optional[torch.BoolTensor],
        cond_precomputed: Optional[torch.FloatTensor],
    ) -> Optional[torch.FloatTensor]:
        """Encode chord/structure ids into the AdaLN condition, with train-time span dropout."""
        if cond_precomputed is not None:
            return cond_precomputed
        if chord_ids is None or structure_ids is None:
            return None

        if self.training:
            if condition_mask is None:
                raise ValueError("condition_mask is required during training.")
            chord_ids = self._drop_audio_condition_spans(
                ids=chord_ids,
                condition_mask=condition_mask,
                trigger_prob=self.chord_dropout_trigger_prob,
                replacement_id=CHORD_N_ID,
                bos_id=CHORD_BOS_ID,
                eos_id=CHORD_EOS_ID,
            )
            structure_ids = self._drop_audio_condition_spans(
                ids=structure_ids,
                condition_mask=condition_mask,
                trigger_prob=self.structure_dropout_trigger_prob,
                replacement_id=SEGMENT_FALLBACK_ID,
                bos_id=STRUCTURE_BOS_ID,
                eos_id=STRUCTURE_EOS_ID,
            )

        return self.condition_encoder(chord_ids, structure_ids)

    def ce_loss(
        self,
        logits: torch.FloatTensor,
        labels: Optional[torch.LongTensor],
        masks: Optional[torch.LongTensor],
    ) -> Optional[torch.Tensor]:
        """Next-token cross-entropy averaged over positions where `masks` is set."""
        if labels is None or masks is None:
            return None

        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].clone()
        valid_token_mask = masks[:, 1:].bool().contiguous()
        if not valid_token_mask.any():
            return shift_logits.new_zeros(())

        shift_labels.masked_fill_(~valid_token_mask, -100)
        loss_sum = F.cross_entropy(
            shift_logits.view(-1, self.config.vocab_size),
            shift_labels.view(-1).to(shift_logits.device),
            ignore_index=-100,
            reduction="sum",
        )
        valid_count = valid_token_mask.sum().to(
            device=loss_sum.device,
            dtype=loss_sum.dtype,
        )
        return loss_sum / valid_count.clamp_min(1)

    def forward(
        self,
        input_ids: torch.LongTensor,
        masks: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        chord_ids: Optional[torch.LongTensor] = None,
        structure_ids: Optional[torch.LongTensor] = None,
        condition_mask: Optional[torch.BoolTensor] = None,
        cond_precomputed: Optional[torch.FloatTensor] = None,
    ) -> CausalLMOutputWithPast:
        """Causal-LM forward with AdaLN conditioning and masks-based CE loss."""
        if use_cache is None:
            use_cache = self.config.use_cache
        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)

        cond = self._build_condition(
            chord_ids=chord_ids,
            structure_ids=structure_ids,
            condition_mask=condition_mask,
            cond_precomputed=cond_precomputed,
        )

        # `cond_expanded` and `condition_mask` are forwarded through the base model
        # to the Qwen3DecoderLayerAdaLN layers installed in __init__.
        base_out = self.model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cond_expanded=cond,
            condition_mask=condition_mask,
            cache_position=cache_position,
        )

        hidden_states = base_out.last_hidden_state
        logits = self.lm_head(hidden_states)
        loss = self.ce_loss(logits=logits, labels=labels, masks=masks)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=base_out.past_key_values,
            hidden_states=base_out.hidden_states,
            attentions=base_out.attentions,
        )
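

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not a training/inference entry point): shows
# how MAGEL might be instantiated from a small Qwen3Config and run for a single
# forward pass. The model sizes, dropout probabilities, and chord/structure id
# ranges below are placeholder assumptions, and actually running this requires
# the project-local `decoders`, `condition_encoders`, and `vocab` modules.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = Qwen3Config(
        vocab_size=2048,            # assumed token vocab size
        hidden_size=256,
        intermediate_size=512,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
    )
    # MAGEL-specific knobs read in __init__ and persisted back onto the config.
    config.magel_chord_dropout_trigger_prob = 0.3
    config.magel_structure_dropout_trigger_prob = 0.3

    model = MAGEL(config)
    model.eval()

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    chord_ids = torch.randint(0, 16, (batch_size, seq_len))      # assumed chord id range
    structure_ids = torch.randint(0, 8, (batch_size, seq_len))   # assumed structure id range
    condition_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
    masks = torch.ones(batch_size, seq_len, dtype=torch.long)

    with torch.no_grad():
        out = model(
            input_ids=input_ids,
            masks=masks,
            labels=input_ids.clone(),
            chord_ids=chord_ids,
            structure_ids=structure_ids,
            condition_mask=condition_mask,
        )
    print("loss:", out.loss, "logits shape:", tuple(out.logits.shape))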